diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..c38f02e033fbb14a604d7a132dbaf0bb73abb36f 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,12 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text +third_party/demucs/ckpt/htdemucs.pth filter=lfs diff=lfs merge=lfs -text +ckpt/100000_dpo.pt filter=lfs diff=lfs merge=lfs -text *.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text *.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +ckpt/vae/autoencoder_music_1320k.ckpt filter=lfs diff=lfs merge=lfs -text +ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e filter=lfs diff=lfs merge=lfs -text +codeclm/tokenizer/Flow1dVAE/third_party/wespeaker/voxceleb_resnet34_LM/voxceleb_resnet34_LM.onnx filter=lfs diff=lfs merge=lfs -text +codeclm/tokenizer/Flow1dVAE/third_party/wespeaker/voxceleb_resnet34_LM/voxceleb_resnet34_LM.pt filter=lfs diff=lfs merge=lfs -text +third_party/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp39-cp39-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text +ckpt/60000_alnew.pt filter=lfs diff=lfs merge=lfs -text +*.wav filter=lfs diff=lfs merge=lfs -text +*.mp3 filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..154875adba2aab3d526c6ba8a1ffe0904fedf869 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +launchs/ +**/__pycache__ +sample/generated/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..61058fbc8ea5dcb8462b43b04298077f349be570 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,13 @@ +FROM juhayna/song-generation-levo:v0.1 + +RUN useradd -m -u 1000 user +USER user +ENV PATH="/home/user/.local/bin:$PATH" + +WORKDIR /app + +COPY --chown=user ./requirements.txt requirements.txt +RUN pip install --no-cache-dir --upgrade -r requirements.txt + +COPY --chown=user . /app +CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"] \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b8ec4f1e531f3026bf31ba6529a632d9a327eb3a --- /dev/null +++ b/LICENSE @@ -0,0 +1,211 @@ +Tencent is pleased to support the open source community by making SongGeneration available. + +Copyright (C) 2025 THL A29 Limited, a Tencent company. All rights reserved. + +SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations. + + +License Terms of SongGeneration: +-------------------------------------------------------------------- + +Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +- You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances. + +- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +Other dependencies and licenses: + + +Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. stable_audio_tools +Copyright (c) 2023 Stability AI + + +Terms of the MIT: +-------------------------------------------------------------------- +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +For the license of other third party components, please refer to the following URL: +https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES + + +Open Source Software Licensed under the MIT License: +-------------------------------------------------------------------- +1. demucs +Copyright (c) Meta Platforms, Inc. and affiliates. + + +A copy of the MIT is included in this file. + + + +Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. torch +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All contributions by Cruise LLC: +Copyright (c) 2022 Cruise LLC. +All rights reserved. + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. + + +Terms of the BSD 3-Clause: +-------------------------------------------------------------------- +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +For the license of other third party components, please refer to the following URL: +https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE + + +Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. torchaudio +Copyright (c) 2017 Facebook Inc. (Soumith Chintala), +All rights reserved. + + +Terms of the BSD 2-Clause: +-------------------------------------------------------------------- +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +For the license of other third party components, please refer to the following URL: +https://github.com/pytorch/audio/blob/v2.0.2/LICENSE + + +Open Source Software License under the Apache License Version 2.0: +-------------------------------------------------------------------- +1. huggingface-hub +Copyright (c) huggingface-hub original author and authors + +2. transformers +Copyright 2018- The Hugging Face team. All rights reserved. + + +Terms of the Apache License Version 2.0: +-------------------------------------------------------------------- +Apache License + +Version 2.0, January 2004 + +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of this License; and + +You must cause any modified files to carry prominent notices stating that You changed the files; and + +You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + +If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS \ No newline at end of file diff --git a/README.md b/README.md index 670b58d1a0b3167a5fa6db705f7e49d3198506f9..8a20699a5e8ac87c2537a711fc1c47283d3081da 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,68 @@ --- -title: SongGeneration LeVo -emoji: 🏃 +title: LeVo Song Generation +emoji: 🎵 colorFrom: purple -colorTo: blue +colorTo: gray sdk: docker -pinned: false -short_description: Demo interface for the LeVo song generation model. +app_port: 7860 --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + +# SongGeration: + +This repository is the official code repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. You can find our paper on [here](https://arxiv.org/). The demo page is available [here](https://levo-demo.github.io/). + +In this repository, we provide the SongGeration model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset. Specifically, we have released the model and inference code corresponding to the SFT + auto-DPO version. + +## Installation + +## Start from scatch +You can install the necessary dependencies using the `requirements.txt` file with Python 3.8.12: + +```bash +pip install -r requirements.txt +``` + +then install flash attention from wget + +```bash +wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl -P /home/ +pip install /home/flash_attn-2.7.4.post1+cu12torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl +``` + +## Start with docker +```bash +docker pull juhayna/song-generation-levo:v0.1 +docker run -it --gpus all --network=host juhayna/song-generation-levo:v0.1 /bin/bash +``` + +## Inference + +Please note that all the two folder below must be downloaded completely for the model to load correctly, which is sourced from [here](https://huggingface.co/waytan22/SongGeneration) + +- Save `ckpt` to the root directory +- Save `third_party` to the root directory + +Then run inference, use the following command: + +```bash +sh generate.sh sample/lyric.jsonl sample/generate +``` +- Input keys in the `sample/lyric.jsonl` + - `idx`: name of the generate song file + - `descriptions`: text description, can be None or specified gender, timbre, genre, mood, instrument and BPM + - `prompt_audio_path`: reference audio path, can be None or 10s song audio path + - `gt_lyric`: lyrics, it needs to follow the format of '\[Structure\] Text', supported structures can be found in `conf/vocab.yaml` + +- Outputs of the loader `sample/generate`: + - `audio`: generated audio files + - `jsonl`: output jsonls + - `token`: Token corresponding to the generated audio files + +## Note + +Since the model is trained based on data longer than 1 minute, if the given lyrics are too short, the model will automatically fill in the lyrics to extend the duration. + +## License + +The code and weights in this repository is released under the MIT license as found in the [LICENSE](LICENSE) file. diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..af1c1eff471f6323c51f737735e0dc2649f2f1e7 --- /dev/null +++ b/app.py @@ -0,0 +1,140 @@ +import os +import gradio as gr +import json +import numpy as np +from datetime import datetime +import os +import sys +import librosa + +EXAMPLE_DESC = """female, dark, pop, sad, piano and drums, the bpm is 125.""" +EXAMPLE_LYRICS = """ +[intro-short] + +[verse] +夜晚的街灯闪烁. +我漫步在熟悉的角落. +回忆像潮水般涌来. +你的笑容如此清晰. +在心头无法抹去. +那些曾经的甜蜜. +如今只剩我独自回忆. + +[bridge] +手机屏幕亮起. +是你发来的消息. +简单的几个字. +却让我泪流满面. +曾经的拥抱温暖. +如今却变得遥远. +我多想回到从前. +重新拥有你的陪伴. + +[chorus] +回忆的温度还在. +你却已不在. +我的心被爱填满. +却又被思念刺痛. +R&B的节奏奏响. +我的心却在流浪. +没有你的日子. +我该如何继续向前. + +[outro-short] +""".strip() + + +# 模拟歌曲生成函数 +def generate_song(description, lyric, prompt_audio=None): + # 这里模拟生成过程 - 实际应用中替换为你的模型调用 + print(f"Generating song with description: {description}") + print(f"Lyrics provided: {lyric}") + if prompt_audio is not None: + print("Using prompt audio for generation") + + # 从文件中加载示例音频 + audio_path = "./sample/example.mp3" + audio_data, sample_rate = librosa.load(audio_path, sr=None) # 保持原始采样率 + + + # 创建输入配置的JSON + input_config = { + "description": description, + "lyric": lyric, + "has_prompt_audio": prompt_audio is not None, + "timestamp": datetime.now().isoformat(), + } + + return (sample_rate, audio_data), json.dumps(input_config, indent=2) + +# 创建Gradio界面 +with gr.Blocks(title="LeVo Demo Space") as demo: + gr.Markdown("# 🎵 LeVo Demo Space") + gr.Markdown("Demo interface for the LeVo song generation model. Provide a description, lyrics, and optionally an audio prompt to generate a custom song.") + + with gr.Row(): + with gr.Column(): + description = gr.Textbox( + label="Song Description", + placeholder="Describe the style, mood, and characteristics of the song...", + lines=1, + max_lines=2, + value=EXAMPLE_DESC, + ) + lyric = gr.Textbox( + label="Lyrics", + placeholder="Enter the lyrics for the song...", + lines=5, + max_lines=8, + value=EXAMPLE_LYRICS, + ) + + with gr.Tabs(elem_id="extra-tabs"): + with gr.Tab("Audio Prompt"): + prompt_audio = gr.Audio( + label="Prompt Audio (Optional)", + type="filepath", + elem_id="audio-prompt" + ) + with gr.Tab("Advanced Config"): + text_prompt = gr.Textbox( + label="Text Prompt", + placeholder="Enter the Text Prompt, eg: emotional piano pop", + ) + + generate_btn = gr.Button("Generate Song", variant="primary") + + with gr.Column(): + output_audio = gr.Audio(label="Generated Song", type="numpy") + output_json = gr.JSON(label="Input Configuration") + + # 示例按钮 + examples = gr.Examples( + examples=[ + ["An uplifting pop song with catchy melodies"], + ["Melancholic piano ballad"], + ], + inputs=[description], + label="Description examples" + ) + + examples = gr.Examples( + examples=[ + ["Shine bright like the stars above\nYou're the one that I'm dreaming of"], + ["The rain keeps falling on my window pane\nReminding me of love that's gone away"], + ], + inputs=[lyric], + label="Lyrics examples" + ) + + # 生成按钮点击事件 + generate_btn.click( + fn=generate_song, + inputs=[description, lyric, prompt_audio], + outputs=[output_audio, output_json] + ) + + +# 启动应用 +if __name__ == "__main__": + demo.launch() \ No newline at end of file diff --git a/codeclm/models/__init__.py b/codeclm/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3038dffb738a15e64217735a811506e0ba981dc --- /dev/null +++ b/codeclm/models/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +""" +Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel. +""" +# flake8: noqa +from . import builders +from .codeclm import CodecLM diff --git a/codeclm/models/builders.py b/codeclm/models/builders.py new file mode 100755 index 0000000000000000000000000000000000000000..09f49edef8d0a869fdd0f7fb84f5871bdf4bf4c6 --- /dev/null +++ b/codeclm/models/builders.py @@ -0,0 +1,139 @@ +""" +All the functions to build the relevant models and modules +from the Hydra config. +""" + +import typing as tp + +import omegaconf +import torch +from codeclm.utils.utils import dict_from_config +from codeclm.modules.pattern import ( + CodebooksPatternProvider, + DelayedPatternProvider, +) +from codeclm.modules.conditioners import ( + BaseConditioner, + QwTokenizerConditioner, + QwTextConditioner, + PhonemeTokenizerConditioner, + QuantizedEmbeddingConditioner, + ConditionerProvider, + ConditionFuser, +) + + +def get_audio_tokenizer_model(checkpoint_path: str, cfg: omegaconf.DictConfig): + from codeclm.tokenizer.audio_tokenizer import AudioTokenizer + """Instantiate a compression model.""" + if checkpoint_path is None: + return None + if checkpoint_path.startswith('//pretrained/'): + name = checkpoint_path.split('/', 3)[-1] + return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode) + elif checkpoint_path == "": + return None + else: + name = checkpoint_path + return AudioTokenizer.get_pretrained(name, cfg.vae_config, cfg.vae_model, 'cpu', mode=cfg.mode) + +def get_lm_model(cfg: omegaconf.DictConfig): #-> LMModel: + """Instantiate a LM.""" + lm_kwargs = dict_from_config(getattr(cfg, 'lm')) + + # n_q: number of RVQ + code_depth = lm_kwargs['code_depth'] + q_modeling = lm_kwargs.pop('q_modeling', None) + + # conditioner + condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg) + + # codebook pattern: delay + codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern') + if codebooks_pattern_cfg.modeling is None: + assert q_modeling is not None, \ + "LM model should either have a codebook pattern defined or transformer_lm.q_modeling" + codebooks_pattern_cfg = omegaconf.OmegaConf.create( + {'modeling': q_modeling, 'delay': {'delays': list(range(code_depth))}} + ) + pattern_provider = get_codebooks_pattern_provider(code_depth, codebooks_pattern_cfg) + + # condition dropout + attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout')) + cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance')) + cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef'] + + # condition fuser + fuser = get_condition_fuser(cfg) + lm_type = lm_kwargs['lm_type'] # YCY: For consistency, choose different lm.py based on lm_type + if lm_type == 'Llama': + from .lm_levo import LmModel + return LmModel( + pattern_provider=pattern_provider, + condition_provider=condition_provider, + fuser=fuser, + cfg_dropout=cfg_prob, + cfg_coef=cfg_coef, + attribute_dropout=attribute_dropout, + cfg=cfg, + **lm_kwargs + ).to('cpu') + else: + raise KeyError(f"Unexpected LM model {lm_type}") + + +def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditionerProvider: + """Instantiate a conditioning model.""" + cfg = getattr(cfg, 'conditioners') + dict_cfg = {} if cfg is None else dict_from_config(cfg) + conditioners: tp.Dict[str, BaseConditioner] = {} + condition_provider_args = dict_cfg.pop('args', {}) + + for cond, cond_cfg in dict_cfg.items(): + model_type = cond_cfg['model'] + model_args = cond_cfg[model_type] + if model_type == 'QwTokenizer': + conditioners[str(cond)] = QwTokenizerConditioner( + output_dim=output_dim, + **model_args + ) + elif model_type == "QwTextTokenizer": + conditioners[str(cond)] = QwTextConditioner( + output_dim=output_dim, + **model_args + ) + elif model_type == 'PhonemeTokenizer': + conditioners[str(cond)] = PhonemeTokenizerConditioner( + output_dim=output_dim, + **model_args + ) + elif model_type == "qt_embedding": + conditioners[str(cond)] = QuantizedEmbeddingConditioner( + dim=output_dim, + **model_args + ) + else: + raise ValueError(f"Unrecognized conditioning model: {model_type}") + conditioner = ConditionerProvider(conditioners, **condition_provider_args) + return conditioner + + +def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser: + """Instantiate a condition fuser object.""" + fuser_cfg = getattr(cfg, 'fuser') + fuser_methods = ['sum', 'prepend'] + fuse2cond = {k: fuser_cfg[k] for k in fuser_methods} + kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods} + fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs) + return fuser + + +def get_codebooks_pattern_provider(code_depth: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider: + """Instantiate a codebooks pattern provider object.""" + pattern_providers = { + 'delay': DelayedPatternProvider, + } + name = cfg.modeling + kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {} + klass = pattern_providers[name] + return klass(code_depth, **kwargs) diff --git a/codeclm/models/codeclm.py b/codeclm/models/codeclm.py new file mode 100644 index 0000000000000000000000000000000000000000..f5928fdf6ed0652224d3b689eade16b1717d22ac --- /dev/null +++ b/codeclm/models/codeclm.py @@ -0,0 +1,303 @@ +""" +Main model for using CodecLM. This will combine all the required components +and provide easy access to the generation API. +""" + +import typing as tp +import warnings + +import torch + +from codeclm.tokenizer.audio_tokenizer import AudioTokenizer +from .lm_levo import LmModel +from ..modules.conditioners import ConditioningAttributes, AudioCondition +from ..utils.autocast import TorchAutocast +import torch +from torch.nn import functional as F +import torchaudio +# from optim.ema import EMA + + +MelodyList = tp.List[tp.Optional[torch.Tensor]] +MelodyType = tp.Union[torch.Tensor, MelodyList] + +class CodecLM: + """CodecLM main model with convenient generation API. + + Args: + name (str): name of the model. + compression_model (CompressionModel): Compression model + used to map audio to invertible discrete representations. + lm (LMModel): Language model over discrete representations. + max_duration (float, optional): maximum duration the model can produce, + otherwise, inferred from the training params. + """ + def __init__(self, name: str, audiotokenizer: AudioTokenizer, lm: LmModel, + max_duration: tp.Optional[float] = None, seperate_tokenizer: AudioTokenizer = None): + self.name = name + self.audiotokenizer = audiotokenizer + self.lm = lm + self.seperate_tokenizer = seperate_tokenizer + # import pdb; pdb.set_trace() + if max_duration is None: + if hasattr(lm, 'cfg'): + max_duration = lm.cfg.dataset.segment_duration # type: ignore + else: + raise ValueError("You must provide max_duration when building directly CodecLM") + assert max_duration is not None + + self.max_duration: float = max_duration + self.device = next(iter(lm.parameters())).device + self.generation_params: dict = {} + # self.set_generation_params(duration=15) # 15 seconds by default + self.set_generation_params(duration=15, extend_stride=self.max_duration // 2) + self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None + if self.device.type == 'cpu': + self.autocast = TorchAutocast(enabled=False) + else: + self.autocast = TorchAutocast(enabled=False) + + + + @property + def frame_rate(self) -> float: + """Roughly the number of AR steps per seconds.""" + return self.audiotokenizer.frame_rate + + @property + def sample_rate(self) -> int: + """Sample rate of the generated audio.""" + return self.audiotokenizer.sample_rate + + @property + def audio_channels(self) -> int: + """Audio channels of the generated audio.""" + return self.audiotokenizer.channels + + def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, + top_p: float = 0.0, temperature: float = 1.0, + duration: float = 30.0, cfg_coef: float = 3.0, + extend_stride: float = 18, record_tokens: bool = False, + record_window: int = 50): + """Set the generation parameters for CodecLM. + + Args: + use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True. + top_k (int, optional): top_k used for sampling. Defaults to 250. + top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0. + temperature (float, optional): Softmax temperature parameter. Defaults to 1.0. + duration (float, optional): Duration of the generated waveform. Defaults to 30.0. + cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0. + two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance, + instead of batching together the two. This has some impact on how things + are padded but seems to have little impact in practice. + extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much + should we extend the audio each time. Larger values will mean less context is + preserved, and shorter value will require extra computations. + """ + assert extend_stride <= self.max_duration, "Cannot stride by more than max generation duration." + self.extend_stride = extend_stride + self.duration = duration + self.generation_params = { + 'use_sampling': use_sampling, + 'temp': temperature, + 'top_k': top_k, + 'top_p': top_p, + 'cfg_coef': cfg_coef, + 'record_tokens': record_tokens, + 'record_window': record_window, + } + + def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None): + """Override the default progress callback.""" + self._progress_callback = progress_callback + + # Inference + def generate(self, lyrics: tp.List[str], + descriptions: tp.List[str], + melody_wavs: torch.Tensor = None, + melody_is_wav: bool = True, + vocal_wavs: torch.Tensor = None, + bgm_wavs: torch.Tensor = None, + return_tokens: bool = False, + ) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]: + """Generate samples conditioned on text and melody. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as + melody conditioning. Should have shape [B, C, T] with B matching the description length, + C=1 or 2. It can be [C, T] if there is a single description. It can also be + a list of [C, T] tensors. + melody_sample_rate: (int): Sample rate of the melody waveforms. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + """ + if melody_wavs is not None: + if melody_wavs.dim() == 2: + melody_wavs = melody_wavs[None] + if melody_wavs.dim() != 3: + raise ValueError("Melody wavs should have a shape [B, C, T].") + melody_wavs = list(melody_wavs) + if vocal_wavs is not None: + if vocal_wavs.dim() == 2: + vocal_wavs = vocal_wavs[None] + if vocal_wavs.dim() != 3: + raise ValueError("Vocal wavs should have a shape [B, C, T].") + vocal_wavs = list(vocal_wavs) + if bgm_wavs is not None: + if bgm_wavs.dim() == 2: + bgm_wavs = bgm_wavs[None] + if bgm_wavs.dim() != 3: + raise ValueError("BGM wavs should have a shape [B, C, T].") + bgm_wavs = list(bgm_wavs) + + texts, audio_qt_embs = self._prepare_tokens_and_attributes(lyrics=lyrics, melody_wavs=melody_wavs, vocal_wavs=vocal_wavs, bgm_wavs=bgm_wavs, melody_is_wav=melody_is_wav) + tokens = self._generate_tokens(texts, descriptions, audio_qt_embs) + + if (tokens == self.lm.eos_token_id).any(): + length = torch.nonzero(torch.eq(tokens, self.lm.eos_token_id))[:,-1].min() + tokens = tokens[...,:length] + + if return_tokens: + return tokens + else: + out = self.generate_audio(tokens) + return out + + + @torch.no_grad() + def _prepare_tokens_and_attributes( + self, + lyrics: tp.Sequence[tp.Optional[str]], + melody_wavs: tp.Optional[MelodyList] = None, + vocal_wavs: tp.Optional[MelodyList] = None, + bgm_wavs: tp.Optional[MelodyList] = None, + melody_is_wav = True + ) -> tp.Tuple[tp.List[str], tp.List[torch.Tensor]]: + """Prepare model inputs. + + Args: + descriptions (list of str): A list of strings used as text conditioning. + prompt (torch.Tensor): A batch of waveforms used for continuation. + melody_wavs (torch.Tensor, optional): A batch of waveforms + used as melody conditioning. Defaults to None. + """ + assert len(lyrics) == 1 + texts = [lyric for lyric in lyrics] + audio_qt_embs = [] + target_melody_token_len = self.lm.cfg.prompt_len * self.audiotokenizer.frame_rate + # import pdb; pdb.set_trace() + if melody_wavs is None: + melody_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long() + elif melody_wavs is not None: + if 'prompt_audio' not in self.lm.condition_provider.conditioners: + raise RuntimeError("This model doesn't support melody conditioning. " + "Use the `melody` model.") + assert len(melody_wavs) == len(texts), \ + f"number of melody wavs must match number of descriptions! " \ + f"got melody len={len(melody_wavs)}, and descriptions len={len(texts)}" + if type(melody_wavs) == list: + melody_wavs = torch.stack(melody_wavs, dim=0) + melody_wavs = melody_wavs.to(self.device) + if melody_is_wav: + melody_tokens, scale = self.audiotokenizer.encode(melody_wavs) + else: + melody_tokens = melody_wavs + if melody_tokens.shape[-1] > target_melody_token_len: + melody_tokens = melody_tokens[...,:target_melody_token_len] + elif melody_tokens.shape[-1] < target_melody_token_len: + melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1) + if self.seperate_tokenizer is not None: + if vocal_wavs is not None: + if type(vocal_wavs) == list: + vocal_wavs = torch.stack(vocal_wavs, dim=0) + if bgm_wavs is None: + use_bgm = False + bgm_wavs = torch.zeros_like(vocal_wavs) + bgm_wavs[:, 0] = 1.0 + bgm_wavs[:, 1:] = torch.randn_like(bgm_wavs[:, 1:])* 0.0003 + else: + use_bgm = True + if type(bgm_wavs) == list: + bgm_wavs = torch.stack(bgm_wavs, dim=0) + vocal_wavs = vocal_wavs.to(self.device) + bgm_wavs = bgm_wavs.to(self.device) + vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs) + assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \ + f"vocal and bgm tokens should have a shape [B, C, T]! " \ + f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}" + assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \ + f"vocal and bgm tokens should have the same length! " \ + f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}" + if not use_bgm: + bgm_tokens = torch.full_like(bgm_tokens, 16385) + if bgm_tokens.shape[-1] > target_melody_token_len: + bgm_tokens = bgm_tokens[...,:target_melody_token_len] + elif bgm_tokens.shape[-1] < target_melody_token_len: + bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1) + if vocal_tokens.shape[-1] > target_melody_token_len: + vocal_tokens = vocal_tokens[...,:target_melody_token_len] + elif vocal_tokens.shape[-1] < target_melody_token_len: + vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1) + else: + bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long() + vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long() + + melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1) + assert melody_tokens.shape[-1] == target_melody_token_len + audio_qt_embs = melody_tokens.long() + return texts, audio_qt_embs + + + + def _generate_tokens(self, + texts: tp.Optional[tp.List[str]] = None, + descriptions: tp.Optional[tp.List[str]] = None, + audio_qt_embs: tp.Optional[tp.List[torch.Tensor]] = None) -> torch.Tensor: + """Generate discrete audio tokens given audio prompt and/or conditions. + + Args: + attributes (list of ConditioningAttributes): Conditions used for generation (text/melody). + prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation. + progress (bool, optional): Flag to display progress of the generation process. Defaults to False. + Returns: + torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params. + """ + total_gen_len = int(self.duration * self.frame_rate) + current_gen_offset: int = 0 + + def _progress_callback(generated_tokens: int, tokens_to_generate: int): + generated_tokens += current_gen_offset + if self._progress_callback is not None: + # Note that total_gen_len might be quite wrong depending on the + # codebook pattern used, but with delay it is almost accurate. + self._progress_callback(generated_tokens, total_gen_len) + else: + print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r') + + if self.duration <= self.max_duration: + # generate by sampling from LM, simple case. + with self.autocast: + gen_tokens = self.lm.generate(texts=texts, + descriptions=descriptions, + audio_qt_embs=audio_qt_embs, + max_gen_len=total_gen_len, + **self.generation_params) + else: + raise NotImplementedError(f"duration {self.duration} < max duration {self.max_duration}") + return gen_tokens + + @torch.no_grad() + def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None): + """Generate Audio from tokens""" + assert gen_tokens.dim() == 3 + if self.seperate_tokenizer is not None: + gen_tokens_song = gen_tokens[:, [0], :] + gen_tokens_vocal = gen_tokens[:, [1], :] + gen_tokens_bgm = gen_tokens[:, [2], :] + # gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt) + gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt) + return gen_audio_seperate + else: + gen_audio = self.audiotokenizer.decode(gen_tokens, prompt) + return gen_audio diff --git a/codeclm/models/levo.py b/codeclm/models/levo.py new file mode 100755 index 0000000000000000000000000000000000000000..2514a55c623eb36f2afff230ffd2ff6a870b52a3 --- /dev/null +++ b/codeclm/models/levo.py @@ -0,0 +1,224 @@ + +from .llama.modeling_llama import LlamaConfig, CausalLMOutputWithPast, BaseModelOutputWithPast, LlamaDecoderLayer, LlamaRMSNorm +from .llama.modeling_llama import LlamaForCausalLM as LlamaForCausalLM_base +from .llama.modeling_llama import LlamaModel as LlamaModel_base +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Union, Optional, Tuple, List +from packaging import version +import transformers +""" +Wrap the original Llama model for potential customized changes. +""" + +"""main class""" +class CausalLM(LlamaForCausalLM_base): + def __init__(self, config): + super().__init__(config) + self.model = LmModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=hidden_states, + attentions=outputs.attentions, + ) + + +"""Submodel class""" +class LmModel(LlamaModel_base): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + layer_cls = LlamaDecoderLayer # cross attention decoder layer can be overwritten here + + assert version.parse(transformers.__version__) < version.parse("4.40") + + self.layers = nn.ModuleList([layer_cls(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + self.gradient_checkpointing_disable() + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + layer_args = (hidden_states, attention_mask, position_ids,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), *layer_args + ) + else: + + layer_outputs = decoder_layer(*layer_args, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + \ No newline at end of file diff --git a/codeclm/models/llama/__init__.py b/codeclm/models/llama/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cfeaa6b692d8b27331fad2c290eed89ba12c94bf --- /dev/null +++ b/codeclm/models/llama/__init__.py @@ -0,0 +1,90 @@ +# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +from transformers.utils import ( + OptionalDependencyNotAvailable, + _LazyModule, + is_sentencepiece_available, + is_tokenizers_available, + is_torch_available, +) + + +_import_structure = { + "configuration_llama": ["LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "LlamaConfig"], +} + +try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_llama"] = ["LlamaTokenizer"] + +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["tokenization_llama_fast"] = ["LlamaTokenizerFast"] + +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + pass +else: + _import_structure["modeling_llama"] = [ + "LlamaForCausalLM", + "LlamaModel", + "LlamaPreTrainedModel", + "LlamaForSequenceClassification", + ] + + +if TYPE_CHECKING: + from .configuration_llama import LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP, LlamaConfig + + try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_llama import LlamaTokenizer + + try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .tokenization_llama_fast import LlamaTokenizerFast + + try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + pass + else: + from .modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaPreTrainedModel + + +else: + import sys + + sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) diff --git a/codeclm/models/llama/configuration_llama.py b/codeclm/models/llama/configuration_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..6660316493d3b95112d1d2d083260c871920709e --- /dev/null +++ b/codeclm/models/llama/configuration_llama.py @@ -0,0 +1,182 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + pretraining_tp (`int`, *optional*, defaults to `1`): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + tie_word_embeddings(`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be an float greater than 1. The expected format + is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + + Example: + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be an float > 1, got {rope_scaling_factor}") diff --git a/codeclm/models/llama/convert_llama_weights_to_hf.py b/codeclm/models/llama/convert_llama_weights_to_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..acc49884ebdb290622e0df239d4c00c1b59de708 --- /dev/null +++ b/codeclm/models/llama/convert_llama_weights_to_hf.py @@ -0,0 +1,318 @@ +# Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import gc +import json +import os +import shutil +import warnings + +import torch + +from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer + + +try: + from transformers import LlamaTokenizerFast +except ImportError as e: + warnings.warn(e) + warnings.warn( + "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion" + ) + LlamaTokenizerFast = None + +""" +Sample usage: + +``` +python src/transformers/models/llama/convert_llama_weights_to_hf.py \ + --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir /output/path +``` + +Thereafter, models can be loaded via: + +```py +from transformers import LlamaForCausalLM, LlamaTokenizer + +model = LlamaForCausalLM.from_pretrained("/output/path") +tokenizer = LlamaTokenizer.from_pretrained("/output/path") +``` + +Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions +come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM). +""" + +NUM_SHARDS = { + "7B": 1, + "7Bf": 1, + "13B": 2, + "13Bf": 2, + "34B": 4, + "30B": 4, + "65B": 8, + "70B": 8, + "70Bf": 8, +} + + +def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256): + return multiple_of * ((int(ffn_dim_multiplier * int(8 * n / 3)) + multiple_of - 1) // multiple_of) + + +def read_json(path): + with open(path, "r") as f: + return json.load(f) + + +def write_json(text, path): + with open(path, "w") as f: + json.dump(text, f) + + +def write_model(model_path, input_base_path, model_size, tokenizer_path=None, safe_serialization=True): + # for backward compatibility, before you needed the repo to be called `my_repo/model_size` + if not os.path.isfile(os.path.join(input_base_path, "params.json")): + input_base_path = os.path.join(input_base_path, model_size) + + os.makedirs(model_path, exist_ok=True) + tmp_model_path = os.path.join(model_path, "tmp") + os.makedirs(tmp_model_path, exist_ok=True) + + params = read_json(os.path.join(input_base_path, "params.json")) + num_shards = NUM_SHARDS[model_size] + n_layers = params["n_layers"] + n_heads = params["n_heads"] + n_heads_per_shard = n_heads // num_shards + dim = params["dim"] + dims_per_head = dim // n_heads + base = params.get("rope_theta", 10000.0) + inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)) + if base > 10000.0: + max_position_embeddings = 16384 + else: + max_position_embeddings = 2048 + + tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + if tokenizer_path is not None: + tokenizer = tokenizer_class(tokenizer_path) + tokenizer.save_pretrained(model_path) + vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000 + + if "n_kv_heads" in params: + num_key_value_heads = params["n_kv_heads"] # for GQA / MQA + num_local_key_value_heads = n_heads_per_shard // num_key_value_heads + key_value_dim = dim // num_key_value_heads + else: # compatibility with other checkpoints + num_key_value_heads = n_heads + num_local_key_value_heads = n_heads_per_shard + key_value_dim = dim + + # permute for sliced rotary + def permute(w, n_heads=n_heads, dim1=dim, dim2=dim): + return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + print(f"Fetching all parameters from the checkpoint at {input_base_path}.") + # Load weights + if model_size == "7B": + # Not sharded + # (The sharded implementation would also work, but this is simpler.) + loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu") + else: + # Sharded + loaded = [ + torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu") + for i in range(num_shards) + ] + param_count = 0 + index_dict = {"weight_map": {}} + for layer_i in range(n_layers): + filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin" + if model_size == "7B": + # Unsharded + state_dict = { + f"model.layers.{layer_i}.self_attn.q_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wq.weight"] + ), + f"model.layers.{layer_i}.self_attn.k_proj.weight": permute( + loaded[f"layers.{layer_i}.attention.wk.weight"] + ), + f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"layers.{layer_i}.attention.wv.weight"], + f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"layers.{layer_i}.attention.wo.weight"], + f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w1.weight"], + f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w2.weight"], + f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"layers.{layer_i}.feed_forward.w3.weight"], + f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"layers.{layer_i}.attention_norm.weight"], + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"layers.{layer_i}.ffn_norm.weight"], + } + else: + # Sharded + # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share + # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is + # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned. + + state_dict = { + f"model.layers.{layer_i}.input_layernorm.weight": loaded[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wq.weight"].view(n_heads_per_shard, dims_per_head, dim) + for i in range(num_shards) + ], + dim=0, + ).reshape(dim, dim) + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute( + torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wk.weight"].view( + num_local_key_value_heads, dims_per_head, dim + ) + for i in range(num_shards) + ], + dim=0, + ).reshape(key_value_dim, dim), + num_key_value_heads, + key_value_dim, + dim, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [ + loaded[i][f"layers.{layer_i}.attention.wv.weight"].view( + num_local_key_value_heads, dims_per_head, dim + ) + for i in range(num_shards) + ], + dim=0, + ).reshape(key_value_dim, dim) + + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=1 + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [loaded[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + + state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin" + if model_size == "7B": + # Unsharded + state_dict = { + "model.embed_tokens.weight": loaded["tok_embeddings.weight"], + "model.norm.weight": loaded["norm.weight"], + "lm_head.weight": loaded["output.weight"], + } + else: + state_dict = { + "model.norm.weight": loaded[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat( + [loaded[i]["tok_embeddings.weight"] for i in range(num_shards)], dim=1 + ), + "lm_head.weight": torch.cat([loaded[i]["output.weight"] for i in range(num_shards)], dim=0), + } + + for k, v in state_dict.items(): + index_dict["weight_map"][k] = filename + param_count += v.numel() + torch.save(state_dict, os.path.join(tmp_model_path, filename)) + + # Write configs + index_dict["metadata"] = {"total_size": param_count * 2} + write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json")) + ffn_dim_multiplier = params["ffn_dim_multiplier"] if "ffn_dim_multiplier" in params else 1 + multiple_of = params["multiple_of"] if "multiple_of" in params else 256 + config = LlamaConfig( + hidden_size=dim, + intermediate_size=compute_intermediate_size(dim, ffn_dim_multiplier, multiple_of), + num_attention_heads=params["n_heads"], + num_hidden_layers=params["n_layers"], + rms_norm_eps=params["norm_eps"], + num_key_value_heads=num_key_value_heads, + vocab_size=vocab_size, + rope_theta=base, + max_position_embeddings=max_position_embeddings, + ) + config.save_pretrained(tmp_model_path) + + # Make space so we can load the model properly now. + del state_dict + del loaded + gc.collect() + + print("Loading the checkpoint in a Llama model.") + model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True) + # Avoid saving this as part of the config. + del model.config._name_or_path + model.config.torch_dtype = torch.float16 + print("Saving in the Transformers format.") + model.save_pretrained(model_path, safe_serialization=safe_serialization) + shutil.rmtree(tmp_model_path) + + +def write_tokenizer(tokenizer_path, input_tokenizer_path): + # Initialize the tokenizer based on the `spm` model + tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast + print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.") + tokenizer = tokenizer_class(input_tokenizer_path) + tokenizer.save_pretrained(tokenizer_path) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--input_dir", + help="Location of LLaMA weights, which contains tokenizer.model and model folders", + ) + parser.add_argument( + "--model_size", + choices=["7B", "7Bf", "13B", "13Bf", "30B", "34B", "65B", "70B", "70Bf", "tokenizer_only"], + help="'f' models correspond to the finetuned versions, and are specific to the Llama2 official release. For more details on Llama2, checkout the original repo: https://huggingface.co/meta-llama", + ) + parser.add_argument( + "--output_dir", + help="Location to write HF model and tokenizer", + ) + parser.add_argument("--safe_serialization", type=bool, help="Whether or not to save using `safetensors`.") + args = parser.parse_args() + spm_path = os.path.join(args.input_dir, "tokenizer.model") + if args.model_size != "tokenizer_only": + write_model( + model_path=args.output_dir, + input_base_path=args.input_dir, + model_size=args.model_size, + safe_serialization=args.safe_serialization, + tokenizer_path=spm_path, + ) + else: + write_tokenizer(args.output_dir, spm_path) + + +if __name__ == "__main__": + main() diff --git a/codeclm/models/llama/modeling_llama.py b/codeclm/models/llama/modeling_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..0e7d4f86a77b393d3e83a4668e30b1329c392787 --- /dev/null +++ b/codeclm/models/llama/modeling_llama.py @@ -0,0 +1,1243 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch LLaMA model.""" +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_available, + logging, + replace_return_docstrings, +) +from .configuration_llama import LlamaConfig + + +if is_flash_attn_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(padding_mask): + seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.bart.modeling_bart._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 +): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + +# Copied from transformers.models.bart.modeling_bart._expand_mask +def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), + ) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dime x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # TODO: llama does not have dropout in the config?? + # It is recommended to use dropout with FA according to the docs + # when training. + dropout_rate = 0.0 # if not self.training else self.attn_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to" + " the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + " float16." + ) + + query_states = query_states.to(torch.float16) + key_states = key_states.to(torch.float16) + value_states = value_states.to(torch.float16) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + padding_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + # Contains at least one padding token in the sequence + if padding_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, padding_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=True, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=True + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + padding_mask = padding_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = ( + LlamaAttention(config=config) + if not getattr(config, "_flash_attn_2_enabled", False) + else LlamaFlashAttention2(config=config) + ) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, LlamaModel): + module.gradient_checkpointing = value + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = _make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + padding_mask = None + else: + if 0 in attention_mask: + padding_mask = attention_mask + else: + padding_mask = None + + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to( + logits.device + ) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/codeclm/models/llama/tokenization_llama.py b/codeclm/models/llama/tokenization_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..bf105b7a58a2084a98ea4d77f50e44d67b52aac7 --- /dev/null +++ b/codeclm/models/llama/tokenization_llama.py @@ -0,0 +1,426 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tokenization classes for LLaMA.""" +import os +from shutil import copyfile +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import sentencepiece as spm + +from transformers.convert_slow_tokenizer import import_protobuf +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + + +if TYPE_CHECKING: + from transformers.tokenization_utils_base import TextInput + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { + "hf-internal-testing/llama-tokenizer": 2048, +} +SPIECE_UNDERLINE = "▁" + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class LlamaTokenizer(PreTrainedTokenizer): + """ + Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. The default padding token is unset as there is + no padding token in the original model. + + Args: + vocab_file (`str`): + Path to the vocabulary file. + legacy (`bool`, *optional*): + Whether or not the `legacy` behavior of the tokenizer should be used. Legacy is before the merge of #24622 + and #25224 which includes fixes to properly handle tokens that appear after special tokens. A simple + example: + + - `legacy=True`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=True) + >>> tokenizer.encode("Hello .") + [8774, 32099, 3, 5, 1] + ``` + - `legacy=False`: + ```python + >>> from transformers import T5Tokenizer + + >>> tokenizer = T5Tokenizer.from_pretrained("t5-base", legacy=False) + >>> tokenizer.encode("Hello .") # the extra space `[3]` is no longer here + [8774, 32099, 5, 1] + ``` + Checkout the [pull request](https://github.com/huggingface/transformers/pull/24565) for more details. + + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + use_default_system_prompt=True, + spaces_between_special_tokens=False, + legacy=None, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token + eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token + unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token + pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token + + if legacy is None: + logger.warning_once( + f"You are using the default legacy behaviour of the {self.__class__}. This is" + " expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you." + " If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it" + " means, and thouroughly read the reason why this was added as explained in" + " https://github.com/huggingface/transformers/pull/24565" + ) + legacy = True + + self.legacy = legacy + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.use_default_system_prompt = use_default_system_prompt + self.sp_model = self.get_spm_processor(kwargs.pop("from_slow", False)) + + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + use_default_system_prompt=use_default_system_prompt, + spaces_between_special_tokens=spaces_between_special_tokens, + legacy=legacy, + **kwargs, + ) + + @property + def unk_token_length(self): + return len(self.sp_model.encode(str(self.unk_token))) + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.get_spm_processor + def get_spm_processor(self, from_slow=False): + tokenizer = spm.SentencePieceProcessor(**self.sp_model_kwargs) + if self.legacy or from_slow: # no dependency on protobuf + tokenizer.Load(self.vocab_file) + return tokenizer + + with open(self.vocab_file, "rb") as f: + sp_model = f.read() + model_pb2 = import_protobuf(f"The new behaviour of {self.__class__.__name__} (with `self.legacy = False`)") + model = model_pb2.ModelProto.FromString(sp_model) + normalizer_spec = model_pb2.NormalizerSpec() + normalizer_spec.add_dummy_prefix = False + model.normalizer_spec.MergeFrom(normalizer_spec) + sp_model = model.SerializeToString() + tokenizer.LoadFromSerializedProto(sp_model) + return tokenizer + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + state["sp_model_proto"] = self.sp_model.serialized_model_proto() + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.LoadFromSerializedProto(self.sp_model_proto) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer.tokenize + def tokenize(self, text: "TextInput", add_special_tokens=False, **kwargs) -> List[str]: + """ + Converts a string to a list of tokens. If `self.legacy` is set to `False`, a prefix token is added unless the + first token is special. + """ + if self.legacy or len(text) == 0: + return super().tokenize(text, **kwargs) + + tokens = super().tokenize(SPIECE_UNDERLINE + text.replace(SPIECE_UNDERLINE, " "), **kwargs) + + if len(tokens) > 1 and tokens[0] == SPIECE_UNDERLINE and tokens[1] in self.all_special_tokens: + tokens = tokens[1:] + return tokens + + # Copied from transformers.models.t5.tokenization_t5.T5Tokenizer._tokenize + def _tokenize(self, text, **kwargs): + """ + Returns a tokenized string. + + We de-activated the `add_dummy_prefix` option, thus the sentencepiece internals will always strip any + SPIECE_UNDERLINE. For example: `self.sp_model.encode(f"{SPIECE_UNDERLINE}Hey", out_type = str)` will give + `['H', 'e', 'y']` instead of `['▁He', 'y']`. Thus we always encode `f"{unk_token}text"` and strip the + `unk_token`. Here is an example with `unk_token = ""` and `unk_token_length = 4`. + `self.tokenizer.sp_model.encode(" Hey", out_type = str)[4:]`. + """ + tokens = self.sp_model.encode(text, out_type=str) + if self.legacy or not text.startswith((SPIECE_UNDERLINE, " ")): + return tokens + + # 1. Encode string + prefix ex: " Hey" + tokens = self.sp_model.encode(self.unk_token + text, out_type=str) + # 2. Remove self.unk_token from ['<','unk','>', '▁Hey'] + return tokens[self.unk_token_length :] if len(tokens) >= self.unk_token_length else tokens + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + # since we manually add the prefix space, we have to remove it when decoding + if tokens[0].startswith(SPIECE_UNDERLINE): + tokens[0] = tokens[0][1:] + + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0 and self.legacy: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. + """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). + """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output + + @property + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! + + The output should look something like: + + [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] + """ + + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template diff --git a/codeclm/models/llama/tokenization_llama_fast.py b/codeclm/models/llama/tokenization_llama_fast.py new file mode 100644 index 0000000000000000000000000000000000000000..46e8e911252899f337b2e600abe4bc2d810efb4f --- /dev/null +++ b/codeclm/models/llama/tokenization_llama_fast.py @@ -0,0 +1,264 @@ +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from shutil import copyfile +from typing import Optional, Tuple + +from tokenizers import processors + +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast +from transformers.utils import is_sentencepiece_available, logging +from transformers.utils.versions import require_version + + +require_version("tokenizers>=0.13.3") + +if is_sentencepiece_available(): + from .tokenization_llama import LlamaTokenizer +else: + LlamaTokenizer = None + +logger = logging.get_logger(__name__) +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model", + }, + "tokenizer_file": { + "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json", + }, +} +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" + +# fmt: off +DEFAULT_SYSTEM_PROMPT = """You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your \ +answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure\ + that your responses are socially unbiased and positive in nature. + +If a question does not make any sense, or is not factually coherent, explain why instead of answering something not \ +correct. If you don't know the answer to a question, please don't share false information.""" +# fmt: on + + +class LlamaTokenizerFast(PreTrainedTokenizerFast): + """ + Construct a Llama tokenizer. Based on byte-level Byte-Pair-Encoding. + + This uses notably ByteFallback and no normalization. + + ``` + from transformers import LlamaTokenizerFast + + tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + tokenizer.encode("Hello this is a test") + >>> [1, 15043, 445, 338, 263, 1243] + ``` + + If you want to change the `bos_token` or the `eos_token`, make sure to specify them when initializing the model, or + call `tokenizer.update_post_processor()` to make sure that the post-processing is correctly done (otherwise the + values of the first token and final token of an encoded sequence will not be correct). For more details, checkout + [post-processors] (https://huggingface.co/docs/tokenizers/api/post-processors) documentation. + + + This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should + refer to this superclass for more information regarding those methods. + + Args: + vocab_file (`str`): + [SentencePiece](https://github.com/google/sentencepiece) file (generally has a .model extension) that + contains the vocabulary necessary to instantiate a tokenizer. + tokenizer_file (`str`): + [tokenizers](https://github.com/huggingface/tokenizers) file (generally has a .json extension) that + contains everything needed to load the tokenizer. + + clean_up_tokenization_spaces (`str`, *optional*, defaults to `False`): + Wether to cleanup spaces after decoding, cleanup consists in removing potential artifacts like extra + spaces. + + bos_token (`str`, *optional*, defaults to `""`): + The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. + + eos_token (`str`, *optional*, defaults to `""`): + The end of sequence token. + + unk_token (`str`, *optional*, defaults to `""`): + The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this + token instead. + """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + slow_tokenizer_class = LlamaTokenizer + padding_side = "left" + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file=None, + tokenizer_file=None, + clean_up_tokenization_spaces=False, + unk_token="", + bos_token="", + eos_token="", + add_bos_token=True, + add_eos_token=False, + use_default_system_prompt=True, + **kwargs, + ): + super().__init__( + vocab_file=vocab_file, + tokenizer_file=tokenizer_file, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + unk_token=unk_token, + bos_token=bos_token, + eos_token=eos_token, + use_default_system_prompt=use_default_system_prompt, + **kwargs, + ) + self._add_bos_token = add_bos_token + self._add_eos_token = add_eos_token + self.update_post_processor() + self.use_default_system_prompt = use_default_system_prompt + self.vocab_file = vocab_file + + @property + def can_save_slow_tokenizer(self) -> bool: + return os.path.isfile(self.vocab_file) if self.vocab_file else False + + def update_post_processor(self): + """ + Updates the underlying post processor with the current `bos_token` and `eos_token`. + """ + bos = self.bos_token + bos_token_id = self.bos_token_id + + eos = self.eos_token + eos_token_id = self.eos_token_id + + single = f"{(bos+':0 ') * self.add_bos_token}$A:0{(' '+eos+':0') if self.add_eos_token else ''}" + pair = f"{single}{(' '+bos+':1') * self.add_bos_token} $B:1{(' '+eos+':1') if self.add_eos_token else ''}" + + special_tokens = [] + if self.add_bos_token: + special_tokens.append((bos, bos_token_id)) + if self.add_eos_token: + special_tokens.append((eos, eos_token_id)) + self._tokenizer.post_processor = processors.TemplateProcessing( + single=single, pair=pair, special_tokens=special_tokens + ) + + @property + def add_eos_token(self): + return self._add_eos_token + + @property + def add_bos_token(self): + return self._add_bos_token + + @add_eos_token.setter + def add_eos_token(self, value): + self._add_eos_token = value + self.update_post_processor() + + @add_bos_token.setter + def add_bos_token(self, value): + self._add_bos_token = value + self.update_post_processor() + + def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: + if not self.can_save_slow_tokenizer: + raise ValueError( + "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow " + "tokenizer." + ) + + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + copyfile(self.vocab_file, out_vocab_file) + + return (out_vocab_file,) + + @property + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.default_chat_template + def default_chat_template(self): + """ + LLaMA uses [INST] and [/INST] to indicate user messages, and <> and <> to indicate system messages. + Assistant messages do not have special tokens, because LLaMA chat models are generally trained with strict + user/assistant/user/assistant message ordering, and so assistant messages can be identified from the ordering + rather than needing special tokens. The system message is partly 'embedded' in the first user message, which + results in an unusual token ordering when it is present. This template should definitely be changed if you wish + to fine-tune a model with more flexible role ordering! + + The output should look something like: + + [INST] B_SYS SystemPrompt E_SYS Prompt [/INST] Answer [INST] Prompt [/INST] Answer + [INST] Prompt [/INST] + """ + + template = ( + "{% if messages[0]['role'] == 'system' %}" + "{% set loop_messages = messages[1:] %}" # Extract system message if it's present + "{% set system_message = messages[0]['content'] %}" + "{% elif USE_DEFAULT_PROMPT == true and not '<>' in messages[0]['content'] %}" + "{% set loop_messages = messages %}" # Or use the default system message if the flag is set + "{% set system_message = 'DEFAULT_SYSTEM_MESSAGE' %}" + "{% else %}" + "{% set loop_messages = messages %}" + "{% set system_message = false %}" + "{% endif %}" + "{% for message in loop_messages %}" # Loop over all non-system messages + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + "{% endif %}" + "{% if loop.index0 == 0 and system_message != false %}" # Embed system message in first message + "{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}" + "{% else %}" + "{% set content = message['content'] %}" + "{% endif %}" + "{% if message['role'] == 'user' %}" # After all of that, handle messages/roles in a fairly normal way + "{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}" + "{% elif message['role'] == 'system' %}" + "{{ '<>\\n' + content.strip() + '\\n<>\\n\\n' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ ' ' + content.strip() + ' ' + eos_token }}" + "{% endif %}" + "{% endfor %}" + ) + template = template.replace("USE_DEFAULT_PROMPT", "true" if self.use_default_system_prompt else "false") + default_message = DEFAULT_SYSTEM_PROMPT.replace("\n", "\\n").replace("'", "\\'") + template = template.replace("DEFAULT_SYSTEM_MESSAGE", default_message) + + return template + + # TODO ArthurZ let's rely on the template processor instead, refactor all fast tokenizers + # Copied from transformers.models.llama.tokenization_llama.LlamaTokenizer.build_inputs_with_special_tokens + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output diff --git a/codeclm/models/lm_levo.py b/codeclm/models/lm_levo.py new file mode 100755 index 0000000000000000000000000000000000000000..42a5b91c2319c5cbab7ea82459ac350b0d3379cd --- /dev/null +++ b/codeclm/models/lm_levo.py @@ -0,0 +1,546 @@ + +import torch +import math +import random +import torch.nn as nn +import typing as tp +import torch.nn.functional as F +from dataclasses import dataclass +from codeclm.models.levo import CausalLM, LlamaConfig +from codeclm.modules.streaming import StreamingModule +from codeclm.modules.conditioners import ( + ConditioningAttributes, + AudioCondition, + ConditionType, + ConditionerProvider, + ConditionFuser, + ClassifierFreeGuidanceDropoutInference, + ClassifierFreeGuidanceDropout, + AttributeDropout, +) +from codeclm.utils.utils import create_norm_fn, init_layer, sample_top_k, sample_top_p, multinomial +from codeclm.modules.pattern import CodebooksPatternProvider +ConditionTensors = tp.Dict[str, ConditionType] + +@dataclass +class LMOutput: + # The logits are already re-aligned with the input codes + # hence no extra shift is required, e.g. when computing CE + logits: torch.Tensor # [B, K, T, card] + mask: torch.Tensor # [B, K, T] + + +class LmModel(StreamingModule): + """Transformer-based language model on multiple streams of codes. + + Args: + pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving. + condition_provider (ConditioningProvider): Conditioning provider from metadata. + fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input. + code_depth (int): Number of parallel streams to model. + code_size (int): Cardinality, vocabulary size. + dim (int): Dimension of the transformer encoder. + num_heads (int): Number of heads for the transformer encoder. + hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. + norm (str): Normalization method. + norm_first (bool): Use pre-norm instead of post-norm. + emb_lr (float, optional): Embedding-specific learning rate. + bias_proj (bool): Use bias for output projections. + weight_init (str, optional): Method for weight initialization. + depthwise_init (str, optional): Method for depthwise weight initialization. + zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros. + cfg_dropout (float): Classifier-free guidance dropout. + cfg_coef (float): Classifier-free guidance coefficient. + attribute_dropout (dict): Attribute dropout probabilities. + two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps. + **kwargs: Additional parameters for the transformer encoder. + """ + def __init__(self, + pattern_provider: CodebooksPatternProvider, + condition_provider: ConditionerProvider, + fuser: ConditionFuser, + code_depth: int = 8, + code_size: int = 1024, + dim: int = 128, + intermediate_size: int = 4096, + num_heads: int = 8, + norm: str = 'layer_norm', norm_first: bool = False, + bias_proj: bool = True, + weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None, + zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0, + attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, + lm_type = 'Llama', + num_layers=16, + cfg = None, + **kwargs): + super().__init__() + + self.cfg_coef = cfg_coef + + self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout,seed=random.randint(0, 9999)) + self.att_dropout = AttributeDropout(p=attribute_dropout,seed=random.randint(0, 9999)) + self.condition_provider = condition_provider + self.fuser = fuser + self.code_size = code_size + 1 # + EOS + input_emb_dim = code_size + 2 # EOP + self.code_depth = code_depth + self.dim = dim + self.cfg = cfg + self.pattern_provider = pattern_provider + self.emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim)]) + # if 'activation' in kwargs: + # kwargs['activation'] = get_activation_fn(kwargs['activation']) + + model_cfg = LlamaConfig( + hidden_size=dim, + intermediate_size = intermediate_size, + num_attention_heads = num_heads, + num_hidden_layers = num_layers, + num_key_value_heads = num_heads, + vocab_size = self.code_size, + use_cache=False, + max_position_embeddings=8196, + _flash_attn_2_enabled=True, + rms_norm_eps= 1e-5, + rope_theta= 100000.0, + use_flash_attn_2=True, + attn_implementation="flash_attention_2" + ) + + self.transformer = CausalLM(model_cfg) + self.mlp = nn.Sequential( + nn.Linear(dim * 2, dim), + nn.GELU(), + nn.Linear(dim, dim) + ) + self.layer2_emb = nn.ModuleList([nn.Embedding(input_emb_dim, dim) #, lr=emb_lr) + for _ in range(self.code_depth)]) + sub_model_cfg = LlamaConfig( + hidden_size=dim, + intermediate_size = intermediate_size, + num_attention_heads = num_heads, + num_hidden_layers = 12, + num_key_value_heads = num_heads, + vocab_size = self.code_size, + use_cache=False, + max_position_embeddings=10000, + rms_norm_eps= 1e-5, + rope_theta= 500000.0, + _flash_attn_2_enabled=True, + use_flash_attn_2=True, + attn_implementation="flash_attention_2" + ) + self.transformer2 = CausalLM(sub_model_cfg) + self.out_norm: tp.Optional[nn.Module] = None + if norm_first: + self.out_norm = create_norm_fn(norm, dim) + # enable EOS prediction + if code_depth > 1: + self.linears = nn.ModuleList([nn.Linear(dim, self.code_size, bias=False) + for _ in range(code_depth - 1)]) + + self._init_weights(weight_init, depthwise_init, zero_bias_init) + self._fsdp: tp.Optional[nn.Module] + self.__dict__['_fsdp'] = None + + self.reset_streaming() + + def _init_weights(self, weight_init: tp.Optional[str], + depthwise_init: tp.Optional[str], zero_bias_init: bool): + """Initialization of the transformer module weights. + + Args: + weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options. + depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid: + 'current' where the depth corresponds to the current layer index or 'global' where the total number + of layer is used as depth. If not set, no depthwise initialization strategy is used. + zero_bias_init (bool): Whether to initialize bias to zero or not. + """ + assert depthwise_init is None or depthwise_init in ['current', 'global'] + assert depthwise_init is None or weight_init is not None, \ + "If 'depthwise_init' is defined, a 'weight_init' method should be provided." + assert not zero_bias_init or weight_init is not None, \ + "If 'zero_bias_init', a 'weight_init' method should be provided" + + if weight_init is None: + return + + for emb_layer in self.emb: + init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) + + + @property + def special_token_id(self) -> int: + return self.code_size # 10001 + + @property + def eos_token_id(self) -> int: + return self.code_size-1 # 10000 + + @torch.no_grad() + def prepare_condition_tensors(self, + batch_size = 1, + text: tp.Optional[tp.List[str]] = None, + descriptions: tp.Optional[tp.List[str]] = None, + audio_qt_emb: tp.Optional[tp.List[torch.Tensor]] = None, + prepare_null_condition = False, + ): + if self.training: + attributes = [] + for i in range(batch_size): + attr = ConditioningAttributes() + if 'description' in self.condition_provider.conditioners: + attr["text"]["description"] = "" + if text is not None: + attr["text"]["description"] = text[i] + if 'prompt_audio' in self.condition_provider.conditioners: + mask = (audio_qt_emb[[i], :, 0] == 16385).bool().unsqueeze(-1) + audio_qt_seq = torch.cat([torch.full_like(audio_qt_emb[i][None][:,:,0], self.eos_token_id).unsqueeze(-1), audio_qt_emb[i][None]], dim=-1) + mask = mask.repeat(1, 1, audio_qt_seq.shape[-1]) + audio_qt_seq[mask] = 16385 + attr["audio"]['prompt_audio'] = AudioCondition( + wav=audio_qt_seq.long(), + length=torch.Tensor([audio_qt_seq.shape[-1]]).long(), + sample_rate=[self.cfg.sample_rate],) + if 'type_info' in self.condition_provider.conditioners: + attr["text"]["type_info"] = "" + if descriptions is not None: + attr["text"]["type_info"] = descriptions[i] + attributes.append(attr) + # print("before cfg dropout", attributes) + attributes = self.cfg_dropout(attributes) # drop ALL conditions + # print("after cfg dropout", attributes) + attributes = self.att_dropout(attributes) # selectively drop some attributes (text, wav, or more fine-grained) + # print("after attribute dropout", attributes) + # attribute to discrete tokenized ids + tokenized = self.condition_provider.tokenize(attributes) + # print("after tokenize", attributes) + # discrete tokenized ids to continuous embeddings + condition_tensors = self.condition_provider(tokenized) + else: + conditions = [] + for i in range(batch_size): + attr = ConditioningAttributes() + if 'description' in self.condition_provider.conditioners: + attr["text"]["description"] = "" + if text is not None: + attr["text"]["description"] = text[i] + if 'prompt_audio' in self.condition_provider.conditioners: + mask = (audio_qt_emb[[i], :, 0] == 16385).bool().unsqueeze(-1) + audio_qt_seq = torch.cat([torch.full_like(audio_qt_emb[i][None][:,:,0], self.eos_token_id).unsqueeze(-1), audio_qt_emb[i][None]], dim=-1) + mask = mask.repeat(1, 1, audio_qt_seq.shape[-1]) + audio_qt_seq[mask] = 16385 + attr["audio"]['prompt_audio'] = AudioCondition( + wav=audio_qt_seq.long().cuda(), + length=torch.Tensor([audio_qt_seq.shape[-1]]).long(), + sample_rate=[self.cfg.sample_rate],) + if 'type_info' in self.condition_provider.conditioners: + attr["text"]["type_info"] = "" + if descriptions is not None: + attr["text"]["type_info"] = descriptions[i] + conditions.append(attr) + print("conditions", conditions) + if prepare_null_condition: + cfg_inference = ClassifierFreeGuidanceDropoutInference() + null_conditions = cfg_inference(conditions, condition_types=["audio", "text"], + customized=None) + conditions = conditions + null_conditions + tokenized_conditions = self.condition_provider.tokenize(conditions) + condition_tensors = self.condition_provider(tokenized_conditions) + return condition_tensors + + def forward(self, + sequence: torch.Tensor, + condition_tensors: ConditionTensors) -> torch.Tensor: + """Apply language model on sequence and conditions. + Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and + S the sequence steps, return the logits with shape [B, card, K, S]. + + Args: + indices (torch.Tensor): Indices of the codes to model. + condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning + tensors, see `conditions`. + Returns: + torch.Tensor: Logits. + """ + + # import pdb; pdb.set_trace() + B, K, S = sequence.shape + assert K == self.code_depth, "Sequence shape must match the specified number of codebooks" + input_1 = self.emb[0](sequence[:, 0]) + input_2 = sum([self.layer2_emb[k](sequence[:, k]) for k in range(1, K)]) + fused_input1, fused_input2 = self.fuser(input_1, input_2, condition_tensors) + output = self.transformer(inputs_embeds=fused_input1, + use_cache=self._is_streaming, + past_key_values=self._streaming_state.get('past_key_values_1', None)) + if self._is_streaming: + self._streaming_state['past_key_values_1'] = output.past_key_values + logits = output.logits # [B, S, card] + logits = logits.unsqueeze(1) # [B, 1, S, card] + + # if self.out_norm: + # out = self.out_norm(out.to(self.out_norm.weight.data.dtype)) + if K > 1: + fused_input2 = torch.cat([fused_input2, output.hidden_states], dim=-1) + fused_input2 = self.mlp(fused_input2) + output2 = self.transformer2(inputs_embeds=fused_input2, + use_cache=self._is_streaming, + past_key_values=self._streaming_state.get('past_key_values_2', None)) + if self._is_streaming: + self._streaming_state['past_key_values_2'] = output2.past_key_values + + res_logits = torch.stack([self.linears[k](output2.hidden_states) for k in range(K - 1)], dim=1) # [B, K, S, card] # [B, K, S, card] + logits = torch.cat([logits, res_logits], dim=1) # [B, K, S, card] + + # remove the prefix from the model outputs + if len(self.fuser.fuse2cond['prepend']) > 0: + logits = logits[:, :, -S:, :] + + return logits # [B, K, S, card] + + def compute_predictions(self, + codes: torch.Tensor, + condition_tensors: tp.Optional[ConditionTensors] = None, + **kwargs, + ): # this function is called during training + """Given an input tensor of codes [B, K, T] and list of conditions, runs the model + forward using the specified codes interleaving pattern. + + Args: + codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size, + K the number of codebooks and T the number of timesteps. + condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning + tensors, see `conditions`. + Returns: + LMOutput: Language model outputs + logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, + i.e. the first item corresponds to logits to predict the first code, meaning that + no additional shifting of codes and logits is required. + mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions. + Given the specified interleaving strategies, parts of the logits and codes should + not be considered as valid predictions because of invalid context. + """ + B, K, T = codes.shape + codes = codes.contiguous() + # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens + pattern = self.pattern_provider.get_pattern(T) + sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence( + codes, self.special_token_id, keep_only_valid_steps=False + ) + model = self if self._fsdp is None else self._fsdp + logits = model(sequence_codes, condition_tensors) # [B, K, S, card] + # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card] + # and provide the corresponding mask over invalid positions of tokens + logits = logits.permute(0, 3, 1, 2) # [B, card, K, S] + # note: we use nans as special token to make it obvious if we feed unexpected logits + logits, logits_indexes, logits_mask = pattern.revert_pattern_logits( + logits, float('nan'), keep_only_valid_steps=False + ) + logits = logits.permute(0, 2, 3, 1) # [B, K, T, card] + logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T] + + return LMOutput(logits, logits_mask) + + @torch.no_grad() + def generate(self, # + # conditions: tp.List[ConditioningAttributes] = [], + texts = None, + descriptions = None, + audio_qt_embs = None, + num_samples: tp.Optional[int] = None, + max_gen_len: int = 256, + use_sampling: bool = True, + temp: float = 1.0, + top_k: int = 250, + top_p: float = 0.0, + cfg_coef: tp.Optional[float] = None, + check: bool = False, + record_tokens: bool = True, + record_window: int = 150 + ) -> torch.Tensor: + """Generate tokens sampling from the model given a prompt or unconditionally. Generation can + be perform in a greedy fashion or using sampling with top K and top P strategies. + + Args: + prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T]. + conditions_tensors (list of ConditioningAttributes, optional): List of conditions. + num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given. + max_gen_len (int): Maximum generation length. + use_sampling (bool): Whether to use a sampling strategy or not. + temp (float): Sampling temperature. + top_k (int): K for "top-k" sampling. + top_p (float): P for "top-p" sampling. + cfg_coeff (float, optional): Classifier-free guidance coefficient. + check (bool): Whether to apply further checks on generated sequence. + callback (Callback, optional): Callback function to report generation progress. + Returns: + torch.Tensor: Generated tokens. + """ + assert not self.training, "generation shouldn't be used in training mode." + first_param = next(iter(self.parameters())) + device = first_param.device + + # 1) Check input shapes are consistent + possible_num_samples = [] + if num_samples is not None: + possible_num_samples.append(num_samples) + elif texts: + possible_num_samples.append(len(texts)) + elif audio_qt_embs: + possible_num_samples.append(len(audio_qt_embs)) + else: + possible_num_samples.append(1) + assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes" + num_samples = possible_num_samples[0] + condition_tensors = self.prepare_condition_tensors(batch_size=1, text=texts, descriptions=descriptions, audio_qt_emb=audio_qt_embs, prepare_null_condition=True) + # 3) Prepare token pool + record_token_pool = None + if record_tokens: + record_token_pool = [] + + # 4) set up startoff patterns + start_offset = 0 + assert start_offset < max_gen_len, f"{start_offset}, {max_gen_len}" + pattern = self.pattern_provider.get_pattern(max_gen_len) + # this token is used as default value for codes that are not generated yet + unknown_token = -1 + # we generate codes up to the max_gen_len that will be mapped to the pattern sequence + B = num_samples + gen_codes = torch.full((B, self.code_depth, max_gen_len), + unknown_token, dtype=torch.long, device=device) + # create the gen_sequence with proper interleaving from the pattern: [B, K, S] + gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id) + output_codes = torch.full_like(gen_sequence, self.code_size) + # retrieve the start_offset in the sequence: + # it is the first sequence step that contains the `start_offset` timestep + start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) + assert start_offset_sequence is not None + is_end = torch.zeros((B, self.code_depth, 1)).bool().to(device) + ignore_tokens = audio_qt_embs[0][0] + # 5) auto-regressive sampling + with self.streaming(): + gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S] + prev_offset = 0 + for offset in range(start_offset_sequence, gen_sequence_len): + # get current sequence (note that the streaming API is providing the caching over previous offsets) + curr_sequence = gen_sequence[..., prev_offset:offset] + curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1) + if check: + # check coherence between mask and sequence + assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all() + # should never happen as gen_sequence is filled progressively + assert not (curr_sequence == unknown_token).any() + # sample next token from the model, next token shape is [B, K, 1] + next_token = self._sample_next_token( + curr_sequence, condition_tensors, use_sampling, temp, top_k, top_p, + cfg_coef=cfg_coef, + sampled_token_pool=record_token_pool[-record_window:] if record_tokens else None, + ignore_tokens = ignore_tokens + ) + # ensure the tokens that should be masked are properly set to special_token_id + # as the model never output special_token_id + valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) + next_token[~valid_mask] = self.special_token_id + # 检查eos id + next_token[is_end] = self.special_token_id + is_end = is_end | (next_token == self.eos_token_id) + # ensure we don't overwrite prompt tokens, we only write over unknown tokens + # (then mask tokens should be left as is as well, which is correct) + gen_sequence[..., offset:offset+1] = torch.where( + gen_sequence[..., offset:offset+1] == unknown_token, + next_token, gen_sequence[..., offset:offset+1]) + + # record sampled tokens in a window + if record_tokens: + record_token_pool.append(next_token.squeeze()) + if torch.all(is_end): + gen_sequence = gen_sequence[..., :offset+1] + break + + prev_offset = offset + + # ensure sequence has been entirely filled + assert not (gen_sequence == unknown_token).any() + max_gen_len = gen_sequence.shape[-1] + output_codes[..., :max_gen_len] = gen_sequence + # ensure gen_sequence pattern and mask are matching + # which means the gen_sequence is valid according to the pattern + # assert (gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, + # self.special_token_id) + # ).all() + # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps + out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(output_codes, special_token=unknown_token) + # sanity checks over the returned codes and corresponding masks + assert (out_codes != unknown_token).all() + assert (out_mask == 1).all() + # ensure the returned codes are all valid + assert (out_codes >= 0).all() and (out_codes <= self.code_size).all() + return out_codes + + def _sample_next_token(self, + sequence: torch.Tensor, + condition_tensors: ConditionTensors, + use_sampling: bool = False, + temp: float = 1.0, + top_k: int = 0, + top_p: float = 0.0, + cfg_coef: tp.Optional[float] = None, + sampled_token_pool: tp.Optional[list] = None, + ignore_tokens: tp.Optional[torch.tensor] = torch.tensor([])) -> torch.Tensor: + """Sample next token from the model given a sequence and a set of conditions. The model supports + multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). + + Args: + sequence (torch.Tensor): Current sequence of shape [B, K, S] + with K corresponding to the number of codebooks and S the number of sequence steps. + S = 1 in streaming mode, except for the first step that contains a bigger prompt. + condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used, + should be twice the batch size, being the concatenation of the conditions + null conditions. + use_sampling (bool): Whether to use a sampling strategy or not. + temp (float): Sampling temperature. + top_k (int): K for "top-k" sampling. + top_p (float): P for "top-p" sampling. + cfg_coef (float, optional): classifier free guidance coefficient + Returns: + next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. + """ + # import pdb; pdb.set_trace() + B = sequence.shape[0] + cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef + model = self if self._fsdp is None else self._fsdp + + # Preparing for CFG, predicting both conditional and unconditional logits. + sequence = torch.cat([sequence, sequence], dim=0) + all_logits = model(sequence, condition_tensors=condition_tensors) + cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card] + logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef + + logits = logits.permute(0, 1, 3, 2) # [B, K, card, T] + logits = logits[..., -1] # [B x K x card] + + # add punishment to pre-sampled tokens + if sampled_token_pool is not None and len(sampled_token_pool) > 0: + sampled_token_pool = torch.stack(sampled_token_pool, -1) # [K, T] + for q in range(self.code_depth): + # q_count = torch.bincount(sampled_token_pool) + q_count = torch.bincount(torch.unique(sampled_token_pool[q])) + tmp = min(q_count.shape[-1], self.code_size - 1) + logits[:, q, :tmp] /= (1.1 ** q_count[:tmp]) + + # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error. + if(ignore_tokens is not None): + logits[0][0][ignore_tokens.to(torch.int)] = float('-inf') + if use_sampling and temp > 0.0: + probs = torch.softmax(logits / temp, dim=-1) + if top_p > 0.0: + next_token = sample_top_p(probs, p=top_p) + elif top_k > 0: + next_token_first = sample_top_k(probs[:,[0],:], k=top_k) + next_token_res = sample_top_k(probs[:,1:,:], k=1) + next_token = torch.cat([next_token_first,next_token_res], dim = 1) + else: + next_token = multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(logits, dim=-1, keepdim=True) + + return next_token diff --git a/codeclm/modules/conditioners.py b/codeclm/modules/conditioners.py new file mode 100755 index 0000000000000000000000000000000000000000..cbf2ccf7097f0f5ba17746150028fa741ca6d95f --- /dev/null +++ b/codeclm/modules/conditioners.py @@ -0,0 +1,883 @@ + +import typing as tp +import torch +import torch.nn as nn +from dataclasses import dataclass, field, fields +from itertools import chain +import warnings +import torch.nn.functional as F +from torch.nn.utils.rnn import pad_sequence +from codeclm.utils.utils import length_to_mask, collate +from codeclm.modules.streaming import StreamingModule +from collections import defaultdict +from copy import deepcopy +ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask + +# ================================================================ +# Condition and Condition attributes definitions +# ================================================================ +class AudioCondition(tp.NamedTuple): + wav: torch.Tensor + length: torch.Tensor + sample_rate: tp.List[int] + path: tp.List[tp.Optional[str]] = [] + seek_time: tp.List[tp.Optional[float]] = [] + +@dataclass +class ConditioningAttributes: + text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) + audio: tp.Dict[str, AudioCondition] = field(default_factory=dict) + + def __getitem__(self, item): + return getattr(self, item) + + @property + def text_attributes(self): + return self.text.keys() + + @property + def audio_attributes(self): + return self.audio.keys() + + @property + def attributes(self): + return { + "text": self.text_attributes, + "audio": self.audio_attributes, + } + + def to_flat_dict(self): + return { + **{f"text.{k}": v for k, v in self.text.items()}, + **{f"audio.{k}": v for k, v in self.audio.items()}, + } + + @classmethod + def from_flat_dict(cls, x): + out = cls() + for k, v in x.items(): + kind, att = k.split(".") + out[kind][att] = v + return out + +# ================================================================ +# Conditioner (tokenize and encode raw conditions) definitions +# ================================================================ + +class BaseConditioner(nn.Module): + """Base model for all conditioner modules. + We allow the output dim to be different than the hidden dim for two reasons: + 1) keep our LUTs small when the vocab is large; + 2) make all condition dims consistent. + + Args: + dim (int): Hidden dim of the model. + output_dim (int): Output dim of the conditioner. + """ + def __init__(self, dim: int, output_dim: int, input_token = False, padding_idx=0): + super().__init__() + self.dim = dim + self.output_dim = output_dim + if input_token: + self.output_proj = nn.Embedding(dim, output_dim, padding_idx) + else: + self.output_proj = nn.Linear(dim, output_dim) + + def tokenize(self, *args, **kwargs) -> tp.Any: + """Should be any part of the processing that will lead to a synchronization + point, e.g. BPE tokenization with transfer to the GPU. + + The returned value will be saved and return later when calling forward(). + """ + raise NotImplementedError() + + def forward(self, inputs: tp.Any) -> ConditionType: + """Gets input that should be used as conditioning (e.g, genre, description or a waveform). + Outputs a ConditionType, after the input data was embedded as a dense vector. + + Returns: + ConditionType: + - A tensor of size [B, T, D] where B is the batch size, T is the length of the + output embedding and D is the dimension of the embedding. + - And a mask indicating where the padding tokens. + """ + raise NotImplementedError() + +class TextConditioner(BaseConditioner): + ... + + +class PhonemeTokenizerConditioner(TextConditioner): + def __init__(self, + output_dim: int, + vocab_list, + max_len = 600, + max_sentence_per_structure = 50, + structure_tokens=None, + structure_split_tokens=[','], + sentence_split_tokens=['.'], + mode='sum', + structure_output_dim = 64, + sentence_output_dim = 64, + max_duration = 120, + ): + + self.vocab_list = vocab_list + self.max_len = max_len + self.mode = mode + self.max_sentence_per_structure = max_sentence_per_structure + voc_size = len(self.vocab_list) + + if structure_tokens is None: + structure_tokens = [i for i in vocab_list if len(i) > 1 and i[0] == '[' and i[-1] == ']'] + self.structure_token_ids = [vocab_list.index(i) for i in structure_tokens if i in vocab_list] + self.structure_split_token_ids = [vocab_list.index(i) for i in structure_split_tokens] + self.sentence_split_token_ids = [vocab_list.index(i) for i in sentence_split_tokens] + + # here initialize a output_proj (nn.Embedding) layer + # By default the first vocab is "" (null) + if mode == 'sum': + content_output_dim = output_dim + sentence_output_dim = output_dim + structure_output_dim = output_dim + else: # concat' + raise NotImplementedError("concat 模式还未实现") + # content_output_dim = output_dim - sentence_output_dim - structure_output_dim # by default + + super().__init__(voc_size, content_output_dim, input_token=True, padding_idx=0) + self.special_emb = nn.Embedding(voc_size, structure_output_dim, padding_idx=0) + + self.blank_emb = nn.Parameter(torch.zeros(1, output_dim), requires_grad=False) + + # the first index is "empty structure" token + self.sentence_idx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim) + self.sentence_reidx_in_structure_emb = nn.Embedding(max_sentence_per_structure, sentence_output_dim) + + print("max_len", self.max_len) + print(self.structure_token_ids) + + self.resolution = max_duration / max_len # e.g., 120 / 600 = 0.2s + print(self.__class__, f"resolution = {self.resolution}") + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: + inputs = [] + for xx in x: + xx = '' if xx is None else xx + vocab_id = [self.vocab_list.index(item) for item in xx.split(" ") if item in self.vocab_list] + inputs.append(torch.tensor(vocab_id).long()) # [T] + return inputs + + + def forward(self, batch_tokens: tp.List, structure_dur = None) -> ConditionType: + """ + Encode token_id into three types of embeddings: + 1) content embedding: phoneme only (or meaningful contents to be sung out) + 2) structure embedding: structure / separation embeddings, including structures (verse/chorus/...), separators (. / ,) + The two above share the same embedding layer, can be changed to separate embedding layers. + 3) sentence_idx embedding (per structure): + """ + embeds_batch = [] + for b in range(len(batch_tokens)): + tokens = batch_tokens[b] + content_tokens = torch.zeros_like(tokens) + special_tokens = torch.zeros_like(tokens) + sentence_idx_in_structure_tokens = torch.zeros_like(tokens) + sentence_reidx_in_structure_tokens = torch.zeros_like(tokens) + + current_sentence_in_structure_idx = 1 + current_structure = 0 + for i in range(tokens.shape[-1]): + token = tokens[i] + if token in self.structure_token_ids: # structure token + # only update structure token, leave content and sentence index token null (default 0) + special_tokens[i] = token + content_tokens[i] = token + current_structure = token + current_sentence_in_structure_idx = 1 + sentence_idx_in_structure_tokens[i] = 0 + + elif token in self.sentence_split_token_ids: # utterance split token + # only update structure token, leave content and sentence index token null (default 0) + # add up sentence index + special_tokens[i] = current_structure + content_tokens[i] = token + sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1) + current_sentence_in_structure_idx += 1 + + elif token in self.structure_split_token_ids: # structure split token + # update structure token (current structure), content token (current token), + # blank index token + content_tokens[i] = token + special_tokens[i] = current_structure + sentence_idx_in_structure_tokens[i] = sentence_idx_in_structure_tokens[i-1] + else: # content tokens + content_tokens[i] = token + special_tokens[i] = current_structure + sentence_idx_in_structure_tokens[i] = min(current_sentence_in_structure_idx, self.max_sentence_per_structure - 1) + # 反推 + current_sentence_num = sentence_idx_in_structure_tokens[-1] + for i in range(tokens.shape[-1]-1,-1,-1): + if current_sentence_num != 0: + sentence_reidx_in_structure_tokens[i] = min(current_sentence_num + 1 - sentence_idx_in_structure_tokens[i], self.max_sentence_per_structure - 1) + if sentence_idx_in_structure_tokens[i] == 0 and i > 0: + current_sentence_num = sentence_idx_in_structure_tokens[i-1] + + # print("tokens", tokens.max(), tokens.min()) + # print("special tokens", special_tokens.max(), special_tokens.min()) + # print("sentence idx in structure", sentence_idx_in_structure_tokens.max(), sentence_idx_in_structure_tokens.min()) + device = self.output_proj.weight.device + + # import pdb; pdb.set_trace() + content_embeds = self.output_proj(content_tokens.to(device)) # [T, N] + structure_embeds = self.output_proj(special_tokens.to(device)) + # sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device)) + sentence_idx_embeds = self.sentence_idx_in_structure_emb(sentence_idx_in_structure_tokens.to(device)) + self.sentence_reidx_in_structure_emb(sentence_reidx_in_structure_tokens.to(device)) + + if self.mode == 'sum': + embeds = content_embeds + structure_embeds + sentence_idx_embeds + else: + embeds = torch.cat((content_embeds, structure_embeds, sentence_idx_embeds), -1) # [T, N] + embeds_batch.append(embeds) + + # set batch_size = 1, [B, T, N] + if self.max_len is not None: + max_len = self.max_len + else: + max_len = max([e.shape[0] for e in embeds_batch]) + embeds, mask = self.pad_2d_tensor(embeds_batch, max_len) + + return embeds, embeds, mask + + + def pad_2d_tensor(self, xs, max_len): + new_tensor = [] + new_mask = [] + for x in xs: + seq_len, dim = x.size() + pad_len = max_len - seq_len + + if pad_len > 0: + pad_tensor = self.blank_emb.repeat(pad_len, 1).to(x.device) # T, D + padded_tensor = torch.cat([x, pad_tensor], dim=0) + mask = torch.cat((torch.ones_like(x[:, 0]), + torch.zeros_like(pad_tensor[:, 0])), 0) # T + elif pad_len < 0: + padded_tensor = x[:max_len] + mask = torch.ones_like(padded_tensor[:, 0]) + else: + padded_tensor = x + mask = torch.ones_like(x[:, 0]) + + new_tensor.append(padded_tensor) + new_mask.append(mask) + # [B, T, D] & [B, T] + return torch.stack(new_tensor, 0), torch.stack(new_mask, 0) + + +class QwTokenizerConditioner(TextConditioner): + def __init__(self, output_dim: int, + token_path = "", + max_len = 300, + add_token_list=[]): #"" + from transformers import Qwen2Tokenizer + self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path) + if add_token_list != []: + self.text_tokenizer.add_tokens(add_token_list, special_tokens=True) + voc_size = len(self.text_tokenizer.get_vocab()) + # here initialize a output_proj (nn.Embedding) layer + super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643) + self.max_len = max_len + self.padding_idx =' <|endoftext|>' + + vocab = self.text_tokenizer.get_vocab() + # struct是全部的结构 + struct_tokens = [i for i in add_token_list if i[0]=='[' and i[-1]==']'] + self.struct_token_ids = [vocab[i] for i in struct_tokens] + self.pad_token_idx = 151643 + + self.structure_emb = nn.Embedding(200, output_dim, padding_idx=0) + # self.split_token_id = vocab["."] + print("all structure tokens: ", {self.text_tokenizer.convert_ids_to_tokens(i):i for i in self.struct_token_ids}) + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: + x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x] + # x = [xi if xi is not None else "" for xi in x] + inputs = self.text_tokenizer(x, return_tensors="pt", padding=True) + return inputs + + def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: + """ + Add structure embeddings of {verse, chorus, bridge} to text/lyric tokens that + belong to these structures accordingly, + Then delete or keep these structure embeddings. + """ + mask = inputs['attention_mask'] + tokens = inputs['input_ids'] + B = tokens.shape[0] + is_sp_embed = torch.any(torch.stack([tokens == i for i in self.struct_token_ids], dim=-1),dim=-1) + + tp_cover_range = torch.zeros_like(tokens) + for b, is_sp in enumerate(is_sp_embed): + sp_list = torch.where(is_sp)[0].tolist() + sp_list.append(mask[b].sum()) + for i, st in enumerate(sp_list[:-1]): + tp_cover_range[b, st: sp_list[i+1]] = tokens[b, st] - 151645 + + if self.max_len is not None: + if inputs['input_ids'].shape[-1] > self.max_len: + warnings.warn(f"Max len limit ({self.max_len}) Exceed! \ + {[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!") + tokens = self.pad_2d_tensor(tokens, self.max_len, self.pad_token_idx).to(self.output_proj.weight.device) + mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device) + tp_cover_range = self.pad_2d_tensor(tp_cover_range, self.max_len, 0).to(self.output_proj.weight.device) + device = self.output_proj.weight.device + content_embeds = self.output_proj(tokens.to(device)) + structure_embeds = self.structure_emb(tp_cover_range.to(device)) + + embeds = content_embeds + structure_embeds + return embeds, embeds, mask + + def pad_2d_tensor(self, x, max_len, pad_id): + batch_size, seq_len = x.size() + pad_len = max_len - seq_len + + if pad_len > 0: + pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device) + padded_tensor = torch.cat([x, pad_tensor], dim=1) + elif pad_len < 0: + padded_tensor = x[:, :max_len] + else: + padded_tensor = x + + return padded_tensor + + +class QwTextConditioner(TextConditioner): + def __init__(self, output_dim: int, + token_path = "", + max_len = 300): #"" + + from transformers import Qwen2Tokenizer + self.text_tokenizer = Qwen2Tokenizer.from_pretrained(token_path) + voc_size = len(self.text_tokenizer.get_vocab()) + # here initialize a output_proj (nn.Embedding) layer + super().__init__(voc_size, output_dim, input_token=True, padding_idx=151643) + + self.max_len = max_len + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: + x = ['<|im_start|>' + xi if xi is not None else "<|im_start|>" for xi in x] + inputs = self.text_tokenizer(x, return_tensors="pt", padding=True) + return inputs + + def forward(self, inputs: tp.Dict[str, torch.Tensor], structure_dur = None) -> ConditionType: + """ + Add structure embeddings of {verse, chorus, bridge} to text/lyric tokens that + belong to these structures accordingly, + Then delete or keep these structure embeddings. + """ + mask = inputs['attention_mask'] + tokens = inputs['input_ids'] + + if self.max_len is not None: + if inputs['input_ids'].shape[-1] > self.max_len: + warnings.warn(f"Max len limit ({self.max_len}) Exceed! \ + {[self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in tokens]} will be cut!") + tokens = self.pad_2d_tensor(tokens, self.max_len, 151643).to(self.output_proj.weight.device) + mask = self.pad_2d_tensor(mask, self.max_len, 0).to(self.output_proj.weight.device) + + embeds = self.output_proj(tokens) + return embeds, embeds, mask + + def pad_2d_tensor(self, x, max_len, pad_id): + batch_size, seq_len = x.size() + pad_len = max_len - seq_len + + if pad_len > 0: + pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device) + padded_tensor = torch.cat([x, pad_tensor], dim=1) + elif pad_len < 0: + padded_tensor = x[:, :max_len] + else: + padded_tensor = x + + return padded_tensor + + +class AudioConditioner(BaseConditioner): + ... + +class QuantizedEmbeddingConditioner(AudioConditioner): + def __init__(self, dim: int, + code_size: int, + code_depth: int, + max_len: int, + **kwargs): + super().__init__(dim, dim, input_token=True) + self.code_depth = code_depth + # add 1 for token + self.emb = nn.ModuleList([nn.Embedding(code_size+2, dim, padding_idx=code_size+1) for _ in range(code_depth)]) + # add End-Of-Text embedding + self.EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True) + self.layer2_EOT_emb = nn.Parameter(torch.randn(1, dim), requires_grad=True) + self.output_proj = None + self.max_len = max_len + self.vocab_size = code_size + + def tokenize(self, x: AudioCondition) -> AudioCondition: + """no extra ops""" + # wav, length, sample_rate, path, seek_time = x + # assert length is not None + return x #AudioCondition(wav, length, sample_rate, path, seek_time) + + def forward(self, x: AudioCondition): + wav, lengths, *_ = x + B = wav.shape[0] + wav = wav.reshape(B, self.code_depth, -1).long() + if wav.shape[2] < self.max_len - 1: + wav = F.pad(wav, [0, self.max_len - 1 - wav.shape[2]], value=self.vocab_size+1) + else: + wav = wav[:, :, :self.max_len-1] + embeds1 = self.emb[0](wav[:, 0]) + embeds1 = torch.cat((self.EOT_emb.unsqueeze(0).repeat(B, 1, 1), + embeds1), dim=1) + embeds2 = sum([self.emb[k](wav[:, k]) for k in range(1, self.code_depth)]) # B,T,D + embeds2 = torch.cat((self.layer2_EOT_emb.unsqueeze(0).repeat(B, 1, 1), + embeds2), dim=1) + lengths = lengths + 1 + lengths = torch.clamp(lengths, max=self.max_len) + + if lengths is not None: + mask = length_to_mask(lengths, max_len=embeds1.shape[1]).int() # type: ignore + else: + mask = torch.ones((B, self.code_depth), device=embeds1.device, dtype=torch.int) + return embeds1, embeds2, mask + + +# ================================================================ +# Aggregate all conditions and corresponding conditioners +# ================================================================ +class ConditionerProvider(nn.Module): + """Prepare and provide conditions given all the supported conditioners. + + Args: + conditioners (dict): Dictionary of conditioners. + device (torch.device or str, optional): Device for conditioners and output condition types. + """ + def __init__(self, conditioners: tp.Dict[str, BaseConditioner]): + super().__init__() + self.conditioners = nn.ModuleDict(conditioners) + + @property + def text_conditions(self): + return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] + + @property + def audio_conditions(self): + return [k for k, v in self.conditioners.items() if isinstance(v, AudioConditioner)] + + @property + def has_audio_condition(self): + return len(self.audio_conditions) > 0 + + def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]: + """Match attributes/audios with existing conditioners in self, and compute tokenize them accordingly. + This should be called before starting any real GPU work to avoid synchronization points. + This will return a dict matching conditioner names to their arbitrary tokenized representations. + + Args: + inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing + text and audio conditions. + """ + assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( + "Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", + f" but types were {set([type(x) for x in inputs])}") + + output = {} + text = self._collate_text(inputs) + audios = self._collate_audios(inputs) + + assert set(text.keys() | audios.keys()).issubset(set(self.conditioners.keys())), ( + f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ", + f"got {text.keys(), audios.keys()}") + + for attribute, batch in chain(text.items(), audios.items()): + output[attribute] = self.conditioners[attribute].tokenize(batch) + return output + + def forward(self, tokenized: tp.Dict[str, tp.Any], structure_dur = None) -> tp.Dict[str, ConditionType]: + """Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations. + The output is for example: + { + "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), + "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), + ... + } + + Args: + tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. + """ + output = {} + for attribute, inputs in tokenized.items(): + if attribute == 'description' and structure_dur is not None: + condition1, condition2, mask = self.conditioners[attribute](inputs, structure_dur = structure_dur) + else: + condition1, condition2, mask = self.conditioners[attribute](inputs) + output[attribute] = (condition1, condition2, mask) + return output + + def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]: + """Given a list of ConditioningAttributes objects, compile a dictionary where the keys + are the attributes and the values are the aggregated input per attribute. + For example: + Input: + [ + ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...), + ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, audio=...), + ] + Output: + { + "genre": ["Rock", "Hip-hop"], + "description": ["A rock song with a guitar solo", "A hip-hop verse"] + } + + Args: + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. + Returns: + dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch. + """ + out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) + texts = [x.text for x in samples] + for text in texts: + for condition in self.text_conditions: + out[condition].append(text[condition]) + return out + + def _collate_audios(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, AudioCondition]: + """Generate a dict where the keys are attributes by which we fetch similar audios, + and the values are Tensors of audios according to said attributes. + + *Note*: by the time the samples reach this function, each sample should have some audios + inside the "audio" attribute. It should be either: + 1. A real audio + 2. A null audio due to the sample having no similar audios (nullified by the dataset) + 3. A null audio due to it being dropped in a dropout module (nullified by dropout) + + Args: + samples (list of ConditioningAttributes): List of ConditioningAttributes samples. + Returns: + dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. + """ + # import pdb; pdb.set_trace() + wavs = defaultdict(list) + lengths = defaultdict(list) + sample_rates = defaultdict(list) + paths = defaultdict(list) + seek_times = defaultdict(list) + out: tp.Dict[str, AudioCondition] = {} + + for sample in samples: + for attribute in self.audio_conditions: + wav, length, sample_rate, path, seek_time = sample.audio[attribute] + assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]" + assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1" + wavs[attribute].append(wav.flatten()) # [C*T] + lengths[attribute].append(length) + sample_rates[attribute].extend(sample_rate) + paths[attribute].extend(path) + seek_times[attribute].extend(seek_time) + + # stack all wavs to a single tensor + for attribute in self.audio_conditions: + stacked_wav, _ = collate(wavs[attribute], dim=0) + out[attribute] = AudioCondition( + stacked_wav.unsqueeze(1), + torch.cat(lengths[attribute]), sample_rates[attribute], + paths[attribute], seek_times[attribute]) + + return out + + +class ConditionFuser(StreamingModule): + """Condition fuser handles the logic to combine the different conditions + to the actual model input. + + Args: + fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse + each condition. For example: + { + "prepend": ["description"], + "sum": ["genre", "bpm"], + } + """ + FUSING_METHODS = ["sum", "prepend"] #, "cross", "input_interpolate"] (not support in this simplest version) + + def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]]): + super().__init__() + assert all([k in self.FUSING_METHODS for k in fuse2cond.keys()] + ), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}" + self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond + self.cond2fuse: tp.Dict[str, str] = {} + for fuse_method, conditions in fuse2cond.items(): + for condition in conditions: + self.cond2fuse[condition] = fuse_method + + def forward( + self, + input1: torch.Tensor, + input2: torch.Tensor, + conditions: tp.Dict[str, ConditionType] + ) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """Fuse the conditions to the provided model input. + + Args: + input (torch.Tensor): Transformer input. + conditions (dict[str, ConditionType]): Dict of conditions. + Returns: + tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input + after the conditions have been fused. The second output tensor is the tensor + used for cross-attention or None if no cross attention inputs exist. + """ + #import pdb; pdb.set_trace() + B, T, _ = input1.shape + + if 'offsets' in self._streaming_state: + first_step = False + offsets = self._streaming_state['offsets'] + else: + first_step = True + offsets = torch.zeros(input1.shape[0], dtype=torch.long, device=input1.device) + + assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \ + f"given conditions contain unknown attributes for fuser, " \ + f"expected {self.cond2fuse.keys()}, got {conditions.keys()}" + + # if 'prepend' mode is used, + # the concatenation order will be the SAME with the conditions in config: + # prepend: ['description', 'prompt_audio'] (then goes the input) + fused_input_1 = input1 + fused_input_2 = input2 + for fuse_op in self.fuse2cond.keys(): + fuse_op_conditions = self.fuse2cond[fuse_op] + if fuse_op == 'sum' and len(fuse_op_conditions) > 0: + for cond in fuse_op_conditions: + this_cond_1, this_cond_2, cond_mask = conditions[cond] + fused_input_1 += this_cond_1 + fused_input_2 += this_cond_2 + elif fuse_op == 'prepend' and len(fuse_op_conditions) > 0: + if not first_step: + continue + reverse_list = deepcopy(fuse_op_conditions) + reverse_list.reverse() + for cond in reverse_list: + this_cond_1, this_cond_2, cond_mask = conditions[cond] + fused_input_1 = torch.cat((this_cond_1, fused_input_1), dim=1) # concat along T dim + fused_input_2 = torch.cat((this_cond_2, fused_input_2), dim=1) # concat along T dim + elif fuse_op not in self.FUSING_METHODS: + raise ValueError(f"unknown op ({fuse_op})") + + if self._is_streaming: + self._streaming_state['offsets'] = offsets + T + + return fused_input_1, fused_input_2 + + + +# ================================================================ +# Condition Dropout +# ================================================================ +class DropoutModule(nn.Module): + """Base module for all dropout modules.""" + def __init__(self, seed: int = 1234): + super().__init__() + self.rng = torch.Generator() + self.rng.manual_seed(seed) + + + +class ClassifierFreeGuidanceDropout(DropoutModule): + """Classifier Free Guidance dropout. + All attributes are dropped with the same probability. + + Args: + p (float): Probability to apply condition dropout during training. + seed (int): Random seed. + """ + def __init__(self, p: float, seed: int = 1234): + super().__init__(seed=seed) + self.p = p + + def check(self, sample, condition_type, condition): + + if condition_type not in ['text', 'audio']: + raise ValueError("dropout_condition got an unexpected condition type!" + f" expected 'text', 'audio' but got '{condition_type}'") + + if condition not in getattr(sample, condition_type): + raise ValueError( + "dropout_condition received an unexpected condition!" + f" expected audio={sample.audio.keys()} and text={sample.text.keys()}" + f" but got '{condition}' of type '{condition_type}'!") + + + def get_null_wav(self, wav, sr=48000) -> AudioCondition: + out = wav * 0 + 16385 + return AudioCondition( + wav=out, + length=torch.Tensor([0]).long(), + sample_rate=[sr],) + + def dropout_condition(self, + sample: ConditioningAttributes, + condition_type: str, + condition: str) -> ConditioningAttributes: + """Utility function for nullifying an attribute inside an ConditioningAttributes object. + If the condition is of type "wav", then nullify it using `nullify_condition` function. + If the condition is of any other type, set its value to None. + Works in-place. + """ + self.check(sample, condition_type, condition) + + if condition_type == 'audio': + audio_cond = sample.audio[condition] + depth = audio_cond.wav.shape[1] + sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0]) + else: + sample.text[condition] = None + + return sample + + def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """ + Args: + samples (list[ConditioningAttributes]): List of conditions. + Returns: + list[ConditioningAttributes]: List of conditions after all attributes were set to None. + """ + # decide on which attributes to drop in a batched fashion + # drop = torch.rand(1, generator=self.rng).item() < self.p + # if not drop: + # return samples + + # nullify conditions of all attributes + samples = deepcopy(samples) + + for sample in samples: + drop = torch.rand(1, generator=self.rng).item() + if drop ConditioningAttributes: + """Utility function for nullifying an attribute inside an ConditioningAttributes object. + If the condition is of type "audio", then nullify it using `nullify_condition` function. + If the condition is of any other type, set its value to None. + Works in-place. + """ + self.check(sample, condition_type, condition) + + if condition_type == 'audio': + audio_cond = sample.audio[condition] + depth = audio_cond.wav.shape[1] + sample.audio[condition] = self.get_null_wav(audio_cond.wav, sr=audio_cond.sample_rate[0]) + else: + if customized is None: + sample.text[condition] = None + else: + text_cond = deepcopy(sample.text[condition]) + if "structure" in customized: + for _s in ['[inst]', '[outro]', '[intro]', '[verse]', '[chorus]', '[bridge]']: + text_cond = text_cond.replace(_s, "") + text_cond = text_cond.replace(' , ', '') + text_cond = text_cond.replace(" ", " ") + if '.' in customized: + text_cond = text_cond.replace(" . ", " ") + text_cond = text_cond.replace(".", " ") + + sample.text[condition] = text_cond + + return sample + + def forward(self, samples: tp.List[ConditioningAttributes], + condition_types=["wav", "text"], + customized=None, + ) -> tp.List[ConditioningAttributes]: + """ + 100% dropout some condition attributes (description, prompt_wav) or types (text, wav) of + samples during inference. + + Args: + samples (list[ConditioningAttributes]): List of conditions. + Returns: + list[ConditioningAttributes]: List of conditions after all attributes were set to None. + """ + new_samples = deepcopy(samples) + for condition_type in condition_types: + for sample in new_samples: + for condition in sample.attributes[condition_type]: + self.dropout_condition_customized(sample, condition_type, condition, customized) + return new_samples + +class AttributeDropout(ClassifierFreeGuidanceDropout): + """Dropout with a given probability per attribute. + This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes + to be dropped out separately. For example, "artist" can be dropped while "genre" remains. + This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre" + must also be dropped. + + Args: + p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example: + ... + "genre": 0.1, + "artist": 0.5, + "audio": 0.25, + ... + active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False. + seed (int, optional): Random seed. + """ + def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234): + super().__init__(p=p, seed=seed) + self.active_on_eval = active_on_eval + # construct dict that return the values from p otherwise 0 + self.p = {} + for condition_type, probs in p.items(): + self.p[condition_type] = defaultdict(lambda: 0, probs) + + def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: + """ + Args: + samples (list[ConditioningAttributes]): List of conditions. + Returns: + list[ConditioningAttributes]: List of conditions after certain attributes were set to None. + """ + if not self.training and not self.active_on_eval: + return samples + + samples = deepcopy(samples) + for condition_type, ps in self.p.items(): # for condition types [text, wav] + for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre]) + if torch.rand(1, generator=self.rng).item() < p: + for sample in samples: + self.dropout_condition(sample, condition_type, condition) + return samples diff --git a/codeclm/modules/pattern.py b/codeclm/modules/pattern.py new file mode 100755 index 0000000000000000000000000000000000000000..bea6f6a7b70d075b382ecf7e693739387aa591b0 --- /dev/null +++ b/codeclm/modules/pattern.py @@ -0,0 +1,351 @@ +from collections import namedtuple +from dataclasses import dataclass +from functools import lru_cache +import logging +import typing as tp + +from abc import ABC, abstractmethod +import torch + +LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index) +PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates +logger = logging.getLogger(__name__) + + +@dataclass +class Pattern: + """Base implementation of a pattern over a sequence with multiple codebooks. + + The codebook pattern consists in a layout, defining for each sequence step + the list of coordinates of each codebook timestep in the resulting interleaved sequence. + The first item of the pattern is always an empty list in order to properly insert a special token + to start with. For convenience, we also keep track of ``code_depth`` the number of codebooks used for the pattern + and ``timesteps`` the number of timesteps corresponding to the original sequence. + + The pattern provides convenient methods to build and revert interleaved sequences from it: + ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T] + to the interleaved sequence of shape [B, K, S] applying the pattern, with S being the batch size, + K being the number of codebooks, T the number of original timesteps and S the number of sequence steps + for the output sequence. The unfilled positions are replaced with a special token and the built sequence + is returned along with a mask indicating valid tokens. + ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment + of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask + to fill and specify invalid positions if needed. + See the dedicated methods for more details. + """ + # Pattern layout, for each sequence step, we have a list of coordinates + # corresponding to the original codebook timestep and position. + # The first list is always an empty list in order to properly insert + # a special token to start with. + layout: PatternLayout + timesteps: int + code_depth: int + + def __post_init__(self): + assert len(self.layout) > 0 + assert self.layout[0] == [] + self._validate_layout() + self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes) + self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes) + logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout)) + + def _validate_layout(self): + """Runs checks on the layout to ensure a valid pattern is defined. + A pattern is considered invalid if: + - Multiple timesteps for a same codebook are defined in the same sequence step + - The timesteps for a given codebook are not in ascending order as we advance in the sequence + (this would mean that we have future timesteps before past timesteps). + """ + q_timesteps = {q: 0 for q in range(self.code_depth)} + for s, seq_coords in enumerate(self.layout): + if len(seq_coords) > 0: + qs = set() + for coord in seq_coords: + qs.add(coord.q) + last_q_timestep = q_timesteps[coord.q] + # assert coord.t >= last_q_timestep, \ + # f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}" + q_timesteps[coord.q] = coord.t + # each sequence step contains at max 1 coordinate per codebook + assert len(qs) == len(seq_coords), \ + f"Multiple entries for a same codebook are found at step {s}" + + @property + def num_sequence_steps(self): + return len(self.layout) - 1 + + @property + def max_delay(self): + max_t_in_seq_coords = 0 + for seq_coords in self.layout[1:]: + for coords in seq_coords: + max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1) + return max_t_in_seq_coords - self.timesteps + + @property + def valid_layout(self): + valid_step = len(self.layout) - self.max_delay + return self.layout[:valid_step] + + def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None): + """Get codebook coordinates in the layout that corresponds to the specified timestep t + and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step + and the actual codebook coordinates. + """ + assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps" + if q is not None: + assert q <= self.code_depth, "provided number of codebooks is greater than the pattern's number of codebooks" + coords = [] + for s, seq_codes in enumerate(self.layout): + for code in seq_codes: + if code.t == t and (q is None or code.q == q): + coords.append((s, code)) + return coords + + def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]: + return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)] + + def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]: + steps_with_timesteps = self.get_steps_with_timestep(t, q) + return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None + + def _build_pattern_sequence_scatter_indexes(self, timesteps: int, + code_depth: int, + keep_only_valid_steps: bool, + device: tp.Union[torch.device, str] = 'cpu'): + """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps. + + Args: + timesteps (int): Maximum number of timesteps steps to consider. + keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S]. + """ + assert code_depth == self.code_depth, f"invalid number of codebooks for the sequence and the pattern: {code_depth} != {self.code_depth}" + assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern" + # use the proper layout based on whether we limit ourselves to valid steps only or not, + # note that using the valid_layout will result in a truncated sequence up to the valid steps + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(code_depth, len(ref_layout), dtype=torch.long).numpy() + mask = torch.zeros(code_depth, len(ref_layout), dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + # the last value is code_depth * timesteps as we have flattened z and append special token as the last token + # which will correspond to the index: code_depth * timesteps + indexes[:] = code_depth * timesteps + # iterate over the pattern and fill scattered indexes and mask + for s, sequence_coords in enumerate(ref_layout): + for coords in sequence_coords: + if coords.t < timesteps: + indexes[coords.q, s] = coords.t + coords.q * timesteps + mask[coords.q, s] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Build sequence corresponding to the pattern from the input tensor z. + The sequence is built using up to sequence_steps if specified, and non-pattern + coordinates are filled with the special token. + + Args: + z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T]. + special_token (int): Special token used to fill non-pattern coordinates in the new sequence. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S + corresponding either to the sequence_steps if provided, otherwise to the length of the pattern. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S]. + """ + B, K, T = z.shape + indexes, mask = self._build_pattern_sequence_scatter_indexes( + T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device) + ) + z = z.reshape(B, -1) + # we append the special token as the last index of our flattened z tensor + z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1) + values = z[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + # import pdb; pdb.set_trace() + return values, indexes, mask + + def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, code_depth: int, + keep_only_valid_steps: bool = False, + is_model_output: bool = False, + device: tp.Union[torch.device, str] = 'cpu'): + """Builds scatter indexes required to retrieve the original multi-codebook sequence + from interleaving pattern. + + Args: + sequence_steps (int): Sequence steps. + code_depth (int): Number of codebooks. + keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. + Steps that are beyond valid steps will be replaced by the special_token in that case. + is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not. + device (torch.device or str): Device for created tensors. + Returns: + indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + ref_layout = self.valid_layout if keep_only_valid_steps else self.layout + timesteps = self.timesteps + assert code_depth == self.code_depth, f"invalid number of codebooks for the sequence and the pattern: {code_depth} != {self.code_depth}" + assert sequence_steps <= len(ref_layout), \ + f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}" + + # ensure we take the appropriate indexes to keep the model output from the first special token as well + if is_model_output: + ref_layout = ref_layout[1:] + + # single item indexing being super slow with pytorch vs. numpy, so we use numpy here + indexes = torch.zeros(code_depth, timesteps, dtype=torch.long).numpy() + mask = torch.zeros(code_depth, timesteps, dtype=torch.bool).numpy() + # fill indexes with last sequence step value that will correspond to our special token + indexes[:] = code_depth * sequence_steps + for s, sequence_codes in enumerate(ref_layout): + if s < sequence_steps: + for code in sequence_codes: + if code.t < timesteps: + indexes[code.q, code.t] = s + code.q * sequence_steps + mask[code.q, code.t] = 1 + indexes = torch.from_numpy(indexes).to(device) + mask = torch.from_numpy(mask).to(device) + return indexes, mask + + def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False): + """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving. + The sequence is reverted using up to timesteps if specified, and non-pattern coordinates + are filled with the special token. + + Args: + s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S]. + special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence. + Returns: + values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T + corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise. + indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T]. + mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T]. + """ + B, K, S = s.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device) + ) + s = s.view(B, -1) + # we append the special token as the last index of our flattened z tensor + s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1) + values = s[:, indexes.view(-1)] + values = values.view(B, K, indexes.shape[-1]) + return values, indexes, mask + + def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False): + """Revert model logits obtained on a sequence built from the pattern + back to a tensor matching the original sequence. + + This method is similar to ``revert_pattern_sequence`` with the following specificities: + 1. It is designed to work with the extra cardinality dimension + 2. We return the logits for the first sequence item that matches the special_token and + which matching target in the original sequence is the first item of the sequence, + while we skip the last logits as there is no matching target + """ + B, card, K, S = logits.shape + indexes, mask = self._build_reverted_sequence_scatter_indexes( + S, K, keep_only_valid_steps, is_model_output=True, device=logits.device + ) + logits = logits.reshape(B, card, -1) + # we append the special token as the last index of our flattened z tensor + logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S] + values = logits[:, :, indexes.view(-1)] + + values = values.view(B, card, K, indexes.shape[-1]) + return values, indexes, mask + + + +class CodebooksPatternProvider(ABC): + """Abstraction around providing pattern for interleaving codebooks. + + The CodebooksPatternProvider abstraction allows to implement various strategies to + define interleaving pattern of sequences composed of multiple codebooks. For a given + number of codebooks `code_depth`, the pattern provider can generate a specified pattern + corresponding to a sequence of `T` timesteps with `code_depth` parallel codebooks. This pattern + can be used to construct a new sequence from the original codes respecting the specified + pattern. The pattern is defined as a list of list of code coordinates, code coordinate + being a tuple with the original timestep and codebook to build the new sequence. + Note that all patterns must start with an empty list that is then used to insert a first + sequence step of special tokens in the newly generated sequence. + + Args: + code_depth (int): number of codebooks. + cached (bool): if True, patterns for a given length are cached. In general + that should be true for efficiency reason to avoid synchronization points. + """ + def __init__(self, code_depth: int, cached: bool = True): + assert code_depth > 0 + self.code_depth = code_depth + self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore + + @abstractmethod + def get_pattern(self, timesteps: int) -> Pattern: + """Builds pattern with specific interleaving between codebooks. + + Args: + timesteps (int): Total number of timesteps. + """ + raise NotImplementedError() + + +class DelayedPatternProvider(CodebooksPatternProvider): + """Provider for delayed pattern across delayed codebooks. + Codebooks are delayed in the sequence and sequence steps will contain codebooks + from different timesteps. + + Example: + Taking timesteps=4 and code_depth=3, delays=None, the multi-codebook sequence: + [[1, 2, 3, 4], + [1, 2, 3, 4], + [1, 2, 3, 4]] + The resulting sequence obtained from the returned pattern is: + [[S, 1, 2, 3, 4], + [S, S, 1, 2, 3], + [S, S, S, 1, 2]] + (with S being a special token) + + Args: + code_depth (int): Number of codebooks. + delays (list of int, optional): Delay for each of the codebooks. + If delays not defined, each codebook is delayed by 1 compared to the previous one. + flatten_first (int): Flatten the first N timesteps. + empty_initial (int): Prepend with N empty list of coordinates. + """ + def __init__(self, code_depth: int, delays: tp.Optional[tp.List[int]] = None, + flatten_first: int = 0, empty_initial: int = 0): + super().__init__(code_depth) + if delays is None: + delays = list(range(code_depth)) + self.delays = delays + self.flatten_first = flatten_first + self.empty_initial = empty_initial + assert len(self.delays) == self.code_depth + assert sorted(self.delays) == self.delays + + def get_pattern(self, timesteps: int) -> Pattern: + out: PatternLayout = [[]] + max_delay = max(self.delays) + if self.empty_initial: + out += [[] for _ in range(self.empty_initial)] + if self.flatten_first: + for t in range(min(timesteps, self.flatten_first)): + for q in range(self.code_depth): + out.append([LayoutCoord(t, q)]) + for t in range(self.flatten_first, timesteps + max_delay): + v = [] + for q, delay in enumerate(self.delays): + t_for_q = t - delay + if t_for_q >= self.flatten_first: + v.append(LayoutCoord(t_for_q, q)) + out.append(v) + return Pattern(out, code_depth=self.code_depth, timesteps=timesteps) diff --git a/codeclm/modules/streaming.py b/codeclm/modules/streaming.py new file mode 100755 index 0000000000000000000000000000000000000000..738b4611a48b8786d777218793648c4d87565d79 --- /dev/null +++ b/codeclm/modules/streaming.py @@ -0,0 +1,112 @@ +""" +Streaming module API that should be implemented by all Streaming components, +""" + +from contextlib import contextmanager +import typing as tp +from torch import nn +import torch + + +State = tp.Dict[str, torch.Tensor] + +class StreamingModule(nn.Module): + """Common API for streaming components. + + Each streaming component has a streaming state, which is just a dict[str, Tensor]. + By convention, the first dim of each tensor must be the batch size. + Don't use dots in the key names, as this would clash with submodules + (like in state_dict). + + If `self._is_streaming` is True, the component should use and remember + the proper state inside `self._streaming_state`. + + To set a streaming component in streaming state, use + + with module.streaming(): + ... + + This will automatically reset the streaming state when exiting the context manager. + This also automatically propagates to all streaming children module. + + Some module might also implement the `StreamingModule.flush` method, although + this one is trickier, as all parents module must be StreamingModule and implement + it as well for it to work properly. See `StreamingSequential` after. + """ + def __init__(self) -> None: + super().__init__() + self._streaming_state: State = {} + self._is_streaming = False + + def _apply_named_streaming(self, fn: tp.Any): + for name, module in self.named_modules(): + if isinstance(module, StreamingModule): + fn(name, module) + + def _set_streaming(self, streaming: bool): + def _set_streaming(name, module): + module._is_streaming = streaming + self._apply_named_streaming(_set_streaming) + + @contextmanager + def streaming(self): + """Context manager to enter streaming mode. Reset streaming state on exit.""" + self._set_streaming(True) + try: + yield + finally: + self._set_streaming(False) + self.reset_streaming() + + def reset_streaming(self): + """Reset the streaming state.""" + def _reset(name: str, module: StreamingModule): + module._streaming_state.clear() + + self._apply_named_streaming(_reset) + + def get_streaming_state(self) -> State: + """Return the streaming state, including that of sub-modules.""" + state: State = {} + + def _add(name: str, module: StreamingModule): + if name: + name += "." + for key, value in module._streaming_state.items(): + state[name + key] = value + + self._apply_named_streaming(_add) + return state + + def set_streaming_state(self, state: State): + """Set the streaming state, including that of sub-modules.""" + state = dict(state) + + def _set(name: str, module: StreamingModule): + if name: + name += "." + module._streaming_state.clear() + for key, value in list(state.items()): + # complexity is not ideal here, but probably fine. + if key.startswith(name): + local_key = key[len(name):] + if '.' not in local_key: + module._streaming_state[local_key] = value + del state[key] + + self._apply_named_streaming(_set) + assert len(state) == 0, list(state.keys()) + + def flush(self, x: tp.Optional[torch.Tensor] = None): + """Flush any remaining outputs that were waiting for completion. + Typically, for convolutions, this will add the final padding + and process the last buffer. + + This should take an optional argument `x`, which will be provided + if a module before this one in the streaming pipeline has already + spitted out a flushed out buffer. + """ + if x is None: + return None + else: + return self(x) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/audio.py b/codeclm/tokenizer/Flow1dVAE/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..29c21532455ef812eec2951653c87eda15d485cd --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/audio.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@File : audio.py +@Time : 2023/8/8 下午7:18 +@Author : waytan +@Contact : waytan@tencent.com +@License : (C)Copyright 2023, Tencent +@Desc : Audio +""" +import json +import subprocess as sp +import typing as tp +from pathlib import Path + +import lameenc +import julius +import torch +import numpy as np +import torchaudio as ta +from contextlib import contextmanager +import tempfile +import os + +@contextmanager +def temp_filenames(count: int, delete=True): + names = [] + try: + for _ in range(count): + names.append(tempfile.NamedTemporaryFile(delete=False).name) + yield names + finally: + if delete: + for name in names: + os.unlink(name) + + +def _read_info(path): + stdout_data = sp.check_output([ + 'ffprobe', "-loglevel", "panic", + str(path), '-print_format', 'json', '-show_format', '-show_streams' + ]) + return json.loads(stdout_data.decode('utf-8')) + + +class AudioFile: + """ + Allows to read audio from any format supported by ffmpeg, as well as resampling or + converting to mono on the fly. See :method:`read` for more details. + """ + def __init__(self, path: Path): + self.path = Path(path) + self._info = None + + def __repr__(self): + features = [("path", self.path)] + features.append(("samplerate", self.samplerate())) + features.append(("channels", self.channels())) + features.append(("streams", len(self))) + features_str = ", ".join(f"{name}={value}" for name, value in features) + return f"AudioFile({features_str})" + + @property + def info(self): + if self._info is None: + self._info = _read_info(self.path) + return self._info + + @property + def duration(self): + return float(self.info['format']['duration']) + + @property + def _audio_streams(self): + return [ + index for index, stream in enumerate(self.info["streams"]) + if stream["codec_type"] == "audio" + ] + + def __len__(self): + return len(self._audio_streams) + + def channels(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['channels']) + + def samplerate(self, stream=0): + return int(self.info['streams'][self._audio_streams[stream]]['sample_rate']) + + def read(self, + seek_time=None, + duration=None, + streams=slice(None), + samplerate=None, + channels=None): + """ + Slightly more efficient implementation than stempeg, + in particular, this will extract all stems at once + rather than having to loop over one file multiple times + for each stream. + + Args: + seek_time (float): seek time in seconds or None if no seeking is needed. + duration (float): duration in seconds to extract or None to extract until the end. + streams (slice, int or list): streams to extract, can be a single int, a list or + a slice. If it is a slice or list, the output will be of size [S, C, T] + with S the number of streams, C the number of channels and T the number of samples. + If it is an int, the output will be [C, T]. + samplerate (int): if provided, will resample on the fly. If None, no resampling will + be done. Original sampling rate can be obtained with :method:`samplerate`. + channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that + as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers. + See https://sound.stackexchange.com/a/42710. + Our definition of mono is simply the average of the two channels. Any other + value will be ignored. + """ + streams = np.array(range(len(self)))[streams] + single = not isinstance(streams, np.ndarray) + if single: + streams = [streams] + + if duration is None: + target_size = None + query_duration = None + else: + target_size = int((samplerate or self.samplerate()) * duration) + query_duration = float((target_size + 1) / (samplerate or self.samplerate())) + + with temp_filenames(len(streams)) as filenames: + command = ['ffmpeg', '-y'] + command += ['-loglevel', 'panic'] + if seek_time: + command += ['-ss', str(seek_time)] + command += ['-i', str(self.path)] + for stream, filename in zip(streams, filenames): + command += ['-map', f'0:{self._audio_streams[stream]}'] + if query_duration is not None: + command += ['-t', str(query_duration)] + command += ['-threads', '1'] + command += ['-f', 'f32le'] + if samplerate is not None: + command += ['-ar', str(samplerate)] + command += [filename] + + sp.run(command, check=True) + wavs = [] + for filename in filenames: + wav = np.fromfile(filename, dtype=np.float32) + wav = torch.from_numpy(wav) + wav = wav.view(-1, self.channels()).t() + if channels is not None: + wav = convert_audio_channels(wav, channels) + if target_size is not None: + wav = wav[..., :target_size] + wavs.append(wav) + wav = torch.stack(wavs, dim=0) + if single: + wav = wav[0] + return wav + + +def convert_audio_channels(wav, channels=2): + """Convert audio to the given number of channels.""" + *shape, src_channels, length = wav.shape + if src_channels == channels: + pass + elif channels == 1: + # Case 1: + # The caller asked 1-channel audio, but the stream have multiple + # channels, downmix all channels. + wav = wav.mean(dim=-2, keepdim=True) + elif src_channels == 1: + # Case 2: + # The caller asked for multiple channels, but the input file have + # one single channel, replicate the audio over all channels. + wav = wav.expand(*shape, channels, length) + elif src_channels >= channels: + # Case 3: + # The caller asked for multiple channels, and the input file have + # more channels than requested. In that case return the first channels. + wav = wav[..., :channels, :] + else: + # Case 4: What is a reasonable choice here? + raise ValueError('The audio file has less channels than requested but is not mono.') + return wav + + +def convert_audio(wav, from_samplerate, to_samplerate, channels): + """Convert audio from a given samplerate to a target one and target number of channels.""" + wav = convert_audio_channels(wav, channels) + return julius.resample_frac(wav, from_samplerate, to_samplerate) + + +def i16_pcm(wav): + """Convert audio to 16 bits integer PCM format.""" + if wav.dtype.is_floating_point: + return (wav.clamp_(-1, 1) * (2**15 - 1)).short() + else: + return wav + + +def f32_pcm(wav): + """Convert audio to float 32 bits PCM format.""" + if wav.dtype.is_floating_point: + return wav + else: + return wav.float() / (2**15 - 1) + + +def as_dtype_pcm(wav): + """Convert audio to either f32 pcm or i16 pcm depending on the given dtype.""" + if wav.dtype.is_floating_point: + return f32_pcm(wav) + else: + return i16_pcm(wav) + + +def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False): + """Save given audio as mp3. This should work on all OSes.""" + c, _ = wav.shape + wav = i16_pcm(wav) + encoder = lameenc.Encoder() + encoder.set_bit_rate(bitrate) + encoder.set_in_sample_rate(samplerate) + encoder.set_channels(c) + encoder.set_quality(2) # 2-highest, 7-fastest + if not verbose: + encoder.silence() + wav = wav.data.cpu() + wav = wav.transpose(0, 1).numpy() + mp3_data = encoder.encode(wav.tobytes()) + mp3_data += encoder.flush() + with open(path, "wb") as f: + f.write(mp3_data) + + +def prevent_clip(wav, mode='rescale'): + """ + different strategies for avoiding raw clipping. + """ + if mode is None or mode == 'none': + return wav + assert wav.dtype.is_floating_point, "too late for clipping" + if mode == 'rescale': + wav = wav / max(1.01 * wav.abs().max(), 1) + elif mode == 'clamp': + wav = wav.clamp(-0.99, 0.99) + elif mode == 'tanh': + wav = torch.tanh(wav) + else: + raise ValueError(f"Invalid mode {mode}") + return wav + + +def save_audio(wav: torch.Tensor, + path: tp.Union[str, Path], + samplerate: int, + bitrate: int = 320, + clip: tp.Union[str] = 'rescale', + bits_per_sample: tp.Union[int] = 16, + as_float: bool = False): + """Save audio file, automatically preventing clipping if necessary + based on the given `clip` strategy. If the path ends in `.mp3`, this + will save as mp3 with the given `bitrate`. + """ + wav = prevent_clip(wav, mode=clip) + path = Path(path) + suffix = path.suffix.lower() + if suffix == ".mp3": + encode_mp3(wav, path, samplerate, bitrate, verbose=True) + elif suffix == ".wav": + if as_float: + bits_per_sample = 32 + encoding = 'PCM_F' + else: + encoding = 'PCM_S' + ta.save(str(path), wav, sample_rate=samplerate, + encoding=encoding, bits_per_sample=bits_per_sample) + elif suffix == ".flac": + ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample) + else: + raise ValueError(f"Invalid suffix for path: {suffix}") + + +def load_track(track, audio_channels, samplerate): + errors = {} + wav = None + + try: + wav = AudioFile(track).read( + streams=0, + samplerate=samplerate, + channels=audio_channels) + except sp.CalledProcessError: + errors['ffmpeg'] = 'FFmpeg could not read the file.' + + if wav is None: + try: + wav, sr = ta.load(str(track)) + except RuntimeError as err: + errors['torchaudio'] = err.args[0] + else: + wav = convert_audio(wav, sr, samplerate, audio_channels) + + return wav, errors \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/cal_token_stat.py b/codeclm/tokenizer/Flow1dVAE/cal_token_stat.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c9f2ad6bcff850bdec1a53c7299bad63a63ab2 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/cal_token_stat.py @@ -0,0 +1,19 @@ +import kaldiio +from tqdm import tqdm +import torch + +if __name__ == "__main__": + bar = torch.zeros(1, 16384) + with open('token.scp', 'r') as f: + for item_idx, line in tqdm(enumerate(f)): + idx, pos = line.strip().split() + codes = kaldiio.load_mat(pos) + for i0 in range(codes.shape[-1]): + bar[0, codes[0, 0, i0]] += 1 + if(item_idx % 1000 == 0): + print("=========") + print(1 - (bar[0]==0).sum() / bar.shape[-1]) + print("=========") + print("=========") + print(1 - (bar[0]==0).sum() / bar.shape[-1]) + print("=========") \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/compare_model_weight.py b/codeclm/tokenizer/Flow1dVAE/compare_model_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..af9c77e6c3b0d5adc6f56bb19d0ecea966acecfd --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/compare_model_weight.py @@ -0,0 +1,13 @@ +import torch +import sys +from safetensors.torch import load_file + +if __name__ == "__main__": + m0, m1 = sys.argv[1], sys.argv[2] + m0 = load_file(m0) + m1 = load_file(m1) + + ks = [k for k in m0.keys() if 'bestrq' in k] + for k in ks: + print(k, (m0[k] - m1[k]).abs().sum()) + \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/configs/models/transformer2D_wocross_inch112_1x4_multi_large.json b/codeclm/tokenizer/Flow1dVAE/configs/models/transformer2D_wocross_inch112_1x4_multi_large.json new file mode 100644 index 0000000000000000000000000000000000000000..611d38d607b77190943ba607fc446a46758a291c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/configs/models/transformer2D_wocross_inch112_1x4_multi_large.json @@ -0,0 +1,26 @@ +{ + "_class_name": "Transformer2DModel", + "_diffusers_version": "0.22.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": true, + "attention_head_dim": 72, + "attention_type": "default", + "cross_attention_dim": null, + "double_self_attention": false, + "dropout": 0.0, + "in_channels": 96, + "norm_elementwise_affine": false, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "norm_type": "ada_norm_single", + "num_attention_heads": 22, + "num_embeds_ada_norm": 1000, + "num_layers": 24, + "num_vector_embeds": null, + "only_cross_attention": false, + "out_channels": 32, + "patch_size": 2, + "sample_size": 384, + "upcast_attention": false, + "use_linear_projection": false +} \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json b/codeclm/tokenizer/Flow1dVAE/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json new file mode 100644 index 0000000000000000000000000000000000000000..ae79af91c539f872ad60a239528e7e32b6903bbb --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/configs/scheduler/stable_diffusion_2.1_largenoise_sample.json @@ -0,0 +1,14 @@ +{ + "_class_name": "DDIMScheduler", + "_diffusers_version": "0.8.0", + "beta_end": 0.02, + "beta_schedule": "scaled_linear", + "beta_start": 0.0015, + "clip_sample": false, + "num_train_timesteps": 1000, + "prediction_type": "sample", + "set_alpha_to_one": false, + "skip_prk_steps": true, + "steps_offset": 1, + "trained_betas": null +} diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py new file mode 100644 index 0000000000000000000000000000000000000000..6922a42d8016477448e92890f9eec84a8eb8d9d9 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_and_sep_npy.py @@ -0,0 +1,121 @@ +import torch,torchaudio +import os,sys,json +from tqdm import tqdm +import numpy as np + +#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango +from generate_septoken import Tango as Tango_sep +from generate_2rvq import Tango as Tango_1x2 +import kaldiio +from kaldiio import WriteHelper +from audio import AudioFile + +from demucs.models.pretrained import get_model_from_yaml +from filelock import FileLock + +# os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml") +class Separator: + def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None: + if torch.cuda.is_available() and gpu_id < torch.cuda.device_count(): + self.device = torch.device(f"cuda:{gpu_id}") + else: + self.device = torch.device("cpu") + self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path) + + def init_demucs_model(self, model_path, config_path): + model = get_model_from_yaml(config_path, model_path) + model.to(self.device) + model.eval() + return model + + def load_audio(self, f): + a, fs = torchaudio.load(f) + if (fs != 48000): + a = torchaudio.functional.resample(a, fs, 48000) + # if a.shape[-1] >= 48000*10: + # a = a[..., :48000*10] + # else: + # a = torch.cat([a, a], -1) + # return a[:, 0:48000*10] + return a + + def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"): + name, _ = os.path.splitext(os.path.split(audio_path)[-1]) + output_paths = [] + # lock_path = os.path.join(output_dir, f"{name}.lock") + # with FileLock(lock_path): # 加一个避免多卡访问时死锁 + for stem in self.demucs_model.sources: + output_path = os.path.join(output_dir, f"{name}_{stem}{ext}") + if os.path.exists(output_path): + output_paths.append(output_path) + if len(output_paths) == 1: # 4 + # drums_path, bass_path, other_path, vocal_path = output_paths + vocal_path = output_paths[0] + else: + lock_path = os.path.join(output_dir, f"{name}_separate.lock") + with FileLock(lock_path): + drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device) + full_audio = self.load_audio(audio_path) + vocal_audio = self.load_audio(vocal_path) + minlen = min(full_audio.shape[-1], vocal_audio.shape[-1]) + # bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen] + bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path) + for path in [drums_path, bass_path, other_path, vocal_path]: + os.remove(path) + return full_audio, vocal_audio, bgm_audio + +def read_wav(fname, sample_rate=48_000): + try: + orig_samples, fs = torchaudio.load(fname) + except: + af = AudioFile(fname) + orig_samples = af.read() + fs = af.samplerate() + orig_samples = orig_samples[0] + if(fs!=sample_rate): + orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate) + fs = sample_rate + if orig_samples.shape[0] == 1: + orig_samples = torch.cat([orig_samples, orig_samples], 0) + return orig_samples + +if __name__ == "__main__": + # Define Model + json_path = sys.argv[1] + + mus_infos = [] + with open(json_path) as f: + for line in f: + item = json.loads(line) + mus_infos.append(item) + + tango_sep = Tango_sep(model_path="./saved/model_septoken/model_2.safetensors") + tango_1x2 = Tango_1x2(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2) + separator = Separator() + + # Feature extraction loop + # for i in tqdm(range(2000)): + first_time = True + for item in tqdm(mus_infos): + if(os.path.exists(item['path'])): + full_path = item['path'] + else: + full_path = '/mnt/share/' + item['path'] + + full_tensor, vocal_tensor, bgm_tensor = separator.run(full_path) + + # full_tensor = read_wav(full_path) + # vocal_tensor = read_wav(vocal_path) + # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1]) + # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length] + # bgm_tensor = full_tensor - vocal_tensor + codes_1x2 = tango_1x2.sound2code(full_tensor) + codes_vocal, codes_bgm = tango_sep.sound2code(vocal_tensor, bgm_tensor) + codes = torch.cat([codes_1x2[:,[0],:], codes_vocal, codes_bgm], 1).cpu().numpy() + save_path = full_path.replace('.wav', '.1x1_and_sep.npy').replace('.mp3', '.1x1_and_sep.npy').replace('.flac', '.1x1_and_sep.npy').replace('.ogg', '.1x1_and_sep.npy') + assert save_path != full_path, (save_path, full_path) + np.save(save_path, codes) + + if(first_time): + first_time = False + print(codes_vocal.shape, codes_bgm.shape) diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py new file mode 100644 index 0000000000000000000000000000000000000000..b46d6afef8297764b7f3ca9b3b652202b12dfc7c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x1_sep.py @@ -0,0 +1,94 @@ +import torch,torchaudio +import os,sys,json +from tqdm import tqdm + +#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango +from generate_septoken import Tango +import kaldiio +from kaldiio import WriteHelper +from audio import AudioFile + +def read_wav(fname, sample_rate=48_000): + try: + orig_samples, fs = torchaudio.load(fname) + except: + af = AudioFile(fname) + orig_samples = af.read() + fs = af.samplerate() + orig_samples = orig_samples[0] + if(fs!=sample_rate): + orig_samples = torchaudio.functional.resample(orig_samples, fs, sample_rate) + fs = sample_rate + if orig_samples.shape[0] == 1: + orig_samples = torch.cat([orig_samples, orig_samples], 0) + return orig_samples + +if __name__ == "__main__": + # Define Model + json_path = sys.argv[1] + outdir = sys.argv[2] + + mus_infos = [] + with open(json_path) as f: + for line in f: + item = json.loads(line) + mus_infos.append(item) + + tango = Tango(model_path="./saved/model_septoken/model_2.safetensors") + + + # Feature extraction loop + # for i in tqdm(range(2000)): + first_time = True + with WriteHelper('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir), write_function="pickle") as writer_vocal, WriteHelper('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir), write_function="pickle") as writer_bgm: + print('ark,scp:{}/token_vocal.ark,{}/token_vocal.scp'.format(outdir, outdir)) + print('ark,scp:{}/token_bgm.ark,{}/token_bgm.scp'.format(outdir, outdir)) + for item in tqdm(mus_infos): + try: + # if True: + idx = item['idx'] + # print(idx) + if(os.path.exists(item['path'])): + full_path = item['path'] + else: + full_path = '/mnt/share/' + item['path'] + if(os.path.exists(item['vocal_path'])): + vocal_path = item['vocal_path'] + bgm_paths = item['bgm_path'] + else: + vocal_path = '/mnt/share/' + item['vocal_path'] + bgm_paths = ['/mnt/share/' + p for p in item['bgm_path']] + vocal_tensor = read_wav(vocal_path) + # full_tensor = read_wav(full_path) + # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1]) + # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length] + # bgm_tensor = full_tensor - vocal_tensor + bgm_tensor = sum([read_wav(p) for p in bgm_paths]) + codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor) + writer_vocal(str(idx), codes_vocal.cpu()) + writer_bgm(str(idx), codes_bgm.cpu()) + if(first_time): + first_time = False + print(codes_vocal.shape, codes_bgm.shape) + except: + print(item['vocal_path']) + print(item['bgm_path']) + continue + + # idx = item['idx'] + # # print(idx) + # full_path = item['path'] + # vocal_path = item['vocal_path'] + # bgm_paths = item['bgm_path'] + # full_tensor = read_wav(full_path) + # vocal_tensor = read_wav(vocal_path) + # length = min(full_tensor.shape[-1], vocal_tensor.shape[-1]) + # full_tensor, vocal_tensor = full_tensor[:, 0:length], vocal_tensor[:, 0:length] + # bgm_tensor = full_tensor - vocal_tensor + # codes_vocal, codes_bgm = tango.sound2code(vocal_tensor, bgm_tensor) + # writer_vocal(str(idx), codes_vocal.cpu()) + # writer_bgm(str(idx), codes_bgm.cpu()) + # if(first_time): + # first_time = False + # print(codes_vocal.shape, codes_bgm.shape) + diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py new file mode 100644 index 0000000000000000000000000000000000000000..4069277c87dd58a05ea9fdf964af9713cfab205c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x2.py @@ -0,0 +1,70 @@ +import torch,torchaudio +import os,sys,json +from tqdm import tqdm + +#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango +from generate_2rvq import Tango +import kaldiio +from kaldiio import WriteHelper +import torch +import subprocess +import time +import sys + +def get_gpu_memory(): + _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1] + + ACCEPTABLE_AVAILABLE_MEMORY = 1024 + COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv" + memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:] + memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] + return memory_free_values + +if __name__ == "__main__": + # Define Model + json_path = sys.argv[1] + outdir = sys.argv[2] + + gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES']) + while True: + free_mem = get_gpu_memory() + free_mem = free_mem[gpu_idx] + if(free_mem > 25_000): + print("GPU memory {}, run matrix cal".format(free_mem)) + break + else: + print("GPU memory {}, sleep 1min".format(free_mem)) + time.sleep(60) + + mus_infos = [] + with open(json_path) as f: + for line in f: + item = json.loads(line) + mus_infos.append(item) + + tango = Tango(model_path = './saved/model_2rvq/model_2_fixed.safetensors', rvq_num=2) + + + # Feature extraction loop + # for i in tqdm(range(2000)): + with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer: + print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir)) + for item in tqdm(mus_infos): + try: + # if True: + idx = item['idx'] + # print(idx) + with torch.autocast(device_type="cuda", dtype=torch.float16): + if(os.path.exists(item['path'])): + codes = tango.file2code(item['path']) + else: + codes = tango.file2code('/mnt/share/' + item['path']) + writer(str(idx), codes.cpu()) + except: + print(item['path']) + continue + # idx = item['idx'] + # # print(idx) + # with torch.autocast(device_type="cuda", dtype=torch.float16): + # codes = tango.file2code(item['path']) + # writer(str(idx), codes.cpu()) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py new file mode 100644 index 0000000000000000000000000000000000000000..5116d4bc1e3bf5ef7c553d348992e0dfc119f303 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4.py @@ -0,0 +1,46 @@ +import torch,torchaudio +import os,sys,json +from tqdm import tqdm + +#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango +from generate_4rvq import Tango +import kaldiio +from kaldiio import WriteHelper + +if __name__ == "__main__": + # Define Model + json_path = sys.argv[1] + outdir = sys.argv[2] + + mus_infos = [] + with open(json_path) as f: + for line in f: + item = json.loads(line) + mus_infos.append(item) + + tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4) + + + # Feature extraction loop + # for i in tqdm(range(2000)): + with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer: + print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir)) + for item in tqdm(mus_infos): + try: + # if True: + idx = item['idx'] + # print(idx) + with torch.autocast(device_type="cuda", dtype=torch.float16): + if(os.path.exists(item['path'])): + codes = tango.file2code(item['path']) + else: + codes = tango.file2code('/mnt/share/' + item['path']) + writer(str(idx), codes.cpu()) + except: + print(item['path']) + continue + # idx = item['idx'] + # # print(idx) + # with torch.autocast(device_type="cuda", dtype=torch.float16): + # codes = tango.file2code(item['path']) + # writer(str(idx), codes.cpu()) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py new file mode 100644 index 0000000000000000000000000000000000000000..8df3f06af60cfd9aa020dd8ac0e50a0d89898b88 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/extract_codes_stereo_7_1x4_ds.py @@ -0,0 +1,86 @@ +import torch,torchaudio +import os,sys,json +from tqdm import tqdm + +#from codeclm_song_v1.codeclm.semantic_extractor.SpeechDecoder_v01.generate import Tango +from generate_4rvq import Tango +import kaldiio +from kaldiio import WriteHelper +import torch +import subprocess +import time +import sys + +def get_gpu_memory(): + _output_to_list = lambda x: x.decode('ascii').split('\n')[:-1] + + ACCEPTABLE_AVAILABLE_MEMORY = 1024 + COMMAND = "nvidia-smi --query-gpu=memory.free --format=csv" + memory_free_info = _output_to_list(subprocess.check_output(COMMAND.split()))[1:] + memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)] + return memory_free_values + +if __name__ == "__main__": + # Define Model + json_path = sys.argv[1] + outdir = sys.argv[2] + ds = int(sys.argv[3]) + + gpu_idx = int(os.environ['CUDA_VISIBLE_DEVICES']) + while True: + free_mem = get_gpu_memory() + free_mem = free_mem[gpu_idx] + if(free_mem > 25_000): + print("GPU memory {}, run matrix cal".format(free_mem)) + break + else: + print("GPU memory {}, sleep 1min".format(free_mem)) + time.sleep(60) + + mus_infos = [] + with open(json_path) as f: + for line in f: + item = json.loads(line) + mus_infos.append(item) + + tango = Tango(model_path = './saved/model_4rvq/model_2_fixed.safetensors', rvq_num=4) + + + # Feature extraction loop + # for i in tqdm(range(2000)): + with WriteHelper('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir), write_function="pickle") as writer: + print('ark,scp:{}/token.ark,{}/token.scp'.format(outdir, outdir)) + bar = torch.zeros(4, 16384) + for item_idx, item in tqdm(enumerate(mus_infos)): + try: + # if True: + idx = item['idx'] + # print(idx) + with torch.autocast(device_type="cuda", dtype=torch.float16): + if(os.path.exists(item['path'])): + codes = tango.file2code_ds(item['path'], ds) + else: + codes = tango.file2code_ds('/mnt/share/' + item['path'], ds) + codes = codes.cpu() + writer(str(idx), codes) + for i0 in range(codes.shape[-1]): + bar[0, codes[0, 0, i0]] += 1 + bar[1, codes[0, 1, i0]] += 1 + bar[2, codes[0, 2, i0]] += 1 + bar[3, codes[0, 3, i0]] += 1 + except Exception as e: + print(item['path']) + # print(e.message, e.args) + # exit(1) + continue + + if(item_idx % 1000 == 0): + print("=========") + print(1 - (bar[0]==0).sum() / bar.shape[-1]) + print("=========") + + # idx = item['idx'] + # # print(idx) + # with torch.autocast(device_type="cuda", dtype=torch.float16): + # codes = tango.file2code(item['path']) + # writer(str(idx), codes.cpu()) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/generate_1rvq.py b/codeclm/tokenizer/Flow1dVAE/generate_1rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..fe6ff0765580ae05f3f87bd708618e3954173dbf --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/generate_1rvq.py @@ -0,0 +1,283 @@ +import json +import torch +from tqdm import tqdm +from model_1rvq import PromptCondAudioDiffusion +from diffusers import DDIMScheduler, DDPMScheduler +import torchaudio +import librosa +import os +import math +import numpy as np +from tools.get_1dvae_large import get_model +import tools.torch_tools as torch_tools +from safetensors.torch import load_file + +class Tango: + def __init__(self, \ + model_path, \ + vae_config="", + vae_model="", + layer_num=6, \ + device="cuda:0"): + + self.sample_rate = 48000 + scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json" + self.device = device + + self.vae = get_model(vae_config, vae_model) + self.vae = self.vae.to(device) + self.vae=self.vae.eval() + self.layer_num = layer_num + + self.MAX_DURATION = 360 + main_config = { + "num_channels":32, + "unet_model_name":None, + "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json", + "snr_gamma":None, + } + self.model = PromptCondAudioDiffusion(**main_config).to(device) + if model_path.endswith(".safetensors"): + main_weights = load_file(model_path) + else: + main_weights = torch.load(model_path, map_location=device) + self.model.load_state_dict(main_weights, strict=False) + print ("Successfully loaded checkpoint from:", model_path) + + self.model.eval() + self.model.init_device_dtype(torch.device(device), torch.float32) + print("scaling factor: ", self.model.normfeat.std) + + # self.scheduler = DDIMScheduler.from_pretrained( \ + # scheduler_name, subfolder="scheduler") + # self.scheduler = DDPMScheduler.from_pretrained( \ + # scheduler_name, subfolder="scheduler") + # print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + # def sound2sound(self, orig_samples, lyric, st_et, batch_size=1, duration=40.96, steps=200, disable_progress=False,scenario = "start_seg"): + # """ Genrate audio without condition. """ + # with torch.no_grad(): + # if(orig_samples.shape[-1] 3 B T + codes=codes[:,:,:output_len] + + return codes + + @torch.no_grad() + def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False): + codes = codes.to(self.device) + + min_samples = int(duration * 25) # 40ms per frame + hop_samples = min_samples // 4 * 3 + ovlp_samples = min_samples - hop_samples + hop_frames = hop_samples + ovlp_frames = ovlp_samples + first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device) + first_latent_length = 0 + first_latent_codes_length = 0 + + if(isinstance(prompt, torch.Tensor)): + # prepare prompt + prompt = prompt.to(self.device) + if(prompt.ndim == 3): + assert prompt.shape[0] == 1, prompt.shape + prompt = prompt[0] + elif(prompt.ndim == 1): + prompt = prompt.unsqueeze(0).repeat(2,1) + elif(prompt.ndim == 2): + if(prompt.shape[0] == 1): + prompt = prompt.repeat(2,1) + + if(prompt.shape[-1] < int(30 * self.sample_rate)): + # if less than 30s, just choose the first 10s + prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24 + else: + # else choose from 20.48s which might includes verse or chorus + prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24 + + true_latent = self.vae.encode_audio(prompt).permute(0,2,1) + # print("true_latent.shape", true_latent.shape) + # print("first_latent.shape", first_latent.shape) + #true_latent.shape torch.Size([1, 250, 64]) + # first_latent.shape torch.Size([1, 1000, 64]) + + first_latent[:,0:true_latent.shape[1],:] = true_latent + first_latent_length = true_latent.shape[1] + first_latent_codes = self.sound2code(prompt) + first_latent_codes_length = first_latent_codes.shape[-1] + codes = torch.cat([first_latent_codes, codes], -1) + + + + + codes_len= codes.shape[-1] + target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate) + # target_len = int(codes_len / 100 * 4 * self.sample_rate) + # code repeat + if(codes_len < min_samples): + while(codes.shape[-1] < min_samples): + codes = torch.cat([codes, codes], -1) + codes = codes[:,:,0:min_samples] + codes_len = codes.shape[-1] + if((codes_len - ovlp_samples) % hop_samples > 0): + len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples + while(codes.shape[-1] < len_codes): + codes = torch.cat([codes, codes], -1) + codes = codes[:,:,0:len_codes] + latent_length = min_samples + latent_list = [] + spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device) + with torch.autocast(device_type="cuda", dtype=torch.float16): + for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples): + codes_input=[] + codes_input.append(codes[:,:,sinx:sinx+min_samples]) + if(sinx == 0): + # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate)) + incontext_length = first_latent_length + latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg') + latent_list.append(latents) + else: + # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate)) + true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1) + print("true_latent.shape", true_latent.shape) + len_add_to_1000 = min_samples - true_latent.shape[-2] + # print("len_add_to_1000", len_add_to_1000) + # exit() + incontext_length = true_latent.shape[-2] + true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2) + latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg') + latent_list.append(latents) + + latent_list = [l.float() for l in latent_list] + latent_list[0] = latent_list[0][:,:,first_latent_length:] + min_samples = int(min_samples * self.sample_rate // 1000 * 40) + hop_samples = int(hop_samples * self.sample_rate // 1000 * 40) + ovlp_samples = min_samples - hop_samples + with torch.no_grad(): + output = None + for i in range(len(latent_list)): + latent = latent_list[i] + cur_output = self.vae.decode_audio(latent)[0].detach().cpu() + + if output is None: + output = cur_output + else: + ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :]) + ov_win = torch.cat([ov_win, 1 - ov_win], -1) + print("output.shape", output.shape) + print("ov_win.shape", ov_win.shape) + output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples] + output = torch.cat([output, cur_output[:, ovlp_samples:]], -1) + output = output[:, 0:target_len] + return output + + @torch.no_grad() + def preprocess_audio(self, input_audios, threshold=0.8): + assert len(input_audios.shape) == 3, input_audios.shape + nchan = input_audios.shape[1] + input_audios = input_audios.reshape(input_audios.shape[0], -1) + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1) + + @torch.no_grad() + def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False): + codes = self.sound2code(sound) + # print(codes.shape) + wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress) + # print(fname, wave.shape) + return wave + + @torch.no_grad() + def sound2sound_vae(self, sound, prompt=None, steps=50, disable_progress=False): + min_samples = int(40 * 25) # 40ms per frame + hop_samples = min_samples // 4 * 3 + ovlp_samples = min_samples - hop_samples + dur = 20 + + latent_list = [] + for i in range(0, sound.shape[-1], dur*48000): + if(i+dur*2*48000 > sound.shape[-1]): + latent = tango.vae.encode_audio(sound.cuda()[None,:,i:]) + break + else: + latent = tango.vae.encode_audio(sound.cuda()[None,:,i:i+dur*48000]) + latent_list.append(latent) + + output = None + for i in range(len(latent_list)): + print(i) + latent = latent_list[i] + cur_output = self.vae.decode_audio(latent)[0].detach().cpu() + if output is None: + output = cur_output + else: + output = torch.cat([output, cur_output], -1) + return output diff --git a/codeclm/tokenizer/Flow1dVAE/generate_2rvq.py b/codeclm/tokenizer/Flow1dVAE/generate_2rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..cc3f21bf1c3db326080701018eb5eee18902ec82 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/generate_2rvq.py @@ -0,0 +1,294 @@ +import json +import torch +from tqdm import tqdm +from model_2rvq import PromptCondAudioDiffusion +from diffusers import DDIMScheduler, DDPMScheduler +import torchaudio +import librosa +import os +import math +import numpy as np +# from tools.get_mulan import get_mulan +from tools.get_1dvae_large import get_model +import tools.torch_tools as torch_tools +from safetensors.torch import load_file +from audio import AudioFile +import kaldiio + +class Tango: + def __init__(self, \ + model_path, \ + layer_num=6, \ + rvq_num=1, \ + device="cuda:0"): + + self.sample_rate = 48000 + scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json" + self.device = device + + self.vae = get_model() + self.vae = self.vae.to(device) + self.vae=self.vae.eval() + self.layer_num = layer_num + + self.MAX_DURATION = 360 + main_config = { + "num_channels":32, + "unet_model_name":None, + "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json", + "snr_gamma":None, + } + self.rvq_num = rvq_num + # print("rvq_num: ", self.rvq_num) + # exit() + self.model = PromptCondAudioDiffusion(**main_config).to(device) + if model_path.endswith(".safetensors"): + main_weights = load_file(model_path) + else: + main_weights = torch.load(model_path, map_location=device) + self.model.load_state_dict(main_weights, strict=False) + print ("Successfully loaded checkpoint from:", model_path) + + self.model.eval() + self.model.init_device_dtype(torch.device(device), torch.float32) + print("scaling factor: ", self.model.normfeat.std) + + # self.scheduler = DDIMScheduler.from_pretrained( \ + # scheduler_name, subfolder="scheduler") + # self.scheduler = DDPMScheduler.from_pretrained( \ + # scheduler_name, subfolder="scheduler") + print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + + + @torch.no_grad() + @torch.autocast(device_type="cuda", dtype=torch.float32) + def sound2code(self, orig_samples, batch_size=8): + if(orig_samples.ndim == 2): + audios = orig_samples.unsqueeze(0).to(self.device) + elif(orig_samples.ndim == 3): + audios = orig_samples.to(self.device) + else: + assert orig_samples.ndim in (2,3), orig_samples.shape + audios = self.preprocess_audio(audios) + audios = audios.squeeze(0) + orig_length = audios.shape[-1] + min_samples = int(40 * self.sample_rate) + # 40秒对应10个token + output_len = int(orig_length / float(self.sample_rate) * 25) + 1 + # print("output_len: ", output_len) + + while(audios.shape[-1] < min_samples): + audios = torch.cat([audios, audios], -1) + int_max_len=audios.shape[-1]//min_samples+1 + audios = torch.cat([audios, audios], -1) + audios=audios[:,:int(int_max_len*(min_samples))] + codes_list=[] + + audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples) + + for audio_inx in range(0, audio_input.shape[0], batch_size): + # import pdb; pdb.set_trace() + codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num) + # print("codes",codes[0].shape) + + codes_list.append(torch.cat(codes, 1)) + # print("codes_list",codes_list[0].shape) + + codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T + codes=codes[:,:,:output_len] + + return codes + + @torch.no_grad() + @torch.autocast(device_type="cuda", dtype=torch.float32) + def sound2code_ds(self, orig_samples, ds, batch_size=8): + if(orig_samples.ndim == 2): + audios = orig_samples.unsqueeze(0).to(self.device) + elif(orig_samples.ndim == 3): + audios = orig_samples.to(self.device) + else: + assert orig_samples.ndim in (2,3), orig_samples.shape + audios = self.preprocess_audio(audios) + audios = audios.squeeze(0) + orig_length = audios.shape[-1] + min_samples = int(40 * self.sample_rate) + # 40秒对应10个token + output_len = int(orig_length / float(self.sample_rate) * 25) + 1 + # print("output_len: ", output_len) + + while(audios.shape[-1] < min_samples): + audios = torch.cat([audios, audios], -1) + int_max_len=audios.shape[-1]//min_samples+1 + audios = torch.cat([audios, audios], -1) + audios=audios[:,:int(int_max_len*(min_samples))] + codes_list=[] + + audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples) + + for audio_inx in range(0, audio_input.shape[0], batch_size): + # import pdb; pdb.set_trace() + codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds) + # print("codes",codes[0].shape) + + codes_list.append(torch.cat(codes, 1)) + # print("codes_list",codes_list[0].shape) + + codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T + codes=codes[:,:,:output_len] + + return codes + + @torch.no_grad() + def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False): + codes = codes.to(self.device) + + min_samples = duration * 25 # 40ms per frame + hop_samples = min_samples // 4 * 3 + ovlp_samples = min_samples - hop_samples + hop_frames = hop_samples + ovlp_frames = ovlp_samples + first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device) + first_latent_length = 0 + first_latent_codes_length = 0 + + if(isinstance(prompt, torch.Tensor)): + # prepare prompt + prompt = prompt.to(self.device) + if(prompt.ndim == 3): + assert prompt.shape[0] == 1, prompt.shape + prompt = prompt[0] + elif(prompt.ndim == 1): + prompt = prompt.unsqueeze(0).repeat(2,1) + elif(prompt.ndim == 2): + if(prompt.shape[0] == 1): + prompt = prompt.repeat(2,1) + + if(prompt.shape[-1] < int(30 * self.sample_rate)): + # if less than 30s, just choose the first 10s + prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24 + else: + # else choose from 20.48s which might includes verse or chorus + prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24 + + true_latent = self.vae.encode_audio(prompt).permute(0,2,1) + # print("true_latent.shape", true_latent.shape) + # print("first_latent.shape", first_latent.shape) + #true_latent.shape torch.Size([1, 250, 64]) + # first_latent.shape torch.Size([1, 1000, 64]) + + first_latent[:,0:true_latent.shape[1],:] = true_latent + first_latent_length = true_latent.shape[1] + first_latent_codes = self.sound2code(prompt) + first_latent_codes_length = first_latent_codes.shape[-1] + codes = torch.cat([first_latent_codes, codes], -1) + + codes_len= codes.shape[-1] + target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate) + # target_len = int(codes_len / 100 * 4 * self.sample_rate) + # code repeat + if(codes_len < min_samples): + while(codes.shape[-1] < min_samples): + codes = torch.cat([codes, codes], -1) + codes = codes[:,:,0:min_samples] + codes_len = codes.shape[-1] + if((codes_len - ovlp_samples) % hop_samples > 0): + len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples + while(codes.shape[-1] < len_codes): + codes = torch.cat([codes, codes], -1) + codes = codes[:,:,0:len_codes] + latent_length = min_samples + latent_list = [] + spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device) + with torch.autocast(device_type="cuda", dtype=torch.float16): + for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples): + codes_input=[] + codes_input.append(codes[:,:,sinx:sinx+min_samples]) + if(sinx == 0): + # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate)) + incontext_length = first_latent_length + latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg') + latent_list.append(latents) + else: + # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate)) + true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1) + print("true_latent.shape", true_latent.shape) + len_add_to_1000 = 1000 - true_latent.shape[-2] + # print("len_add_to_1000", len_add_to_1000) + # exit() + incontext_length = true_latent.shape[-2] + true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2) + latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg') + latent_list.append(latents) + + latent_list = [l.float() for l in latent_list] + latent_list[0] = latent_list[0][:,:,first_latent_length:] + min_samples = int(min_samples * self.sample_rate // 1000 * 40) + hop_samples = int(hop_samples * self.sample_rate // 1000 * 40) + ovlp_samples = min_samples - hop_samples + with torch.no_grad(): + output = None + for i in range(len(latent_list)): + latent = latent_list[i] + cur_output = self.vae.decode_audio(latent)[0].detach().cpu() + + if output is None: + output = cur_output + else: + ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :]) + ov_win = torch.cat([ov_win, 1 - ov_win], -1) + print("output.shape", output.shape) + print("ov_win.shape", ov_win.shape) + output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples] + output = torch.cat([output, cur_output[:, ovlp_samples:]], -1) + output = output[:, 0:target_len] + return output + + @torch.no_grad() + def preprocess_audio(self, input_audios, threshold=0.8): + assert len(input_audios.shape) == 3, input_audios.shape + nchan = input_audios.shape[1] + input_audios = input_audios.reshape(input_audios.shape[0], -1) + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1) + + @torch.no_grad() + def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False): + codes = self.sound2code(sound) + # print(codes.shape) + # exit() + wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress) + # print(fname, wave.shape) + return wave + + def file2code(self, fname): + try: + orig_samples, fs = torchaudio.load(fname) + except: + af = AudioFile(fname) + orig_samples = af.read() + fs = af.samplerate() + orig_samples = orig_samples[0] + if(fs!=self.sample_rate): + orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate) + fs = self.sample_rate + if orig_samples.shape[0] == 1: + orig_samples = torch.cat([orig_samples, orig_samples], 0) + return self.sound2code(orig_samples) + + def file2code_ds(self, fname, ds): + try: + orig_samples, fs = torchaudio.load(fname) + except: + af = AudioFile(fname) + orig_samples = af.read() + fs = af.samplerate() + orig_samples = orig_samples[0] + if(fs!=self.sample_rate): + orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate) + fs = self.sample_rate + if orig_samples.shape[0] == 1: + orig_samples = torch.cat([orig_samples, orig_samples], 0) + return self.sound2code_ds(orig_samples, ds) diff --git a/codeclm/tokenizer/Flow1dVAE/generate_4rvq.py b/codeclm/tokenizer/Flow1dVAE/generate_4rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..e87e43ed247fb8e10f1cf11139c7b3ace0ca1493 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/generate_4rvq.py @@ -0,0 +1,293 @@ +import json +import torch +from tqdm import tqdm +from model_4rvq import PromptCondAudioDiffusion +from diffusers import DDIMScheduler, DDPMScheduler +import torchaudio +import librosa +import os +import math +import numpy as np +# from tools.get_mulan import get_mulan +from tools.get_1dvae_large import get_model +import tools.torch_tools as torch_tools +from safetensors.torch import load_file +from audio import AudioFile + +class Tango: + def __init__(self, \ + model_path, \ + layer_num=6, \ + rvq_num=1, \ + device="cuda:0"): + + self.sample_rate = 48000 + scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json" + self.device = device + + self.vae = get_model() + self.vae = self.vae.to(device) + self.vae=self.vae.eval() + self.layer_num = layer_num + + self.MAX_DURATION = 360 + main_config = { + "num_channels":32, + "unet_model_name":None, + "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json", + "snr_gamma":None, + } + self.rvq_num = rvq_num + # print("rvq_num: ", self.rvq_num) + # exit() + self.model = PromptCondAudioDiffusion(**main_config).to(device) + if model_path.endswith(".safetensors"): + main_weights = load_file(model_path) + else: + main_weights = torch.load(model_path, map_location=device) + self.model.load_state_dict(main_weights, strict=False) + print ("Successfully loaded checkpoint from:", model_path) + + self.model.eval() + self.model.init_device_dtype(torch.device(device), torch.float32) + print("scaling factor: ", self.model.normfeat.std) + + # self.scheduler = DDIMScheduler.from_pretrained( \ + # scheduler_name, subfolder="scheduler") + # self.scheduler = DDPMScheduler.from_pretrained( \ + # scheduler_name, subfolder="scheduler") + print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + + + @torch.no_grad() + @torch.autocast(device_type="cuda", dtype=torch.float32) + def sound2code(self, orig_samples, batch_size=8): + if(orig_samples.ndim == 2): + audios = orig_samples.unsqueeze(0).to(self.device) + elif(orig_samples.ndim == 3): + audios = orig_samples.to(self.device) + else: + assert orig_samples.ndim in (2,3), orig_samples.shape + audios = self.preprocess_audio(audios) + audios = audios.squeeze(0) + orig_length = audios.shape[-1] + min_samples = int(40 * self.sample_rate) + # 40秒对应10个token + output_len = int(orig_length / float(self.sample_rate) * 25) + 1 + # print("output_len: ", output_len) + + while(audios.shape[-1] < min_samples): + audios = torch.cat([audios, audios], -1) + int_max_len=audios.shape[-1]//min_samples+1 + audios = torch.cat([audios, audios], -1) + audios=audios[:,:int(int_max_len*(min_samples))] + codes_list=[] + + audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples) + + for audio_inx in range(0, audio_input.shape[0], batch_size): + # import pdb; pdb.set_trace() + codes, _, spk_embeds = self.model.fetch_codes_batch((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num) + # print("codes",codes[0].shape) + + codes_list.append(torch.cat(codes, 1)) + # print("codes_list",codes_list[0].shape) + + codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T + codes=codes[:,:,:output_len] + + return codes + + @torch.no_grad() + @torch.autocast(device_type="cuda", dtype=torch.float32) + def sound2code_ds(self, orig_samples, ds, batch_size=6): + if(orig_samples.ndim == 2): + audios = orig_samples.unsqueeze(0).to(self.device) + elif(orig_samples.ndim == 3): + audios = orig_samples.to(self.device) + else: + assert orig_samples.ndim in (2,3), orig_samples.shape + audios = self.preprocess_audio(audios) + audios = audios.squeeze(0) + orig_length = audios.shape[-1] + min_samples = int(40 * self.sample_rate) + # 40秒对应10个token + output_len = int(orig_length / float(self.sample_rate) * 25) + 1 + # print("output_len: ", output_len) + + while(audios.shape[-1] < min_samples): + audios = torch.cat([audios, audios], -1) + int_max_len=audios.shape[-1]//min_samples+1 + audios = torch.cat([audios, audios], -1) + audios=audios[:,:int(int_max_len*(min_samples))] + codes_list=[] + + audio_input = audios.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples) + + for audio_inx in range(0, audio_input.shape[0], batch_size): + # import pdb; pdb.set_trace() + codes, _, spk_embeds = self.model.fetch_codes_batch_ds((audio_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer=self.layer_num, rvq_num=self.rvq_num, ds=ds) + # print("codes",codes[0].shape) + + codes_list.append(torch.cat(codes, 1)) + # print("codes_list",codes_list[0].shape) + + codes = torch.cat(codes_list, 0).permute(1,0,2).reshape(self.rvq_num, -1)[None] # B 3 T -> 3 B T + codes=codes[:,:,:output_len] + + return codes + + @torch.no_grad() + def code2sound(self, codes, prompt=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False): + codes = codes.to(self.device) + + min_samples = duration * 25 # 40ms per frame + hop_samples = min_samples // 4 * 3 + ovlp_samples = min_samples - hop_samples + hop_frames = hop_samples + ovlp_frames = ovlp_samples + first_latent = torch.randn(codes.shape[0], min_samples, 64).to(self.device) + first_latent_length = 0 + first_latent_codes_length = 0 + + if(isinstance(prompt, torch.Tensor)): + # prepare prompt + prompt = prompt.to(self.device) + if(prompt.ndim == 3): + assert prompt.shape[0] == 1, prompt.shape + prompt = prompt[0] + elif(prompt.ndim == 1): + prompt = prompt.unsqueeze(0).repeat(2,1) + elif(prompt.ndim == 2): + if(prompt.shape[0] == 1): + prompt = prompt.repeat(2,1) + + if(prompt.shape[-1] < int(30 * self.sample_rate)): + # if less than 30s, just choose the first 10s + prompt = prompt[:,:int(10*self.sample_rate)] # limit max length to 10.24 + else: + # else choose from 20.48s which might includes verse or chorus + prompt = prompt[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24 + + true_latent = self.vae.encode_audio(prompt).permute(0,2,1) + # print("true_latent.shape", true_latent.shape) + # print("first_latent.shape", first_latent.shape) + #true_latent.shape torch.Size([1, 250, 64]) + # first_latent.shape torch.Size([1, 1000, 64]) + + first_latent[:,0:true_latent.shape[1],:] = true_latent + first_latent_length = true_latent.shape[1] + first_latent_codes = self.sound2code(prompt) + first_latent_codes_length = first_latent_codes.shape[-1] + codes = torch.cat([first_latent_codes, codes], -1) + + codes_len= codes.shape[-1] + target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate) + # target_len = int(codes_len / 100 * 4 * self.sample_rate) + # code repeat + if(codes_len < min_samples): + while(codes.shape[-1] < min_samples): + codes = torch.cat([codes, codes], -1) + codes = codes[:,:,0:min_samples] + codes_len = codes.shape[-1] + if((codes_len - ovlp_samples) % hop_samples > 0): + len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples + while(codes.shape[-1] < len_codes): + codes = torch.cat([codes, codes], -1) + codes = codes[:,:,0:len_codes] + latent_length = min_samples + latent_list = [] + spk_embeds = torch.zeros([1, 32, 1, 32], device=codes.device) + with torch.autocast(device_type="cuda", dtype=torch.float16): + for sinx in range(0, codes.shape[-1]-hop_samples, hop_samples): + codes_input=[] + codes_input.append(codes[:,:,sinx:sinx+min_samples]) + if(sinx == 0): + # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate)) + incontext_length = first_latent_length + latents = self.model.inference_codes(codes_input, spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg') + latent_list.append(latents) + else: + # print("Processing {} to {}".format(sinx/self.sample_rate, (sinx + min_samples)/self.sample_rate)) + true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1) + print("true_latent.shape", true_latent.shape) + len_add_to_1000 = 1000 - true_latent.shape[-2] + # print("len_add_to_1000", len_add_to_1000) + # exit() + incontext_length = true_latent.shape[-2] + true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2) + latents = self.model.inference_codes(codes_input, spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg') + latent_list.append(latents) + + latent_list = [l.float() for l in latent_list] + latent_list[0] = latent_list[0][:,:,first_latent_length:] + min_samples = int(min_samples * self.sample_rate // 1000 * 40) + hop_samples = int(hop_samples * self.sample_rate // 1000 * 40) + ovlp_samples = min_samples - hop_samples + with torch.no_grad(): + output = None + for i in range(len(latent_list)): + latent = latent_list[i] + cur_output = self.vae.decode_audio(latent)[0].detach().cpu() + + if output is None: + output = cur_output + else: + ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :]) + ov_win = torch.cat([ov_win, 1 - ov_win], -1) + print("output.shape", output.shape) + print("ov_win.shape", ov_win.shape) + output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples] + output = torch.cat([output, cur_output[:, ovlp_samples:]], -1) + output = output[:, 0:target_len] + return output + + @torch.no_grad() + def preprocess_audio(self, input_audios, threshold=0.8): + assert len(input_audios.shape) == 3, input_audios.shape + nchan = input_audios.shape[1] + input_audios = input_audios.reshape(input_audios.shape[0], -1) + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios.reshape(input_audios.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1) + + @torch.no_grad() + def sound2sound(self, sound, prompt=None, steps=50, disable_progress=False): + codes = self.sound2code(sound) + # print(codes.shape) + # exit() + wave = self.code2sound(codes, prompt, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress) + # print(fname, wave.shape) + return wave + + def file2code(self, fname): + try: + orig_samples, fs = torchaudio.load(fname) + except: + af = AudioFile(fname) + orig_samples = af.read() + fs = af.samplerate() + orig_samples = orig_samples[0] + if(fs!=self.sample_rate): + orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate) + fs = self.sample_rate + if orig_samples.shape[0] == 1: + orig_samples = torch.cat([orig_samples, orig_samples], 0) + return self.sound2code(orig_samples) + + def file2code_ds(self, fname, ds): + try: + orig_samples, fs = torchaudio.load(fname) + except: + af = AudioFile(fname) + orig_samples = af.read() + fs = af.samplerate() + orig_samples = orig_samples[0] + if(fs!=self.sample_rate): + orig_samples = torchaudio.functional.resample(orig_samples, fs, self.sample_rate) + fs = self.sample_rate + if orig_samples.shape[0] == 1: + orig_samples = torch.cat([orig_samples, orig_samples], 0) + return self.sound2code_ds(orig_samples, ds) diff --git a/codeclm/tokenizer/Flow1dVAE/generate_septoken.py b/codeclm/tokenizer/Flow1dVAE/generate_septoken.py new file mode 100644 index 0000000000000000000000000000000000000000..883e28d0252515321b931bed9114625ce0fbb07a --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/generate_septoken.py @@ -0,0 +1,302 @@ +import json +import torch +from tqdm import tqdm +from model_septoken import PromptCondAudioDiffusion +from diffusers import DDIMScheduler, DDPMScheduler +import torchaudio +import librosa +import os +import math +import numpy as np +# from tools.get_mulan import get_mulan +from tools.get_1dvae_large import get_model +import tools.torch_tools as torch_tools +from safetensors.torch import load_file +from third_party.demucs.models.pretrained import get_model_from_yaml +from filelock import FileLock +import kaldiio +# os.path.join(args.model_dir, "htdemucs.pth"), os.path.join(args.model_dir, "htdemucs.yaml") +class Separator: + def __init__(self, dm_model_path='demucs/ckpt/htdemucs.pth', dm_config_path='demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None: + if torch.cuda.is_available() and gpu_id < torch.cuda.device_count(): + self.device = torch.device(f"cuda:{gpu_id}") + else: + self.device = torch.device("cpu") + self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path) + + def init_demucs_model(self, model_path, config_path): + model = get_model_from_yaml(config_path, model_path) + model.to(self.device) + model.eval() + return model + + def load_audio(self, f): + a, fs = torchaudio.load(f) + if (fs != 48000): + a = torchaudio.functional.resample(a, fs, 48000) + # if a.shape[-1] >= 48000*10: + # a = a[..., :48000*10] + # else: + # a = torch.cat([a, a], -1) + # return a[:, 0:48000*10] + return a + + def run(self, audio_path, output_dir='demucs/test_output', ext=".flac"): + name, _ = os.path.splitext(os.path.split(audio_path)[-1]) + output_paths = [] + # lock_path = os.path.join(output_dir, f"{name}.lock") + # with FileLock(lock_path): # 加一个避免多卡访问时死锁 + for stem in self.demucs_model.sources: + output_path = os.path.join(output_dir, f"{name}_{stem}{ext}") + if os.path.exists(output_path): + output_paths.append(output_path) + if len(output_paths) == 1: # 4 + # drums_path, bass_path, other_path, vocal_path = output_paths + vocal_path = output_paths[0] + else: + lock_path = os.path.join(output_dir, f"{name}_separate.lock") + with FileLock(lock_path): + drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device) + full_audio = self.load_audio(audio_path) + vocal_audio = self.load_audio(vocal_path) + minlen = min(full_audio.shape[-1], vocal_audio.shape[-1]) + # bgm_audio = full_audio[:, 0:minlen] - vocal_audio[:, 0:minlen] + bgm_audio = self.load_audio(drums_path) + self.load_audio(bass_path) + self.load_audio(other_path) + for path in [drums_path, bass_path, other_path, vocal_path]: + os.remove(path) + return full_audio, vocal_audio, bgm_audio + +class Tango: + def __init__(self, \ + model_path, \ + vae_config, + vae_model, + layer_vocal=7,\ + layer_bgm=3,\ + device="cuda:0"): + + self.sample_rate = 48000 + scheduler_name = "configs/scheduler/stable_diffusion_2.1_largenoise_sample.json" + self.device = device + + self.vae = get_model(vae_config, vae_model) + self.vae = self.vae.to(device) + self.vae=self.vae.eval() + self.layer_vocal=layer_vocal + self.layer_bgm=layer_bgm + + self.MAX_DURATION = 360 + main_config = { + "num_channels":32, + "unet_model_name":None, + "unet_model_config_path":"configs/models/transformer2D_wocross_inch112_1x4_multi_large.json", + "snr_gamma":None, + } + self.model = PromptCondAudioDiffusion(**main_config).to(device) + if model_path.endswith(".safetensors"): + main_weights = load_file(model_path) + else: + main_weights = torch.load(model_path, map_location=device) + self.model.load_state_dict(main_weights, strict=False) + print ("Successfully loaded checkpoint from:", model_path) + + self.model.eval() + self.model.init_device_dtype(torch.device(device), torch.float32) + print("scaling factor: ", self.model.normfeat.std) + + # self.scheduler = DDIMScheduler.from_pretrained( \ + # scheduler_name, subfolder="scheduler") + # self.scheduler = DDPMScheduler.from_pretrained( \ + # scheduler_name, subfolder="scheduler") + print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + + @torch.no_grad() + @torch.autocast(device_type="cuda", dtype=torch.float32) + def sound2code(self, orig_vocal, orig_bgm, batch_size=8): + if(orig_vocal.ndim == 2): + audios_vocal = orig_vocal.unsqueeze(0).to(self.device) + elif(orig_vocal.ndim == 3): + audios_vocal = orig_vocal.to(self.device) + else: + assert orig_vocal.ndim in (2,3), orig_vocal.shape + + if(orig_bgm.ndim == 2): + audios_bgm = orig_bgm.unsqueeze(0).to(self.device) + elif(orig_bgm.ndim == 3): + audios_bgm = orig_bgm.to(self.device) + else: + assert orig_bgm.ndim in (2,3), orig_bgm.shape + + + audios_vocal = self.preprocess_audio(audios_vocal) + audios_vocal = audios_vocal.squeeze(0) + audios_bgm = self.preprocess_audio(audios_bgm) + audios_bgm = audios_bgm.squeeze(0) + if audios_vocal.shape[-1] > audios_bgm.shape[-1]: + audios_vocal = audios_vocal[:,:audios_bgm.shape[-1]] + else: + audios_bgm = audios_bgm[:,:audios_vocal.shape[-1]] + + + orig_length = audios_vocal.shape[-1] + min_samples = int(40 * self.sample_rate) + # 40秒对应10个token + output_len = int(orig_length / float(self.sample_rate) * 25) + 1 + + while(audios_vocal.shape[-1] < min_samples): + audios_vocal = torch.cat([audios_vocal, audios_vocal], -1) + audios_bgm = torch.cat([audios_bgm, audios_bgm], -1) + int_max_len=audios_vocal.shape[-1]//min_samples+1 + audios_vocal = torch.cat([audios_vocal, audios_vocal], -1) + audios_bgm = torch.cat([audios_bgm, audios_bgm], -1) + audios_vocal=audios_vocal[:,:int(int_max_len*(min_samples))] + audios_bgm=audios_bgm[:,:int(int_max_len*(min_samples))] + codes_vocal_list=[] + codes_bgm_list=[] + + + + audio_vocal_input = audios_vocal.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples) + audio_bgm_input = audios_bgm.reshape(2, -1, min_samples).permute(1, 0, 2).reshape(-1, 2, min_samples) + + for audio_inx in range(0, audio_vocal_input.shape[0], batch_size): + [codes_vocal,codes_bgm], _, spk_embeds = self.model.fetch_codes_batch((audio_vocal_input[audio_inx:audio_inx+batch_size]), (audio_bgm_input[audio_inx:audio_inx+batch_size]), additional_feats=[],layer_vocal=self.layer_vocal,layer_bgm=self.layer_bgm) + codes_vocal_list.append(codes_vocal) + codes_bgm_list.append(codes_bgm) + + codes_vocal = torch.cat(codes_vocal_list, 0).permute(1,0,2).reshape(1, -1)[None] + codes_bgm = torch.cat(codes_bgm_list, 0).permute(1,0,2).reshape(1, -1)[None] + codes_vocal=codes_vocal[:,:,:output_len] + codes_bgm=codes_bgm[:,:,:output_len] + + return codes_vocal, codes_bgm + + @torch.no_grad() + def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False): + codes_vocal,codes_bgm = codes + codes_vocal = codes_vocal.to(self.device) + codes_bgm = codes_bgm.to(self.device) + + min_samples = duration * 25 # 40ms per frame + hop_samples = min_samples // 4 * 3 + ovlp_samples = min_samples - hop_samples + hop_frames = hop_samples + ovlp_frames = ovlp_samples + first_latent = torch.randn(codes_vocal.shape[0], min_samples, 64).to(self.device) + first_latent_length = 0 + first_latent_codes_length = 0 + + + if(isinstance(prompt_vocal, torch.Tensor)): + # prepare prompt + prompt_vocal = prompt_vocal.to(self.device) + prompt_bgm = prompt_bgm.to(self.device) + if(prompt_vocal.ndim == 3): + assert prompt_vocal.shape[0] == 1, prompt_vocal.shape + prompt_vocal = prompt_vocal[0] + prompt_bgm = prompt_bgm[0] + elif(prompt_vocal.ndim == 1): + prompt_vocal = prompt_vocal.unsqueeze(0).repeat(2,1) + prompt_bgm = prompt_bgm.unsqueeze(0).repeat(2,1) + elif(prompt_vocal.ndim == 2): + if(prompt_vocal.shape[0] == 1): + prompt_vocal = prompt_vocal.repeat(2,1) + prompt_bgm = prompt_bgm.repeat(2,1) + + if(prompt_vocal.shape[-1] < int(30 * self.sample_rate)): + # if less than 30s, just choose the first 10s + prompt_vocal = prompt_vocal[:,:int(10*self.sample_rate)] # limit max length to 10.24 + prompt_bgm = prompt_bgm[:,:int(10*self.sample_rate)] # limit max length to 10.24 + else: + # else choose from 20.48s which might includes verse or chorus + prompt_vocal = prompt_vocal[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24 + prompt_bgm = prompt_bgm[:,int(20*self.sample_rate):int(30*self.sample_rate)] # limit max length to 10.24 + + true_latent = self.vae.encode_audio(prompt_vocal+prompt_bgm).permute(0,2,1) + + first_latent[:,0:true_latent.shape[1],:] = true_latent + first_latent_length = true_latent.shape[1] + first_latent_codes = self.sound2code(prompt_vocal, prompt_bgm) + first_latent_codes_vocal = first_latent_codes[0] + first_latent_codes_bgm = first_latent_codes[1] + first_latent_codes_length = first_latent_codes_vocal.shape[-1] + codes_vocal = torch.cat([first_latent_codes_vocal, codes_vocal], -1) + codes_bgm = torch.cat([first_latent_codes_bgm, codes_bgm], -1) + + + codes_len= codes_vocal.shape[-1] + target_len = int((codes_len - first_latent_codes_length) / 100 * 4 * self.sample_rate) + # target_len = int(codes_len / 100 * 4 * self.sample_rate) + # code repeat + if(codes_len < min_samples): + while(codes_vocal.shape[-1] < min_samples): + codes_vocal = torch.cat([codes_vocal, codes_vocal], -1) + codes_bgm = torch.cat([codes_bgm, codes_bgm], -1) + + codes_vocal = codes_vocal[:,:,0:min_samples] + codes_bgm = codes_bgm[:,:,0:min_samples] + codes_len = codes_vocal.shape[-1] + if((codes_len - ovlp_samples) % hop_samples > 0): + len_codes=math.ceil((codes_len - ovlp_samples) / float(hop_samples)) * hop_samples + ovlp_samples + while(codes_vocal.shape[-1] < len_codes): + codes_vocal = torch.cat([codes_vocal, codes_vocal], -1) + codes_bgm = torch.cat([codes_bgm, codes_bgm], -1) + codes_vocal = codes_vocal[:,:,0:len_codes] + codes_bgm = codes_bgm[:,:,0:len_codes] + latent_length = min_samples + latent_list = [] + spk_embeds = torch.zeros([1, 32, 1, 32], device=codes_vocal.device) + with torch.autocast(device_type="cuda", dtype=torch.float16): + for sinx in range(0, codes_vocal.shape[-1]-hop_samples, hop_samples): + codes_vocal_input=codes_vocal[:,:,sinx:sinx+min_samples] + codes_bgm_input=codes_bgm[:,:,sinx:sinx+min_samples] + if(sinx == 0): + incontext_length = first_latent_length + latents = self.model.inference_codes([codes_vocal_input,codes_bgm_input], spk_embeds, first_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg') + latent_list.append(latents) + else: + true_latent = latent_list[-1][:,:,-ovlp_frames:].permute(0,2,1) + len_add_to_1000 = min_samples - true_latent.shape[-2] + incontext_length = true_latent.shape[-2] + true_latent = torch.cat([true_latent, torch.randn(true_latent.shape[0], len_add_to_1000, true_latent.shape[-1]).to(self.device)], -2) + latents = self.model.inference_codes([codes_vocal_input,codes_bgm_input], spk_embeds, true_latent, latent_length, incontext_length=incontext_length, additional_feats=[], guidance_scale=1.5, num_steps = num_steps, disable_progress=disable_progress, scenario='other_seg') + latent_list.append(latents) + + latent_list = [l.float() for l in latent_list] + latent_list[0] = latent_list[0][:,:,first_latent_length:] + min_samples = int(min_samples * self.sample_rate // 1000 * 40) + hop_samples = int(hop_samples * self.sample_rate // 1000 * 40) + ovlp_samples = min_samples - hop_samples + with torch.no_grad(): + output = None + for i in range(len(latent_list)): + latent = latent_list[i] + cur_output = self.vae.decode_audio(latent)[0].detach().cpu() + + if output is None: + output = cur_output + else: + ov_win = torch.from_numpy(np.linspace(0, 1, ovlp_samples)[None, :]) + ov_win = torch.cat([ov_win, 1 - ov_win], -1) + output[:, -ovlp_samples:] = output[:, -ovlp_samples:] * ov_win[:, -ovlp_samples:] + cur_output[:, 0:ovlp_samples] * ov_win[:, 0:ovlp_samples] + output = torch.cat([output, cur_output[:, ovlp_samples:]], -1) + output = output[:, 0:target_len] + return output + + @torch.no_grad() + def preprocess_audio(self, input_audios_vocal, threshold=0.8): + assert len(input_audios_vocal.shape) == 3, input_audios_vocal.shape + nchan = input_audios_vocal.shape[1] + input_audios_vocal = input_audios_vocal.reshape(input_audios_vocal.shape[0], -1) + norm_value = torch.ones_like(input_audios_vocal[:,0]) + max_volume = input_audios_vocal.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios_vocal.reshape(input_audios_vocal.shape[0], nchan, -1)/norm_value.unsqueeze(-1).unsqueeze(-1) + + @torch.no_grad() + def sound2sound(self, orig_vocal,orig_bgm, prompt_vocal=None,prompt_bgm=None, steps=50, disable_progress=False): + codes_vocal, codes_bgm = self.sound2code(orig_vocal,orig_bgm) + codes=[codes_vocal, codes_bgm] + wave = self.code2sound(codes, prompt_vocal,prompt_bgm, guidance_scale=1.5, num_steps=steps, disable_progress=disable_progress) + return wave diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f0af158226ee61ac9b3f268cc402779d9ae1ea00 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/datasets/MusicSoundMixedDataset.py @@ -0,0 +1,1278 @@ +from torch.utils.data import Dataset +from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List, Union +from beartype import beartype +from beartype.door import is_bearable +import random +import pandas as pd +import os +from torchaudio.functional import resample +import torch +import typing as tp +from pathlib import Path +import torchaudio as ta +import torch.nn.functional as F +import numpy as np +import json +import yaml +import torchaudio +import math +import re +from loguru import logger +import ffmpeg + +class Read_and_PadCrop_Normalized_T(torch.nn.Module): + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: + if self.n_samples < 0: #means not clip + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = 1.0 + offset = 0 + else: + if(duration<(float(self.n_samples)/self.sample_rate+1)): + # print(duration,(float(self.n_samples)/self.sample_rate+1)) + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) + offset = 0 + # print('c1:',chunk.shape) + else: + offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + t_start = offset / float(cur_sample_rate) / duration + t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration + chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + # print('offset:',offset) + # print('c0:',chunk.shape) + # Pad with silence if necessary. + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + if(cur_sample_rate!=self.sample_rate): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) + # print('b:',self.sample_rate,chunk.shape) + + if self.n_samples > 0: + if chunk.shape[-1] < self.n_samples: + chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) + else: + chunk = chunk[:,0:self.n_samples] + seconds_start = math.floor(offset / cur_sample_rate) + seconds_total = math.floor(duration) + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total + ) + +class Read_and_PadCrop_Normalized_T_Avoid_Watermark(torch.nn.Module): + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True, w_start = 0, w_interval = 11.3): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + self.w_start = w_start + self.w_interval = w_interval + + def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: + if self.n_samples < 0: #means not clip + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = 1.0 + offset = 0 + else: + if(duration<(float(self.n_samples)/self.sample_rate+1)): + # print(duration,(float(self.n_samples)/self.sample_rate+1)) + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) + offset = 0 + # print('c1:',chunk.shape) + else: + n_offset_option = (duration - self.w_start) // self.w_interval + if n_offset_option <= 1: + offset = 0 + else: + offset = int((random.randint(0,n_offset_option-1) * self.w_interval + self.w_start) * cur_sample_rate) + # offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + t_start = offset / float(cur_sample_rate) / duration + t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration + chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + # print('offset:',offset) + # print('c0:',chunk.shape) + # Pad with silence if necessary. + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + if(cur_sample_rate!=self.sample_rate): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) + # print('b:',self.sample_rate,chunk.shape) + + if self.n_samples > 0: + if chunk.shape[-1] < self.n_samples: + chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) + else: + chunk = chunk[:,0:self.n_samples] + seconds_start = math.floor(offset / cur_sample_rate) + seconds_total = math.floor(duration) + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total + ) + +USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替 +if USE_DUMMY_AUDIO: + logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!") + +class SafeAudioReader: + """ + This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data. + """ + def __init__(self, + duration: float, # 返回音频长度 + sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample + randomize: bool = True, + use_avoid_watermark_policy = False, + ): + self.n_samples = int(sample_rate * duration) + self.reader = ( + Read_and_PadCrop_Normalized_T_Avoid_Watermark if use_avoid_watermark_policy \ + else Read_and_PadCrop_Normalized_T + )(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize) + + #NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数! + def __call__(self, + filepath: os.PathLike, # 音频路径 + origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取 + origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取 + ) -> torch.Tensor: + if USE_DUMMY_AUDIO: + wav = torch.zeros(self.n_samples, dtype=torch.float32) + return wav + try: + if origin_sample_rate is None or origin_duration is None: + # audio_info = torchaudio.info(filepath) + # origin_sample_rate = audio_info.sample_rate + # origin_duration = audio_info.num_frames / origin_sample_rate + info = ffmpeg.probe(filepath) + origin_duration = float(info['format']['duration']) + origin_sample_rate = int(info['streams'][0]['sample_rate']) + wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate) + wav = wav.squeeze_(0) + except Exception as e: + logger.error(f"Error reading {filepath}: {e}") + wav = torch.zeros(self.n_samples, dtype=torch.float32) + return wav + + +class PromptTemplate: + def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'): + self.template_text = template_text + self.tag_map = tag_map + self.lang = lang + + @property + def tags(self): + return tuple(self.tag_map.keys()) + + def apply(self, **kwargs): + for tag in list(kwargs.keys()): + if kwargs[tag] == '': + kwargs.pop(tag) + for tag in self.tags: + if tag in kwargs: + kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]') + else: + kwargs[tag] = '' + prompt = self.template_text.format(**kwargs) + + return self.beautify(prompt) + + def beautify(self, text): + if self.lang == 'en': + return self._beautify_en(text) + elif self.lang == 'zh': + return self._beautify_zh(text) + else: + raise ValueError(f'Unknown language {self.lang}') + + @staticmethod + def _beautify_en(text): + # no continuous commas without content between them + text = re.sub(r'[,\s]*,[,\s]*', r', ', text) + # no continuous whitespace + text = re.sub(r'\s+', ' ', text) + # the comma is NOT followed by whitespace, and should be followed by ONE whitespace + text = re.sub(r'\s+,', r',', text) + text = re.sub(r',\s+', r', ', text) + # no whitespace before the full stop + text = re.sub(r'\s+\.', r'.', text) + # strip whitespace, comma, and replace ',.' + text = text.strip(' ,') + text = text.replace(',.', '.') + return text + + @staticmethod + def _beautify_zh(text): + # no continuous commas without content between them + text = re.sub(r'[,、\s]*,[,、\s]*', r',', text) + text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text) + # assume there should be NO whitespace in Chinese + text = re.sub(r'\s+', r'', text) + # strip whitespace, comma, and replace ',。' + text = text.strip(', 、') + text = text.replace(',。', '。') + return text + + def __repr__(self): + return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})' + + __str__ = __repr__ + +def parse_prompt_template(prompt_template_text, lang='en'): + span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL) + tag_pattern = re.compile(r'{.+?}', re.DOTALL) + + template_text = prompt_template_text.strip() + span_texts = span_pattern.findall(prompt_template_text) + tag_map = {} + for span_text in span_texts: + tag = tag_pattern.findall(span_text)[0].strip('{}') + tag_map[tag] = span_text + template_text = template_text.replace(span_text, '{'+tag+'}') + + return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang) + +def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]: + with open(path, 'r') as f: + lines = f.readlines() + cnt = 0 + pts = [] + for line in lines: + pt = parse_prompt_template(line, lang=lang) + cnt += 1 + if len(pt.tags) < num: + logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}') + pts.append(pt) + + return pts + + +def get_base_dir_file(key: os.PathLike): + base = os.path.basename(key) + dirname = os.path.basename(os.path.dirname(key)) + return os.path.join(dirname, base) + +def read_jsonlike(path: os.PathLike): + #json or jsonl + if str(path).endswith(".json"): + with open(path, 'r', encoding='utf8') as f: + data = json.load(f) + return data + elif str(path).endswith(".jsonl"): + with open(path, 'r', encoding='utf8') as f: + data = [json.loads(line) for line in f.readlines()] + return data + else: + raise ValueError("Unknown file format") + +dist_prob_map = { + 1: (1.0,), + 2: (0.5, 0.5), + 3: (0.3, 0.4, 0.3), + 4: (0.2, 0.3, 0.3, 0.2), + 5: (0.2, 0.2, 0.3, 0.2, 0.1), + 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15), + 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1), + 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12), + 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08), + 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09) +} + +''' +#更加偏向短文本的方案 +dist_prob_map = { + 1: (1.0,), + 2: (0.7, 0.3), + 3: (0.7, 0.2, 0.1), + 4: (0.6, 0.2, 0.1, 0.1), + 5: (0.6, 0.2, 0.1, 0.05, 0.05), + 6: (0.6, 0.15, 0.1, 0.05, 0.05, 0.05), + 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1), + 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12), + 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08), + 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09) +} +''' + +#全部都用的方案 +# dist_prob_map = { +# 1: (1.0,), +# 2: (0, 1.0), +# 3: (0, 0, 1.0), +# 4: (0, 0, 0, 1.0), +# 5: (0, 0, 0, 0, 1.0), +# 6: (0, 0, 0, 0, 0, 1.0), +# 7: (0, 0, 0, 0, 0, 0, 1.0), +# 8: (0, 0, 0, 0, 0, 0, 0, 1.0), +# 9: (0, 0, 0, 0, 0, 0, 0, 0, 1.0), +# 10: (0, 0, 0, 0, 0, 0, 0, 0, 0, 1.0) +# } + +dist_prob_map_low = { + 1: (1.0,), + 2: (0.8, 0.2), + 3: (0.8, 0.1, 0.1), + 4: (0.7, 0.1, 0.1, 0.1), + 5: (0.7, 0.1, 0.1, 0.05, 0.05), + 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05), +} + +_bpm_range_rights = ( + (40, '20-40'), + (60, '40-60'), + (66, '60-66'), + (76, '66-76'), + (108, '76-108'), + (120, '108-120'), + (168, '120-168'), + (176, '168-176'), + (200, '176-200') +) +_bpm_desc_map = { + '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"), + '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"), + '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'), + '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'), + '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"), + '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'), + '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'), + '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'), + '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'), + '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo') +} +_bpm_desc_map_zh = { + '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"), + '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"), + '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'), + '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'), + '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"), + '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'), + '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'), + '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'), + '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'), + '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板') +} +def get_bpm_range(bpm): + bpm = int(bpm) + for right, tag in _bpm_range_rights: + if bpm <= right: + return tag + return '>200' + +def gen_bpm_descript(bpm, lang='en'): + bpm_range = get_bpm_range(bpm) + if lang == 'en': + return random.choice(_bpm_desc_map[bpm_range]) + elif lang == 'zh': + return random.choice(_bpm_desc_map_zh[bpm_range]) + else: + raise ValueError(f"Unknown language {lang}") + +def read_translate(translate: Union[Dict[str, os.PathLike], os.PathLike, None]): + if translate is None: + return None + if isinstance(translate, str): + return read_jsonlike(translate) + return {k: read_jsonlike(path) for k, path in translate.items()} + + +def gen_plain_prompt(key_list, sep=', '): + if len(key_list) == 0: + return 'none' + + key_list = [k.strip() for k in key_list] + + if len(key_list) > 10: + random.shuffle(key_list) + key_list = key_list[:10] + + probs = dist_prob_map[len(key_list)] + + num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0] + + random.shuffle(key_list) + tags = key_list[:num_tags] + tags_str = sep.join(tags) + return tags_str + + +class MagnaTagATuneDataset(Dataset): + def __init__(self): + pass + + +def tags_to_desc(tag_list, sep=',') -> str: + if not isinstance(tag_list, Sequence): + return str(tag_list) + if isinstance(tag_list, str): + return tag_list + if len(tag_list) <= 0: + return '' + elif len(tag_list) <= 5: + probs = dist_prob_map[len(tag_list)] + tags_num = random.choices(range(1, len(tag_list)+1), probs)[0] + random.shuffle(tag_list) + tag_list = tag_list[:tags_num] + return sep.join(tag_list) + else: + probs = dist_prob_map[5] + tags_num = random.choices(range(1, 6), probs)[0] + random.shuffle(tag_list) + tag_list = tag_list[:tags_num] + return sep.join(tag_list) + +def get_sr_and_duration_info(item): + return item.get('sample_rate', None), item.get('duration', None) + +class MtgJamendoDatasetFromJson(Dataset): + def __init__(self, + data_dir:str, + json_path:str, + duration:float=10, + sr:int = 0, + lang = 'en', + plain_rate = 0, + return_audio = True, + return_path = False, + prompt_template_path: os.PathLike = None, + tag_types = [], + translate:Optional[Dict[str, os.PathLike]] = None, + use_literal_none = True, + ): + self.audio_reader = SafeAudioReader(duration, sr) + + self.data_dir = data_dir + self._load_metadata_json(json_path) + self.sr = sr + self.duration = duration + self.plain_rate = plain_rate + self.return_audio = return_audio + self.return_path = return_path + self.use_literal_none = use_literal_none + self.lang = lang + + self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0 + if self.use_dynamic_prompt: + self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types)) + self.tag_types = tag_types + + self.translate = read_translate(translate) + + #这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示 + WEAK_TAG_LIST = ["title", "artist"] + + def _load_metadata_json(self, json_path): + with open(json_path) as fp: + self.data = json.load(fp) + + def convert_key_to_path(self, key): + return os.path.join(self.data_dir, get_base_dir_file(key)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + path = self.convert_key_to_path(item['key']) + description = self.generate_description(item) + + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio = self.audio_reader(path, sr, duration) + else: + audio = None + + if self.return_path: + return audio, description, path + return audio, description + + def tags_to_desc(self, tag_list, tag_type) -> str: + if self.lang == 'en': + return tags_to_desc(tag_list) + elif self.lang == 'zh': + translator = self.translate[tag_type] + translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] + return tags_to_desc(translated_tag_list, sep='、') + + def generate_description(self, item): + if random.random() > self.plain_rate: + # dynamically generate prompt from given prompt template + prompt_template = random.choice(self.prompt_templates) + description = self.generate_description_dynamic(item, prompt_template) + else: + # use plain prompt, i.e. tags sequence separated by comma + description = self.generate_description_plain(item) + return description + + def generate_description_dynamic(self, data, prompt_template: PromptTemplate): + exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] + exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag)) + exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag)) + + if len(exists_strong_tag) > 0: + probs = dist_prob_map[len(exists_strong_tag)] + tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0] + random.shuffle(exists_strong_tag) + tags = exists_strong_tag[:tags_num] + weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1] + weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0] + random.shuffle(exists_weak_tag) + weak_tags = exists_weak_tag[:weak_tags_num] + tags += weak_tags + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} + prompt = prompt_template.apply(**tags_args) + else: + # no strong tags, use all weak tags instead + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag} + prompt = prompt_template.apply(**tags_args) + + if self.use_literal_none and len(tags_args) == 0: + return 'none' + + return prompt + + def generate_description_plain(self, item): + keywords = [] + for tag_t in self.tag_types: + this_key = item[tag_t] + if this_key is None: + continue + if isinstance(this_key, str): + this_key = [this_key] + if self.lang != 'en': + this_key = [self.get_translation(tag_t, k) for k in this_key] + keywords += this_key + return gen_plain_prompt(keywords, sep=self.keysep) + + def get_translation(self, tag_t, k): + k = k.strip() + if k in self.translate[tag_t]: + return self.translate[tag_t][k] + else: + return k + + @property + def keysep(self): + if self.lang == 'zh': + return ',' if random.random() > 0.5 else '、' + elif self.lang == 'en': + return ', ' + +class AudioStockDataset(Dataset): + def __init__(self, + metadata_path:str, + duration:float=10, + sr:int = 0, + plain_rate = 0, + return_path = False, + return_audio = True, + prompt_template_path: os.PathLike = None, + tag_types = [], + lang = 'en', + translate:Optional[Dict[str, os.PathLike]] = None, + use_literal_none = True, + ): + self.audio_reader = SafeAudioReader(duration, sr) + + self._load_metadata(metadata_path) + self.sr = sr + self.duration = duration + self.plain_rate = plain_rate + self.return_path = return_path + self.return_audio = return_audio + self.use_literal_none = use_literal_none + + self.use_dynamic_prompt = prompt_template_path is not None and plain_rate < 1.0 + if self.use_dynamic_prompt: + self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang) + self.tag_types = tag_types + + self.lang = lang + self.translate = read_translate(translate) + + def _load_metadata(self, metadata_path): + with open(metadata_path) as fp: + lines = fp.readlines() + self.data = [] + for line in lines: + item = json.loads(line) + self.data.append(item) + self.is_info_recorded = bool('Tags' in self.data[0]) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + path:str = self.data[idx]["path"] + json_path = path[:path.rfind('.')] + ".json" + if self.is_info_recorded: + item = self.data[idx] + else: + try: + with open(json_path) as fp: + item:dict = json.load(fp) + except Exception as e: + print(f"Error loading json file {json_path} :\n{e}") + item = {} + description = self.generate_description(item) + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio = self.audio_reader(path, sr, duration) + else: + audio = None + if self.return_path: + return audio, description, path + return audio, description + + def generate_description(self, item): + if random.random() > self.plain_rate: + # dynamically generate prompt from given prompt template + prompt_template = random.choice(self.prompt_templates) + description = self.generate_description_dynamic(item, prompt_template) + else: + # use plain prompt, i.e. tags sequence separated by comma + description = self.generate_description_plain(item) + return description + + def generate_description_dynamic(self, data, prompt_template: PromptTemplate): + exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] + + if len(exists_tag) > 0: + probs = dist_prob_map[len(exists_tag)] + tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0] + random.shuffle(exists_tag) + tags = exists_tag[:tags_num] + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} + tags_args = self.handle_BPM_tag(tags_args) + prompt = prompt_template.apply(**tags_args) + else: + return 'none' + + if self.use_literal_none and len(tags_args) == 0: + return 'none' + + return prompt + + def get_translation(self, tag_t, k): + k = k.strip() + if k in self.translate[tag_t]: + return self.translate[tag_t][k] + else: + return k + + def generate_description_plain(self, item): + keywords = [] + for tag_t in self.tag_types: + if tag_t == 'BPMDescript': + bpm = item['BPM'] + if bpm is None or bpm.strip() == '' or bpm.strip() == '0': + continue + this_key = gen_bpm_descript(bpm.strip(), lang=self.lang) + elif tag_t == 'BPM': + bpm = item['BPM'] + if bpm is None or bpm.strip() == '' or bpm.strip() == '0': + continue + this_key = f"{bpm.strip()} bpm" + else: + this_key = item[tag_t] + if this_key is None: + continue + if isinstance(this_key, str): + this_key = [this_key] + if self.lang != 'en': + this_key = [self.get_translation(tag_t, k) for k in this_key] + if this_key is None: + continue + if isinstance(this_key, str): + this_key = [this_key] + keywords += this_key + return gen_plain_prompt(keywords, sep=self.keysep) + + @property + def keysep(self): + if self.lang == 'zh': + return ',' if random.random() > 0.5 else '、' + elif self.lang == 'en': + return ', ' + + def tags_to_desc(self, tag_list, tag_type) -> str: + if self.lang == 'en': + return tags_to_desc(tag_list) + elif self.lang == 'zh': + if tag_type == 'BPM': + return tags_to_desc(tag_list, sep='、') + translator = self.translate[tag_type] + translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] + return tags_to_desc(translated_tag_list, sep='、') + + def handle_BPM_tag(self, tags_args): + if "BPM" in tags_args and 'BPMDescript' in self.tag_types: + bpm = tags_args["BPM"] + del tags_args["BPM"] + tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript'))) + for tag_type in tag_types_used: + tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang) + return tags_args + +def mp3_path_to_id(mp3_path): + return int( + mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.')] + ) + +class TmeDataset(Dataset): + def __init__(self, + data_index:str, + music_info:str = None, + duration:float = 10, + sr:int = 0, + plain_rate = 0, + return_path = False, + return_audio = True, + return_ID = False, + prompt_format_path: os.PathLike = None, + tag_types = ['*'], + lang = 'zh', + translate: Optional[os.PathLike] = None, + prompt_dir: os.PathLike = None, #使用GPT生成的预有的prompt + ): + if plain_rate > 0: + print("Tme Dataset do not support plain rate > 0, use plain_rate = 0 instead.") + plain_rate = 0 + self.audio_reader = SafeAudioReader(duration, sr) + + self.sr = sr + self.duration = duration + self.plain_rate = plain_rate + self.return_path = return_path + self.return_audio = return_audio + self.return_ID = return_ID + self.lang = lang + + self.use_ready_prompt = prompt_dir is not None + + data_index = read_jsonlike(data_index) + self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index} + self.data_ids = list(self.data_index_dict.keys()) + + if not self.use_ready_prompt: + #读取音乐的信息文件 + music_info = read_jsonlike(music_info) + if 'music' in music_info: + music_info = music_info['music'] + self.music_info_dict = {d["歌曲ID"]:d for d in music_info} + self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict} + self.data_ids = list(self.data_index_dict.keys()) + + with open(prompt_format_path) as fp: + self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader) + + #加载tag types,并分成一般的tag_types和关键的key_tag_types + if '*' in tag_types: + self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag'] + else: + self.tag_types = tag_types + + self.key_tag_types = [] + if 'tag' in self.tag_types: + self.tag_types.remove('tag') + self.key_tag_types = list(self.prompt_formats['tag'].keys()) + + #加载translate翻译 + if translate is not None: + self.translator = read_jsonlike(translate) + else: + data_ids_set = set(self.data_ids) + self.prompts_dict = {} + for fname in os.listdir(prompt_dir): + items = read_jsonlike(os.path.join(prompt_dir, fname)) + for item in items: + if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']): + continue + if item['ID'] not in self.prompts_dict: + self.prompts_dict[item['ID']] = [] + self.prompts_dict[item['ID']].append(item['Text']) + self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict} + self.data_ids = list(self.data_index_dict.keys()) + + def tags_to_desc(self, tag_list) -> str: + if is_bearable(tag_list, int): + return str(tag_list) + if self.lang == 'zh': + return tags_to_desc(tag_list, sep=self.sep) + else: + translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ] + return tags_to_desc(translated_tag_list, sep=self.sep) + + def gen_desc_of_tag(self, formats, tags): + fmt = random.choice(formats) + return fmt.format(self.tags_to_desc(tags)) + + @staticmethod + def check_valid(value): + if isinstance(value, int) or isinstance(value, float): + return value > 0 + if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0): + return True + return False + + @staticmethod + def remove_repeat(data): + #若专辑名和歌曲名相同,则只使用后者 + album_name = data.get('专辑名', None) + if album_name is not None and album_name == data.get('歌曲名', None): + del data['专辑名'] + return data + + @property + def comma(self): + if self.lang == 'zh': + return ',' + elif self.lang == 'en': + return ', ' + + @property + def sep(self): + if self.lang == 'zh': + return '、' + elif self.lang == 'en': + return ', ' + + + def generate_description(self, item): + if random.random() > self.plain_rate: + # dynamically generate prompt from given prompt template + description = self.generate_description_dynamic(item) + else: + # use plain prompt, i.e. tags sequence separated by comma + description = self.generate_description_plain(item) + return description + + def generate_description_dynamic(self, data): + data = self.remove_repeat(data) + + weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低 + + key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个 + + prompts = [] + if len(weak_tags) > 0: + probs = dist_prob_map_low[len(weak_tags)] + if len(key_tags) > 0: + tags_num = random.choices(range(0, len(weak_tags)), probs)[0] + else: + tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0] + random.shuffle(weak_tags) + tags = weak_tags[:tags_num] + for tag_type in tags: + tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type]) + prompts.append(tag_desc) + + if len(key_tags) > 0: + probs = dist_prob_map[len(key_tags)] + tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0] + random.shuffle(key_tags) + tags = key_tags[:tags_num] + for tag_type in tags: + tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type]) + prompts.append(tag_desc) + + random.shuffle(prompts) + return self.comma.join(prompts) + + def generate_description_plain(self, item): + keywords = item['tag'] + if self.lang != 'en': + keywords = [self.translator[k.strip()] for k in keywords] + return gen_plain_prompt(keywords, sep=self.keysep) + + @property + def keysep(self): + if self.lang == 'zh': + return ',' if random.random() > 0.5 else '、' + elif self.lang == 'en': + return ', ' + + def is_valid_prompt_text(self, text): + for bad in ('抱歉','sorry', 'Sorry'): + if bad in text: + return False + return True + + def get_ready_prompt(self, path): + sid = mp3_path_to_id(path) + return random.choice(self.prompts_dict[sid]) + + def __len__(self): + return len(self.data_ids) + + def __getitem__(self, idx): + data_id = self.data_ids[idx] + item = self.data_index_dict[data_id] + path = item['path'] + if not self.use_ready_prompt: + info = self.music_info_dict[data_id] + description = self.generate_description(info) + else: + description = self.get_ready_prompt(path) + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio = self.audio_reader(path, sr, duration) + else: + audio = None + if self.return_path: + if self.return_ID: + return audio, description, path, info['歌曲ID'] + return audio, description, path + if self.return_ID: + return audio, description, info['歌曲ID'] + return audio, description + + +class Pond5Dataset(Dataset): + MAX_PROMPT_LEN = 200 + def __init__(self, + metadata_path:str, + index_path:str, + duration:float=10, + sr:int = 0, + plain_rate = 0, + return_path = False, + return_audio = True, + lang = 'en', + translate:Optional[Dict[str, os.PathLike]] = None, + use_literal_none = True, + use_avoid_watermark_policy = None, + ): + + if use_avoid_watermark_policy is None: + raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type") + self.use_avoid_watermark_policy = use_avoid_watermark_policy + self.audio_reader = SafeAudioReader(duration, sr, use_avoid_watermark_policy=use_avoid_watermark_policy) + + self._load_metadata(metadata_path, index_path) + self.sr = sr + self.duration = duration + self.plain_rate = plain_rate + self.return_path = return_path + self.return_audio = return_audio + self.use_literal_none = use_literal_none + + self.lang = lang + self.translate = read_translate(translate) + + def _load_metadata(self, metadata_path, index_path): + data_index = read_jsonlike(index_path) + data_ids = set([item['id'] for item in data_index]) + + with open(metadata_path) as fp: + lines = fp.readlines() + + append_ids = set() + + self.data = [] + for line in lines: + item = json.loads(line) + if item['id'] in data_ids and item['id'] not in append_ids: + self.data.append(item) + append_ids.add(item['id']) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + path:str = item["path"] + description = self.generate_description(item) + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio = self.audio_reader(path, sr, duration) + else: + audio = None + if self.return_path: + return audio, description, path + return audio, description + + @property + def keysep(self): + if self.lang == 'zh': + return ',' if random.random() > 0.5 else '、' + elif self.lang == 'en': + return ', ' + + def generate_description(self, item): + if random.random() > self.plain_rate: + # dynamically generate prompt from given prompt template + description = self.generate_description_dynamic(item) + else: + # use plain prompt, i.e. tags sequence separated by comma + description = self.generate_description_plain(item) + return description + + def get_translation(self, k): + k = k.strip() + if k in self.translate: + return self.translate[k] + else: + return k + + def generate_description_plain(self, item): + keywords = item['keywords'] + if self.lang != 'en': + keywords = [self.get_translation(k) for k in keywords] + return gen_plain_prompt(keywords, sep=self.keysep) + + def generate_description_dynamic(self,item): + desc = item.get('desc', 'none') + if desc is None: + desc = 'none' + desc = desc.strip() + if len(desc) > self.MAX_PROMPT_LEN: + shorter_desc = desc[:self.MAX_PROMPT_LEN] + # find last stop + stop_idx = shorter_desc.rfind('.') + if stop_idx == -1: + stop_idx = shorter_desc.rfind('!') + if stop_idx == -1: + stop_idx = shorter_desc.rfind(',') + if stop_idx == -1: + stop_idx = self.MAX_PROMPT_LEN - 1 + desc = desc[:stop_idx+1] + return desc + +class SoundDataset(Dataset): + def __init__(self, + metadata_index: str, + duration:float = 10, + min_non_silent_duration:float = 3, + sr:int = 0, + return_path = False, + return_audio = True, + ): + self.data = read_jsonlike(metadata_index) + self.sr = sr + self.reader = SafeAudioReader(duration, sr) + self.duration = duration + self.min_non_silent_duration = min_non_silent_duration + self.return_audio = return_audio + self.return_path = return_path + + def __getitem__(self, index): + item = self.data[index] + if self.return_audio: + origin_duration = item['duration'] + if origin_duration < self.min_non_silent_duration: + audio = self.read_and_repeat_and_pad(item) + else: + audio = self.reader(item['path'], item['sample_rate'], origin_duration) + else: + audio = None + desc = item['caption'] + if self.return_path: + return audio, desc, item['path'] + else: + return audio, desc + + def __len__(self): + return len(self.data) + + def read_and_repeat_and_pad(self, item): + path = item['path'] + try: + # read + clip, sr = torchaudio.load(path) + if len(clip.shape) > 1: + clip = torch.mean(clip, dim=0, keepdim=True) + clip = resample(clip, sr, self.sr) + #repeat + n_repeats = math.ceil(self.min_non_silent_duration/item['duration']) + clip = torch.repeat_interleave(clip, n_repeats, dim=0).reshape(-1) + #pad + n_samples = int(self.duration * self.sr) + if clip.shape[0] >= n_samples: + audio = clip[:n_samples] + else: + audio = torch.zeros(int(self.duration * self.sr), dtype=clip.dtype) + start_pos = np.random.randint(0, max(0,(n_samples - clip.shape[0]))) + audio[start_pos:start_pos+clip.shape[0]] = clip + return audio + + except Exception as e: + logger.error(f"Error reading {path}: {e}") + wav = torch.zeros(int(self.duration * self.sr), dtype=torch.float32) + return wav + +class CombinedDataset(Dataset): + @beartype + def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]): + self.datasets = datasets + self.datasets_index = [] + + for i,dataset in enumerate(datasets): + if dataset is None: + continue + for dup in range(ratios[i]): + for j in range(len(dataset)): + self.datasets_index.append((i,j)) + + def __len__(self): + return len(self.datasets_index) + + def __getitem__(self, idx): + index = self.datasets_index[idx] + i,j = index + return self.datasets[i][j] + +class CombinedDataset_random(Dataset): + @beartype + def __init__(self, num_examples:int, datasets: Sequence[Dataset], ratios: Sequence[int]): + self.datasets = datasets + self.datasets_index = [] + + for i,dataset in enumerate(datasets): + if dataset is None: + continue + for dup in range(ratios[i]): + for j in range(len(dataset)): + self.datasets_index.append((i,j)) + + if num_examples > 0: + self.random_choose = True + self.dataset_len = num_examples + else: + self.random_choose = False + self.dataset_len = len(self.datasets_index) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, idx): + first_try = True + try_cnt = 0 + while True: + try: + if(self.random_choose or not first_try): + index2 = [] + index2.append(np.random.randint(0,len(self.datasets))) + index2.append(np.random.randint(0,len(self.datasets[index2[-1]]))) + else: + index2 = self.datasets_index[idx] + first_try = False + out = list(self.datasets[index2[0]][index2[1]]) + return out + except: + print("Error loadding ", index2) + try_cnt += 1 + if(try_cnt>10): + raise ValueError() + +class SoundMixedDataset(Dataset): + @staticmethod + def music_desc(desc): + return f'Music:<{desc}>' + @staticmethod + def sound_desc(desc): + return f'Effect:<{desc}>' + + def __init__(self, + music_dataset: Dataset, + sound_dataset: Dataset, + mixed_ratios: Tuple[float, float, float] = (0.3, 0.3, 0.4) # 只有音乐:只有音效:音乐音效混合 的比例 + ) -> None: + self.music_dataset = music_dataset + self.sound_dataset = sound_dataset + music_r, sound_r, mix_r = [r/sum(mixed_ratios) for r in mixed_ratios] #化为0-1间的比例 + #三个概率区间的左端点 + self.music_anchor = 0 + self.sound_anchor = music_r + self.mix_anchor = music_r + sound_r + + def __len__(self): + return len(self.music_dataset) + + def get_random_sound_data(self): + idx = random.randint(0, len(self.sound_dataset)-1) + return self.sound_dataset[idx] + + def __getitem__(self, idx): + p = random.random() + if p >= self.mix_anchor: + music, m_desc = self.music_dataset[idx] + sound, s_desc = self.get_random_sound_data() + audio = music + sound + if(audio.abs().max()>1.0): + music = music / audio.abs().max() * 0.95 + audio = audio / audio.abs().max() * 0.95 + desc = self.music_desc(m_desc) + self.sound_desc(s_desc) + return audio[None,:], music[None,:], desc + elif p >= self.sound_anchor: + audio, desc = self.get_random_sound_data() + return audio[None,:], torch.zeros_like(audio[None,:]), self.sound_desc(desc) + else: + audio, desc = self.music_dataset[idx] + return audio[None,:], audio[None,:], self.music_desc(desc) + + +class DecoTagDataset(Dataset): + '''这个类把普通的datatset包装成适用于标签解耦学习的dataset''' + + TAG_TYPES = ('genre', 'mood', 'insrument') + + def __init__(self, dataset_class: type, tag_map: Dict[str, str], *args, **kwargs): + self.datasets = [] + for i, tag_t in enumerate(self.TAG_TYPES): + kwargs['tag_types'] = [tag_map[tag_t]] + kwargs['return_audio'] = (i == 0) #只有第0个需要返回音频和文本,其余只需要返回文本 + self.datasets.append(dataset_class(*args, **kwargs)) + + def __len__(self): + return len(self.datasets[0]) + + def __getitem__(self, idx): + audio, text = self.datasets[0][idx] + texts = (text, self.datasets[1][idx][1], self.datasets[2][idx][1]) + return audio, texts + + +class DecoTagWrapper: + '''这是一个包装器,便于选择是否使用标签解耦学习''' + def __init__(self, dataset_class: Dataset, deco_tag_types: List[str] = list(), switch_on: bool = False): + self.dataset_class = dataset_class + self.tag_map = dict(zip(DecoTagDataset.TAG_TYPES, deco_tag_types)) + self.switch_on = switch_on + + def __call__(self, *args, **kwargs): + if self.switch_on: + return DecoTagDataset(self.dataset_class, self.tag_map, *args, **kwargs) + else: + return self.dataset_class(*args, **kwargs) diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad93e9eaf529d30df7db3973d0c9822857df1a2 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_429.py @@ -0,0 +1,372 @@ +import re +import sys +import json +from typing import List, Union + +from torch.utils.data import Dataset +import torchaudio +from torchaudio.functional import resample +import torch +import numpy as np + +from torch.nn.utils.rnn import pad_sequence + +PARAGRAPH_GAP = 6 +MIN_MUSIC_LEN = 3 + +def check_lryics(lyric): + _FILTER_STRING = [ + '作词', '作曲', '编曲', '【', '策划', + '录音', '混音', '母带', ':', '制作', + '版权', '校对', '演奏', '制作', '伴奏' + ] + for item in _FILTER_STRING: + if item in lyric: + return True + + return False + + + +def process_lyrics(lines): + lyric_part = [] + timestamp_part = [] + + timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]') + + for i, line in enumerate(lines): + + # 删除前几行的特定信息 + if i<10 and check_lryics(line): + continue + + # 检查是否包含有效的时间戳和歌词内容 + if timestamp_pattern.match(line): + timestamp_end = line.rfind(']') + lyrics = line[timestamp_end + 1:].strip() + timestamps = line[:timestamp_end + 1] + + if ':' in lyrics: + if len(lyrics.split(":")[0]) <=5: + lyrics = "".join(lyrics.split(":")[1:]) + # if lyrics: # 确保歌词部分不是空的 + # lyric_part.append(lyrics) + # timestamp_part.append(timestamps) + # print(processed_lyrics) + return timestamp_part, lyric_part + +def get_timestamps(timestamp_part): + + # 转换为秒 + + timestamps = [] + + for line in timestamp_part: + match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line) + if match: + minutes = int(match.group(1)) + seconds = float(match.group(2)) + millis = float(match.group(3)) if match.group(3) else 0 + total_seconds = minutes * 60 + seconds + millis + timestamps.append(total_seconds) + + + return timestamps + +def process_lyrics_lrc(lyrics): + timestamp_part, lyric_part = process_lyrics(lyrics) + # print(timestamp_part) + # print(lyric_part) + timestamps = get_timestamps(timestamp_part) + # print(timestamps) + if len(timestamps) == 0: + # print(f'{lyric_path}') + return [] + + slice_start = timestamps[0] + slice_start_idx = 0 + + output_list = [] + for i in range(1, len(timestamps)): + # 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉 + if timestamps[i] - slice_start > 30: + output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) + + slice_start = timestamps[i] + slice_start_idx = i + + return output_list + + + +def process_lyrics_yrc(lyrics): + + timestamps, lyric_part = extract_lrc(lyrics) + + # timestamp_part, lyric_part = process_lyrics(lyrics) + # import pdb; pdb.set_trace() + # print(timestamp_part) + # print(lyric_part) + # timestamps = get_timestamps(timestamp_part) + # print(timestamps) + if len(timestamps) == 0: + # print(f'{lyric_path}') + return [] + + slice_start = timestamps[0] + slice_start_idx = 0 + + output_list = [] + for i in range(1, len(timestamps)): + # 如果累积时间超过30秒,则进行切分 + if timestamps[i] - slice_start > 30: + output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) + + slice_start = timestamps[i] + slice_start_idx = i + # import pdb; pdb.set_trace() + return output_list + +def extract_lrc(lyrics): + timestamp_part, lyric_part = [], [] + + for i, text in enumerate(lyrics): + # 提取中括号内的内容 + bracket_content = re.search(r'\[(.*?)\]', text).group(1) + bracket_content = bracket_content.split(',') + # 提取小括号内的内容 + parentheses_content = re.findall(r'\((.*?)\)', text) + # 提取其他内容 + other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip() + + # 数据怎么处理? + if i<10 and check_lryics(other_content): + continue + timestamp_part.append(float(bracket_content[0])/1000) + lyric_part.append(other_content) + return timestamp_part, lyric_part + + + +class WYYSongDataset(Dataset): + def __init__(self, + metadata_path: Union[str, List[str]], + sr:int = 0, + use_lang = ['en', 'zh-cn'], + num_examples = -1, + max_dur = 20, + min_dur=0, + add_music=False, + pad_to_max= True, + ): + + self.sr = sr + self.use_lang = use_lang + self.data = [] + if type(metadata_path) == str: + metadata_path = [metadata_path] + for _meta in metadata_path: + self._load_metadata(_meta) + self.max_dur = max_dur + self.min_dur = min_dur + self.pad_to_max = pad_to_max + self.add_music = add_music + + # buffer + self.lyric_buffer = {} + + if(num_examples<=0): + self.dataset_len = len(self.data) + self.random_slc = False + else: + self.dataset_len = num_examples + self.random_slc = True + + + # 读取jsonl文件 + def _load_metadata(self, metadata_path): + with open(metadata_path) as fp: + lines = fp.readlines() + for line in lines: + item = json.loads(line) + if '伴奏' not in item['path']: + # if "lang_type" in item and item['lang_type'] == 'en': + if "lang_type" in item: + self.data.append(item) + + + def __len__(self): + return self.dataset_len + + + def __getitem__(self, idx): + try_cnt = 0 + while True: + if(self.random_slc): + idx = np.random.randint(0, len(self.data)) + yrc_lyrics = [] + lrc_lyrics = [] + try: + info = self.data[idx] + + # audio path + path = info["path"] + lang_type = info["lang_type"] + lyrics = info['lyrics'] # chinese + # lyrics = info['lyrics_phone'] + + # 随机选取一个lyric段落 + + parsed_lyrics = [] + # st_idx = np.random.randint(0, len(lyrics)) + for ly_id in range(len(lyrics)): + lyric = lyrics[ly_id].strip() + st, et, lyric = self.parse_lyric(lyric) + + if et - st >= self.max_dur: + continue #TODO 前后外沿 [MUSIC] + + if parsed_lyrics != []: + if st - parsed_lyrics[-1][1] >= PARAGRAPH_GAP: # 大gap + parsed_lyrics.append((parsed_lyrics[-1][1], st, '[GAP]')) + elif self.add_music and st - parsed_lyrics[-1][1] >= MIN_MUSIC_LEN: + parsed_lyrics.append((parsed_lyrics[-1][1], st, '[MUSIC]')) + + lyric = lyric.replace("\xa0", " ") + lyric = " ".join(lyric.split()) + parsed_lyrics.append((st, et, lyric)) + + assert parsed_lyrics != [] + # if parsed_lyrics[-1][1] - parsed_lyrics[0][0] > self.max_dur: + # print(f"{parsed_lyrics[0][0]}-{parsed_lyrics[-1][1]} {parsed_lyrics}", file=open('tmp.txt', 'a')) + + parsed_lyrics = [(0, parsed_lyrics[0][0], '[GAP]')] + parsed_lyrics + + possible_starts = [e for e,i in enumerate(parsed_lyrics) if i[2]=='[GAP]'] + st_idx = np.random.choice(possible_starts) + + paraphrase = [] + for i in parsed_lyrics[st_idx+1:]: + if i[2] == '[GAP]': + break + paraphrase.append(i) + # print(paraphrase, lyrics) + + while paraphrase[-1][1] - paraphrase[0][0] > self.max_dur: + if np.random.rand() > 0.2: + paraphrase.pop(-1) # 大概率从后面截断 + else: + paraphrase.pop(0) # 小概率截前面 + + st, et, lyric = paraphrase[0][0], paraphrase[-1][1], ', '.join([i[2] for i in paraphrase]) # [SEP] + # print(st, et, lyric) + # import pdb; pdb.set_trace() + assert self.min_dur < et - st < self.max_dur, f"{st}-{et} {lyric}" + # print(et-st, lyric) + # import pdb; pdb.set_trace() + + if info["lang_type"] == 'en': + # print(len(lyric.split())/(et-st)) + char_num = sum([len(lrc[-1].split()) for lrc in paraphrase]) + assert 6 > char_num / (et-st) > 1 + else: + # print(len(lyric.split())/(et-st)) + char_num = sum([len(lrc[-1]) for lrc in paraphrase]) + assert 6 > char_num / (et-st) > 1 + + # 读取音频文件 + cur_sample_rate = torchaudio.info(path).sample_rate + offset = int(cur_sample_rate*st) + num_frames = int(cur_sample_rate * (et -st)) + chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames) + # chunk = torch.zeros(1, 48000*15) + if abs(chunk.shape[-1] - num_frames) > num_frames * 0.05: # 音频文件长度与歌词不一致 + print(f"fail to load {path} from {st} to {et} !") + raise FileNotFoundError + # 随机选取一个channel + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + + if(cur_sample_rate!=self.sr): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr) + + if self.pad_to_max: + chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0) + + # print(self.sz_cnt) + return chunk, lyric, [st, et], path, lang_type + except (AssertionError, FileNotFoundError, RuntimeError) as e: # 其他Error不ok + # print("Error loadding ", info["path"]) + try_cnt += 1 + idx = np.random.randint(0, len(self.data)) + if(try_cnt>100): + raise e + + def parse_lyric(self, lyric): + pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)' + match = re.search(pattern, lyric) + + start_time = float(match.group(1)) + end_time = float(match.group(2)) + content = match.group(3) + return start_time, end_time, content + + def pad_2d_tensor(self, x, max_len, pad_id): + # 获取输入 tensor 的形状 + batch_size, seq_len = x.size() + max_len = max(max_len, seq_len) + # 计算需要填充的长度 + pad_len = max_len - seq_len + + # 如果需要填充 + if pad_len > 0: + # 创建填充 tensor + pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device) + + # 沿第二个维度(列)连接输入 tensor 和填充 tensor + padded_tensor = torch.cat([x, pad_tensor], dim=1) + else: + # 如果不需要填充,直接返回输入 tensor + padded_tensor = x + + return padded_tensor + +def collect_data(data_list): + audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2) + lyrics = [data[1] for data in data_list] + st_et = [data[2] for data in data_list] + paths = [data[3] for data in data_list] + lang_types = [data[4] for data in data_list] + return audios, lyrics, st_et + # return audios, lyrics, st_et + + +def build_dataset(train_jsonl_list, val_jsonl_list, min_dur=0, max_dur=20, add_music=False): + print(min_dur,max_dur) + print(train_jsonl_list) + # ["exp/wyy3_20240418_v2f.jsonl", + # "exp/tme_lyric_baokuan.jsonl"] + train_dataset = WYYSongDataset( + metadata_path = train_jsonl_list, + sr = 48000, + use_lang = ['zh-cn', 'en'], + num_examples = 10*10000, + min_dur=min_dur, + max_dur=max_dur, + add_music=add_music + ) + + valid_dataset = WYYSongDataset( + metadata_path = val_jsonl_list, + sr = 48000, + use_lang = ['zh-cn', 'en'], + num_examples = 500, + min_dur=min_dur, + max_dur=max_dur, + add_music=add_music + ) + print(train_jsonl_list, "\t total_song = ", len(train_dataset.data)) + return train_dataset, valid_dataset diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ec74a70b8491e7c973ed1dff68d843049c044d --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined.py @@ -0,0 +1,830 @@ +from torch.utils.data import Dataset +from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List +from beartype import beartype +from beartype.door import is_bearable +import random +import pandas as pd +import os +from torchaudio.functional import resample +import torch +import typing as tp +from pathlib import Path +import torchaudio as ta +import torch.nn.functional as F +import numpy as np +import json +import yaml +import torchaudio +import math +import re +from loguru import logger + +class Read_and_PadCrop_Normalized_T(torch.nn.Module): + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: + if(duration<(float(self.n_samples)/self.sample_rate+1)): + # print(duration,(float(self.n_samples)/self.sample_rate+1)) + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) + offset = 0 + # print('c1:',chunk.shape) + else: + offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + t_start = offset / float(cur_sample_rate) / duration + t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration + chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + # print('offset:',offset) + # print('c0:',chunk.shape) + # Pad with silence if necessary. + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + if(cur_sample_rate!=self.sample_rate): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) + # print('b:',self.sample_rate,chunk.shape) + if chunk.shape[-1] < self.n_samples: + chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) + else: + chunk = chunk[:,0:self.n_samples] + seconds_start = math.floor(offset / cur_sample_rate) + seconds_total = math.floor(duration) + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total + ) + + +USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替 +if USE_DUMMY_AUDIO: + logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!") + +class SafeAudioReader: + """ + This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data. + """ + def __init__(self, + duration: float, # 返回音频长度 + sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample + randomize: bool = True + ): + self.n_samples = int(sample_rate * max(duration, 0)) + self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize) + + #NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数! + def __call__(self, + filepath: os.PathLike, # 音频路径 + origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取 + origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取 + ) -> torch.Tensor: + if USE_DUMMY_AUDIO: + wav = torch.zeros(self.n_samples, dtype=torch.float32) + return wav + try: + if origin_sample_rate is None or origin_duration is None: + audio_info = torchaudio.info(filepath) + origin_sample_rate = audio_info.sample_rate + origin_duration = audio_info.num_frames / origin_sample_rate + wav, *ignored = self.reader(filepath, origin_duration, origin_sample_rate) + except Exception as e: + logger.error(f"Error reading {filepath}: {e}") + wav = torch.zeros(self.n_samples, dtype=torch.float32) + return wav + + +class PromptTemplate: + def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'): + self.template_text = template_text + self.tag_map = tag_map + self.lang = lang + + @property + def tags(self): + return tuple(self.tag_map.keys()) + + def apply(self, **kwargs): + for tag in list(kwargs.keys()): + if kwargs[tag] == '': + kwargs.pop(tag) + for tag in self.tags: + if tag in kwargs: + kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]') + else: + kwargs[tag] = '' + prompt = self.template_text.format(**kwargs) + + return self.beautify(prompt) + + def beautify(self, text): + if self.lang == 'en': + return self._beautify_en(text) + elif self.lang == 'zh': + return self._beautify_zh(text) + else: + raise ValueError(f'Unknown language {self.lang}') + + @staticmethod + def _beautify_en(text): + # no continuous commas without content between them + text = re.sub(r'[,\s]*,[,\s]*', r', ', text) + # no continuous whitespace + text = re.sub(r'\s+', ' ', text) + # the comma is NOT followed by whitespace, and should be followed by ONE whitespace + text = re.sub(r'\s+,', r',', text) + text = re.sub(r',\s+', r', ', text) + # no whitespace before the full stop + text = re.sub(r'\s+\.', r'.', text) + # strip whitespace, comma, and replace ',.' + text = text.strip(' ,') + text = text.replace(',.', '.') + return text + + @staticmethod + def _beautify_zh(text): + # no continuous commas without content between them + text = re.sub(r'[,、\s]*,[,、\s]*', r',', text) + text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text) + # assume there should be NO whitespace in Chinese + text = re.sub(r'\s+', r'', text) + # strip whitespace, comma, and replace ',。' + text = text.strip(', 、') + text = text.replace(',。', '。') + return text + + def __repr__(self): + return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})' + + __str__ = __repr__ + +def parse_prompt_template(prompt_template_text, lang='en'): + span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL) + tag_pattern = re.compile(r'{.+?}', re.DOTALL) + + template_text = prompt_template_text.strip() + span_texts = span_pattern.findall(prompt_template_text) + tag_map = {} + for span_text in span_texts: + tag = tag_pattern.findall(span_text)[0].strip('{}') + tag_map[tag] = span_text + template_text = template_text.replace(span_text, '{'+tag+'}') + + return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang) + +def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]: + with open(path, 'r') as f: + lines = f.readlines() + cnt = 0 + pts = [] + for line in lines: + pt = parse_prompt_template(line, lang=lang) + cnt += 1 + if len(pt.tags) < num: + logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}') + pts.append(pt) + + return pts + + +def get_base_dir_file(key: os.PathLike): + base = os.path.basename(key) + dirname = os.path.basename(os.path.dirname(key)) + return os.path.join(dirname, base) + +def read_jsonlike(path: os.PathLike): + #json or jsonl + if str(path).endswith(".json"): + with open(path, 'r', encoding='utf8') as f: + data = json.load(f) + return data + elif str(path).endswith(".jsonl"): + with open(path, 'r', encoding='utf8') as f: + data = [json.loads(line) for line in f.readlines()] + return data + else: + raise ValueError("Unknown file format") + +dist_prob_map = { + 1: (1.0,), + 2: (0.5, 0.5), + 3: (0.3, 0.4, 0.3), + 4: (0.2, 0.3, 0.3, 0.2), + 5: (0.2, 0.2, 0.3, 0.2, 0.1), + 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15), + 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1), + 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12), + 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08), + 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09) +} + +dist_prob_map_low = { + 1: (1.0,), + 2: (0.8, 0.2), + 3: (0.8, 0.1, 0.1), + 4: (0.7, 0.1, 0.1, 0.1), + 5: (0.7, 0.1, 0.1, 0.05, 0.05), + 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05), +} + +_bpm_range_rights = ( + (40, '20-40'), + (60, '40-60'), + (66, '60-66'), + (76, '66-76'), + (108, '76-108'), + (120, '108-120'), + (168, '120-168'), + (176, '168-176'), + (200, '176-200') +) +_bpm_desc_map = { + '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"), + '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"), + '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'), + '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'), + '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"), + '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'), + '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'), + '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'), + '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'), + '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo') +} +_bpm_desc_map_zh = { + '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"), + '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"), + '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'), + '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'), + '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"), + '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'), + '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'), + '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'), + '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'), + '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板') +} +def get_bpm_range(bpm): + bpm = int(bpm) + for right, tag in _bpm_range_rights: + if bpm <= right: + return tag + return '>200' + +def gen_bpm_descript(bpm, lang='en'): + bpm_range = get_bpm_range(bpm) + if lang == 'en': + return random.choice(_bpm_desc_map[bpm_range]) + elif lang == 'zh': + return random.choice(_bpm_desc_map_zh[bpm_range]) + else: + raise ValueError(f"Unknown language {lang}") + +def read_translate(translate: Optional[Dict[str, os.PathLike]]): + if translate is None: + return None + return {k: read_jsonlike(path) for k, path in translate.items()} + + +class MagnaTagATuneDataset(Dataset): + def __init__(self): + pass + + +def tags_to_desc(tag_list, sep=',') -> str: + if not isinstance(tag_list, Sequence): + return str(tag_list) + if isinstance(tag_list, str): + return tag_list + if len(tag_list) <= 0: + return '' + elif len(tag_list) <= 5: + probs = dist_prob_map[len(tag_list)] + tags_num = random.choices(range(1, len(tag_list)+1), probs)[0] + random.shuffle(tag_list) + tag_list = tag_list[:tags_num] + return sep.join(tag_list) + else: + probs = dist_prob_map[5] + tags_num = random.choices(range(1, 6), probs)[0] + random.shuffle(tag_list) + tag_list = tag_list[:tags_num] + return sep.join(tag_list) + +def get_sr_and_duration_info(item): + return item.get('sample_rate', None), item.get('duration', None) + +class MtgJamendoDatasetFromJson(Dataset): + def __init__(self, + data_dir:str, + json_path:str, + duration:float=10, + sr:int = 0, + *, + lang = 'en', + return_path = False, + prompt_template_path: os.PathLike = None, + tag_types = [], + translate:Optional[Dict[str, os.PathLike]] = None, + ): + self.audio_reader = SafeAudioReader(duration, sr) + + self.data_dir = data_dir + self._load_metadata_json(json_path) + self.sr = sr + self.duration = duration + self.return_path = return_path + self.lang = lang + + self.use_dynamic_prompt = prompt_template_path is not None + if self.use_dynamic_prompt: + self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types)) + self.tag_types = tag_types + + self.translate = read_translate(translate) + if not self.use_dynamic_prompt and self.lang != 'en': + raise NotImplementedError + + #这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示 + WEAK_TAG_LIST = ["title", "artist"] + + def _load_metadata_json(self, json_path): + with open(json_path) as fp: + self.data = json.load(fp) + + def convert_key_to_path(self, key): + return os.path.join(self.data_dir, get_base_dir_file(key)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + path = self.convert_key_to_path(item['key']) + description = self.generate_description(item) + + sr, duration = get_sr_and_duration_info(item) + audio = self.audio_reader(path, sr, duration) + + if self.return_path: + return audio, description, path + return audio, description + + def tags_to_desc(self, tag_list, tag_type) -> str: + if self.lang == 'en': + return tags_to_desc(tag_list) + elif self.lang == 'zh': + translator = self.translate[tag_type] + translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] + return tags_to_desc(translated_tag_list, sep='、') + + def generate_description(self, item): + if self.use_dynamic_prompt: + # dynamically generate prompt from given prompt template + prompt_template = random.choice(self.prompt_templates) + description = self.generate_description_dynamic(item, prompt_template) + + else: + # use ordinary static prompt instead + description = self.generate_description_ordinary(item) + return description + + def generate_description_dynamic(self, data, prompt_template: PromptTemplate): + exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] + exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag)) + exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag)) + + if len(exists_strong_tag) > 0: + probs = dist_prob_map[len(exists_strong_tag)] + tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0] + random.shuffle(exists_strong_tag) + tags = exists_strong_tag[:tags_num] + weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1] + weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0] + random.shuffle(exists_weak_tag) + weak_tags = exists_weak_tag[:weak_tags_num] + tags += weak_tags + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} + prompt = prompt_template.apply(**tags_args) + else: + # no strong tags, use all weak tags instead + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag} + prompt = prompt_template.apply(**tags_args) + + return prompt + + def generate_description_ordinary(self, data, thresh = 0.3): + # Initialize the description with title and artist + description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}' + + # Add genre if available + if data["genre"] and random.random() > thresh: + genres = ', '.join(data["genre"]) + description += f', belonging to the {genres} genres' + + # Add moods if available + if data["moods"] and random.random() > thresh: + moods = ', '.join(data["moods"]) + description += f'. This track conveys a {moods} mood' + + # Add instruments if available + if data["instrument"] and random.random() > thresh: + instruments = ', '.join(data["instrument"]) + description += f', and primarily features the following instruments: {instruments}' + + # Add a period to end the description + description += '.' + + return description + +class AudioStockDataset(Dataset): + def __init__(self, + metadata_path:str, + duration:float=10, + sr:int = 0, + return_path = False, + return_audio = True, + prompt_template_path: os.PathLike = None, + tag_types = [], + lang = 'en', + translate:Optional[Dict[str, os.PathLike]] = None + ): + self.audio_reader = SafeAudioReader(duration, sr) + + self._load_metadata(metadata_path) + self.sr = sr + self.duration = duration + self.return_path = return_path + self.return_audio = return_audio + + self.use_dynamic_prompt = prompt_template_path is not None + if self.use_dynamic_prompt: + self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang) + self.tag_types = tag_types + + self.lang = lang + self.translate = read_translate(translate) + + def _load_metadata(self, metadata_path): + with open(metadata_path) as fp: + lines = fp.readlines() + self.data = [] + for line in lines: + item = json.loads(line) + self.data.append(item) + self.is_info_recorded = bool('Tags' in self.data[0]) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + path:str = self.data[idx]["path"] + json_path = path[:path.rfind('.')] + ".json" + if self.is_info_recorded: + item = self.data[idx] + else: + try: + with open(json_path) as fp: + item:dict = json.load(fp) + except Exception as e: + print(f"Error loading json file {json_path} :\n{e}") + item = {} + description = self.generate_description(item) + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio = self.audio_reader(path, sr, duration) + else: + audio = None + if self.return_path: + return audio, description, path + return audio, description + + def generate_description(self, item): + if self.use_dynamic_prompt: + # dynamically generate prompt from given prompt template + prompt_template = random.choice(self.prompt_templates) + description = self.generate_description_dynamic(item, prompt_template) + else: + # use ordinary static prompt instead + description = self.generate_description_ordinary(item) + return description + + def generate_description_dynamic(self, data, prompt_template: PromptTemplate): + exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] + + if len(exists_tag) > 0: + probs = dist_prob_map[len(exists_tag)] + tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0] + random.shuffle(exists_tag) + tags = exists_tag[:tags_num] + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} + tags_args = self.handle_BPM_tag(tags_args) + prompt = prompt_template.apply(**tags_args) + else: + # no strong tags, use all weak tags instead + prompt = prompt_template.apply() + + return prompt + + def tags_to_desc(self, tag_list, tag_type) -> str: + if self.lang == 'en': + return tags_to_desc(tag_list) + elif self.lang == 'zh': + if tag_type == 'BPM': + return tags_to_desc(tag_list, sep='、') + translator = self.translate[tag_type] + translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] + return tags_to_desc(translated_tag_list, sep='、') + + def handle_BPM_tag(self, tags_args): + if "BPM" in tags_args and 'BPMDescript' in self.tag_types: + bpm = tags_args["BPM"] + del tags_args["BPM"] + tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript'))) + for tag_type in tag_types_used: + tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang) + return tags_args + + def generate_description_ordinary(self, data, thresh = 0.3): + if self.lang != 'en': + raise ValueError(f'Language {self.lang} is not supported for ordinary description generation') + description = f'a piece of music by {data["Artist"]}' + + # Add genre if available + if data["Genre"] and random.random() > thresh: + genres = ', '.join(data["Genre"]) + description += f', belonging to the {genres} genres' + + # Add moods if available + if data["Tags"] and random.random() > thresh: + tags = ', '.join(data["Tags"]) + description += f'. This track contains the tags:{tags}' + + # Add moods if available + if data["Mood"] and random.random() > thresh: + moods = ', '.join(data["Mood"]) + description += f'. This track conveys a {moods} mood.' + + # Add instruments if available + if data["Instrument"] and random.random() > thresh: + instruments = ', '.join(data["Instrument"]) + description += f'. and primarily features the following instruments: {instruments}' + + # Add a period to end the description + description += '.' + + return description + +def mp3_path_to_id(mp3_path): + return int( + mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')] + ) + +class TmeDataset(Dataset): + def __init__(self, + data_index:str, + music_info:str = None, + duration:float = 10, + sr:int = 0, + return_path = False, + return_audio = True, + prompt_format_path: os.PathLike = None, + tag_types = ['*'], + lang = 'zh', + translate: Optional[os.PathLike] = None, + prompt_dir: os.PathLike = None, + ): + self.audio_reader = SafeAudioReader(duration, sr) + + self.sr = sr + self.duration = duration + self.return_path = return_path + self.return_audio = return_audio + self.lang = lang + + self.use_ready_prompt = prompt_dir is not None + + data_index = read_jsonlike(data_index) + self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index} + self.data_ids = list(self.data_index_dict.keys()) + + if not self.use_ready_prompt: + #读取音乐的信息文件 + music_info = read_jsonlike(music_info) + if 'music' in music_info: + music_info = music_info['music'] + self.music_info_dict = {d["歌曲ID"]:d for d in music_info} + self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict} + self.data_ids = list(self.data_index_dict.keys()) + + with open(prompt_format_path) as fp: + self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader) + + #加载tag types,并分成一般的tag_types和关键的key_tag_types + if '*' in tag_types: + self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag'] + else: + self.tag_types = tag_types + + self.key_tag_types = [] + if 'tag' in self.tag_types: + self.tag_types.remove('tag') + self.key_tag_types = list(self.prompt_formats['tag'].keys()) + + #加载translate翻译 + if translate is not None: + self.translator = read_jsonlike(translate) + else: + data_ids_set = set(self.data_ids) + self.prompts_dict = {} + for fname in os.listdir(prompt_dir): + items = read_jsonlike(os.path.join(prompt_dir, fname)) + for item in items: + if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']): + continue + if item['ID'] not in self.prompts_dict: + self.prompts_dict[item['ID']] = [] + self.prompts_dict[item['ID']].append(item['Text']) + self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict} + self.data_ids = list(self.data_index_dict.keys()) + + def tags_to_desc(self, tag_list) -> str: + if is_bearable(tag_list, int): + return str(tag_list) + if self.lang == 'zh': + return tags_to_desc(tag_list, sep=self.sep) + else: + translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ] + return tags_to_desc(translated_tag_list, sep=self.sep) + + def gen_desc_of_tag(self, formats, tags): + fmt = random.choice(formats) + return fmt.format(self.tags_to_desc(tags)) + + @staticmethod + def check_valid(value): + if isinstance(value, int) or isinstance(value, float): + return value > 0 + if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0): + return True + return False + + @staticmethod + def remove_repeat(data): + #若专辑名和歌曲名相同,则只使用后者 + album_name = data.get('专辑名', None) + if album_name is not None and album_name == data.get('歌曲名', None): + del data['专辑名'] + return data + + @property + def comma(self): + if self.lang == 'zh': + return ',' + elif self.lang == 'en': + return ', ' + + @property + def sep(self): + if self.lang == 'zh': + return '、' + elif self.lang == 'en': + return ', ' + + def generate_description(self, data): + data = self.remove_repeat(data) + weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低 + + key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个 + + prompts = [] + if len(weak_tags) > 0: + probs = dist_prob_map_low[len(weak_tags)] + if len(key_tags) > 0: + tags_num = random.choices(range(0, len(weak_tags)), probs)[0] + else: + tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0] + random.shuffle(weak_tags) + tags = weak_tags[:tags_num] + for tag_type in tags: + tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type]) + prompts.append(tag_desc) + + if len(key_tags) > 0: + probs = dist_prob_map[len(key_tags)] + tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0] + random.shuffle(key_tags) + tags = key_tags[:tags_num] + for tag_type in tags: + tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type]) + prompts.append(tag_desc) + + random.shuffle(prompts) + return self.comma.join(prompts) + + def is_valid_prompt_text(self, text): + for bad in ('抱歉','sorry', 'Sorry'): + if bad in text: + return False + return True + + def get_ready_prompt(self, path): + sid = mp3_path_to_id(path) + return random.choice(self.prompts_dict[sid]) + + def __len__(self): + return len(self.data_ids) + + def __getitem__(self, idx): + data_id = self.data_ids[idx] + item = self.data_index_dict[data_id] + path = item['path'] + if not self.use_ready_prompt: + info = self.music_info_dict[data_id] + description = self.generate_description(info) + else: + description = self.get_ready_prompt(path) + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio = self.audio_reader(path, sr, duration) + else: + audio = None + if self.return_path: + return audio, description, path + return audio, description + +class CombinedDataset(Dataset): + @beartype + def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]): + self.datasets = datasets + self.datasets_index = [] + + for i,dataset in enumerate(datasets): + if dataset is None: + continue + for dup in range(ratios[i]): + for j in range(len(dataset)): + self.datasets_index.append((i,j)) + + def __len__(self): + return len(self.datasets_index) + + def __getitem__(self, idx): + index = self.datasets_index[idx] + i,j = index + return self.datasets[i][j] + +class CombinedDataset_random(Dataset): + @beartype + def __init__(self, + num_examples:int, + datasets: Sequence[Dataset], ratios: Sequence[int] + ): + self.datasets = datasets + self.datasets_index = [] + + for i,dataset in enumerate(datasets): + if dataset is None: + continue + for dup in range(ratios[i]): + for j in range(len(dataset)): + self.datasets_index.append((i,j)) + if num_examples > 0: + self.random_choose = True + self.dataset_len = num_examples + else: + self.random_choose = False + self.dataset_len = len(self.datasets_index) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, idx): + first_try = True + try_cnt = 0 + while True: + try: + if(self.random_choose or not first_try): + index2 = [] + index2.append(np.random.randint(0,len(self.datasets))) + index2.append(np.random.randint(0,len(self.datasets[index2[-1]]))) + else: + index2 = self.datasets_index[idx] + first_try = False + out = self.datasets[index2[0]][index2[1]] + if(len(out[0].shape)==1):out[0]=out[0][None,:] + return out + except: + print("Error loadding ", index2) + try_cnt += 1 + if(try_cnt>10): + raise ValueError() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py new file mode 100644 index 0000000000000000000000000000000000000000..39044e6b5a6c945b86ddd0091b6e76775e0573d9 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_combined_withset.py @@ -0,0 +1,994 @@ +from torch.utils.data import Dataset +from beartype.typing import Sequence, Callable, Optional, Dict, Tuple, List +from beartype import beartype +from beartype.door import is_bearable +import random +import pandas as pd +import os +from torchaudio.functional import resample +import torch +import typing as tp +from pathlib import Path +import torchaudio as ta +import torch.nn.functional as F +import numpy as np +import json +import yaml +import torchaudio +import math +import re +from loguru import logger + +def gen_plain_prompt(key_list, sep=', '): + if len(key_list) == 0: + return 'none' + + key_list = [k.strip() for k in key_list] + + if len(key_list) > 10: + random.shuffle(key_list) + key_list = key_list[:10] + + probs = dist_prob_map[len(key_list)] + + num_tags = random.choices(range(1, len(key_list)+1), probs, k=1)[0] + + random.shuffle(key_list) + tags = key_list[:num_tags] + tags_str = sep.join(tags) + return tags_str + +class Read_and_PadCrop_Normalized_T(torch.nn.Module): + + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + self.prob = {"is_start":0.2, "is_end":0.9} + self.shift_secs = 5 + + def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: + if(duration<(float(self.n_samples)/self.sample_rate+1)): + raise ValueError(duration,float(self.n_samples),self.sample_rate) + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) + offset = 0 + is_start = True + is_end = True + else: + prob = random.uniform(0,1) + if(probself.prob['is_end']): + is_start = False + is_end = True + offset = int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate) + else: + is_start = False + is_end = False + offset = np.random.randint(self.shift_secs*cur_sample_rate, \ + int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)-self.shift_secs*cur_sample_rate) + t_start = offset / float(cur_sample_rate) / duration + t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration + chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + if(cur_sample_rate!=self.sample_rate): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) + # print('b:',self.sample_rate,chunk.shape) + if chunk.shape[-1] != self.n_samples: + raise ValueError(chunk.shape, self.n_samples, offset, int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + # if chunk.shape[-1] < self.n_samples: + # chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) + # else: + # chunk = chunk[:,0:self.n_samples] + seconds_start = math.floor(offset / cur_sample_rate) + seconds_total = math.floor(duration) + + # # In this dataset, we do not introduce zeros + # if(is_start): + # chunk = torch.cat([torch.zeros(1, self.shift_secs*self.sample_rate), chunk],1)[:,0:self.n_samples] + # elif(is_end): + # chunk = torch.cat([chunk, torch.zeros(1, self.shift_secs*self.sample_rate)],1)[:,self.shift_secs*self.sample_rate:] + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total, + is_start, + is_end, + ) + + +USE_DUMMY_AUDIO = False #当测试代码时,可以将其置为True,这样就不会读取实际数据,而是用生成的静默音频代替 +if USE_DUMMY_AUDIO: + logger.warning("USE_DUMMY_AUDIO flag is True, don't use it when train or test!") + +class SafeAudioReader: + """ + This class is an adaptor to Read_and_PadCrop_Normalized_T, make it safe to read audio data. + """ + def __init__(self, + duration: float, # 返回音频长度 + sample_rate: int, # 返回音频的采样率,如与实际音频采样率不同,会作resample + randomize: bool = True + ): + self.n_samples = int(sample_rate * max(duration, 0)) + self.reader = Read_and_PadCrop_Normalized_T(n_samples=self.n_samples, sample_rate=sample_rate, randomize=randomize) + + #NOTE:这个是核心的函数,所有数据集读取音频都是调用的这个函数! + def __call__(self, + filepath: os.PathLike, # 音频路径 + origin_sample_rate: Optional[int] = None, # 从json文件中读取的实际采样率,如果不给定,则会从文件头中读取 + origin_duration: float = None, # 从json文件中读取的实际时长,如果不给定,则会从文件头中读取 + ) -> torch.Tensor: + if USE_DUMMY_AUDIO: + wav = torch.zeros(self.n_samples, dtype=torch.float32) + return wav + try: + # if origin_sample_rate is None or origin_duration is None: + # audio_info = torchaudio.info(filepath) + # origin_sample_rate = audio_info.sample_rate + # origin_duration = audio_info.num_frames / origin_sample_rate + audio_info = torchaudio.info(filepath) + origin_sample_rate = audio_info.sample_rate + origin_duration = audio_info.num_frames / origin_sample_rate + wav, *ignored, is_start, is_end = self.reader(filepath, origin_duration, origin_sample_rate) + except Exception as e: + logger.error(f"Error reading {filepath}: {e}") + raise FileNotFoundError(filepath) + return wav, is_start, is_end + + +class PromptTemplate: + def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'): + self.template_text = template_text + self.tag_map = tag_map + self.lang = lang + + @property + def tags(self): + return tuple(self.tag_map.keys()) + + def apply(self, **kwargs): + for tag in list(kwargs.keys()): + if kwargs[tag] == '': + kwargs.pop(tag) + for tag in self.tags: + if tag in kwargs: + kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]') + else: + kwargs[tag] = '' + prompt = self.template_text.format(**kwargs) + + return self.beautify(prompt) + + def beautify(self, text): + if self.lang == 'en': + return self._beautify_en(text) + elif self.lang == 'zh': + return self._beautify_zh(text) + else: + raise ValueError(f'Unknown language {self.lang}') + + @staticmethod + def _beautify_en(text): + # no continuous commas without content between them + text = re.sub(r'[,\s]*,[,\s]*', r', ', text) + # no continuous whitespace + text = re.sub(r'\s+', ' ', text) + # the comma is NOT followed by whitespace, and should be followed by ONE whitespace + text = re.sub(r'\s+,', r',', text) + text = re.sub(r',\s+', r', ', text) + # no whitespace before the full stop + text = re.sub(r'\s+\.', r'.', text) + # strip whitespace, comma, and replace ',.' + text = text.strip(' ,') + text = text.replace(',.', '.') + return text + + @staticmethod + def _beautify_zh(text): + # no continuous commas without content between them + text = re.sub(r'[,、\s]*,[,、\s]*', r',', text) + text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text) + # assume there should be NO whitespace in Chinese + text = re.sub(r'\s+', r'', text) + # strip whitespace, comma, and replace ',。' + text = text.strip(', 、') + text = text.replace(',。', '。') + return text + + def __repr__(self): + return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})' + + __str__ = __repr__ + +def parse_prompt_template(prompt_template_text, lang='en'): + span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL) + tag_pattern = re.compile(r'{.+?}', re.DOTALL) + + template_text = prompt_template_text.strip() + span_texts = span_pattern.findall(prompt_template_text) + tag_map = {} + for span_text in span_texts: + tag = tag_pattern.findall(span_text)[0].strip('{}') + tag_map[tag] = span_text + template_text = template_text.replace(span_text, '{'+tag+'}') + + return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang) + +def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]: + with open(path, 'r') as f: + lines = f.readlines() + cnt = 0 + pts = [] + for line in lines: + pt = parse_prompt_template(line, lang=lang) + cnt += 1 + if len(pt.tags) < num: + logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}') + pts.append(pt) + + return pts + + +def get_base_dir_file(key: os.PathLike): + base = os.path.basename(key) + dirname = os.path.basename(os.path.dirname(key)) + return os.path.join(dirname, base) + +def read_jsonlike(path: os.PathLike): + #json or jsonl + if str(path).endswith(".json"): + with open(path, 'r', encoding='utf8') as f: + data = json.load(f) + return data + elif str(path).endswith(".jsonl"): + with open(path, 'r', encoding='utf8') as f: + data = [json.loads(line) for line in f.readlines()] + return data + else: + raise ValueError("Unknown file format") + +dist_prob_map = { + 1: (1.0,), + 2: (0.5, 0.5), + 3: (0.3, 0.4, 0.3), + 4: (0.2, 0.3, 0.3, 0.2), + 5: (0.2, 0.2, 0.3, 0.2, 0.1), + 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15), + 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1), + 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12), + 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08), + 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09) +} + +dist_prob_map_low = { + 1: (1.0,), + 2: (0.8, 0.2), + 3: (0.8, 0.1, 0.1), + 4: (0.7, 0.1, 0.1, 0.1), + 5: (0.7, 0.1, 0.1, 0.05, 0.05), + 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05), +} + +_bpm_range_rights = ( + (40, '20-40'), + (60, '40-60'), + (66, '60-66'), + (76, '66-76'), + (108, '76-108'), + (120, '108-120'), + (168, '120-168'), + (176, '168-176'), + (200, '176-200') +) +_bpm_desc_map = { + '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"), + '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"), + '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'), + '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'), + '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"), + '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'), + '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'), + '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'), + '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'), + '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo') +} +_bpm_desc_map_zh = { + '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"), + '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"), + '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'), + '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'), + '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"), + '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'), + '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'), + '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'), + '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'), + '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板') +} +def get_bpm_range(bpm): + bpm = int(bpm) + for right, tag in _bpm_range_rights: + if bpm <= right: + return tag + return '>200' + +def gen_bpm_descript(bpm, lang='en'): + bpm_range = get_bpm_range(bpm) + if lang == 'en': + return random.choice(_bpm_desc_map[bpm_range]) + elif lang == 'zh': + return random.choice(_bpm_desc_map_zh[bpm_range]) + else: + raise ValueError(f"Unknown language {lang}") + +def read_translate(translate: Optional[Dict[str, os.PathLike]]): + if translate is None: + return None + if isinstance(translate, str): + return read_jsonlike(translate) + return {k: read_jsonlike(path) for k, path in translate.items()} + + +class MagnaTagATuneDataset(Dataset): + def __init__(self): + pass + + +def tags_to_desc(tag_list, sep=',') -> str: + if not isinstance(tag_list, Sequence): + return str(tag_list) + if isinstance(tag_list, str): + return tag_list + if len(tag_list) <= 0: + return '' + elif len(tag_list) <= 5: + probs = dist_prob_map[len(tag_list)] + tags_num = random.choices(range(1, len(tag_list)+1), probs)[0] + random.shuffle(tag_list) + tag_list = tag_list[:tags_num] + return sep.join(tag_list) + else: + probs = dist_prob_map[5] + tags_num = random.choices(range(1, 6), probs)[0] + random.shuffle(tag_list) + tag_list = tag_list[:tags_num] + return sep.join(tag_list) + +def get_sr_and_duration_info(item): + return item.get('sample_rate', None), item.get('duration', None) + +class MtgJamendoDatasetFromJson(Dataset): + def __init__(self, + data_dir:str, + json_path:str, + duration:float=10, + sr:int = 0, + *, + lang = 'en', + return_path = False, + prompt_template_path: os.PathLike = None, + tag_types = [], + translate:Optional[Dict[str, os.PathLike]] = None, + ): + self.audio_reader = SafeAudioReader(duration, sr) + + self.data_dir = data_dir + self._load_metadata_json(json_path) + self.sr = sr + self.duration = duration + self.return_path = return_path + self.lang = lang + + self.use_dynamic_prompt = prompt_template_path is not None + if self.use_dynamic_prompt: + self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types)) + self.tag_types = tag_types + + self.translate = read_translate(translate) + if not self.use_dynamic_prompt and self.lang != 'en': + raise NotImplementedError + + #这些tag被认为是弱语义的,会避免产生仅包含这些tag的文本提示 + WEAK_TAG_LIST = ["title", "artist"] + + def _load_metadata_json(self, json_path): + with open(json_path) as fp: + self.data = json.load(fp) + + def convert_key_to_path(self, key): + return os.path.join(self.data_dir, get_base_dir_file(key)) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + path = self.convert_key_to_path(item['key']) + description = self.generate_description(item) + + sr, duration = get_sr_and_duration_info(item) + audio, is_start, is_end = self.audio_reader(path, sr, duration) + + if self.return_path: + return audio, description, path + return audio, description, is_start, is_end + + def tags_to_desc(self, tag_list, tag_type) -> str: + if self.lang == 'en': + return tags_to_desc(tag_list) + elif self.lang == 'zh': + translator = self.translate[tag_type] + translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] + return tags_to_desc(translated_tag_list, sep='、') + + def generate_description(self, item): + if self.use_dynamic_prompt: + # dynamically generate prompt from given prompt template + prompt_template = random.choice(self.prompt_templates) + description = self.generate_description_dynamic(item, prompt_template) + + else: + # use ordinary static prompt instead + description = self.generate_description_ordinary(item) + return description + + def generate_description_dynamic(self, data, prompt_template: PromptTemplate): + exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] + exists_weak_tag = list(filter(lambda t: t in self.WEAK_TAG_LIST, exists_tag)) + exists_strong_tag = list(filter(lambda t: t not in self.WEAK_TAG_LIST, exists_tag)) + + if len(exists_strong_tag) > 0: + probs = dist_prob_map[len(exists_strong_tag)] + tags_num = random.choices(range(1, len(exists_strong_tag)+1), probs)[0] + random.shuffle(exists_strong_tag) + tags = exists_strong_tag[:tags_num] + weak_probs = dist_prob_map_low[len(exists_weak_tag) + 1] + weak_tags_num = random.choices(range(0, len(exists_weak_tag) + 1), weak_probs)[0] + random.shuffle(exists_weak_tag) + weak_tags = exists_weak_tag[:weak_tags_num] + tags += weak_tags + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} + prompt = prompt_template.apply(**tags_args) + else: + # no strong tags, use all weak tags instead + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in exists_weak_tag} + prompt = prompt_template.apply(**tags_args) + + return prompt + + def generate_description_ordinary(self, data, thresh = 0.3): + # Initialize the description with title and artist + description = f'"{data["title"]+" is " if random.random() > thresh else ""}"a piece of music by {data["artist"]}' + + # Add genre if available + if data["genre"] and random.random() > thresh: + genres = ', '.join(data["genre"]) + description += f', belonging to the {genres} genres' + + # Add moods if available + if data["moods"] and random.random() > thresh: + moods = ', '.join(data["moods"]) + description += f'. This track conveys a {moods} mood' + + # Add instruments if available + if data["instrument"] and random.random() > thresh: + instruments = ', '.join(data["instrument"]) + description += f', and primarily features the following instruments: {instruments}' + + # Add a period to end the description + description += '.' + + return description + +class AudioStockDataset(Dataset): + def __init__(self, + metadata_path:str, + duration:float=10, + sr:int = 0, + return_path = False, + return_audio = True, + prompt_template_path: os.PathLike = None, + tag_types = [], + lang = 'en', + translate:Optional[Dict[str, os.PathLike]] = None + ): + self.audio_reader = SafeAudioReader(duration, sr) + + self.duration = duration + self._load_metadata(metadata_path) + self.sr = sr + self.return_path = return_path + self.return_audio = return_audio + + self.use_dynamic_prompt = prompt_template_path is not None + if self.use_dynamic_prompt: + self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang) + self.tag_types = tag_types + + self.lang = lang + self.translate = read_translate(translate) + + def _load_metadata(self, metadata_path): + with open(metadata_path) as fp: + lines = fp.readlines() + self.data = [] + for line in lines: + item = json.loads(line) + if(item['duration']>self.duration+10): + self.data.append(item) + self.is_info_recorded = bool('Tags' in self.data[0]) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + path:str = self.data[idx]["path"] + json_path = path[:path.rfind('.')] + ".json" + if self.is_info_recorded: + item = self.data[idx] + else: + try: + with open(json_path) as fp: + item:dict = json.load(fp) + except Exception as e: + print(f"Error loading json file {json_path} :\n{e}") + item = {} + description = self.generate_description(item) + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio, is_start, is_end = self.audio_reader(path, sr, duration) + else: + audio = None + if self.return_path: + return audio, description, path, is_start, is_end + else: + return audio, description, is_start, is_end + + def generate_description(self, item): + if self.use_dynamic_prompt: + # dynamically generate prompt from given prompt template + prompt_template = random.choice(self.prompt_templates) + description = self.generate_description_dynamic(item, prompt_template) + else: + # use ordinary static prompt instead + description = self.generate_description_ordinary(item) + return description + + def generate_description_dynamic(self, data, prompt_template: PromptTemplate): + exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] + + if len(exists_tag) > 0: + probs = dist_prob_map[len(exists_tag)] + tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0] + random.shuffle(exists_tag) + tags = exists_tag[:tags_num] + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} + tags_args = self.handle_BPM_tag(tags_args) + prompt = prompt_template.apply(**tags_args) + else: + # no strong tags, use all weak tags instead + prompt = prompt_template.apply() + + return prompt + + def tags_to_desc(self, tag_list, tag_type) -> str: + if self.lang == 'en': + return tags_to_desc(tag_list) + elif self.lang == 'zh': + if tag_type == 'BPM': + return tags_to_desc(tag_list, sep='、') + translator = self.translate[tag_type] + translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] + return tags_to_desc(translated_tag_list, sep='、') + + def handle_BPM_tag(self, tags_args): + if "BPM" in tags_args and 'BPMDescript' in self.tag_types: + bpm = tags_args["BPM"] + del tags_args["BPM"] + tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript'))) + for tag_type in tag_types_used: + tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang) + return tags_args + + def generate_description_ordinary(self, data, thresh = 0.3): + if self.lang != 'en': + raise ValueError(f'Language {self.lang} is not supported for ordinary description generation') + description = f'a piece of music by {data["Artist"]}' + + # Add genre if available + if data["Genre"] and random.random() > thresh: + genres = ', '.join(data["Genre"]) + description += f', belonging to the {genres} genres' + + # Add moods if available + if data["Tags"] and random.random() > thresh: + tags = ', '.join(data["Tags"]) + description += f'. This track contains the tags:{tags}' + + # Add moods if available + if data["Mood"] and random.random() > thresh: + moods = ', '.join(data["Mood"]) + description += f'. This track conveys a {moods} mood.' + + # Add instruments if available + if data["Instrument"] and random.random() > thresh: + instruments = ', '.join(data["Instrument"]) + description += f'. and primarily features the following instruments: {instruments}' + + # Add a period to end the description + description += '.' + + return description + +def mp3_path_to_id(mp3_path): + return int( + mp3_path[mp3_path.rindex('/') + 1 : mp3_path.rindex('.mp3')] + ) + +class TmeDataset(Dataset): + def __init__(self, + data_index:str, + music_info:str = None, + duration:float = 10, + sr:int = 0, + return_path = False, + return_audio = True, + prompt_format_path: os.PathLike = None, + tag_types = ['*'], + lang = 'zh', + translate: Optional[os.PathLike] = None, + prompt_dir: os.PathLike = None, + ): + self.audio_reader = SafeAudioReader(duration, sr) + + self.sr = sr + self.duration = duration + self.return_path = return_path + self.return_audio = return_audio + self.lang = lang + + self.use_ready_prompt = prompt_dir is not None + + data_index = read_jsonlike(data_index) + data_index = [d for d in data_index if d['duration']>self.duration+10] + self.data_index_dict = {mp3_path_to_id(d['path']) : d for d in data_index} + self.data_ids = list(self.data_index_dict.keys()) + + if not self.use_ready_prompt: + #读取音乐的信息文件 + music_info = read_jsonlike(music_info) + if 'music' in music_info: + music_info = music_info['music'] + self.music_info_dict = {d["歌曲ID"]:d for d in music_info} + self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.music_info_dict} + self.data_ids = list(self.data_index_dict.keys()) + + with open(prompt_format_path) as fp: + self.prompt_formats = yaml.load(fp, Loader=yaml.FullLoader) + + #加载tag types,并分成一般的tag_types和关键的key_tag_types + if '*' in tag_types: + self.tag_types = ['歌曲名', 'bpm', '专辑名', '歌手名', '作曲', 'tag'] + else: + self.tag_types = tag_types + + self.key_tag_types = [] + if 'tag' in self.tag_types: + self.tag_types.remove('tag') + self.key_tag_types = list(self.prompt_formats['tag'].keys()) + + #加载translate翻译 + if translate is not None: + self.translator = read_jsonlike(translate) + else: + data_ids_set = set(self.data_ids) + self.prompts_dict = {} + for fname in os.listdir(prompt_dir): + items = read_jsonlike(os.path.join(prompt_dir, fname)) + for item in items: + if item['ID'] not in data_ids_set or not self.is_valid_prompt_text(item['Text']): + continue + if item['ID'] not in self.prompts_dict: + self.prompts_dict[item['ID']] = [] + self.prompts_dict[item['ID']].append(item['Text']) + self.data_index_dict = {k:v for k,v in self.data_index_dict.items() if k in self.prompts_dict} + self.data_ids = list(self.data_index_dict.keys()) + + def tags_to_desc(self, tag_list) -> str: + if is_bearable(tag_list, int): + return str(tag_list) + if self.lang == 'zh': + return tags_to_desc(tag_list, sep=self.sep) + else: + translated_tag_list = [self.translator[tag] for tag in tag_list if tag in self.translator ] + return tags_to_desc(translated_tag_list, sep=self.sep) + + def gen_desc_of_tag(self, formats, tags): + fmt = random.choice(formats) + return fmt.format(self.tags_to_desc(tags)) + + @staticmethod + def check_valid(value): + if isinstance(value, int) or isinstance(value, float): + return value > 0 + if (value is not None) and (not isinstance(value, Sequence) or len(value) > 0): + return True + return False + + @staticmethod + def remove_repeat(data): + #若专辑名和歌曲名相同,则只使用后者 + album_name = data.get('专辑名', None) + if album_name is not None and album_name == data.get('歌曲名', None): + del data['专辑名'] + return data + + @property + def comma(self): + if self.lang == 'zh': + return ',' + elif self.lang == 'en': + return ', ' + + @property + def sep(self): + if self.lang == 'zh': + return '、' + elif self.lang == 'en': + return ', ' + + def generate_description(self, data): + data = self.remove_repeat(data) + weak_tags = [key for key in data if (key in self.tag_types and self.check_valid(data[key]))] #弱语义的tag,这些tag的出现比例会放低 + + key_tags = [key for key in data['tag'] if (key in self.key_tag_types and self.check_valid(data['tag'][key]))] #关键的tag,这些tag必须出现至少一个 + + prompts = [] + if len(weak_tags) > 0: + probs = dist_prob_map_low[len(weak_tags)] + if len(key_tags) > 0: + tags_num = random.choices(range(0, len(weak_tags)), probs)[0] + else: + tags_num = random.choices(range(1, len(weak_tags) + 1), probs)[0] + random.shuffle(weak_tags) + tags = weak_tags[:tags_num] + for tag_type in tags: + tag_desc = self.gen_desc_of_tag(self.prompt_formats[tag_type], int(data[tag_type]) if tag_type == 'bpm' else data[tag_type]) + prompts.append(tag_desc) + + if len(key_tags) > 0: + probs = dist_prob_map[len(key_tags)] + tags_num = random.choices(range(1, len(key_tags) + 1), probs)[0] + random.shuffle(key_tags) + tags = key_tags[:tags_num] + for tag_type in tags: + tag_desc = self.gen_desc_of_tag(self.prompt_formats['tag'][tag_type], data['tag'][tag_type]) + prompts.append(tag_desc) + + random.shuffle(prompts) + return self.comma.join(prompts) + + def is_valid_prompt_text(self, text): + for bad in ('抱歉','sorry', 'Sorry'): + if bad in text: + return False + return True + + def get_ready_prompt(self, path): + sid = mp3_path_to_id(path) + return random.choice(self.prompts_dict[sid]) + + def __len__(self): + return len(self.data_ids) + + def __getitem__(self, idx): + data_id = self.data_ids[idx] + item = self.data_index_dict[data_id] + path = item['path'] + if not self.use_ready_prompt: + info = self.music_info_dict[data_id] + description = self.generate_description(info) + else: + description = self.get_ready_prompt(path) + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio, is_start, is_end = self.audio_reader(path, sr, duration) + else: + audio = None + if self.return_path: + return audio, description, path, is_start, is_end + else: + return audio, description, is_start, is_end + +class Pond5Dataset(Dataset): + MAX_PROMPT_LEN = 200 + def __init__(self, + metadata_path:str, + index_path:str, + duration:float=10, + sr:int = 0, + plain_rate = 0, + return_path = False, + return_audio = True, + lang = 'en', + translate:Optional[Dict[str, os.PathLike]] = None, + use_literal_none = True, + use_avoid_watermark_policy = None, + ): + + if use_avoid_watermark_policy is None: + raise ValueError("`use_avoid_watermark_policy` is an important param, you need to explicitly specify it with bool type") + self.use_avoid_watermark_policy = use_avoid_watermark_policy + assert self.use_avoid_watermark_policy is False + self.audio_reader = SafeAudioReader(duration, sr) + + self.duration = duration + self._load_metadata(metadata_path, index_path) + self.sr = sr + self.plain_rate = plain_rate + self.return_path = return_path + self.return_audio = return_audio + self.use_literal_none = use_literal_none + + self.lang = lang + self.translate = read_translate(translate) + + def _load_metadata(self, metadata_path, index_path): + data_index = read_jsonlike(index_path) + data_ids = set([item['id'] for item in data_index]) + + with open(metadata_path) as fp: + lines = fp.readlines() + + append_ids = set() + + self.data = [] + for line in lines: + item = json.loads(line) + if item['id'] in data_ids and item['id'] not in append_ids and item["details"]["duration"] is not None and item["details"]["duration"]>self.duration+10: + self.data.append(item) + append_ids.add(item['id']) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + path:str = item["path"] + description = self.generate_description(item) + if self.return_audio: + sr, duration = get_sr_and_duration_info(item) + audio, is_start, is_end = self.audio_reader(path, sr, duration) + else: + audio = None + if self.return_path: + return audio, description, path + return audio, description, is_start, is_end + + @property + def keysep(self): + if self.lang == 'zh': + return ',' if random.random() > 0.5 else '、' + elif self.lang == 'en': + return ', ' + + def generate_description(self, item): + if random.random() > self.plain_rate: + # dynamically generate prompt from given prompt template + description = self.generate_description_dynamic(item) + else: + # use plain prompt, i.e. tags sequence separated by comma + description = self.generate_description_plain(item) + return description + + def get_translation(self, k): + k = k.strip() + if k in self.translate: + return self.translate[k] + else: + return k + + def generate_description_plain(self, item): + keywords = item['keywords'] + if self.lang != 'en': + keywords = [self.get_translation(k) for k in keywords] + return gen_plain_prompt(keywords, sep=self.keysep) + + def generate_description_dynamic(self,item): + desc = item.get('desc', 'none') + if desc is None: + desc = 'none' + desc = desc.strip() + if len(desc) > self.MAX_PROMPT_LEN: + shorter_desc = desc[:self.MAX_PROMPT_LEN] + # find last stop + stop_idx = shorter_desc.rfind('.') + if stop_idx == -1: + stop_idx = shorter_desc.rfind('!') + if stop_idx == -1: + stop_idx = shorter_desc.rfind(',') + if stop_idx == -1: + stop_idx = self.MAX_PROMPT_LEN - 1 + desc = desc[:stop_idx+1] + return desc + +class CombinedDataset(Dataset): + @beartype + def __init__(self, datasets: Sequence[Dataset], ratios: Sequence[int]): + self.datasets = datasets + self.datasets_index = [] + + for i,dataset in enumerate(datasets): + if dataset is None: + continue + for dup in range(ratios[i]): + for j in range(len(dataset)): + self.datasets_index.append((i,j)) + + def __len__(self): + return len(self.datasets_index) + + def __getitem__(self, idx): + index = self.datasets_index[idx] + i,j = index + return self.datasets[i][j] + +class CombinedDataset_random(Dataset): + @beartype + def __init__(self, + num_examples:int, + datasets: Sequence[Dataset], ratios: Sequence[int] + ): + self.datasets = datasets + self.datasets_index = [] + + for i,dataset in enumerate(datasets): + if dataset is None: + continue + for dup in range(ratios[i]): + for j in range(len(dataset)): + self.datasets_index.append((i,j)) + if num_examples > 0: + self.random_choose = True + self.dataset_len = num_examples + else: + self.random_choose = False + self.dataset_len = len(self.datasets_index) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, idx): + first_try = True + try_cnt = 0 + while True: + try: + if(self.random_choose or not first_try): + index2 = [] + index2.append(np.random.randint(0,len(self.datasets))) + index2.append(np.random.randint(0,len(self.datasets[index2[-1]]))) + else: + index2 = self.datasets_index[idx] + first_try = False + out = self.datasets[index2[0]][index2[1]] + if(len(out[0].shape)==1):out[0]=out[0][None,:] + return out + except: + print("Error loadding ", index2) + try_cnt += 1 + if(try_cnt>10): + raise FileNotFoundError() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py new file mode 100644 index 0000000000000000000000000000000000000000..46d619c298718d4869dcdb54a420a1d080fac217 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song.py @@ -0,0 +1,313 @@ +import re +import sys +import json + +from torch.utils.data import Dataset +import torchaudio +from torchaudio.functional import resample +import torch +import numpy as np + +from torch.nn.utils.rnn import pad_sequence + + + +def check_lryics(lyric): + _FILTER_STRING = [ + '作词', '作曲', '编曲', '【', '策划', + '录音', '混音', '母带', ':', '制作', + '版权', '校对', '演奏', '制作', '伴奏' + ] + for item in _FILTER_STRING: + if item in lyric: + return True + + return False + + + +def process_lyrics(lines): + lyric_part = [] + timestamp_part = [] + + timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]') + + for i, line in enumerate(lines): + + # 删除前几行的特定信息 + if i<10 and check_lryics(line): + continue + + # 检查是否包含有效的时间戳和歌词内容 + if timestamp_pattern.match(line): + timestamp_end = line.rfind(']') + lyrics = line[timestamp_end + 1:].strip() + timestamps = line[:timestamp_end + 1] + + if ':' in lyrics: + if len(lyrics.split(":")[0]) <=5: + lyrics = "".join(lyrics.split(":")[1:]) + # if lyrics: # 确保歌词部分不是空的 + # lyric_part.append(lyrics) + # timestamp_part.append(timestamps) + # print(processed_lyrics) + return timestamp_part, lyric_part + +def get_timestamps(timestamp_part): + + # 转换为秒 + + timestamps = [] + + for line in timestamp_part: + match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line) + if match: + minutes = int(match.group(1)) + seconds = float(match.group(2)) + millis = float(match.group(3)) if match.group(3) else 0 + total_seconds = minutes * 60 + seconds + millis + timestamps.append(total_seconds) + + + return timestamps + +def process_lyrics_lrc(lyrics): + timestamp_part, lyric_part = process_lyrics(lyrics) + # print(timestamp_part) + # print(lyric_part) + timestamps = get_timestamps(timestamp_part) + # print(timestamps) + if len(timestamps) == 0: + # print(f'{lyric_path}') + return [] + + slice_start = timestamps[0] + slice_start_idx = 0 + + output_list = [] + for i in range(1, len(timestamps)): + # 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉 + if timestamps[i] - slice_start > 30: + output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) + + slice_start = timestamps[i] + slice_start_idx = i + + return output_list + + + +def process_lyrics_yrc(lyrics): + + timestamps, lyric_part = extract_lrc(lyrics) + + # timestamp_part, lyric_part = process_lyrics(lyrics) + # import pdb; pdb.set_trace() + # print(timestamp_part) + # print(lyric_part) + # timestamps = get_timestamps(timestamp_part) + # print(timestamps) + if len(timestamps) == 0: + # print(f'{lyric_path}') + return [] + + slice_start = timestamps[0] + slice_start_idx = 0 + + output_list = [] + for i in range(1, len(timestamps)): + # 如果累积时间超过30秒,则进行切分 + if timestamps[i] - slice_start > 30: + output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) + + slice_start = timestamps[i] + slice_start_idx = i + # import pdb; pdb.set_trace() + return output_list + +def extract_lrc(lyrics): + timestamp_part, lyric_part = [], [] + + for i, text in enumerate(lyrics): + # 提取中括号内的内容 + bracket_content = re.search(r'\[(.*?)\]', text).group(1) + bracket_content = bracket_content.split(',') + # 提取小括号内的内容 + parentheses_content = re.findall(r'\((.*?)\)', text) + # 提取其他内容 + other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip() + + # 数据怎么处理? + # import pdb; pdb.set_trace() + if i<10 and check_lryics(other_content): + continue + + # import pdb; pdb.set_trace() + timestamp_part.append(float(bracket_content[0])/1000) + lyric_part.append(other_content) + # import pdb; pdb.set_trace() + return timestamp_part, lyric_part + + + +class WYYSongDataset(Dataset): + def __init__(self, + metadata_path:str, + sr:int = 0, + use_lang = ['en', 'zh-cn'], + num_examples = -1, + ): + + self.sr = sr + self.use_lang = use_lang + self._load_metadata(metadata_path) + + # buffer + self.lyric_buffer = {} + + if(num_examples<=0): + self.dataset_len = len(self.data) + self.random_slc = False + else: + self.dataset_len = num_examples + self.random_slc = True + + # 读取jsonl文件 + def _load_metadata(self, metadata_path): + with open(metadata_path) as fp: + lines = fp.readlines() + self.data = [] + for line in lines: + item = json.loads(line) + # if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None: + if 'lyrics' in item and 'lang_info' in item: + if len(item['lyrics']) > 0: + for lang in self.use_lang: + if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9: + # if '伴奏' not in item['path'] and "cloud" in item['path']: + if '伴奏' not in item['path']: + self.data.append(item) + + + def __len__(self): + return self.dataset_len + + + def __getitem__(self, idx): + try_cnt = 0 + while True: + if(self.random_slc): + idx = np.random.randint(0, len(self.data)) + yrc_lyrics = [] + lrc_lyrics = [] + try: + info = self.data[idx] + + # audio path + path:str = info["path"] + + # 读取歌词段落 + if 'lyrics' not in info: + if idx not in self.lyric_buffer: + # 字级别align的歌词 + if info['yrc-lyric'] is not None: + with open(info['yrc-lyric']) as f_in: + yrc_lyric = json.load(f_in) + yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1]) + + # 句子级align的歌词 + if info['lrc-lyric'] is not None: + with open(info['lrc-lyric']) as f_in: + lrc_lyric = json.load(f_in) + lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1]) + + # 优先使用字级别align的歌词 + if len(yrc_lyrics) > 0: + lyrics = yrc_lyrics + else: + lyrics = lrc_lyrics + self.lyric_buffer[idx] = lyrics + + # TODO 每段歌词进行长度筛选,过滤掉太长和太短的歌曲 + else: + lyrics = self.lyric_buffer[idx] + else: + lyrics = info['lyrics'] + + # 随机选取一个lyric段落 + ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item() + # ly_id = 0 + + lyric = lyrics[ly_id] + + + + st, et, lyric = self.parse_lyric(lyric) + + assert et - st < 40 + + # 文本过滤 + + lyric = re.sub(r'【.*?】', '', lyric) + if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8: + assert 200 > len(lyric.replace(" ", "")) > 30 + if ':' in lyrics: + if len(lyrics.split(":")[0]) <=5: + lyrics = "".join(lyrics.split(":")[1:]) + + if ':' in lyrics: + if len(lyrics.split(":")[0]) <=5: + lyrics = "".join(lyrics.split(":")[1:]) + + if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8: + assert 200 > len(lyric.split()) > 20 + + if ':' in lyrics: + if len(lyrics.split(":")[0].split()) <=3: + lyrics = "".join(lyrics.split(":")[1:]) + + if ':' in lyrics: + if len(lyrics.split(":")[0].split()) <=3: + lyrics = "".join(lyrics.split(":")[1:]) + + + + # 读取音频文件 + cur_sample_rate = torchaudio.info(path).sample_rate + offset = int(cur_sample_rate*st) + num_frames = int(cur_sample_rate * (et -st)) + chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames) + + # 随机选取一个channel + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + + if(cur_sample_rate!=self.sr): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr) + + return chunk, lyric, [st, et], path + except: + print("Error loadding ", info["path"]) + try_cnt += 1 + idx = np.random.randint(0, len(self.data)) + if(try_cnt>10): + raise FileNotFoundError() + + def parse_lyric(self, lyric): + pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)' + match = re.search(pattern, lyric) + + start_time = float(match.group(1)) + end_time = float(match.group(2)) + content = match.group(3) + return start_time, end_time, content + +def collect_song(data_list): + audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2) + lyrics = [data[1] for data in data_list] + st_et = [data[2] for data in data_list] + paths = [data[3] for data in data_list] + return audios, lyrics, st_et diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py new file mode 100644 index 0000000000000000000000000000000000000000..991c59786412dc7f2bd22c57c7e4a7e3d30e5776 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_20s.py @@ -0,0 +1,313 @@ +import re +import sys +import json + +from torch.utils.data import Dataset +import torchaudio +from torchaudio.functional import resample +import torch +import numpy as np + +from torch.nn.utils.rnn import pad_sequence + + + +def check_lryics(lyric): + _FILTER_STRING = [ + '作词', '作曲', '编曲', '【', '策划', + '录音', '混音', '母带', ':', '制作', + '版权', '校对', '演奏', '制作', '伴奏' + ] + for item in _FILTER_STRING: + if item in lyric: + return True + + return False + + + +def process_lyrics(lines): + lyric_part = [] + timestamp_part = [] + + timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]') + + for i, line in enumerate(lines): + + # 删除前几行的特定信息 + if i<10 and check_lryics(line): + continue + + # 检查是否包含有效的时间戳和歌词内容 + if timestamp_pattern.match(line): + timestamp_end = line.rfind(']') + lyrics = line[timestamp_end + 1:].strip() + timestamps = line[:timestamp_end + 1] + + if ':' in lyrics: + if len(lyrics.split(":")[0]) <=5: + lyrics = "".join(lyrics.split(":")[1:]) + # if lyrics: # 确保歌词部分不是空的 + # lyric_part.append(lyrics) + # timestamp_part.append(timestamps) + # print(processed_lyrics) + return timestamp_part, lyric_part + +def get_timestamps(timestamp_part): + + # 转换为秒 + + timestamps = [] + + for line in timestamp_part: + match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line) + if match: + minutes = int(match.group(1)) + seconds = float(match.group(2)) + millis = float(match.group(3)) if match.group(3) else 0 + total_seconds = minutes * 60 + seconds + millis + timestamps.append(total_seconds) + + + return timestamps + +def process_lyrics_lrc(lyrics): + timestamp_part, lyric_part = process_lyrics(lyrics) + # print(timestamp_part) + # print(lyric_part) + timestamps = get_timestamps(timestamp_part) + # print(timestamps) + if len(timestamps) == 0: + # print(f'{lyric_path}') + return [] + + slice_start = timestamps[0] + slice_start_idx = 0 + + output_list = [] + for i in range(1, len(timestamps)): + # 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉 + if timestamps[i] - slice_start > 30: + output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) + + slice_start = timestamps[i] + slice_start_idx = i + + return output_list + + + +def process_lyrics_yrc(lyrics): + + timestamps, lyric_part = extract_lrc(lyrics) + + # timestamp_part, lyric_part = process_lyrics(lyrics) + # import pdb; pdb.set_trace() + # print(timestamp_part) + # print(lyric_part) + # timestamps = get_timestamps(timestamp_part) + # print(timestamps) + if len(timestamps) == 0: + # print(f'{lyric_path}') + return [] + + slice_start = timestamps[0] + slice_start_idx = 0 + + output_list = [] + for i in range(1, len(timestamps)): + # 如果累积时间超过30秒,则进行切分 + if timestamps[i] - slice_start > 30: + output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) + + slice_start = timestamps[i] + slice_start_idx = i + # import pdb; pdb.set_trace() + return output_list + +def extract_lrc(lyrics): + timestamp_part, lyric_part = [], [] + + for i, text in enumerate(lyrics): + # 提取中括号内的内容 + bracket_content = re.search(r'\[(.*?)\]', text).group(1) + bracket_content = bracket_content.split(',') + # 提取小括号内的内容 + parentheses_content = re.findall(r'\((.*?)\)', text) + # 提取其他内容 + other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip() + + # 数据怎么处理? + # import pdb; pdb.set_trace() + if i<10 and check_lryics(other_content): + continue + + # import pdb; pdb.set_trace() + timestamp_part.append(float(bracket_content[0])/1000) + lyric_part.append(other_content) + # import pdb; pdb.set_trace() + return timestamp_part, lyric_part + + + +class WYYSongDataset(Dataset): + def __init__(self, + metadata_path:str, + sr:int = 0, + use_lang = ['en', 'zh-cn'], + num_examples = -1, + ): + + self.sr = sr + self.use_lang = use_lang + self._load_metadata(metadata_path) + + # buffer + self.lyric_buffer = {} + + if(num_examples<=0): + self.dataset_len = len(self.data) + self.random_slc = False + else: + self.dataset_len = num_examples + self.random_slc = True + + # 读取jsonl文件 + def _load_metadata(self, metadata_path): + with open(metadata_path) as fp: + lines = fp.readlines() + self.data = [] + for line in lines: + item = json.loads(line) + # if item['lrc-lyric'] is not None and item['yrc-lyric'] is not None: + if 'lyrics' in item and 'lang_info' in item: + if len(item['lyrics']) > 0: + for lang in self.use_lang: + if lang in item['lang_info'] and item['lang_info'][lang]['proportion'] > 0.8 and item['lang_info'][lang]['probability'] > 0.9: + # if '伴奏' not in item['path'] and "cloud" in item['path']: + if '伴奏' not in item['path']: + self.data.append(item) + + + def __len__(self): + return self.dataset_len + + + def __getitem__(self, idx): + try_cnt = 0 + while True: + if(self.random_slc): + idx = np.random.randint(0, len(self.data)) + yrc_lyrics = [] + lrc_lyrics = [] + try: + info = self.data[idx] + + # audio path + path:str = info["path"] + + # 读取歌词段落 + if 'lyrics' not in info: + if idx not in self.lyric_buffer: + # 字级别align的歌词 + if info['yrc-lyric'] is not None: + with open(info['yrc-lyric']) as f_in: + yrc_lyric = json.load(f_in) + yrc_lyrics = process_lyrics_yrc(yrc_lyric['lyrics'][:-1]) + + # 句子级align的歌词 + if info['lrc-lyric'] is not None: + with open(info['lrc-lyric']) as f_in: + lrc_lyric = json.load(f_in) + lrc_lyrics = process_lyrics_lrc(lrc_lyric['lyrics'][:-1]) + + # 优先使用字级别align的歌词 + if len(yrc_lyrics) > 0: + lyrics = yrc_lyrics + else: + lyrics = lrc_lyrics + self.lyric_buffer[idx] = lyrics + + # TODO 每段歌词进行长度筛选,过滤掉太长和太短的歌曲 + else: + lyrics = self.lyric_buffer[idx] + else: + lyrics = info['lyrics'] + + # 随机选取一个lyric段落 + ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item() + # ly_id = 0 + + lyric = lyrics[ly_id] + + + + st, et, lyric = self.parse_lyric(lyric) + + assert et - st < 20 + + # 文本过滤 + + lyric = re.sub(r'【.*?】', '', lyric) + if 'zh-cn' in info['lang_info'] and info['lang_info']['zh-cn']['proportion'] > 0.8: + assert 100 > len(lyric.replace(" ", "")) > 5 + if ':' in lyrics: + if len(lyrics.split(":")[0]) <=5: + lyrics = "".join(lyrics.split(":")[1:]) + + if ':' in lyrics: + if len(lyrics.split(":")[0]) <=5: + lyrics = "".join(lyrics.split(":")[1:]) + + if 'en' in info['lang_info'] and info['lang_info']['en']['proportion'] > 0.8: + assert 100 > len(lyric.split()) > 5 + + if ':' in lyrics: + if len(lyrics.split(":")[0].split()) <=3: + lyrics = "".join(lyrics.split(":")[1:]) + + if ':' in lyrics: + if len(lyrics.split(":")[0].split()) <=3: + lyrics = "".join(lyrics.split(":")[1:]) + + + + # 读取音频文件 + cur_sample_rate = torchaudio.info(path).sample_rate + offset = int(cur_sample_rate*st) + num_frames = int(cur_sample_rate * (et -st)) + chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames) + + # 随机选取一个channel + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + + if(cur_sample_rate!=self.sr): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr) + + return chunk, lyric, [st, et], path + except: + print("Error loadding ", info["path"]) + try_cnt += 1 + idx = np.random.randint(0, len(self.data)) + if(try_cnt>10): + raise FileNotFoundError() + + def parse_lyric(self, lyric): + pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)' + match = re.search(pattern, lyric) + + start_time = float(match.group(1)) + end_time = float(match.group(2)) + content = match.group(3) + return start_time, end_time, content + +def collect_song(data_list): + audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2) + lyrics = [data[1] for data in data_list] + st_et = [data[2] for data in data_list] + paths = [data[3] for data in data_list] + return audios, lyrics, st_et diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py new file mode 100644 index 0000000000000000000000000000000000000000..ab395273c6270912f5d84df71c70386f5eeab71b --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_song_new_429.py @@ -0,0 +1,313 @@ +import re +import sys +import json + +from torch.utils.data import Dataset +import torchaudio +from torchaudio.functional import resample +import torch +import numpy as np + +from torch.nn.utils.rnn import pad_sequence + + + +def check_lryics(lyric): + _FILTER_STRING = [ + '作词', '作曲', '编曲', '【', '策划', + '录音', '混音', '母带', ':', '制作', + '版权', '校对', '演奏', '制作', '伴奏' + ] + for item in _FILTER_STRING: + if item in lyric: + return True + + return False + + + +def process_lyrics(lines): + lyric_part = [] + timestamp_part = [] + + timestamp_pattern = re.compile(r'\[\d+:\d+(\.\d+)?\]') + + for i, line in enumerate(lines): + + # 删除前几行的特定信息 + if i<10 and check_lryics(line): + continue + + # 检查是否包含有效的时间戳和歌词内容 + if timestamp_pattern.match(line): + timestamp_end = line.rfind(']') + lyrics = line[timestamp_end + 1:].strip() + timestamps = line[:timestamp_end + 1] + + if ':' in lyrics: + if len(lyrics.split(":")[0]) <=5: + lyrics = "".join(lyrics.split(":")[1:]) + # if lyrics: # 确保歌词部分不是空的 + # lyric_part.append(lyrics) + # timestamp_part.append(timestamps) + # print(processed_lyrics) + return timestamp_part, lyric_part + +def get_timestamps(timestamp_part): + + # 转换为秒 + + timestamps = [] + + for line in timestamp_part: + match = re.match(r'\[(\d+):(\d+)(\.\d+)?\]', line) + if match: + minutes = int(match.group(1)) + seconds = float(match.group(2)) + millis = float(match.group(3)) if match.group(3) else 0 + total_seconds = minutes * 60 + seconds + millis + timestamps.append(total_seconds) + + + return timestamps + +def process_lyrics_lrc(lyrics): + timestamp_part, lyric_part = process_lyrics(lyrics) + # print(timestamp_part) + # print(lyric_part) + timestamps = get_timestamps(timestamp_part) + # print(timestamps) + if len(timestamps) == 0: + # print(f'{lyric_path}') + return [] + + slice_start = timestamps[0] + slice_start_idx = 0 + + output_list = [] + for i in range(1, len(timestamps)): + # 如果累积时间超过30秒,则进行切分, 如果整体小于30s, 整句会被丢掉 + if timestamps[i] - slice_start > 30: + output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) + + slice_start = timestamps[i] + slice_start_idx = i + + return output_list + + + +def process_lyrics_yrc(lyrics): + + timestamps, lyric_part = extract_lrc(lyrics) + + # timestamp_part, lyric_part = process_lyrics(lyrics) + # import pdb; pdb.set_trace() + # print(timestamp_part) + # print(lyric_part) + # timestamps = get_timestamps(timestamp_part) + # print(timestamps) + if len(timestamps) == 0: + # print(f'{lyric_path}') + return [] + + slice_start = timestamps[0] + slice_start_idx = 0 + + output_list = [] + for i in range(1, len(timestamps)): + # 如果累积时间超过30秒,则进行切分 + if timestamps[i] - slice_start > 30: + output_list.append(f'[{str(slice_start)}:{str(timestamps[i])}]' + ", ".join(lyric_part[slice_start_idx:i])) + + slice_start = timestamps[i] + slice_start_idx = i + # import pdb; pdb.set_trace() + return output_list + +def extract_lrc(lyrics): + timestamp_part, lyric_part = [], [] + + for i, text in enumerate(lyrics): + # 提取中括号内的内容 + bracket_content = re.search(r'\[(.*?)\]', text).group(1) + bracket_content = bracket_content.split(',') + # 提取小括号内的内容 + parentheses_content = re.findall(r'\((.*?)\)', text) + # 提取其他内容 + other_content = re.sub(r'\[(.*?)\]|\((.*?)\)', '', text).strip() + + # 数据怎么处理? + if i<10 and check_lryics(other_content): + continue + timestamp_part.append(float(bracket_content[0])/1000) + lyric_part.append(other_content) + return timestamp_part, lyric_part + + + +class WYYSongDataset(Dataset): + def __init__(self, + metadata_path:str, + sr:int = 0, + use_lang = ['en', 'zh-cn'], + num_examples = -1, + max_dur = 20, + pad_to_max= True, + ): + + self.sr = sr + self.use_lang = use_lang + self._load_metadata(metadata_path) + self.max_dur = max_dur + self.pad_to_max = pad_to_max + + # buffer + self.lyric_buffer = {} + + if(num_examples<=0): + self.dataset_len = len(self.data) + self.random_slc = False + else: + self.dataset_len = num_examples + self.random_slc = True + + # 读取jsonl文件 + def _load_metadata(self, metadata_path): + with open(metadata_path) as fp: + lines = fp.readlines() + self.data = [] + for line in lines: + item = json.loads(line) + if '伴奏' not in item['path']: + # if "lang_type" in item and item['lang_type'] == 'en': + if "lang_type" in item: + self.data.append(item) + + + def __len__(self): + return self.dataset_len + + + def __getitem__(self, idx): + try_cnt = 0 + while True: + if(self.random_slc): + idx = np.random.randint(0, len(self.data)) + yrc_lyrics = [] + lrc_lyrics = [] + try: + info = self.data[idx] + + # audio path + path = info["path"] + lang_type = info["lang_type"] + if info["lang_type"] == 'en': + lyrics = info['lyrics'] + else: + lyrics = info['lyrics_phone'] + + # 随机选取一个lyric段落 + ly_id = torch.randint(low=1, high=len(lyrics), size=(1,))[0].item() + lyric = lyrics[ly_id].strip() + + st, et, lyric = self.parse_lyric(lyric) + lyric = lyric.replace("\xa0", " ") + + lyric = " ".join(lyric.split()) + + assert et - st < self.max_dur + + + if info["lang_type"] == 'en': + # print(len(lyric.split())/(et-st)) + assert 6 > len(lyric.split())/(et-st) > 1 + else: + # print(len(lyric.split())/(et-st)) + lyric = lyric.replace("-", "") + assert 6 > len(lyric.split())/(et-st) > 1 + + + # 读取音频文件 + cur_sample_rate = torchaudio.info(path).sample_rate + offset = int(cur_sample_rate*st) + num_frames = int(cur_sample_rate * (et -st)) + chunk, _ = torchaudio.load(path, frame_offset=offset, num_frames=num_frames) + # chunk = torch.zeros(1, 48000*15) + + # 随机选取一个channel + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + + if(cur_sample_rate!=self.sr): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sr) + + if self.pad_to_max: + chunk = self.pad_2d_tensor(chunk, int(self.max_dur * self.sr), 0) + + return chunk, lyric, et-st, path, lang_type + except: + # print("Error loadding ", info["path"]) + try_cnt += 1 + idx = np.random.randint(0, len(self.data)) + if(try_cnt>20): + raise FileNotFoundError() + + def parse_lyric(self, lyric): + pattern = r'\[(\d+\.\d+):(\d+\.\d+)\](.*)' + match = re.search(pattern, lyric) + + start_time = float(match.group(1)) + end_time = float(match.group(2)) + content = match.group(3) + return start_time, end_time, content + + def pad_2d_tensor(self, x, max_len, pad_id): + # 获取输入 tensor 的形状 + batch_size, seq_len = x.size() + max_len = max(max_len, seq_len) + # 计算需要填充的长度 + pad_len = max_len - seq_len + + # 如果需要填充 + if pad_len > 0: + # 创建填充 tensor + pad_tensor = torch.full((batch_size, pad_len), pad_id, dtype=x.dtype, device=x.device) + + # 沿第二个维度(列)连接输入 tensor 和填充 tensor + padded_tensor = torch.cat([x, pad_tensor], dim=1) + else: + # 如果不需要填充,直接返回输入 tensor + padded_tensor = x + + return padded_tensor + +def collect_data(data_list): + audios = pad_sequence([data[0].t() for data in data_list], batch_first=True, padding_value=0).transpose(1,2) + lyrics = [data[1] for data in data_list] + st_et = [data[2] for data in data_list] + paths = [data[3] for data in data_list] + lang_types = [data[4] for data in data_list] + return audios, lyrics, st_et, lang_types + # return audios, lyrics, st_et + + +def build_dataset(): + train_dataset = WYYSongDataset( + metadata_path = "train.jsonl", + sr = 48000, + use_lang = ['zh-cn', 'en'], + num_examples = 10*10000 + ) + + valid_dataset = WYYSongDataset( + metadata_path = "valid.jsonl", + sr = 48000, + use_lang = ['zh-cn', 'en'], + num_examples = 500 + ) + + return train_dataset, valid_dataset diff --git a/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py new file mode 100644 index 0000000000000000000000000000000000000000..693efb42b76cf4c15ed2d045997e92695af30b3a --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/datasets/dataset_stock.py @@ -0,0 +1,461 @@ +from torch.utils.data import Dataset +from beartype.typing import Sequence, Callable, Optional, Dict, List +from beartype.door import is_bearable +import random +import os +from torchaudio.functional import resample +import torch +import typing as tp +from pathlib import Path +import torchaudio as ta +import torch.nn.functional as F +import soundfile +import numpy as np +import json +import yaml +import random +import librosa +from loguru import logger +import re + + +def _av_read(filepath, seek_time=0, duration=None): + if duration is not None: + sr = librosa.get_samplerate(filepath) + offset = seek_time + num_samples = int(duration * sr) + wav, _ = librosa.load(filepath, sr=sr, offset=offset, duration=duration) + else: + wav, sr = librosa.load(filepath, sr=None, offset=seek_time) + + return wav, sr + +def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0., + duration: float = -1., pad: bool = True) -> tp.Tuple[torch.Tensor, int]: + """Read audio by picking the most appropriate backend tool based on the audio format. + + Args: + filepath (str or Path): Path to audio file to read. + seek_time (float): Time at which to start reading in the file. + duration (float): Duration to read from the file. If set to -1, the whole file is read. + pad (bool): Pad output audio if not reaching expected duration. + Returns: + tuple of torch.Tensor, int: Tuple containing audio data and sample rate. + """ + fp = Path(filepath) + if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg + # There is some bug with ffmpeg and reading flac + info = soundfile.info(filepath) + frames = -1 if duration <= 0 else int(duration * info.samplerate) + frame_offset = int(seek_time * info.samplerate) + wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32) + assert info.samplerate == sr, f"Mismatch of sample rates {info.samplerate} {sr}" + wav = torch.from_numpy(wav).t().contiguous() + if len(wav.shape) == 1: + wav = torch.unsqueeze(wav, 0) + elif ( + fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats() + and duration <= 0 and seek_time == 0 + ): + # Torchaudio is faster if we load an entire file at once. + wav, sr = librosa.load(fp, sr=None, mono=True) + else: + wav, sr = _av_read(filepath, seek_time, duration) + if pad and duration > 0: + expected_frames = int(duration * sr) + wav = F.pad(torch.tensor(wav), (0, expected_frames - wav.shape[-1])) + if not isinstance(wav, torch.Tensor): + wav = torch.tensor(wav) + return wav, sr + +def random_seek_read(filepath, duration): + if duration > 0: + total_duration = librosa.get_duration(path=filepath) + acceptable_start = max(0, total_duration - duration) + wav, sr = audio_read(filepath, random.uniform(0, acceptable_start), duration, pad=True) + else: + wav, sr = audio_read(filepath, 0, -1, pad=False) + return wav, sr + +def safe_random_seek_read(filepath, duration, sample_rate): + try: + wav, sr = random_seek_read(filepath, duration) + if sr != sample_rate: + wav = resample(wav, sr, sample_rate) + sr = sample_rate + except Exception as e: + logger.error(f"Error reading {filepath}: {e}") + sr = sample_rate + wav = torch.zeros(sr * max(duration, 0), dtype=torch.float32) + return wav, sr + +def read_jsonlike(path: os.PathLike): + #json or jsonl + if str(path).endswith(".json"): + with open(path, 'r', encoding='utf8') as f: + data = json.load(f) + return data + elif str(path).endswith(".jsonl"): + with open(path, 'r', encoding='utf8') as f: + data = [json.loads(line) for line in f.readlines()] + return data + else: + raise ValueError("Unknown file format") + +dist_prob_map = { + 1: (1.0,), + 2: (0.5, 0.5), + 3: (0.3, 0.4, 0.3), + 4: (0.2, 0.3, 0.3, 0.2), + 5: (0.2, 0.2, 0.3, 0.2, 0.1), + 6: (0.1, 0.15, 0.2, 0.2, 0.2, 0.15), + 7: (0.05, 0.1, 0.1, 0.2, 0.25, 0.2, 0.1), + 8: (0.03, 0.05, 0.1, 0.15, 0.25, 0.2, 0.1, 0.12), + 9: (0.02, 0.1, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.08), + 10: (0.01, 0.1, 0.1, 0.15, 0.2, 0.15, 0.1, 0.05, 0.05, 0.09) +} + +dist_prob_map_low = { + 1: (1.0,), + 2: (0.8, 0.2), + 3: (0.8, 0.1, 0.1), + 4: (0.7, 0.1, 0.1, 0.1), + 5: (0.7, 0.1, 0.1, 0.05, 0.05), + 6: (0.7, 0.1, 0.05, 0.05, 0.05, 0.05), +} + + +_bpm_range_rights = ( + (40, '20-40'), + (60, '40-60'), + (66, '60-66'), + (76, '66-76'), + (108, '76-108'), + (120, '108-120'), + (168, '120-168'), + (176, '168-176'), + (200, '176-200') +) +_bpm_desc_map = { + '20-40': ("glacial pace", "extremely slow tempo", "crawl-like speed", "snail's pace", "almost motionless rhythm", "Larghissimo"), + '40-60': ("broad and slow", "spacious tempo", "unhurried pace", "calm rhythm", "relaxed speed", "Largo"), + '60-66': ("gentle tempo", "leisurely pace", "easy-going rhythm", "unrushed speed", "smooth and slow", 'Larghetto'), + '66-76': ("slow and steady", "deliberate tempo", "unhurried pace", "relaxed rhythm", "easy speed", 'Adagio'), + '76-108': ("walking pace", "moderate tempo", "steady rhythm", "balanced speed", "easy-flowing tempo", "Andante"), + '108-120': ("medium pace", "comfortable tempo", "even rhythm", "measured speed", "controlled tempo", 'Moderato'), + '120-168': ("quick and lively", "brisk pace", "energetic tempo", "upbeat rhythm", "spirited speed", 'Allegro'), + '168-176': ("lively and fast", "bright tempo", "sprightly pace", "vibrant rhythm", "animated speed", 'Vivace'), + '176-200': ("very fast tempo", "rapid pace", "high-speed rhythm", "hurried speed", "accelerated tempo", 'Presto'), + '>200': ("extremely fast", "breakneck speed", "blazing tempo", "lightning-fast rhythm", "supercharged pace", 'Prestissimo') +} +_bpm_desc_map_zh = { + '20-40': ("极度缓慢", "极慢的节奏", "悠长的旋律", "迟缓的节奏", "几乎静止的节奏", "甚缓"), + '40-60': ("宽广而缓慢", "宽敞的节奏", "从容不迫的速度", "平静的节奏", "轻松的速度", "广板"), + '60-66': ("柔和的节奏", "悠闲的速度", "轻松的节奏", "不慌不忙的速度", "平滑而缓慢", '小广板'), + '66-76': ("缓慢而稳定", "沉稳的旋律", "从容不迫的速度", "轻松的节奏", "轻松的速度", '慢板'), + '76-108': ("步行速度", "适中的节奏", "稳定的节奏", "平衡的速度", "流畅的节奏", "行板"), + '108-120': ("中等速度", "舒适的节奏", "均匀的节奏", "有节制的速度", "稳定的氛围", '中板'), + '120-168': ("快速而生动", "轻快的速度", "充满活力的节奏", "欢快的节奏", "富有精神的速度", '快板'), + '168-176': ("生动而快速", "明快的节奏", "活泼的速度", "充满活力的节奏", "生气勃勃的速度", '活泼的'), + '176-200': ("非常快的节奏", "快速的速度", "高速的节奏", "匆忙的速度", "加速的节奏", '急板'), + '>200': ("极快的速度", "极速旋律", "炽热的节奏", "闪电般的节奏", "疾驰的速度", '最急板') +} +def get_bpm_range(bpm): + bpm = int(bpm) + for right, tag in _bpm_range_rights: + if bpm <= right: + return tag + return '>200' + +def gen_bpm_descript(bpm, lang='en'): + bpm_range = get_bpm_range(bpm) + if lang == 'en': + return random.choice(_bpm_desc_map[bpm_range]) + elif lang == 'zh': + return random.choice(_bpm_desc_map_zh[bpm_range]) + else: + raise ValueError(f"Unknown language {lang}") + +def read_translate(translate: Optional[Dict[str, os.PathLike]]): + if translate is None: + return None + return {k: read_jsonlike(path) for k, path in translate.items()} + + +def tags_to_desc(tag_list, sep=',') -> str: + if not isinstance(tag_list, Sequence): + return str(tag_list) + if isinstance(tag_list, str): + return tag_list + if len(tag_list) <= 0: + return '' + elif len(tag_list) <= 5: + probs = dist_prob_map[len(tag_list)] + tags_num = random.choices(range(1, len(tag_list)+1), probs)[0] + random.shuffle(tag_list) + tag_list = tag_list[:tags_num] + return sep.join(tag_list) + else: + probs = dist_prob_map[5] + tags_num = random.choices(range(1, 6), probs)[0] + random.shuffle(tag_list) + tag_list = tag_list[:tags_num] + return sep.join(tag_list) + + +class PromptTemplate: + def __init__(self, template_text: str, tag_map: Dict[str, str], lang:str ='en'): + self.template_text = template_text + self.tag_map = tag_map + self.lang = lang + + @property + def tags(self): + return tuple(self.tag_map.keys()) + + def apply(self, **kwargs): + for tag in list(kwargs.keys()): + if kwargs[tag] == '': + kwargs.pop(tag) + for tag in self.tags: + if tag in kwargs: + kwargs[tag] = self.tag_map[tag].format(**{tag: kwargs[tag]}).strip('[]') + else: + kwargs[tag] = '' + prompt = self.template_text.format(**kwargs) + + return self.beautify(prompt) + + def beautify(self, text): + if self.lang == 'en': + return self._beautify_en(text) + elif self.lang == 'zh': + return self._beautify_zh(text) + else: + raise ValueError(f'Unknown language {self.lang}') + + @staticmethod + def _beautify_en(text): + # no continuous commas without content between them + text = re.sub(r'[,\s]*,[,\s]*', r', ', text) + # no continuous whitespace + text = re.sub(r'\s+', ' ', text) + # the comma is NOT followed by whitespace, and should be followed by ONE whitespace + text = re.sub(r'\s+,', r',', text) + text = re.sub(r',\s+', r', ', text) + # no whitespace before the full stop + text = re.sub(r'\s+\.', r'.', text) + # strip whitespace, comma, and replace ',.' + text = text.strip(' ,') + text = text.replace(',.', '.') + return text + + @staticmethod + def _beautify_zh(text): + # no continuous commas without content between them + text = re.sub(r'[,、\s]*,[,、\s]*', r',', text) + text = re.sub(r'[,、\s]*、[,、\s]*', r'、', text) + # assume there should be NO whitespace in Chinese + text = re.sub(r'\s+', r'', text) + # strip whitespace, comma, and replace ',。' + text = text.strip(', 、') + text = text.replace(',。', '。') + return text + + def __repr__(self): + return f'PromptTemplate({self.template_text!r}, {self.tag_map!r})' + + __str__ = __repr__ + +def parse_prompt_template(prompt_template_text, lang='en'): + span_pattern = re.compile(r'\[.*?{.+?}.*?\]', re.DOTALL) + tag_pattern = re.compile(r'{.+?}', re.DOTALL) + + template_text = prompt_template_text.strip() + span_texts = span_pattern.findall(prompt_template_text) + tag_map = {} + for span_text in span_texts: + tag = tag_pattern.findall(span_text)[0].strip('{}') + tag_map[tag] = span_text + template_text = template_text.replace(span_text, '{'+tag+'}') + + return PromptTemplate(template_text=template_text, tag_map=tag_map, lang=lang) + +def load_prompt_templates(path, num = 5, lang='en') -> List[PromptTemplate]: + with open(path, 'r') as f: + lines = f.readlines() + cnt = 0 + pts = [] + for line in lines: + pt = parse_prompt_template(line, lang=lang) + cnt += 1 + if len(pt.tags) < num: + logger.error(f'Not enough tags on {path} in line {cnt}: {pt.tags}') + pts.append(pt) + + return pts + + +class AudioStockDataset(Dataset): + def __init__(self, + num_examples:int, + metadata_path:str, + duration:float=60, + sr:int = 0, + return_path = False, + return_audio = True, + prompt_template_path: os.PathLike = None, + tag_types = [], + lang = 'en', + translate:Optional[Dict[str, os.PathLike]] = None + ): + self.duration = duration + self.MAX_DURATION = 360 + self._load_metadata(metadata_path) + if num_examples > 0: + self.random_choose = True + self.dataset_len = num_examples + else: + self.random_choose = False + self.dataset_len = len(self.data) + self.sr = sr + self.return_path = return_path + self.return_audio = return_audio + + self.use_dynamic_prompt = prompt_template_path is not None + if self.use_dynamic_prompt: + self.prompt_templates = load_prompt_templates(prompt_template_path, num = len(tag_types), lang = lang) + self.tag_types = tag_types + + self.lang = lang + self.translate = read_translate(translate) + + def _load_metadata(self, metadata_path): + total_len = 0; valid_len = 0 + with open(metadata_path) as fp: + lines = fp.readlines() + self.data = [] + for line in lines: + item = json.loads(line) + total_len += 1 + if(item['duration']>self.duration and item['duration']10): + raise ValueError() + + def getitem_main(self, idx): + path:str = self.data[idx]["path"] + json_path = path[:path.rfind('.')] + ".json" + if self.is_info_recorded: + item = self.data[idx] + else: + with open(json_path) as fp: + item:dict = json.load(fp) + description = self.generate_description(item) + if self.return_audio: + audio, sr = safe_random_seek_read(path, duration=self.duration, sample_rate=self.sr) + else: + audio = None + if self.return_path: + return audio, description, path + return audio, description + + + + def generate_description(self, item): + if self.use_dynamic_prompt: + # dynamically generate prompt from given prompt template + prompt_template = random.choice(self.prompt_templates) + description = self.generate_description_dynamic(item, prompt_template) + else: + # use ordinary static prompt instead + description = self.generate_description_ordinary(item) + return description + + def generate_description_dynamic(self, data, prompt_template: PromptTemplate): + exists_tag = [key for key in data if (key in self.tag_types) and (data[key] is not None) and (len(data[key]) > 0)] + + if len(exists_tag) > 0: + probs = dist_prob_map[len(exists_tag)] + tags_num = random.choices(range(1, len(exists_tag)+1), probs)[0] + random.shuffle(exists_tag) + tags = exists_tag[:tags_num] + tags_args = {tag: self.tags_to_desc(data[tag], tag) for tag in tags} + tags_args = self.handle_BPM_tag(tags_args) + prompt = prompt_template.apply(**tags_args) + else: + # no strong tags, use all weak tags instead + prompt = prompt_template.apply() + + return prompt + + def tags_to_desc(self, tag_list, tag_type) -> str: + if self.lang == 'en': + return tags_to_desc(tag_list) + elif self.lang == 'zh': + if tag_type == 'BPM': + return tags_to_desc(tag_list, sep='、') + translator = self.translate[tag_type] + translated_tag_list = [translator[tag] for tag in tag_list if tag in translator ] + return tags_to_desc(translated_tag_list, sep='、') + + def handle_BPM_tag(self, tags_args): + if "BPM" in tags_args and 'BPMDescript' in self.tag_types: + bpm = tags_args["BPM"] + del tags_args["BPM"] + tag_types_used = random.choice((('BPM',), ('BPMDescript',), ('BPM', 'BPMDescript'))) + for tag_type in tag_types_used: + tags_args[tag_type] = bpm if tag_type == 'BPM' else gen_bpm_descript(bpm, lang=self.lang) + return tags_args + + def generate_description_ordinary(self, data, thresh = 0.3): + if self.lang != 'en': + raise ValueError(f'Language {self.lang} is not supported for ordinary description generation') + description = f'a piece of music by {data["Artist"]}' + + # Add genre if available + if data["Genre"] and random.random() > thresh: + genres = ', '.join(data["Genre"]) + description += f', belonging to the {genres} genres' + + # Add moods if available + if data["Tags"] and random.random() > thresh: + tags = ', '.join(data["Tags"]) + description += f'. This track contains the tags:{tags}' + + # Add moods if available + if data["Mood"] and random.random() > thresh: + moods = ', '.join(data["Mood"]) + description += f'. This track conveys a {moods} mood.' + + # Add instruments if available + if data["Instrument"] and random.random() > thresh: + instruments = ', '.join(data["Instrument"]) + description += f'. and primarily features the following instruments: {instruments}' + + # Add a period to end the description + description += '.' + + return description + diff --git a/codeclm/tokenizer/Flow1dVAE/libs/fsq/fsq.py b/codeclm/tokenizer/Flow1dVAE/libs/fsq/fsq.py new file mode 100644 index 0000000000000000000000000000000000000000..5d603530f061f7bbba2c3d08e1a7021799d73171 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/fsq/fsq.py @@ -0,0 +1,236 @@ +""" +Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 +Code adapted from Jax version in Appendix A.1 +""" + +from __future__ import annotations +from functools import wraps, partial +from contextlib import nullcontext +from typing import List, Tuple + +import torch +import torch.nn as nn +from torch.nn import Module +from torch import Tensor, int32 +from torch.amp import autocast + +from einops import rearrange, pack, unpack + +# helper functions + +def exists(v): + return v is not None + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + +def maybe(fn): + @wraps(fn) + def inner(x, *args, **kwargs): + if not exists(x): + return x + return fn(x, *args, **kwargs) + return inner + +def pack_one(t, pattern): + return pack([t], pattern) + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + +# tensor helpers + +def round_ste(z: Tensor) -> Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + +# main class + +class FSQ(Module): + def __init__( + self, + levels: List[int], + dim: int | None = None, + num_codebooks = 1, + keep_num_codebooks_dim: bool | None = None, + scale: float | None = None, + allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64), + channel_first: bool = False, + projection_has_bias: bool = True, + return_indices = True, + force_quantization_f32 = True + ): + super().__init__() + _levels = torch.tensor(levels, dtype=int32) + self.register_buffer("_levels", _levels, persistent = False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) + self.register_buffer("_basis", _basis, persistent = False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + self.channel_first = channel_first + + has_projections = self.dim != effective_codebook_dim + self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = projection_has_bias) if has_projections else nn.Identity() + self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = projection_has_bias) if has_projections else nn.Identity() + + self.has_projections = has_projections + + self.return_indices = return_indices + if return_indices: + self.codebook_size = self._levels.prod().item() + implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size)) + self.register_buffer("implicit_codebook", implicit_codebook, persistent = False) + + self.allowed_dtypes = allowed_dtypes + self.force_quantization_f32 = force_quantization_f32 + + def bound(self, z, eps: float = 1e-3): + """ Bound `z`, an array of shape (..., d). """ + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z): + """ Quantizes z, returns quantized zhat, same shape as z. """ + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized): + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def _indices_to_codes(self, indices): + level_indices = self.indices_to_level_indices(indices) + codes = self._scale_and_shift_inverse(level_indices) + return codes + + def codes_to_indices(self, zhat): + """ Converts a `code` to an index in the codebook. """ + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat) + return (zhat * self._basis).sum(dim=-1).to(int32) + + def indices_to_level_indices(self, indices): + """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """ + indices = rearrange(indices, '... -> ... 1') + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered + + def indices_to_codes(self, indices): + """ Inverse of `codes_to_indices`. """ + assert exists(indices) + + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + + codes = self._indices_to_codes(indices) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, '... c d -> ... (c d)') + + codes = self.project_out(codes) + + if is_img_or_video or self.channel_first: + codes = rearrange(codes, 'b ... d -> b d ...') + + return codes + + def forward(self, z): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension + c - number of codebook dim + """ + + is_img_or_video = z.ndim >= 4 + need_move_channel_last = is_img_or_video or self.channel_first + + # standardize image or video into (batch, seq, dimension) + + if need_move_channel_last: + z = rearrange(z, 'b d ... -> b ... d') + z, ps = pack_one(z, 'b * d') + + assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}' + + z = self.project_in(z) + + z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks) + + # whether to force quantization step to be full precision or not + + force_f32 = self.force_quantization_f32 + quantization_context = partial(autocast, 'cuda', enabled = False) if force_f32 else nullcontext + + with quantization_context(): + orig_dtype = z.dtype + + if force_f32 and orig_dtype not in self.allowed_dtypes: + z = z.float() + + codes = self.quantize(z) + + # returning indices could be optional + + indices = None + + if self.return_indices: + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, 'b n c d -> b n (c d)') + + codes = codes.type(orig_dtype) + + # project out + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if need_move_channel_last: + out = unpack_one(out, ps, 'b * d') + out = rearrange(out, 'b ... d -> b d ...') + + indices = maybe(unpack_one)(indices, ps, 'b * c') + + if not self.keep_num_codebooks_dim and self.return_indices: + indices = maybe(rearrange)(indices, '... 1 -> ...') + + # return quantized output and indices + + return out, indices + + +if __name__ == '__main__': + # test + fsq = FSQ([4, 4, 4],dim=1024) + z = torch.randn(2, 3, 1024) + out, indices = fsq(z) + print(out.shape, indices.shape) + # print(out, indices) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..1993fb6d00854e5ad749e66e88268c34800d777b --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py @@ -0,0 +1,366 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +# from .. import distrib + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + """ + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + self.runed_steps = 0 + self.stop_steps = 50_000 + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + # distrib.broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + # self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + self.runed_steps += 1 + + if self.training and self.runed_steps < self.stop_steps: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1., + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) + self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, + kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, + decay=decay, epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x, do_debug=False): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + """ + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layerinx, layer in enumerate(self.layers[:n_q]): + print("Layer {} Used ratio {:.1f}".format(layerinx, (layer._codebook.cluster_size > 1.0).sum() / layer._codebook.cluster_size.shape[0] * 100.)) + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..ab004c36be0559c0aab97e35cc2e7cb05d99ac71 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize.py @@ -0,0 +1,268 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize2.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize2.py new file mode 100644 index 0000000000000000000000000000000000000000..af345a5b3f71cf6504a8db92e66cec7e6cc8e86f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize2.py @@ -0,0 +1,290 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3.py new file mode 100644 index 0000000000000000000000000000000000000000..946d20ed109fb7fc17f25ce287701db65eab63db --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3.py @@ -0,0 +1,299 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1 + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + # for n in range(encodings.shape[1]): + # print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + # (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + # )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc1a8a316cd2821fa20b37dc55e872d44283670 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer.py @@ -0,0 +1,303 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm +import random + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + random_num = random.random() + if random_num<0.6: + n_quantizers = torch.ones((z.shape[0],)) * 1 + elif random_num<0.8: + n_quantizers = torch.ones((z.shape[0],)) * 2 + else: + n_quantizers = torch.ones((z.shape[0],)) * 4 + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_freezelayer1.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_freezelayer1.py new file mode 100644 index 0000000000000000000000000000000000000000..c42975d3e3473079333b02398950eef782cfdc9d --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_freezelayer1.py @@ -0,0 +1,301 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm +import random + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + random_num = random.random() + if random_num<0.6: + n_quantizers = torch.ones((z.shape[0],)) * 2 + else: + n_quantizers = torch.ones((z.shape[0],)) * 4 + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + # for n in range(encodings.shape[1]): + # print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + # (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + # )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_return_layer.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_return_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..2e535fa033efaec31dad9a43b6e430c19c8b9928 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_return_layer.py @@ -0,0 +1,305 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm +import random + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + layer = self.n_codebooks + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + random_num = random.random() + if random_num<0.6: + n_quantizers = torch.ones((z.shape[0],)) * 1 + elif random_num<0.8: + n_quantizers = torch.ones((z.shape[0],)) * 2 + layer = 2 + else: + n_quantizers = torch.ones((z.shape[0],)) * 4 + layer = 4 + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1,layer + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_test.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c4fae8908358fcf10cf07165c32ecdfa69f40145 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_4layer_test.py @@ -0,0 +1,307 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm +import random + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + random_num = random.random() + # random_num = 1.0 + print("Random number: {:.2f}".format(random_num)) + if random_num<0.6: + n_quantizers = torch.ones((z.shape[0],)) * 1 + elif random_num<0.8: + n_quantizers = torch.ones((z.shape[0],)) * 2 + else: + n_quantizers = torch.ones((z.shape[0],)) * 4 + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + n_quantizers = n_quantizers.to(z.device) + print("Number of quantizers: ", n_quantizers) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + print("mask: ", mask) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_nodown.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_nodown.py new file mode 100644 index 0000000000000000000000000000000000000000..fae6dceba85e83795e9cdcd68c924a78d12b6f3f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_nodown.py @@ -0,0 +1,300 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + # self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.in_proj = nn.Identity() + self.out_proj = nn.Identity() + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1 + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_randomdrop.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_randomdrop.py new file mode 100644 index 0000000000000000000000000000000000000000..5be160bebc5a13b1460d7fb017c0c935c0c902f9 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_randomdrop.py @@ -0,0 +1,321 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def dropout_code(self,code,dropout_rate=0.05): + total_elements = code.shape[0] * code.shape[1] + dropout_elements = int(total_elements * dropout_rate) + dropout_indices = np.random.choice(np.arange(total_elements), size=dropout_elements, replace=False) + # 对每个选择的元素位置进行替换 + for idx in dropout_indices: + # 计算二维索引 + i = idx // code.shape[1] + j = idx % code.shape[1] + # 替换元素 + code[i, j] = np.random.randint(0, self.codebook_size) + return code + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + if self.training: + print("train") + indices = self.dropout_code(indices, dropout_rate=0.1) + + # print("indices", indices.shape) + #random replace 10% indices + # if(self.training): + + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1 + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + # rvq.eval() + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_randomdrop_freezervq.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_randomdrop_freezervq.py new file mode 100644 index 0000000000000000000000000000000000000000..f933803c82ff1ef5ec57c8ab0c5006bfc11d162c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_randomdrop_freezervq.py @@ -0,0 +1,323 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100,dropout_rate=0.1): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.dropout_rate = dropout_rate + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def dropout_code(self,code): + total_elements = code.shape[0] * code.shape[1] + dropout_elements = int(total_elements * self.dropout_rate) + dropout_indices = np.random.choice(np.arange(total_elements), size=dropout_elements, replace=False) + # 对每个选择的元素位置进行替换 + for idx in dropout_indices: + # 计算二维索引 + i = idx // code.shape[1] + j = idx % code.shape[1] + # 替换元素 + code[i, j] = np.random.randint(0, self.codebook_size) + return code + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + # if self.training: + # print("train") + indices = self.dropout_code(indices) + + # print("indices", indices.shape) + #random replace 10% indices + # if(self.training): + + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + dropout_rate=0.1 + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance,dropout_rate=dropout_rate) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1 + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + # rvq.eval() + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_simple_vq.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_simple_vq.py new file mode 100644 index 0000000000000000000000000000000000000000..8ccaaa73111366bc9b55c3707b78c3b453fff7df --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/descript_quantize3_simple_vq.py @@ -0,0 +1,299 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1 + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + for n in range(encodings.shape[1]): + print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 1024, codebook_dim = 32, quantizer_dropout = 0.0) + x = torch.randn(16, 1024, 80) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = rvq(x) + print(quantized_prompt_embeds.shape) + print(codes.shape) + # w/o reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # w/ reconstruction + loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/mert_with_kmeans.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/mert_with_kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..b4375ee6c7a68e74f4abeb46b70b13e345455aeb --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/mert_with_kmeans.py @@ -0,0 +1,187 @@ +import os, sys +from transformers import AutoModel +import torch +from torch import nn +import torchaudio.transforms as T +import einops +import numpy as np +import joblib +from torch.nn.utils.rnn import pad_sequence + + +def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. + + >>> lengths = torch.tensor([1, 3, 2, 5]) + >>> make_pad_mask(lengths) + tensor([[False, True, True, True, True], + [False, False, False, True, True], + [False, False, True, True, True], + [False, False, False, False, False]]) + """ + + assert lengths.ndim == 1, lengths.ndim + + max_len = lengths.max() + n = lengths.size(0) + expaned_lengths = torch.arange(max_len).expand(n, max_len).to(lengths) + + return expaned_lengths >= lengths.unsqueeze(1) + +class KmeansQuantizer(nn.Module): + def __init__(self, centroids) -> None: + super().__init__() + if type(centroids) == np.ndarray: + centroids = torch.from_numpy(centroids) + # self.clusters = nn.Embedding(n_cluster, feature_dim) + self.clusters = nn.Parameter(centroids) + + @classmethod + def from_pretrained(cls, km_path): + km_model = joblib.load(km_path) + centroids = km_model.cluster_centers_ + return cls(centroids) + + @property + def n_cluster(self) -> int: + return self.clusters.shape[0] + + @property + def feature_dim(self) -> int: + return self.clusters.shape[1] + + + def forward(self, inp: torch.Tensor): + if inp.ndim == 3 and inp.shape[-1] == self.feature_dim: + return self.feat2indice(inp) + elif inp.ndim < 3: + return self.indice2feat(inp) + else: + raise NotImplementedError + + def feat2indice(self, feat): + ''' + feat: B,T,D + ''' + batched_cluster_centers = einops.repeat(self.clusters, 'c d -> b c d', b = feat.shape[0]) + dists = torch.cdist(feat, batched_cluster_centers, p = 2) + indices = dists.argmin(dim = -1) + return indices + + def indice2feat(self, indice): + ''' + indice: B, T + ''' + return nn.functional.embedding(input=indice, weight=self.clusters) + +class MERTwithKmeans(nn.Module): + def __init__(self, pretrained_model_name_or_path, kmeans_path=None, sampling_rate=44100, output_layer=-1, mean_pool=1) -> None: + super().__init__() + + # assert pretrained_model_name_or_path in ["MERT-v1-95M", "MERT-v1-330M"] + assert pretrained_model_name_or_path == "MERT-v1-330M" + # loading our model weights + # self.model = AutoModel.from_pretrained(f"vocal2accmpl/model/.cache/models--m-a-p--MERT-v1-95M/snapshots/8881df140a93e2ea270235b5d7be802245e3d2c6", trust_remote_code=True) + self.model = AutoModel.from_pretrained('pretrained/models--m-a-p--MERT-v1-330M/snapshots/af10da70c94a0c849de9cc94b83e12769c4db499', trust_remote_code=True) + # print(self.model) + if kmeans_path is not None: + centroids = joblib.load(kmeans_path).cluster_centers_ + self.kmeans = KmeansQuantizer(centroids) + else: + self.kmeans = None + + # loading the corresponding preprocessor config + # self.processor = Wav2Vec2FeatureExtractor.from_pretrained(f"m-a-p/{pretrained_model_name_or_path}",trust_remote_code=True) + + # make sure the sample_rate aligned + self.sampling_rate = sampling_rate + self.resampler = T.Resample(sampling_rate, 24000) if sampling_rate != 24000 else lambda x: x + + self.do_normalization = (pretrained_model_name_or_path == "MERT-v1-95M") + self.output_layer = output_layer + self.mean_pool = mean_pool + assert self.mean_pool % 2 == 1 + + @torch.no_grad() + def forward(self, input_audio, seq_len=None, apply_kmeans=True): + ''' + input_audio: B,T + seq_len: B, + ''' + device = input_audio.device + return_seq_len = True + if seq_len is None: + return_seq_len = False + seq_len = [input_audio.shape[1] for _ in input_audio] + + input_audio = [self.resampler(x[:l]) for x, l in zip(input_audio, seq_len)] + new_seq_len = torch.tensor([len(i) for i in input_audio], device=device) + + + # std_inp = self.processor([x.numpy() for x in input_audio], sampling_rate=24000, return_tensors="pt", padding=True) + + if self.do_normalization: + input_audio = self.zero_mean_unit_var_norm(input_audio, new_seq_len) + + padded_input = pad_sequence(input_audio, batch_first=True) + attention_mask = ~ make_pad_mask(new_seq_len) + + # assert (~(attention_mask == std_inp['attention_mask'])).sum() == 0, f"{attention_mask}, {std_inp['attention_mask']}" + # assert (~(padded_input.to(dtype=std_inp['input_values'].dtype) == std_inp['input_values'])).sum() == 0, f"{torch.sum((padded_input - std_inp['input_values']))}" + + outputs = self.model(input_values=padded_input, attention_mask=attention_mask, output_hidden_states=True) + + output = outputs['hidden_states'][self.output_layer] + output_len = torch.round(new_seq_len.float() / 24000 * 75).long() + # print(output_len) + # output_len = output_len.masked_fill(output_len > output.shape[1], output.shape[1]).long() + output = nn.functional.interpolate(output.transpose(-1,-2), output_len.max().item()).transpose(-1,-2) + + if self.mean_pool > 1: + output_len = output_len // 3 + output = nn.functional.avg_pool1d(output.transpose(-1, -2), kernel_size=self.mean_pool, stride=self.mean_pool) + output = output.transpose(-1,-2) + # print(output.shape, output_len) + # print(output.shape, output_len) + + + if apply_kmeans: + output = self.kmeans.feat2indice(output) + + if return_seq_len: + return output, output_len + + return output + + + + # from transformers.models.wav2vec2.feature_extraction_wav2vec2 + # rewrite it by pytorch + @staticmethod + def zero_mean_unit_var_norm( + input_values: torch.Tensor, seq_len: torch.Tensor = None, padding_value: float = 0.0 + ) -> torch.Tensor: + """ + Every array in the list is normalized to have zero mean and unit variance + """ + if seq_len is not None: + normed_input_values = [] + + for vector, length in zip(input_values, seq_len): + normed_slice = (vector - vector[:length].mean()) / torch.sqrt(vector[:length].var() + 1e-7) + if length < normed_slice.shape[0]: + normed_slice[length:] = padding_value + + normed_input_values.append(normed_slice) + # normed_input_values = torch.stack(normed_input_values, dim=0) + else: + normed_input_values = (input_values - input_values.mean(dim=-1, keepdim=True)) / torch.sqrt(input_values.var(dim=-1, keepdim=True) + 1e-7) + + return normed_input_values diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/rvq2.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/rvq2.py new file mode 100644 index 0000000000000000000000000000000000000000..390a550df6700458babe4cb2ec55e8715e5810aa --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/rvq2.py @@ -0,0 +1,149 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +# https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py +class VQEmbeddingEMA(nn.Module): + def __init__(self, nband, num_code, code_dim, decay=0.99, layer=0): + super(VQEmbeddingEMA, self).__init__() + + self.nband = nband + self.num_code = num_code + self.code_dim = code_dim + self.decay = decay + self.layer = layer + self.stale_tolerance = 50 + self.eps = torch.finfo(torch.float32).eps + + if layer == 0: + embedding = torch.empty(nband, num_code, code_dim).normal_() + embedding = embedding / (embedding.pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1) # TODO + else: + embedding = torch.empty(nband, num_code, code_dim).normal_() / code_dim + embedding[:,0] = embedding[:,0] * 0 # TODO + self.register_buffer("embedding", embedding) + self.register_buffer("ema_weight", self.embedding.clone()) + self.register_buffer("ema_count", torch.zeros(self.nband, self.num_code)) + self.register_buffer("stale_counter", torch.zeros(nband, self.num_code)) + + def forward(self, input): + num_valid_bands = 1 + B, C, N, T = input.shape + assert N == self.code_dim + assert C == num_valid_bands + + input_detach = input.detach().permute(0,3,1,2).contiguous().view(B*T, num_valid_bands, self.code_dim) # B*T, nband, dim + embedding = self.embedding[:num_valid_bands,:,:].contiguous() + # distance + eu_dis = input_detach.pow(2).sum(2).unsqueeze(2) + embedding.pow(2).sum(2).unsqueeze(0) # B*T, nband, num_code + eu_dis = eu_dis - 2 * torch.stack([input_detach[:,i].mm(embedding[i].T) for i in range(num_valid_bands)], 1) # B*T, nband, num_code + + # best codes + indices = torch.argmin(eu_dis, dim=-1) # B*T, nband + quantized = [] + for i in range(num_valid_bands): + quantized.append(torch.gather(embedding[i], 0, indices[:,i].unsqueeze(-1).expand(-1, self.code_dim))) # B*T, dim + quantized = torch.stack(quantized, 1) + quantized = quantized.view(B, T, C, N).permute(0,2,3,1).contiguous() # B, C, N, T + + # calculate perplexity + encodings = F.one_hot(indices, self.num_code).float() # B*T, nband, num_code + avg_probs = encodings.mean(0) # nband, num_code + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean() + + if self.training: + # EMA update for codebook + + self.ema_count[:num_valid_bands] = self.decay * self.ema_count[:num_valid_bands] + (1 - self.decay) * torch.sum(encodings, dim=0) # nband, num_code + + update_direction = encodings.permute(1,2,0).bmm(input_detach.permute(1,0,2)) # nband, num_code, dim + self.ema_weight[:num_valid_bands] = self.decay * self.ema_weight[:num_valid_bands] + (1 - self.decay) * update_direction # nband, num_code, dim + + # Laplace smoothing on the counters + # make sure the denominator will never be zero + n = torch.sum(self.ema_count[:num_valid_bands], dim=-1, keepdim=True) # nband, 1 + self.ema_count[:num_valid_bands] = (self.ema_count[:num_valid_bands] + self.eps) / (n + self.num_code * self.eps) * n # nband, num_code + + self.embedding[:num_valid_bands] = self.ema_weight[:num_valid_bands] / self.ema_count[:num_valid_bands].unsqueeze(-1) + + # calculate code usage + stale_codes = (encodings.sum(0) == 0).float() # nband, num_code + self.stale_counter[:num_valid_bands] = self.stale_counter[:num_valid_bands] * stale_codes + stale_codes + print("Lyaer {}, Ratio of unused vector : {}, {:.1f}, {:.1f}".format(self.layer, encodings.sum(), stale_codes.sum()/torch.numel(stale_codes)*100., (self.stale_counter > self.stale_tolerance //2).sum() /torch.numel(self.stale_counter)*100.)) + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter[:num_valid_bands] == self.stale_tolerance).float() # nband, num_code + if replace_code.sum(-1).max() > 0: + random_input_idx = torch.randperm(input_detach.shape[0]) + random_input = input_detach[random_input_idx].view(input_detach.shape) + if random_input.shape[0] < self.num_code: + random_input = torch.cat([random_input]*(self.num_code // random_input.shape[0] + 1), 0) + random_input = random_input[:self.num_code,:].contiguous().transpose(0,1) # nband, num_code, dim + + self.embedding[:num_valid_bands] = self.embedding[:num_valid_bands] * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.ema_weight[:num_valid_bands] = self.ema_weight[:num_valid_bands] * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.ema_count[:num_valid_bands] = self.ema_count[:num_valid_bands] * (1 - replace_code) + self.stale_counter[:num_valid_bands] = self.stale_counter[:num_valid_bands] * (1 - replace_code) + + # TODO: + # code constraints + if self.layer == 0: + self.embedding[:num_valid_bands] = self.embedding[:num_valid_bands] / (self.embedding[:num_valid_bands].pow(2).sum(-1) + self.eps).sqrt().unsqueeze(-1) + # else: + # # make sure there is always a zero code + # self.embedding[:,0] = self.embedding[:,0] * 0 + # self.ema_weight[:,0] = self.ema_weight[:,0] * 0 + + return quantized, indices.reshape(B, T, -1), perplexity + +class RVQEmbedding(nn.Module): + def __init__(self, nband, code_dim, decay=0.99, num_codes=[1024, 1024]): + super(RVQEmbedding, self).__init__() + + self.nband = nband + self.code_dim = code_dim + self.decay = decay + self.eps = torch.finfo(torch.float32).eps + self.min_max = [10000, -10000] + self.bins = [256+i*8 for i in range(32)] + + self.VQEmbedding = nn.ModuleList([]) + for i in range(len(num_codes)): + self.VQEmbedding.append(VQEmbeddingEMA(nband, num_codes[i], code_dim, decay, layer=i)) + + def forward(self, input): + norm_value = torch.norm(input, p=2, dim=-2) # b c t + if(norm_value.min()self.min_max[-1]):self.min_max[-1]=norm_value.max().cpu().item() + print("Min-max : {}".format(self.min_max)) + norm_value = (((norm_value - 256) / 20).clamp(min=0, max=7).int() * 20 + 256 + 10).float() + print("Min-max : {}, {}".format(norm_value.min(), norm_value.max())) + input = torch.nn.functional.normalize(input, p = 2, dim = -2) + + quantized_list = [] + perplexity_list = [] + indices_list = [] + c = [] + + residual_input = input + for i in range(len(self.VQEmbedding)): + this_quantized, this_indices, this_perplexity = self.VQEmbedding[i](residual_input) + perplexity_list.append(this_perplexity) + indices_list.append(this_indices) + residual_input = residual_input - this_quantized + if i == 0: + quantized_list.append(this_quantized) + else: + quantized_list.append(quantized_list[-1] + this_quantized) + + quantized_list = torch.stack(quantized_list, -1) # b,1,1024,768,1 + perplexity_list = torch.stack(perplexity_list, -1) + indices_list = torch.stack(indices_list, -1) # B T 1 codebooknum + latent_loss = 0 + for i in range(quantized_list.shape[-1]): + latent_loss = latent_loss + F.mse_loss(input, quantized_list.detach()[:,:,:,:,i]) + # TODO: remove unit norm + quantized_list = quantized_list / (quantized_list.pow(2).sum(2) + self.eps).sqrt().unsqueeze(2) # unit norm + + return quantized_list, norm_value, indices_list, latent_loss diff --git a/codeclm/tokenizer/Flow1dVAE/model_1rvq.py b/codeclm/tokenizer/Flow1dVAE/model_1rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..7687447204ec7337870411fa8889a113a49e6265 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/model_1rvq.py @@ -0,0 +1,710 @@ +import yaml +import random +import inspect +import numpy as np +from tqdm import tqdm +import typing as tp +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from tools.torch_tools import wav_to_fbank + +from diffusers.utils.torch_utils import randn_tensor +from transformers import HubertModel +from libs.rvq.descript_quantize3 import ResidualVectorQuantize + +from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model +from models_gpt.models.gpt2_config import GPT2Config + +from torch.cuda.amp import autocast + + +from our_MERT_BESTRQ.test import load_model + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. + print("hidden_size:",config.hidden_size) + print("classifier_proj_size:",config.classifier_proj_size) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + """Project back from diffusion space to the actual sample space.""" + return z + +class Feature1DProcessor(SampleProcessor): + def __init__(self, dim: int = 100, power_std = 1., \ + num_samples: int = 100_000, cal_num_frames: int = 600): + super().__init__() + + self.num_samples = num_samples + self.dim = dim + self.power_std = power_std + self.cal_num_frames = cal_num_frames + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(dim)) + self.register_buffer('sum_x2', torch.zeros(dim)) + self.register_buffer('sum_target_x2', torch.zeros(dim)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + if(self.counts < 10): + mean = torch.zeros_like(mean) + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + if(self.counts < 10): + std = torch.ones_like(std) + return std + + @property + def target_std(self): + return 1 + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + if self.counts.item() < self.num_samples: + self.counts += len(x) + self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) + self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) + return x + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + rescale = (self.std / self.target_std) ** self.power_std + # print(rescale, self.mean) + x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) + return x + +def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): + if(prior_text_encoder_hidden_states.shape[1] 1.0): + + model_input = torch.cat([ \ + torch.cat([latent_mask_input, latent_mask_input], 0), \ + torch.cat([incontext_x, incontext_x], 0), \ + torch.cat([torch.zeros_like(mu), mu], 0), \ + torch.cat([x, x], 0), \ + ], 2) + timestep=t.unsqueeze(-1).repeat(2) + + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) + dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) + else: + model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) + timestep=t.unsqueeze(-1) + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + + dphi_dt = dphi_dt[: ,:, -x.shape[2]:] + # print("dphi_dt.shape:",dphi_dt.shape) + # print("x.shape:",x.shape) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def projection_loss(self,hidden_proj, bestrq_emb): + bsz = hidden_proj.shape[0] + + hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) + bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) + + proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) + proj_loss = 1+proj_loss.mean() + + return proj_loss + + def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_channels, mel_timesteps, n_feats) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_channels, mel_timesteps, n_feats) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_channels, mel_timesteps, n_feats) + """ + b = mu[0].shape[0] + len_x = x1.shape[2] + # random timestep + if(validation_mode): + t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 + else: + t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + # print("y.shape:",y.shape) + #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state + model_input = torch.cat([*mu,y], 2) + t=t.squeeze(-1).squeeze(-1) + # print("model_input.shape:",model_input.shape) + # print("attention_mask.shape:",attention_mask.shape) + out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) + hidden_layer = out.hidden_states[self.ssl_layer] + hidden_proj = self.mlp(hidden_layer) + # print("hidden_proj.shape:",hidden_proj.shape) + # print("mert_emb.shape:",mert_emb.shape) + # exit() + + + out = out.last_hidden_state + + out=out[:,:,-len_x:] + # out=self.proj_out(out) + + weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 + # print("out.shape",out.shape) + # print("u.shape",u.shape) + loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() + # print("hidden_proj.shape:",hidden_proj.shape) + # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) + loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) + loss = loss_re + loss_cos * 0.5 + # print("loss_cos:",loss_cos,loss_cos.device) + print("loss:",loss,loss.device) + # exit() + return loss, loss_re, loss_cos + +class PromptCondAudioDiffusion(nn.Module): + def __init__( + self, + num_channels, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + hubert_layer=None, + ssl_layer=None, + uncondition=True, + out_paint=False, + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.uncondition = uncondition + self.num_channels = num_channels + self.hubert_layer = hubert_layer + self.ssl_layer = ssl_layer + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.normfeat = Feature1DProcessor(dim=64) + + self.sample_rate = 48000 + self.num_samples_perseg = self.sample_rate * 20 // 1000 + self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) + self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) + # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + self.bestrq = load_model( + model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq', + checkpoint_dir='ckpt/encode-s12k.pt', + ) + self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) + self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) + for v in self.bestrq.parameters():v.requires_grad = False + self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False + self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") + for v in self.hubert.parameters():v.requires_grad = False + self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) + # self.xvecmodel = XVECModel() + config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) + unet = GPT2Model(config) + mlp = nn.Sequential( + nn.Linear(1200, 1024), + nn.SiLU(), + nn.Linear(1024, 1024), + nn.SiLU(), + nn.Linear(1024, 768) + ) + self.set_from = "random" + self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) + self.mask_emb = torch.nn.Embedding(3, 48) + print("Transformer initialized from pretrain.") + torch.cuda.empty_cache() + # self.unet.set_attn_processor(AttnProcessor2_0()) + # self.unet.set_use_memory_efficient_attention_xformers(True) + + # self.start_embedding = nn.Parameter(torch.randn(1,1024)) + # self.end_embedding = nn.Parameter(torch.randn(1,1024)) + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + def preprocess_audio(self, input_audios, threshold=0.9): + assert len(input_audios.shape) == 2, input_audios.shape + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios/norm_value.unsqueeze(-1) + + def extract_wav2vec_embeds(self, input_audios,output_len): + wav2vec_stride = 2 + + wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 + # print(wav2vec_embeds) + # print("audio.shape:",input_audios.shape) + wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] + # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) + wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) + return wav2vec_embeds_last + + def extract_mert_embeds(self, input_audios): + prompt_stride = 3 + inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 + mert_emb= prompt_embeds[-1] + mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) + + return mert_emb + + def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): + self.bestrq.eval() + # print("audio shape:",input_audio_0.shape) + input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 + # print("input_wav_mean.shape:",input_wav_mean.shape) + # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) + input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) + layer_results = input_wav_mean['layer_results'] + # print("layer_results.shape:",layer_results[layer].shape) + bestrq_emb = layer_results[layer] + bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() + #[b,t,1024] t=t/960 + #35.84s->batch,896,1024 + return bestrq_emb + + + def extract_spk_embeds(self, input_audios): + spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) + spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) + return spk_embeds + + def extract_lyric_feats(self, lyric): + with torch.no_grad(): + try: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) + except: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) + text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) + text_mask = text_mask.to(self.device) + text_encoder_hidden_states, text_mask, text_prompt_embeds = \ + pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() + return text_encoder_hidden_states, text_mask + + def extract_energy_bar(self, input_audios): + if(input_audios.shape[-1] % self.num_samples_perseg > 0): + energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) + else: + energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) + energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T + energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() + energy_embedding = self.energy_embedding(energy_bar) + energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t + return energy_embedding + + def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ + additional_feats = ['spk', 'lyric'], \ + train_rvq=True, train_ssl=False,layer=5): + if not hasattr(self,"device"): + self.device = input_audios.device + if not hasattr(self,"dtype"): + self.dtype = input_audios.dtype + device = self.device + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 + # energy_embedding = self.extract_energy_bar(input_audios) + # print("energy_embedding.shape:",energy_embedding.shape) + # with autocast(enabled=False): + if(train_ssl): + self.wav2vec.train() + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) + self.clap_embd_extractor.train() + prompt_embeds = self.extract_mert_embeds(input_audios) + if('spk' in additional_feats): + self.xvecmodel.train() + spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) + else: + with torch.no_grad(): + with autocast(enabled=False): + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + # mert_emb = self.extract_mert_embeds(input_audios_mert) + + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) + + bestrq_emb = bestrq_emb.detach() + if('lyric' in additional_feats): + text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) + else: + text_encoder_hidden_states, text_mask = None, None + + # prompt_embeds_13 = torch.cat([mert_emb_13, energy_embedding], 1) + # print("prompt_embes.shape:",prompt_embeds.shape) + #prompt_embes.shape: torch.Size([3, 1088, 896]) + # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) + #wav2vec_embeds.shape:torch.Size([3, 1024, 896]) + if(train_rvq): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + else: + bestrq_emb = bestrq_emb.float() + self.rvq_bestrq_emb.eval() + # with autocast(enabled=False): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() + codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() + quantized_bestrq_emb = quantized_bestrq_emb.detach() + + commitment_loss = commitment_loss_bestrq_emb + codebook_loss = codebook_loss_bestrq_emb + + + alpha=1 + quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) + + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + scenario = np.random.choice(['start_seg', 'other_seg']) + if(scenario == 'other_seg'): + for binx in range(input_audios.shape[0]): + # latent_masks[binx,0:64] = 1 + latent_masks[binx,0:random.randint(64,128)] = 1 + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + + + + + if self.uncondition: + mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] + if len(mask_indices) > 0: + quantized_bestrq_emb[mask_indices] = 0 + # print("latents.shape:",latents.shape) + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.project_sample(latents) + latents = latents.permute(0,2,1).contiguous() + incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + # print("incontext_latents.shape:",incontext_latents.shape) + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + latent_mask_input = self.mask_emb(latent_masks) + #64+48+64+1024 + loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) + return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() + + def init_device_dtype(self, device, dtype): + self.device = device + self.dtype = dtype + + @torch.no_grad() + def fetch_codes(self, input_audios, additional_feats,layer): + input_audio_0 = input_audios[[0],:] + input_audio_1 = input_audios[[1],:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + + @torch.no_grad() + def fetch_codes_batch(self, input_audios, additional_feats,layer): + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + + @torch.no_grad() + def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, + guidance_scale=2, num_steps=20, + disable_progress=True, scenario='start_seg'): + classifier_free_guidance = guidance_scale > 1.0 + device = self.device + dtype = self.dtype + # codes_bestrq_middle, codes_bestrq_last = codes + codes_bestrq_emb = codes[0] + + + batch_size = codes_bestrq_emb.shape[0] + + + quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + + if('spk' in additional_feats): + spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() + + num_frames = quantized_bestrq_emb.shape[1] + + num_channels_latents = self.num_channels + shape = (batch_size, num_frames, 64) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + + + + latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) + latent_masks[:,0:latent_length] = 2 + if(scenario=='other_seg'): + latent_masks[:,0:incontext_length] = 1 + + + + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + true_latents = true_latents.permute(0,2,1).contiguous() + true_latents = self.normfeat.project_sample(true_latents) + true_latents = true_latents.permute(0,2,1).contiguous() + incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] + + + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + + if('spk' in additional_feats): + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) + additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) + else: + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) + additional_model_input = torch.cat([quantized_bestrq_emb],1) + + temperature = 1.0 + t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) + latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) + + latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.return_sample(latents) + # latents = latents.permute(0,2,1).contiguous() + return latents + + @torch.no_grad() + def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) + + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents + + @torch.no_grad() + def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) + import time + start = time.time() + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents,time.time()-start + + def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): + divisor = 4 + shape = (batch_size, num_channels_latents, num_frames, 32) + if(num_frames%divisor>0): + num_frames = round(num_frames/float(divisor))*divisor + shape = (batch_size, num_channels_latents, num_frames, 32) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + return latents + + diff --git a/codeclm/tokenizer/Flow1dVAE/model_2rvq.py b/codeclm/tokenizer/Flow1dVAE/model_2rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..f1ac8c206dd6579af434b67ba0d0aa73c671dc5c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/model_2rvq.py @@ -0,0 +1,774 @@ +import yaml +import random +import inspect +import numpy as np +from tqdm import tqdm +import typing as tp +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from einops import repeat +from tools.torch_tools import wav_to_fbank + +import diffusers +from diffusers.utils.torch_utils import randn_tensor +from diffusers import DDPMScheduler +from models.transformer_2d_flow import Transformer2DModel +from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel +# from tools.get_mulan import get_mulan +from third_party.wespeaker.extract_embd import XVECModel +# from libs.rvq2 import RVQEmbedding +from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize + +from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model +from models_gpt.models.gpt2_config import GPT2Config + +from torch.cuda.amp import autocast + + +from our_MERT_BESTRQ.test import load_model + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. + print("hidden_size:",config.hidden_size) + print("classifier_proj_size:",config.classifier_proj_size) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + """Project back from diffusion space to the actual sample space.""" + return z + +class Feature1DProcessor(SampleProcessor): + def __init__(self, dim: int = 100, power_std = 1., \ + num_samples: int = 100_000, cal_num_frames: int = 600): + super().__init__() + + self.num_samples = num_samples + self.dim = dim + self.power_std = power_std + self.cal_num_frames = cal_num_frames + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(dim)) + self.register_buffer('sum_x2', torch.zeros(dim)) + self.register_buffer('sum_target_x2', torch.zeros(dim)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + if(self.counts < 10): + mean = torch.zeros_like(mean) + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + if(self.counts < 10): + std = torch.ones_like(std) + return std + + @property + def target_std(self): + return 1 + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + if self.counts.item() < self.num_samples: + self.counts += len(x) + self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) + self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) + return x + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + rescale = (self.std / self.target_std) ** self.power_std + # print(rescale, self.mean) + x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) + return x + +def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): + if(prior_text_encoder_hidden_states.shape[1] 1.0): + + model_input = torch.cat([ \ + torch.cat([latent_mask_input, latent_mask_input], 0), \ + torch.cat([incontext_x, incontext_x], 0), \ + torch.cat([torch.zeros_like(mu), mu], 0), \ + torch.cat([x, x], 0), \ + ], 2) + timestep=t.unsqueeze(-1).repeat(2) + + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) + dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) + else: + model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) + timestep=t.unsqueeze(-1) + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + + dphi_dt = dphi_dt[: ,:, -x.shape[2]:] + # print("dphi_dt.shape:",dphi_dt.shape) + # print("x.shape:",x.shape) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def projection_loss(self,hidden_proj, bestrq_emb): + bsz = hidden_proj.shape[0] + + hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) + bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) + + proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) + proj_loss = 1+proj_loss.mean() + + return proj_loss + + def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_channels, mel_timesteps, n_feats) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_channels, mel_timesteps, n_feats) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_channels, mel_timesteps, n_feats) + """ + b = mu[0].shape[0] + len_x = x1.shape[2] + # random timestep + if(validation_mode): + t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 + else: + t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + # print("y.shape:",y.shape) + #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state + model_input = torch.cat([*mu,y], 2) + t=t.squeeze(-1).squeeze(-1) + # print("model_input.shape:",model_input.shape) + # print("attention_mask.shape:",attention_mask.shape) + out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) + hidden_layer = out.hidden_states[self.ssl_layer] + hidden_proj = self.mlp(hidden_layer) + # print("hidden_proj.shape:",hidden_proj.shape) + # print("mert_emb.shape:",mert_emb.shape) + # exit() + + + out = out.last_hidden_state + + out=out[:,:,-len_x:] + # out=self.proj_out(out) + + weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 + # print("out.shape",out.shape) + # print("u.shape",u.shape) + loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() + # print("hidden_proj.shape:",hidden_proj.shape) + # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) + loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) + loss = loss_re + loss_cos * 0.5 + # print("loss_cos:",loss_cos,loss_cos.device) + print("loss:",loss,loss.device) + # exit() + return loss, loss_re, loss_cos + +class PromptCondAudioDiffusion(nn.Module): + def __init__( + self, + num_channels, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + hubert_layer=None, + ssl_layer=None, + uncondition=True, + out_paint=False, + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.uncondition = uncondition + self.num_channels = num_channels + self.hubert_layer = hubert_layer + self.ssl_layer = ssl_layer + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.normfeat = Feature1DProcessor(dim=64) + + self.sample_rate = 48000 + self.num_samples_perseg = self.sample_rate * 20 // 1000 + self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) + self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) + # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + self.bestrq = load_model( + model_dir='path/to/our-MERT/mert_fairseq', + checkpoint_dir='checkpoint-120000.pt', + ) + self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) + self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) + for v in self.bestrq.parameters():v.requires_grad = False + self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + # for v in self.rvq_bestrq_emb.parameters(): + # print(v) + freeze_parameters='quantizers.0' + for name, param in self.rvq_bestrq_emb.named_parameters(): + if freeze_parameters in name: + param.requires_grad = False + print("Freezing RVQ parameters:", name) + self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") + for v in self.hubert.parameters():v.requires_grad = False + self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) + # self.xvecmodel = XVECModel() + config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) + unet = GPT2Model(config) + mlp = nn.Sequential( + nn.Linear(1200, 1024), + nn.SiLU(), + nn.Linear(1024, 1024), + nn.SiLU(), + nn.Linear(1024, 768) + ) + self.set_from = "random" + self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) + self.mask_emb = torch.nn.Embedding(3, 48) + print("Transformer initialized from pretrain.") + torch.cuda.empty_cache() + # self.unet.set_attn_processor(AttnProcessor2_0()) + # self.unet.set_use_memory_efficient_attention_xformers(True) + + # self.start_embedding = nn.Parameter(torch.randn(1,1024)) + # self.end_embedding = nn.Parameter(torch.randn(1,1024)) + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + def preprocess_audio(self, input_audios, threshold=0.9): + assert len(input_audios.shape) == 2, input_audios.shape + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios/norm_value.unsqueeze(-1) + + def extract_wav2vec_embeds(self, input_audios,output_len): + wav2vec_stride = 2 + + wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 + # print(wav2vec_embeds) + # print("audio.shape:",input_audios.shape) + wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] + # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) + wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) + return wav2vec_embeds_last + + def extract_mert_embeds(self, input_audios): + prompt_stride = 3 + inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 + mert_emb= prompt_embeds[-1] + mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) + + return mert_emb + + def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): + self.bestrq.eval() + # print("audio shape:",input_audio_0.shape) + input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 + # print("input_wav_mean.shape:",input_wav_mean.shape) + # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) + input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) + layer_results = input_wav_mean['layer_results'] + # print("layer_results.shape:",layer_results[layer].shape) + bestrq_emb = layer_results[layer] + bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() + #[b,t,1024] t=t/960 + #35.84s->batch,896,1024 + return bestrq_emb + + + def extract_spk_embeds(self, input_audios): + spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) + spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) + return spk_embeds + + def extract_lyric_feats(self, lyric): + with torch.no_grad(): + try: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) + except: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) + text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) + text_mask = text_mask.to(self.device) + text_encoder_hidden_states, text_mask, text_prompt_embeds = \ + pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() + return text_encoder_hidden_states, text_mask + + def extract_energy_bar(self, input_audios): + if(input_audios.shape[-1] % self.num_samples_perseg > 0): + energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) + else: + energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) + energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T + energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() + energy_embedding = self.energy_embedding(energy_bar) + energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t + return energy_embedding + + def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ + additional_feats = ['spk', 'lyric'], \ + train_rvq=True, train_ssl=False,layer=5): + if not hasattr(self,"device"): + self.device = input_audios.device + if not hasattr(self,"dtype"): + self.dtype = input_audios.dtype + device = self.device + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 + # energy_embedding = self.extract_energy_bar(input_audios) + # print("energy_embedding.shape:",energy_embedding.shape) + # with autocast(enabled=False): + if(train_ssl): + self.wav2vec.train() + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) + self.clap_embd_extractor.train() + prompt_embeds = self.extract_mert_embeds(input_audios) + if('spk' in additional_feats): + self.xvecmodel.train() + spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) + else: + with torch.no_grad(): + with autocast(enabled=False): + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + # mert_emb = self.extract_mert_embeds(input_audios_mert) + + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) + + bestrq_emb = bestrq_emb.detach() + if('lyric' in additional_feats): + text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) + else: + text_encoder_hidden_states, text_mask = None, None + + + if(train_rvq): + random_num=random.random() + if(random_num<0.6): + rvq_layer = 1 + elif(random_num<0.8): + rvq_layer = 2 + else: + rvq_layer = 4 + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t + else: + bestrq_emb = bestrq_emb.float() + self.rvq_bestrq_emb.eval() + # with autocast(enabled=False): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() + codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() + quantized_bestrq_emb = quantized_bestrq_emb.detach() + + commitment_loss = commitment_loss_bestrq_emb + codebook_loss = codebook_loss_bestrq_emb + + + alpha=1 + quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) + + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + scenario = np.random.choice(['start_seg', 'other_seg']) + if(scenario == 'other_seg'): + for binx in range(input_audios.shape[0]): + # latent_masks[binx,0:64] = 1 + latent_masks[binx,0:random.randint(64,128)] = 1 + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + + + + + if self.uncondition: + mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] + if len(mask_indices) > 0: + quantized_bestrq_emb[mask_indices] = 0 + # print("latents.shape:",latents.shape) + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.project_sample(latents) + latents = latents.permute(0,2,1).contiguous() + incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + # print("incontext_latents.shape:",incontext_latents.shape) + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + latent_mask_input = self.mask_emb(latent_masks) + #64+48+64+1024 + loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) + return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() + + def init_device_dtype(self, device, dtype): + self.device = device + self.dtype = dtype + + @torch.no_grad() + def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1): + input_audio_0 = input_audios[[0],:] + input_audio_1 = input_audios[[1],:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1): + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250): + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds) + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, + guidance_scale=2, num_steps=20, + disable_progress=True, scenario='start_seg'): + classifier_free_guidance = guidance_scale > 1.0 + device = self.device + dtype = self.dtype + # codes_bestrq_middle, codes_bestrq_last = codes + codes_bestrq_emb = codes[0] + + + batch_size = codes_bestrq_emb.shape[0] + + + quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + + if('spk' in additional_feats): + spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() + + num_frames = quantized_bestrq_emb.shape[1] + + num_channels_latents = self.num_channels + shape = (batch_size, num_frames, 64) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + + + + latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) + latent_masks[:,0:latent_length] = 2 + if(scenario=='other_seg'): + latent_masks[:,0:incontext_length] = 1 + + + + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + true_latents = true_latents.permute(0,2,1).contiguous() + true_latents = self.normfeat.project_sample(true_latents) + true_latents = true_latents.permute(0,2,1).contiguous() + incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] + + + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + + if('spk' in additional_feats): + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) + additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) + else: + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) + additional_model_input = torch.cat([quantized_bestrq_emb],1) + + temperature = 1.0 + t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) + latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) + + latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.return_sample(latents) + # latents = latents.permute(0,2,1).contiguous() + return latents + + @torch.no_grad() + def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg',rvq_num=1): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num) + + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents + + @torch.no_grad() + def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) + import time + start = time.time() + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents,time.time()-start + + def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): + divisor = 4 + shape = (batch_size, num_channels_latents, num_frames, 32) + if(num_frames%divisor>0): + num_frames = round(num_frames/float(divisor))*divisor + shape = (batch_size, num_channels_latents, num_frames, 32) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + return latents + + diff --git a/codeclm/tokenizer/Flow1dVAE/model_4rvq.py b/codeclm/tokenizer/Flow1dVAE/model_4rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..1cb3ea89f0b3f91b56e734d0f42257d8b48199ad --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/model_4rvq.py @@ -0,0 +1,774 @@ +import yaml +import random +import inspect +import numpy as np +from tqdm import tqdm +import typing as tp +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from einops import repeat +from tools.torch_tools import wav_to_fbank + +import diffusers +from diffusers.utils.torch_utils import randn_tensor +from diffusers import DDPMScheduler +from models.transformer_2d_flow import Transformer2DModel +from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel +# from tools.get_mulan import get_mulan +from third_party.wespeaker.extract_embd import XVECModel +# from libs.rvq2 import RVQEmbedding +from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize + +from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model +from models_gpt.models.gpt2_config import GPT2Config + +from torch.cuda.amp import autocast + + +from our_MERT_BESTRQ.test import load_model + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. + print("hidden_size:",config.hidden_size) + print("classifier_proj_size:",config.classifier_proj_size) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + """Project back from diffusion space to the actual sample space.""" + return z + +class Feature1DProcessor(SampleProcessor): + def __init__(self, dim: int = 100, power_std = 1., \ + num_samples: int = 100_000, cal_num_frames: int = 600): + super().__init__() + + self.num_samples = num_samples + self.dim = dim + self.power_std = power_std + self.cal_num_frames = cal_num_frames + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(dim)) + self.register_buffer('sum_x2', torch.zeros(dim)) + self.register_buffer('sum_target_x2', torch.zeros(dim)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + if(self.counts < 10): + mean = torch.zeros_like(mean) + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + if(self.counts < 10): + std = torch.ones_like(std) + return std + + @property + def target_std(self): + return 1 + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + if self.counts.item() < self.num_samples: + self.counts += len(x) + self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) + self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) + return x + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + rescale = (self.std / self.target_std) ** self.power_std + # print(rescale, self.mean) + x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) + return x + +def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): + if(prior_text_encoder_hidden_states.shape[1] 1.0): + + model_input = torch.cat([ \ + torch.cat([latent_mask_input, latent_mask_input], 0), \ + torch.cat([incontext_x, incontext_x], 0), \ + torch.cat([torch.zeros_like(mu), mu], 0), \ + torch.cat([x, x], 0), \ + ], 2) + timestep=t.unsqueeze(-1).repeat(2) + + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) + dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) + else: + model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) + timestep=t.unsqueeze(-1) + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + + dphi_dt = dphi_dt[: ,:, -x.shape[2]:] + print("dphi_dt.shape:",dphi_dt.shape) + print("x.shape:",x.shape) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def projection_loss(self,hidden_proj, bestrq_emb): + bsz = hidden_proj.shape[0] + + hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) + bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) + + proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) + proj_loss = 1+proj_loss.mean() + + return proj_loss + + def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_channels, mel_timesteps, n_feats) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_channels, mel_timesteps, n_feats) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_channels, mel_timesteps, n_feats) + """ + b = mu[0].shape[0] + len_x = x1.shape[2] + # random timestep + if(validation_mode): + t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 + else: + t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + # print("y.shape:",y.shape) + #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state + model_input = torch.cat([*mu,y], 2) + t=t.squeeze(-1).squeeze(-1) + # print("model_input.shape:",model_input.shape) + # print("attention_mask.shape:",attention_mask.shape) + out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) + hidden_layer = out.hidden_states[self.ssl_layer] + hidden_proj = self.mlp(hidden_layer) + # print("hidden_proj.shape:",hidden_proj.shape) + # print("mert_emb.shape:",mert_emb.shape) + # exit() + + + out = out.last_hidden_state + + out=out[:,:,-len_x:] + # out=self.proj_out(out) + + weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 + # print("out.shape",out.shape) + # print("u.shape",u.shape) + loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() + # print("hidden_proj.shape:",hidden_proj.shape) + # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) + loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) + loss = loss_re + loss_cos * 0.5 + # print("loss_cos:",loss_cos,loss_cos.device) + print("loss:",loss,loss.device) + # exit() + return loss, loss_re, loss_cos + +class PromptCondAudioDiffusion(nn.Module): + def __init__( + self, + num_channels, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + hubert_layer=None, + ssl_layer=None, + uncondition=True, + out_paint=False, + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.uncondition = uncondition + self.num_channels = num_channels + self.hubert_layer = hubert_layer + self.ssl_layer = ssl_layer + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.normfeat = Feature1DProcessor(dim=64) + + self.sample_rate = 48000 + self.num_samples_perseg = self.sample_rate * 20 // 1000 + self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) + self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) + # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + self.bestrq = load_model( + model_dir='path/to/our-MERT/mert_fairseq', + checkpoint_dir='checkpoint-120000.pt', + ) + self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) + self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) + for v in self.bestrq.parameters():v.requires_grad = False + self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + # for v in self.rvq_bestrq_emb.parameters(): + # print(v) + freeze_parameters='quantizers.0' + for name, param in self.rvq_bestrq_emb.named_parameters(): + if freeze_parameters in name: + param.requires_grad = False + print("Freezing RVQ parameters:", name) + self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") + for v in self.hubert.parameters():v.requires_grad = False + self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) + # self.xvecmodel = XVECModel() + config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) + unet = GPT2Model(config) + mlp = nn.Sequential( + nn.Linear(1200, 1024), + nn.SiLU(), + nn.Linear(1024, 1024), + nn.SiLU(), + nn.Linear(1024, 768) + ) + self.set_from = "random" + self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) + self.mask_emb = torch.nn.Embedding(3, 48) + print("Transformer initialized from pretrain.") + torch.cuda.empty_cache() + # self.unet.set_attn_processor(AttnProcessor2_0()) + # self.unet.set_use_memory_efficient_attention_xformers(True) + + # self.start_embedding = nn.Parameter(torch.randn(1,1024)) + # self.end_embedding = nn.Parameter(torch.randn(1,1024)) + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + def preprocess_audio(self, input_audios, threshold=0.9): + assert len(input_audios.shape) == 2, input_audios.shape + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios/norm_value.unsqueeze(-1) + + def extract_wav2vec_embeds(self, input_audios,output_len): + wav2vec_stride = 2 + + wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 + # print(wav2vec_embeds) + # print("audio.shape:",input_audios.shape) + wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] + # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) + wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) + return wav2vec_embeds_last + + def extract_mert_embeds(self, input_audios): + prompt_stride = 3 + inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 + mert_emb= prompt_embeds[-1] + mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) + + return mert_emb + + def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): + self.bestrq.eval() + # print("audio shape:",input_audio_0.shape) + input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 + # print("input_wav_mean.shape:",input_wav_mean.shape) + # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) + input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) + layer_results = input_wav_mean['layer_results'] + # print("layer_results.shape:",layer_results[layer].shape) + bestrq_emb = layer_results[layer] + bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() + #[b,t,1024] t=t/960 + #35.84s->batch,896,1024 + return bestrq_emb + + + def extract_spk_embeds(self, input_audios): + spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) + spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) + return spk_embeds + + def extract_lyric_feats(self, lyric): + with torch.no_grad(): + try: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) + except: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) + text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) + text_mask = text_mask.to(self.device) + text_encoder_hidden_states, text_mask, text_prompt_embeds = \ + pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() + return text_encoder_hidden_states, text_mask + + def extract_energy_bar(self, input_audios): + if(input_audios.shape[-1] % self.num_samples_perseg > 0): + energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) + else: + energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) + energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T + energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() + energy_embedding = self.energy_embedding(energy_bar) + energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t + return energy_embedding + + def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ + additional_feats = ['spk', 'lyric'], \ + train_rvq=True, train_ssl=False,layer=5): + if not hasattr(self,"device"): + self.device = input_audios.device + if not hasattr(self,"dtype"): + self.dtype = input_audios.dtype + device = self.device + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 + # energy_embedding = self.extract_energy_bar(input_audios) + # print("energy_embedding.shape:",energy_embedding.shape) + # with autocast(enabled=False): + if(train_ssl): + self.wav2vec.train() + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) + self.clap_embd_extractor.train() + prompt_embeds = self.extract_mert_embeds(input_audios) + if('spk' in additional_feats): + self.xvecmodel.train() + spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) + else: + with torch.no_grad(): + with autocast(enabled=False): + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + # mert_emb = self.extract_mert_embeds(input_audios_mert) + + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) + + bestrq_emb = bestrq_emb.detach() + if('lyric' in additional_feats): + text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) + else: + text_encoder_hidden_states, text_mask = None, None + + + if(train_rvq): + random_num=random.random() + if(random_num<0.6): + rvq_layer = 1 + elif(random_num<0.8): + rvq_layer = 2 + else: + rvq_layer = 4 + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t + else: + bestrq_emb = bestrq_emb.float() + self.rvq_bestrq_emb.eval() + # with autocast(enabled=False): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() + codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() + quantized_bestrq_emb = quantized_bestrq_emb.detach() + + commitment_loss = commitment_loss_bestrq_emb + codebook_loss = codebook_loss_bestrq_emb + + + alpha=1 + quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) + + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + scenario = np.random.choice(['start_seg', 'other_seg']) + if(scenario == 'other_seg'): + for binx in range(input_audios.shape[0]): + # latent_masks[binx,0:64] = 1 + latent_masks[binx,0:random.randint(64,128)] = 1 + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + + + + + if self.uncondition: + mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] + if len(mask_indices) > 0: + quantized_bestrq_emb[mask_indices] = 0 + # print("latents.shape:",latents.shape) + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.project_sample(latents) + latents = latents.permute(0,2,1).contiguous() + incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + # print("incontext_latents.shape:",incontext_latents.shape) + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + latent_mask_input = self.mask_emb(latent_masks) + #64+48+64+1024 + loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) + return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() + + def init_device_dtype(self, device, dtype): + self.device = device + self.dtype = dtype + + @torch.no_grad() + def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1): + input_audio_0 = input_audios[[0],:] + input_audio_1 = input_audios[[1],:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1): + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250): + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds) + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, + guidance_scale=2, num_steps=20, + disable_progress=True, scenario='start_seg'): + classifier_free_guidance = guidance_scale > 1.0 + device = self.device + dtype = self.dtype + # codes_bestrq_middle, codes_bestrq_last = codes + codes_bestrq_emb = codes[0] + + + batch_size = codes_bestrq_emb.shape[0] + + + quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + + if('spk' in additional_feats): + spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() + + num_frames = quantized_bestrq_emb.shape[1] + + num_channels_latents = self.num_channels + shape = (batch_size, num_frames, 64) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + + + + latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) + latent_masks[:,0:latent_length] = 2 + if(scenario=='other_seg'): + latent_masks[:,0:incontext_length] = 1 + + + + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + true_latents = true_latents.permute(0,2,1).contiguous() + true_latents = self.normfeat.project_sample(true_latents) + true_latents = true_latents.permute(0,2,1).contiguous() + incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] + + + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + + if('spk' in additional_feats): + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) + additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) + else: + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) + additional_model_input = torch.cat([quantized_bestrq_emb],1) + + temperature = 1.0 + t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) + latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) + + latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.return_sample(latents) + # latents = latents.permute(0,2,1).contiguous() + return latents + + @torch.no_grad() + def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg',rvq_num=1): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num) + + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents + + @torch.no_grad() + def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) + import time + start = time.time() + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents,time.time()-start + + def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): + divisor = 4 + shape = (batch_size, num_channels_latents, num_frames, 32) + if(num_frames%divisor>0): + num_frames = round(num_frames/float(divisor))*divisor + shape = (batch_size, num_channels_latents, num_frames, 32) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + return latents + + diff --git a/codeclm/tokenizer/Flow1dVAE/model_septoken.py b/codeclm/tokenizer/Flow1dVAE/model_septoken.py new file mode 100644 index 0000000000000000000000000000000000000000..331f2a3fa23e5f8873c62b532ef41d0519e93d11 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/model_septoken.py @@ -0,0 +1,670 @@ +import yaml +import random +import inspect +import numpy as np +from tqdm import tqdm +import typing as tp +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from einops import repeat +from tools.torch_tools import wav_to_fbank + +from diffusers.utils.torch_utils import randn_tensor +from transformers import HubertModel +from libs.rvq.descript_quantize3 import ResidualVectorQuantize + +from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model +from models_gpt.models.gpt2_config import GPT2Config + +from torch.cuda.amp import autocast +from our_MERT_BESTRQ.test import load_model + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. + print("hidden_size:",config.hidden_size) + print("classifier_proj_size:",config.classifier_proj_size) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + """Project back from diffusion space to the actual sample space.""" + return z + +class Feature1DProcessor(SampleProcessor): + def __init__(self, dim: int = 100, power_std = 1., \ + num_samples: int = 100_000, cal_num_frames: int = 600): + super().__init__() + + self.num_samples = num_samples + self.dim = dim + self.power_std = power_std + self.cal_num_frames = cal_num_frames + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(dim)) + self.register_buffer('sum_x2', torch.zeros(dim)) + self.register_buffer('sum_target_x2', torch.zeros(dim)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + if(self.counts < 10): + mean = torch.zeros_like(mean) + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + if(self.counts < 10): + std = torch.ones_like(std) + return std + + @property + def target_std(self): + return 1 + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + if self.counts.item() < self.num_samples: + self.counts += len(x) + self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) + self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) + return x + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + rescale = (self.std / self.target_std) ** self.power_std + x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) + return x + +def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): + if(prior_text_encoder_hidden_states.shape[1] 1.0): + + model_input = torch.cat([ \ + torch.cat([latent_mask_input, latent_mask_input], 0), \ + torch.cat([incontext_x, incontext_x], 0), \ + torch.cat([torch.zeros_like(mu), mu], 0), \ + torch.cat([x, x], 0), \ + ], 2) + timestep=t.unsqueeze(-1).repeat(2) + + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) + dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) + else: + model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) + timestep=t.unsqueeze(-1) + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + + dphi_dt = dphi_dt[: ,:, -x.shape[2]:] + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def projection_loss(self,hidden_proj, bestrq_emb): + bsz = hidden_proj.shape[0] + + hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) + bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) + + proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) + proj_loss = 1+proj_loss.mean() + + return proj_loss + + def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_channels, mel_timesteps, n_feats) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_channels, mel_timesteps, n_feats) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_channels, mel_timesteps, n_feats) + """ + b = mu[0].shape[0] + len_x = x1.shape[2] + # random timestep + if(validation_mode): + t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 + else: + t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + model_input = torch.cat([*mu,y], 2) + t=t.squeeze(-1).squeeze(-1) + out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) + hidden_layer_7 = out.hidden_states[7] + hidden_proj = self.mlp(hidden_layer_7) + out = out.last_hidden_state + out=out[:,:,-len_x:] + + weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 + loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() + loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) + loss = loss_re + loss_cos * 0.5 + return loss, loss_re, loss_cos + +class PromptCondAudioDiffusion(nn.Module): + def __init__( + self, + num_channels, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + uncondition=True, + out_paint=False, + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.uncondition = uncondition + self.num_channels = num_channels + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.normfeat = Feature1DProcessor(dim=64) + + self.sample_rate = 48000 + self.num_samples_perseg = self.sample_rate * 20 // 1000 + self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) + self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) + # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + self.bestrq = load_model( + model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq', + checkpoint_dir='ckpt/encode-s12k.pt', + ) + self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) + self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) + for v in self.bestrq.parameters():v.requires_grad = False + self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") + for v in self.hubert.parameters():v.requires_grad = False + self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) + # self.xvecmodel = XVECModel() + config = GPT2Config(n_positions=1000,n_layer=16,n_head=20,n_embd=2200,n_inner=4400) + unet = GPT2Model(config) + mlp = nn.Sequential( + nn.Linear(2200, 1024), + nn.SiLU(), + nn.Linear(1024, 1024), + nn.SiLU(), + nn.Linear(1024, 768) + ) + self.set_from = "random" + self.cfm_wrapper = BASECFM(unet, mlp) + self.mask_emb = torch.nn.Embedding(3, 24) + print("Transformer initialized from pretrain.") + torch.cuda.empty_cache() + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + def preprocess_audio(self, input_audios, threshold=0.9): + assert len(input_audios.shape) == 2, input_audios.shape + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios/norm_value.unsqueeze(-1) + + def extract_wav2vec_embeds(self, input_audios,output_len): + wav2vec_stride = 2 + + wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 + wav2vec_embeds_last=wav2vec_embeds[-1] + wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) + return wav2vec_embeds_last + + def extract_mert_embeds(self, input_audios): + prompt_stride = 3 + inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 + mert_emb= prompt_embeds[-1] + mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=375, mode='linear', align_corners=False).permute(0, 2, 1) + + return mert_emb + + def extract_bestrq_embeds(self, input_audio_vocal_0,input_audio_vocal_1,layer): + input_wav_mean = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 + input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) + layer_results = input_wav_mean['layer_results'] + bestrq_emb = layer_results[layer] + bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() + return bestrq_emb + + + def extract_spk_embeds(self, input_audios): + spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) + spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) + return spk_embeds + + def extract_lyric_feats(self, lyric): + with torch.no_grad(): + try: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) + except: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) + text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) + text_mask = text_mask.to(self.device) + text_encoder_hidden_states, text_mask, text_prompt_embeds = \ + pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() + return text_encoder_hidden_states, text_mask + + def extract_energy_bar(self, input_audios): + if(input_audios.shape[-1] % self.num_samples_perseg > 0): + energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) + else: + energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) + energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T + energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() + energy_embedding = self.energy_embedding(energy_bar) + energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t + return energy_embedding + + def forward(self, input_audios_vocal,input_audios_bgm, lyric, latents, latent_masks, validation_mode=False, \ + additional_feats = ['spk', 'lyric'], \ + train_rvq=True, train_ssl=False,layer_vocal=7,layer_bgm=7): + if not hasattr(self,"device"): + self.device = input_audios_vocal.device + if not hasattr(self,"dtype"): + self.dtype = input_audios_vocal.dtype + device = self.device + input_audio_vocal_0 = input_audios_vocal[:,0,:] + input_audio_vocal_1 = input_audios_vocal[:,1,:] + input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) + input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) + input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 + + input_audio_bgm_0 = input_audios_bgm[:,0,:] + input_audio_bgm_1 = input_audios_bgm[:,1,:] + input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) + input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) + input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 + + if(train_ssl): + self.wav2vec.train() + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) + self.clap_embd_extractor.train() + prompt_embeds = self.extract_mert_embeds(input_audios) + if('spk' in additional_feats): + self.xvecmodel.train() + spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) + else: + with torch.no_grad(): + with autocast(enabled=False): + bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) + bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) + # mert_emb = self.extract_mert_embeds(input_audios_mert) + output_len = bestrq_emb.shape[2] + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec+input_audios_bgm_wav2vec,output_len) + + + bestrq_emb = bestrq_emb.detach() + bestrq_emb_bgm = bestrq_emb_bgm.detach() + + if('lyric' in additional_feats): + text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) + else: + text_encoder_hidden_states, text_mask = None, None + + + if(train_rvq): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm,_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t + else: + bestrq_emb = bestrq_emb.float() + self.rvq_bestrq_emb.eval() + # with autocast(enabled=False): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() + codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() + quantized_bestrq_emb = quantized_bestrq_emb.detach() + + commitment_loss = commitment_loss_bestrq_emb+commitment_loss_bestrq_emb_bgm + codebook_loss = codebook_loss_bestrq_emb+codebook_loss_bestrq_emb_bgm + + + alpha=1 + quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) + quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm * alpha + bestrq_emb_bgm * (1-alpha) + + + + + scenario = np.random.choice(['start_seg', 'other_seg']) + if(scenario == 'other_seg'): + for binx in range(input_audios_vocal.shape[0]): + # latent_masks[binx,0:64] = 1 + latent_masks[binx,0:random.randint(64,128)] = 1 + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous() + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + + + + + if self.uncondition: + mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] + if len(mask_indices) > 0: + quantized_bestrq_emb[mask_indices] = 0 + quantized_bestrq_emb_bgm[mask_indices] = 0 + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.project_sample(latents) + latents = latents.permute(0,2,1).contiguous() + incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb,quantized_bestrq_emb_bgm], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) + return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() + + def init_device_dtype(self, device, dtype): + self.device = device + self.dtype = dtype + + @torch.no_grad() + def fetch_codes(self, input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7): + input_audio_vocal_0 = input_audios_vocal[[0],:] + input_audio_vocal_1 = input_audios_vocal[[1],:] + input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) + input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) + input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 + + input_audio_bgm_0 = input_audios_bgm[[0],:] + input_audio_bgm_1 = input_audios_bgm[[1],:] + input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) + input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) + input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) + bestrq_emb = bestrq_emb.detach() + + bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) + bestrq_emb_bgm = bestrq_emb_bgm.detach() + + + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + + self.rvq_bestrq_bgm_emb.eval() + quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch(self, input_audios_vocal, input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7): + input_audio_vocal_0 = input_audios_vocal[:,0,:] + input_audio_vocal_1 = input_audios_vocal[:,1,:] + input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) + input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) + input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 + + input_audio_bgm_0 = input_audios_bgm[:,0,:] + input_audio_bgm_1 = input_audios_bgm[:,1,:] + input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) + input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) + input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) + bestrq_emb = bestrq_emb.detach() + + bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) + bestrq_emb_bgm = bestrq_emb_bgm.detach() + + + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + + self.rvq_bestrq_bgm_emb.eval() + quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + + @torch.no_grad() + def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats,incontext_length=127, + guidance_scale=2, num_steps=20, + disable_progress=True, scenario='start_seg'): + classifier_free_guidance = guidance_scale > 1.0 + device = self.device + dtype = self.dtype + # codes_bestrq_middle, codes_bestrq_last = codes + codes_bestrq_emb,codes_bestrq_emb_bgm = codes + + + batch_size = codes_bestrq_emb.shape[0] + + + quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) + quantized_bestrq_emb_bgm,_,_=self.rvq_bestrq_bgm_emb.from_codes(codes_bestrq_emb_bgm) + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous() + if('spk' in additional_feats): + spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() + + num_frames = quantized_bestrq_emb.shape[1] + + num_channels_latents = self.num_channels + shape = (batch_size, num_frames, 64) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + + + + latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) + latent_masks[:,0:latent_length] = 2 + if(scenario=='other_seg'): + latent_masks[:,0:incontext_length] = 1 + + + + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + true_latents = true_latents.permute(0,2,1).contiguous() + true_latents = self.normfeat.project_sample(true_latents) + true_latents = true_latents.permute(0,2,1).contiguous() + incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] + + + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + + if('spk' in additional_feats): + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) + additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm, spk_embeds],2) + else: + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) + additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm],2) + + temperature = 1.0 + t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) + latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) + + latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.return_sample(latents) + # latents = latents.permute(0,2,1).contiguous() + return latents + + @torch.no_grad() + def inference(self, input_audios_vocal,input_audios_bgm, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer_vocal=7,layer_bgm=3,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal,layer_bgm) + + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents + + def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): + divisor = 4 + shape = (batch_size, num_channels_latents, num_frames, 32) + if(num_frames%divisor>0): + num_frames = round(num_frames/float(divisor))*divisor + shape = (batch_size, num_channels_latents, num_frames, 32) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + return latents + + diff --git a/codeclm/tokenizer/Flow1dVAE/models/__init__.py b/codeclm/tokenizer/Flow1dVAE/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/codeclm/tokenizer/Flow1dVAE/models/attention.py b/codeclm/tokenizer/Flow1dVAE/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..a565aa34bbf03b79c5130a8538b0b1e34d3fa5e2 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/attention.py @@ -0,0 +1,682 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import USE_PEFT_BACKEND +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm + + +def _chunked_feed_forward( + ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None +): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + if lora_scale is None: + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + else: + # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete + ff_output = torch.cat( + [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + + return ff_output + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if self.use_ada_layer_norm: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.attn2 = None + + # 3. Feed-forward + if self.use_ada_layer_norm_continuous: + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + elif not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.use_ada_layer_norm_single: + # print("Using PixArt-Alpha norm") + # print("time step: ", timestep.shape) + # print("self.scale_shift_table: ", self.scale_shift_table.shape) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + # print("scale_msa: ", scale_msa.shape) + # print("shift_msa: ", shift_msa.shape) + #scale_msa: torch.Size([5, 1, 1152]) + #shift_msa: torch.Size([5, 1, 1152]) + # exit() + # print("before: ", norm_hidden_states.shape) + #before: torch.Size([5, 3584, 1152]) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + # print("after: ", norm_hidden_states.shape) + #before: torch.Size([5, 3584, 1152]) + # exit() + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.FloatTensor, + num_frames: int, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/codeclm/tokenizer/Flow1dVAE/models/attention_add_rot_emb.py b/codeclm/tokenizer/Flow1dVAE/models/attention_add_rot_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..743144907fcf312604912ea254b6a5bca8e9572f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/attention_add_rot_emb.py @@ -0,0 +1,734 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import USE_PEFT_BACKEND +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm + + +def _chunked_feed_forward( + ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None +): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + if lora_scale is None: + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + else: + # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete + ff_output = torch.cat( + [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + + return ff_output + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if self.use_ada_layer_norm: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.attn2 = None + + # 3. Feed-forward + if self.use_ada_layer_norm_continuous: + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + elif not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.use_ada_layer_norm_single: + # print("Using PixArt-Alpha norm") + # print("time step: ", timestep.shape) + # print("self.scale_shift_table: ", self.scale_shift_table.shape) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + # print("scale_msa: ", scale_msa.shape) + # print("shift_msa: ", shift_msa.shape) + #scale_msa: torch.Size([5, 1, 1152]) + #shift_msa: torch.Size([5, 1, 1152]) + # exit() + # print("before: ", norm_hidden_states.shape) + #before: torch.Size([5, 3584, 1152]) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + # print("after: ", norm_hidden_states.shape) + #before: torch.Size([5, 3584, 1152]) + # exit() + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + freqs_cis = precompute_freqs_cis(norm_hidden_states.shape[-1], norm_hidden_states.shape[1]).to(norm_hidden_states.device) + print("norm_hidden_states1: ", norm_hidden_states.shape) + norm_hidden_states = apply_rotary_emb(norm_hidden_states, freqs_cis) + print("norm_hidden_states2: ", norm_hidden_states.shape) + + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.FloatTensor, + num_frames: int, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/codeclm/tokenizer/Flow1dVAE/models/test_model.py b/codeclm/tokenizer/Flow1dVAE/models/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea8626df7ee4cced7bcda40f3e5cbb7a6752704 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/test_model.py @@ -0,0 +1,92 @@ +from thop import profile +from thop import clever_format +import torch +from tqdm import tqdm +import time +import sys +sys.path.append('./') + + +def analyze_model(model, inputs): + # model size + num_trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + print("Num trainable parameters: {} M".format(num_trainable_parameters/1000./1000.)) + + # computation cost + with torch.no_grad(): + model.eval() + macs, params = profile(model, inputs=inputs) + macs, params = clever_format([macs, params], "%.3f") + print("Macs: {}, Params: {}".format(macs, params)) + + run_times = 50 + # eval forward 100 times + with torch.no_grad(): + model = model.eval().to('cuda') + inputs = [i.to('cuda') if isinstance(i, torch.Tensor) else i for i in inputs] + model.init_device_dtype(inputs[0].device, inputs[0].dtype) + st = time.time() + for i in tqdm(range(run_times)): + _ = model(*inputs) + et = time.time() + print("Eval forward : {:.03f} secs/per iter".format((et-st)/float(run_times))) + + # train backward 100 times + model = model.train().to('cuda') + inputs = [i.to('cuda') if isinstance(i, torch.Tensor) else i for i in inputs] + model.init_device_dtype(inputs[0].device, inputs[0].dtype) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + optimizer.zero_grad() + st = time.time() + for i in tqdm(range(run_times)): + inputs = [torch.rand_like(i) if isinstance(i, torch.cuda.FloatTensor) else i for i in inputs] + out = model(*inputs) + optimizer.zero_grad() + out.mean().backward() + optimizer.step() + et = time.time() + print("Train forward : {:.03f} secs/per iter".format((et-st)/float(run_times))) + +def fetch_model_v3_transformer(): + # num params: 326M + # macs (uncorrect): 261G/iter + # infer: 0.32s/iter + # train: 2.54s/iter + from models_transformercond_winorm_ch16_everything_512 import PromptCondAudioDiffusion + model = PromptCondAudioDiffusion( \ + "configs/scheduler/stable_diffusion_2.1_largenoise.json", \ + None, \ + "configs/models/transformer2D.json" + ) + inputs = [ + torch.rand(1,16,1024*3//8,32), + torch.rand(1,7,512), + torch.tensor([1,]), + torch.tensor([0,]), + False, + ] + return model, inputs + +def fetch_model_v3_unet(): + # num params: 310M + # infer: 0.10s/iter + # train: 0.70s/iter + from models_musicldm_winorm_ch16_everything_sepnorm import PromptCondAudioDiffusion + model = PromptCondAudioDiffusion( \ + "configs/scheduler/stable_diffusion_2.1_largenoise.json", \ + None, \ + "configs/diffusion_clapcond_model_config_ch16_everything.json" + ) + inputs = [ + torch.rand(1,16,1024*3//8,32), + torch.rand(1,7,512), + torch.tensor([1,]), + torch.tensor([0,]), + False, + ] + return model, inputs + +if __name__=="__main__": + model, inputs = fetch_model_v3_transformer() + # model, inputs = fetch_model_v3_unet() + analyze_model(model, inputs) diff --git a/codeclm/tokenizer/Flow1dVAE/models/transformer_2d.py b/codeclm/tokenizer/Flow1dVAE/models/transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..7891daaac6ceebeb6b540b280f1dbfeab8ee8960 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/transformer_2d.py @@ -0,0 +1,487 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from models.attention import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + print("before input hidden_states.shape", hidden_states.shape) + # print("is_input_patches", self.is_input_patches) + #true + # print("is_input_vectorized", self.is_input_vectorized) + # print("is_input_continuous", self.is_input_continuous) + #false + # print("use_linear_projection", self.use_linear_projection) + #false + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + # print("input 1 hidden_states.shape", hidden_states.shape) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + # print("input 2 hidden_states.shape", hidden_states.shape) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + # print("before patches hidden_states.shape", hidden_states.shape) + hidden_states = self.pos_embed(hidden_states) + # print("after patches hidden_states.shape", hidden_states.shape) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + print("time step before adaln_single", timestep) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + print("time step after adaln_single", timestep.shape) + print("embedded_timestep after adaln_single", embedded_timestep.shape) + + + # 2. Blocks + print("before blocks hidden_states.shape", hidden_states.shape) + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + + # 3. Output + print("after blocks hidden_states.shape", hidden_states.shape) + if self.is_input_continuous: + print("1") + if not self.use_linear_projection: + print("1 hidden_states.shape", hidden_states.shape) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + print("2 hidden_states.shape", hidden_states.shape) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + print("2") + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + print("3") + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + print("before proj_out hidden_states.shape", hidden_states.shape) + hidden_states = self.proj_out(hidden_states) + print("after proj_out hidden_states.shape", hidden_states.shape) + + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + print("4") + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_additionalemb.py b/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_additionalemb.py new file mode 100644 index 0000000000000000000000000000000000000000..06b803337054093783e33a19f6645fcdd6327c9b --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_additionalemb.py @@ -0,0 +1,468 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from models.attention import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + self.timestep_linear = nn.Sequential( + nn.Linear(inner_dim * 2, inner_dim), + nn.SiLU(), + nn.Linear(inner_dim, 6 * inner_dim), + ) + self.embedded_timestep_linear = nn.Linear(inner_dim * 2, inner_dim) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + additional_embd: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states = self.pos_embed(hidden_states) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + timestep = self.timestep_linear(torch.cat([embedded_timestep, additional_embd],1)) + embedded_timestep = self.embedded_timestep_linear(torch.cat([embedded_timestep, additional_embd],1)) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_flow.py b/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf331555534d63fcacc05ee5163530a41cfb910 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_flow.py @@ -0,0 +1,545 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +import math +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from models.attention import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import TimestepEmbedding + +class PixArtAlphaCombinedFlowEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.flow_t_size = 512 + self.outdim = size_emb_dim + self.timestep_embedder = TimestepEmbedding(in_channels=self.flow_t_size, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87 + def timestep_embedding(self, timesteps, max_period=10000, scale=1000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = self.flow_t_size // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type()) + args = timesteps[:, None] * freqs[None] * scale + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if self.flow_t_size % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.timestep_embedding(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + +class AdaLayerNormSingleFlow(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = PixArtAlphaCombinedFlowEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingleFlow(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states = self.pos_embed(hidden_states) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_rot_emb.py b/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_rot_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..734fcc706e54bf594813b79940734b2cbfb42970 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/transformer_2d_rot_emb.py @@ -0,0 +1,487 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from models.attention_add_rot_emb import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + print("before input hidden_states.shape", hidden_states.shape) + # print("is_input_patches", self.is_input_patches) + #true + # print("is_input_vectorized", self.is_input_vectorized) + # print("is_input_continuous", self.is_input_continuous) + #false + # print("use_linear_projection", self.use_linear_projection) + #false + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + # print("input 1 hidden_states.shape", hidden_states.shape) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + # print("input 2 hidden_states.shape", hidden_states.shape) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + # print("before patches hidden_states.shape", hidden_states.shape) + hidden_states = self.pos_embed(hidden_states) + # print("after patches hidden_states.shape", hidden_states.shape) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + print("time step before adaln_single", timestep) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + print("time step after adaln_single", timestep.shape) + print("embedded_timestep after adaln_single", embedded_timestep.shape) + + + # 2. Blocks + print("before blocks hidden_states.shape", hidden_states.shape) + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + + # 3. Output + print("after blocks hidden_states.shape", hidden_states.shape) + if self.is_input_continuous: + print("1") + if not self.use_linear_projection: + print("1 hidden_states.shape", hidden_states.shape) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + print("2 hidden_states.shape", hidden_states.shape) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + print("2") + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + print("3") + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + print("before proj_out hidden_states.shape", hidden_states.shape) + hidden_states = self.proj_out(hidden_states) + print("after proj_out hidden_states.shape", hidden_states.shape) + + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + print("4") + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py b/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py new file mode 100644 index 0000000000000000000000000000000000000000..1c733ab9ee1092da1ebc0330a65892d68b608a34 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py @@ -0,0 +1,996 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.activations import get_activation +from diffusers.models.embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + additional_embd: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + emb = emb + additional_embd + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py b/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..71eb958b9188df5f6fef76bf3d7d19dd329bbc16 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py @@ -0,0 +1,934 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.activations import get_activation +from diffusers.models.embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + self.block_out_channels = block_out_channels + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87 + def timestep_embedding(self, timesteps, max_period=10000, scale=1000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + dim = self.block_out_channels[-1] + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type()) + args = timesteps[:, None] * freqs[None] * scale + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + emb = self.timestep_embedding(timesteps) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/attention.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..33a060345e3163667c2ad29d2900244db8fdacd8 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/attention.py @@ -0,0 +1,668 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import USE_PEFT_BACKEND +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm + + +def _chunked_feed_forward( + ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None +): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + if lora_scale is None: + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + else: + # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete + ff_output = torch.cat( + [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + + return ff_output + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if self.use_ada_layer_norm: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.attn2 = None + + # 3. Feed-forward + if self.use_ada_layer_norm_continuous: + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + elif not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.use_ada_layer_norm_single: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.FloatTensor, + num_frames: int, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/attention_add_rot_emb.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/attention_add_rot_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..f4341c729867600a41fdd38245dd8f161e5035e5 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/attention_add_rot_emb.py @@ -0,0 +1,724 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import USE_PEFT_BACKEND +from diffusers.utils.torch_utils import maybe_allow_in_graph +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention_processor import Attention +from diffusers.models.embeddings import SinusoidalPositionalEmbedding +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm + + +def _chunked_feed_forward( + ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None +): + # "feed_forward_chunk_size" can be used to save memory + if hidden_states.shape[chunk_dim] % chunk_size != 0: + raise ValueError( + f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." + ) + + num_chunks = hidden_states.shape[chunk_dim] // chunk_size + if lora_scale is None: + ff_output = torch.cat( + [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + else: + # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete + ff_output = torch.cat( + [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], + dim=chunk_dim, + ) + + return ff_output + + +@maybe_allow_in_graph +class GatedSelfAttentionDense(nn.Module): + r""" + A gated self-attention dense layer that combines visual features and object features. + + Parameters: + query_dim (`int`): The number of channels in the query. + context_dim (`int`): The number of channels in the context. + n_heads (`int`): The number of heads to use for attention. + d_head (`int`): The number of channels in each head. + """ + + def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): + super().__init__() + + # we need a linear projection since we need cat visual feature and obj feature + self.linear = nn.Linear(context_dim, query_dim) + + self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) + self.ff = FeedForward(query_dim, activation_fn="geglu") + + self.norm1 = nn.LayerNorm(query_dim) + self.norm2 = nn.LayerNorm(query_dim) + + self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) + self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) + + self.enabled = True + + def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: + if not self.enabled: + return x + + n_visual = x.shape[1] + objs = self.linear(objs) + + x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] + x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) + + return x +def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0): + # 计算词向量元素两两分组之后,每组元素对应的旋转角度 + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + # 生成 token 序列索引 t = [0, 1,..., seq_len-1] + t = torch.arange(seq_len, device=freqs.device) + # freqs.shape = [seq_len, dim // 2] + freqs = torch.outer(t, freqs).float() + # torch.polar 的文档 + # https://pytorch.org/docs/stable/generated/torch.polar.html + # 计算结果是个复数向量 + # 假设 freqs = [x, y] + # 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i] + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + +def apply_rotary_emb( + xq: torch.Tensor, + freqs_cis: torch.Tensor, +): + # xq.shape = [batch_size, seq_len, dim] + # xq_.shape = [batch_size, seq_len, dim // 2, 2] + xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2).to(freqs_cis.device) + + # 转为复数域 + xq_ = torch.view_as_complex(xq_) + + # 应用旋转操作,然后将结果转回实数域 + # xq_out.shape = [batch_size, seq_len, dim] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2) + return xq_out.type_as(xq) + + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + norm_type (`str`, *optional*, defaults to `"layer_norm"`): + The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' + norm_eps: float = 1e-5, + final_dropout: bool = False, + attention_type: str = "default", + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + + self.freqs_cis = precompute_freqs_cis(dim, 2000) + + self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" + self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" + self.use_ada_layer_norm_single = norm_type == "ada_norm_single" + self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" + + if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: + raise ValueError( + f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" + f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." + ) + + if positional_embeddings and (num_positional_embeddings is None): + raise ValueError( + "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." + ) + + if positional_embeddings == "sinusoidal": + self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) + else: + self.pos_embed = None + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + if self.use_ada_layer_norm: + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_zero: + self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + if self.use_ada_layer_norm: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim if not double_self_attention else None, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.attn2 = None + + # 3. Feed-forward + if self.use_ada_layer_norm_continuous: + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + elif not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 4. Fuser + if attention_type == "gated" or attention_type == "gated-text-image": + self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) + + # 5. Scale-shift for PixArt-Alpha. + if self.use_ada_layer_norm_single: + self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + if self.use_ada_layer_norm: + norm_hidden_states = self.norm1(hidden_states, timestep) + elif self.use_ada_layer_norm_zero: + norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( + hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + elif self.use_layer_norm: + norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif self.use_ada_layer_norm_single: + # print("Using PixArt-Alpha norm") + # print("time step: ", timestep.shape) + # print("self.scale_shift_table: ", self.scale_shift_table.shape) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + norm_hidden_states = self.norm1(hidden_states) + # print("scale_msa: ", scale_msa.shape) + # print("shift_msa: ", shift_msa.shape) + #scale_msa: torch.Size([5, 1, 1152]) + #shift_msa: torch.Size([5, 1, 1152]) + # exit() + # print("before: ", norm_hidden_states.shape) + #before: torch.Size([5, 3584, 1152]) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + # print("after: ", norm_hidden_states.shape) + #before: torch.Size([5, 3584, 1152]) + # exit() + norm_hidden_states = norm_hidden_states.squeeze(1) + else: + raise ValueError("Incorrect norm used") + + if self.pos_embed is not None: + norm_hidden_states = self.pos_embed(norm_hidden_states) + + # print("norm_hidden_states1: ", norm_hidden_states.shape) + norm_hidden_states = apply_rotary_emb(norm_hidden_states, self.freqs_cis) + # print("norm_hidden_states2: ", norm_hidden_states.shape) + + # 1. Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 2. Prepare GLIGEN inputs + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + gligen_kwargs = cross_attention_kwargs.pop("gligen", None) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + if self.use_ada_layer_norm_zero: + attn_output = gate_msa.unsqueeze(1) * attn_output + elif self.use_ada_layer_norm_single: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 2.5 GLIGEN Control + if gligen_kwargs is not None: + hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.use_ada_layer_norm: + norm_hidden_states = self.norm2(hidden_states, timestep) + elif self.use_ada_layer_norm_zero or self.use_layer_norm: + norm_hidden_states = self.norm2(hidden_states) + elif self.use_ada_layer_norm_single: + # For PixArt norm2 isn't applied here: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 + norm_hidden_states = hidden_states + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) + else: + raise ValueError("Incorrect norm") + + if self.pos_embed is not None and self.use_ada_layer_norm_single is False: + norm_hidden_states = self.pos_embed(norm_hidden_states) + # print("1") + # print("norm_hidden_states:",norm_hidden_states.shape) + norm_hidden_states = apply_rotary_emb(norm_hidden_states, self.freqs_cis) + # print("norm_hidden_states:",norm_hidden_states.shape) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + # print("attn_output: ", attn_output.shape) + # print("hidden_states: ", hidden_states.shape) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + if self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.use_ada_layer_norm_single: + norm_hidden_states = self.norm3(hidden_states) + + if self.use_ada_layer_norm_zero: + norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] + + if self.use_ada_layer_norm_single: + norm_hidden_states = self.norm2(hidden_states) + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale + ) + else: + ff_output = self.ff(norm_hidden_states, scale=lora_scale) + + if self.use_ada_layer_norm_zero: + ff_output = gate_mlp.unsqueeze(1) * ff_output + elif self.use_ada_layer_norm_single: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +@maybe_allow_in_graph +class TemporalBasicTransformerBlock(nn.Module): + r""" + A basic Transformer block for video like data. + + Parameters: + dim (`int`): The number of channels in the input and output. + time_mix_inner_dim (`int`): The number of channels for temporal attention. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + """ + + def __init__( + self, + dim: int, + time_mix_inner_dim: int, + num_attention_heads: int, + attention_head_dim: int, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.is_res = dim == time_mix_inner_dim + + self.norm_in = nn.LayerNorm(dim) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm_in = nn.LayerNorm(dim) + self.ff_in = FeedForward( + dim, + dim_out=time_mix_inner_dim, + activation_fn="geglu", + ) + + self.norm1 = nn.LayerNorm(time_mix_inner_dim) + self.attn1 = Attention( + query_dim=time_mix_inner_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + cross_attention_dim=None, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None: + # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. + # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during + # the second cross attention block. + self.norm2 = nn.LayerNorm(time_mix_inner_dim) + self.attn2 = Attention( + query_dim=time_mix_inner_dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + ) # is self-attn if encoder_hidden_states is none + else: + self.norm2 = None + self.attn2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(time_mix_inner_dim) + self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = None + + def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): + # Sets chunk feed-forward + self._chunk_size = chunk_size + # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off + self._chunk_dim = 1 + + def forward( + self, + hidden_states: torch.FloatTensor, + num_frames: int, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + batch_frames, seq_length, channels = hidden_states.shape + batch_size = batch_frames // num_frames + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) + + residual = hidden_states + hidden_states = self.norm_in(hidden_states) + + if self._chunk_size is not None: + hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) + else: + hidden_states = self.ff_in(hidden_states) + + if self.is_res: + hidden_states = hidden_states + residual + + norm_hidden_states = self.norm1(hidden_states) + attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) + hidden_states = attn_output + hidden_states + + # 3. Cross-Attention + if self.attn2 is not None: + norm_hidden_states = self.norm2(hidden_states) + attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + + if self._chunk_size is not None: + ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) + else: + ff_output = self.ff(norm_hidden_states) + + if self.is_res: + hidden_states = ff_output + hidden_states + else: + hidden_states = ff_output + + hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) + hidden_states = hidden_states.permute(0, 2, 1, 3) + hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) + + return hidden_states + + +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + if activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2.py new file mode 100644 index 0000000000000000000000000000000000000000..8199aa18c3d617c359e59e5d71f561862f5ed4a1 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2.py @@ -0,0 +1,1954 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + + + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config.py new file mode 100644 index 0000000000000000000000000000000000000000..da09b49949b3cb676feacd6d814bdd72430b6968 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1280, + n_embd=768, + n_layer=12, + n_head=8, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config1.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config1.py new file mode 100644 index 0000000000000000000000000000000000000000..da09b49949b3cb676feacd6d814bdd72430b6968 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config1.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1280, + n_embd=768, + n_layer=12, + n_head=8, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config_flashattn.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config_flashattn.py new file mode 100644 index 0000000000000000000000000000000000000000..c9028f4c512995339f5f4946c123fe7b7d12f727 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_config_flashattn.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1280, + n_embd=768, + n_layer=12, + n_head=8, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id,_attn_implementation = "flash_attention_2", **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config.py new file mode 100644 index 0000000000000000000000000000000000000000..e3a8c446ce836547dada997ef8a79c7e6834eb34 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1280, + n_embd=768, + n_layer=12, + n_head=8, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.add_cross_attention = True + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id,add_cross_attention = True, **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config1.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config1.py new file mode 100644 index 0000000000000000000000000000000000000000..2f818027fbdf0d30a7f4feea8ea32a8edb705378 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config1.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=500, + n_embd=768, + n_layer=12, + n_head=8, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.add_cross_attention = True + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id,add_cross_attention = True, **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config_flashattn.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config_flashattn.py new file mode 100644 index 0000000000000000000000000000000000000000..6f4b03e7efbdf5ab9faf0da23f72ad972f6748b0 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_crossattn_config_flashattn.py @@ -0,0 +1,271 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1280, + n_embd=768, + n_layer=12, + n_head=8, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.add_cross_attention = True + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id,add_cross_attention = True,_attn_implementation = "flash_attention_2", **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2.py new file mode 100644 index 0000000000000000000000000000000000000000..7a37aa11d94cb0a042d495c2b21c2df27ab7100f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2.py @@ -0,0 +1,2017 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + # print("1") + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_casual_mask.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_casual_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..e95948db15476c62993e50b3473a63a8eb0f607c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_casual_mask.py @@ -0,0 +1,2017 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.tril() + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time.py new file mode 100644 index 0000000000000000000000000000000000000000..32ac7dfc3d3e77254d76fe2b7911e2bf660b2d4f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time.py @@ -0,0 +1,2038 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + key = key.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + time_step: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + time_step.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + # print("shift_msa:",shift_msa.shape) + # print("scale_msa:",scale_msa.shape) + #shift_msa: torch.Size([5, 1, 768]) + #scale_msa: torch.Size([5, 1, 768]) + # print("before hidden:",hidden_states.shape) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + # print("after hidden:",hidden_states.shape) + hidden_states = hidden_states.squeeze(1) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + attn_output = attn_output * gate_msa + # print("attn_output:",attn_output.shape) + hidden_states = attn_output + residual + # print("hidden_states:",hidden_states.shape) + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + feed_forward_hidden_states = self.mlp(hidden_states) + feed_forward_hidden_states = feed_forward_hidden_states * gate_mlp + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + time_step: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + time_step=time_step, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new copy.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new copy.py new file mode 100644 index 0000000000000000000000000000000000000000..03e1cb0fb0cc1e54cfef579fdbbe0d010043b34a --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new copy.py @@ -0,0 +1,2043 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + # print("1") + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + time_step: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + time_step.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + hidden_states = hidden_states.squeeze(1) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + attn_output = attn_output * gate_msa + + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + feed_forward_hidden_states = self.mlp(hidden_states) + feed_forward_hidden_states = feed_forward_hidden_states * gate_mlp + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.proj_out=torch.nn.Linear(self.embed_dim, self.embed_dim) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + self.scale_shift_table = nn.Parameter(torch.randn(2, config.hidden_size) / config.hidden_size**0.5) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + time_step: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # print("time_step", time_step) + time_step, embedded_timestep = self.adaln_single( + time_step, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + time_step=time_step, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new.py new file mode 100644 index 0000000000000000000000000000000000000000..03e1cb0fb0cc1e54cfef579fdbbe0d010043b34a --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new.py @@ -0,0 +1,2043 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + # print("1") + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + time_step: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + time_step.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + hidden_states = hidden_states.squeeze(1) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + attn_output = attn_output * gate_msa + + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + feed_forward_hidden_states = self.mlp(hidden_states) + feed_forward_hidden_states = feed_forward_hidden_states * gate_mlp + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.proj_out=torch.nn.Linear(self.embed_dim, self.embed_dim) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + self.scale_shift_table = nn.Parameter(torch.randn(2, config.hidden_size) / config.hidden_size**0.5) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + time_step: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # print("time_step", time_step) + time_step, embedded_timestep = self.adaln_single( + time_step, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + time_step=time_step, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..06049398f167a868211c70e24eb3d090fcee78ca --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask.py @@ -0,0 +1,2051 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + # print("1") + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + time_step: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + time_step.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + hidden_states = hidden_states.squeeze(1) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + attn_output = attn_output * gate_msa + + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + feed_forward_hidden_states = self.mlp(hidden_states) + feed_forward_hidden_states = feed_forward_hidden_states * gate_mlp + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.proj_out=torch.nn.Linear(self.embed_dim, self.embed_dim) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + self.scale_shift_table = nn.Parameter(torch.randn(2, config.hidden_size) / config.hidden_size**0.5) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + time_step: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + if attention_mask.dim() == 4: + attention_mask = attention_mask.to(dtype=self.dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + else: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + if encoder_attention_mask.dim() == 4: + encoder_attention_mask = encoder_attention_mask.to(dtype=self.dtype) + encoder_attention_mask = (1.0 - encoder_attention_mask) * torch.finfo(self.dtype).min + else: + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # print("time_step", time_step) + time_step, embedded_timestep = self.adaln_single( + time_step, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + time_step=time_step, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual.py new file mode 100644 index 0000000000000000000000000000000000000000..a59b14b42ba7251c1b275ac67c1f71c6313ec196 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual.py @@ -0,0 +1,2051 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models_gpt.models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if not self.is_cross_attention: + # # if only "normal" attention layer implements causal mask + # query_length, key_length = query.size(-2), key.size(-2) + # causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + # mask_value = torch.finfo(attn_weights.dtype).min + # # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + # attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + # print("1") + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + time_step: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + time_step.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + hidden_states = hidden_states.squeeze(1) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + attn_output = attn_output * gate_msa + + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + feed_forward_hidden_states = self.mlp(hidden_states) + feed_forward_hidden_states = feed_forward_hidden_states * gate_mlp + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.proj_out=torch.nn.Linear(self.embed_dim, self.embed_dim) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + self.scale_shift_table = nn.Parameter(torch.randn(2, config.hidden_size) / config.hidden_size**0.5) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + time_step: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + if attention_mask.dim() == 4: + attention_mask = attention_mask.to(dtype=self.dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + else: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + if encoder_attention_mask.dim() == 4: + encoder_attention_mask = encoder_attention_mask.to(dtype=self.dtype) + encoder_attention_mask = (1.0 - encoder_attention_mask) * torch.finfo(self.dtype).min + else: + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # print("time_step", time_step) + time_step, embedded_timestep = self.adaln_single( + time_step, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + time_step=time_step, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_flashattn.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_flashattn.py new file mode 100644 index 0000000000000000000000000000000000000000..fbf05c19a03742e3b8617bdec39349da712bd149 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_flashattn.py @@ -0,0 +1,2061 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if not self.is_cross_attention: + # # if only "normal" attention layer implements causal mask + # query_length, key_length = query.size(-2), key.size(-2) + # causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + # mask_value = torch.finfo(attn_weights.dtype).min + # # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + # attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Add ropes to the query + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + # print("1") + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + time_step: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + time_step.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + hidden_states = hidden_states.squeeze(1) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + attn_output = attn_output * gate_msa + + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + feed_forward_hidden_states = self.mlp(hidden_states) + feed_forward_hidden_states = feed_forward_hidden_states * gate_mlp + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.proj_out=torch.nn.Linear(self.embed_dim, self.embed_dim) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + self.scale_shift_table = nn.Parameter(torch.randn(2, config.hidden_size) / config.hidden_size**0.5) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + time_step: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + if attention_mask.dim() == 4: + attention_mask = attention_mask.to(dtype=self.dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + else: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + if encoder_attention_mask.dim() == 4: + encoder_attention_mask = encoder_attention_mask.to(dtype=self.dtype) + encoder_attention_mask = (1.0 - encoder_attention_mask) * torch.finfo(self.dtype).min + else: + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # print("time_step", time_step) + time_step, embedded_timestep = self.adaln_single( + time_step, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + time_step=time_step, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_reflow.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_reflow.py new file mode 100644 index 0000000000000000000000000000000000000000..f6fec37d0d46048a46ecf089e2fb890d952b2a0d --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_correct_mask_noncasual_reflow.py @@ -0,0 +1,2142 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models_gpt.models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.models.embeddings import TimestepEmbedding + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +class AdaLayerNormSingleFlow(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = PixArtAlphaCombinedFlowEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + +class PixArtAlphaCombinedFlowEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.flow_t_size = 512 + self.outdim = size_emb_dim + self.timestep_embedder = TimestepEmbedding(in_channels=self.flow_t_size, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87 + def timestep_embedding(self, timesteps, max_period=10000, scale=1000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = self.flow_t_size // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type()) + args = timesteps[:, None] * freqs[None] * scale + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if self.flow_t_size % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.timestep_embedding(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + + + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + # if not self.is_cross_attention: + # # if only "normal" attention layer implements causal mask + # query_length, key_length = query.size(-2), key.size(-2) + # causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + # mask_value = torch.finfo(attn_weights.dtype).min + # # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + # mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + # attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + # print("1") + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + time_step: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + time_step.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + hidden_states = hidden_states.squeeze(1) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + attn_output = attn_output * gate_msa + + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + feed_forward_hidden_states = self.mlp(hidden_states) + feed_forward_hidden_states = feed_forward_hidden_states * gate_mlp + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.proj_out=torch.nn.Linear(self.embed_dim, self.embed_dim) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + self.adaln_single = AdaLayerNormSingleFlow(config.hidden_size, use_additional_conditions=False) + self.scale_shift_table = nn.Parameter(torch.randn(2, config.hidden_size) / config.hidden_size**0.5) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + time_step: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + if attention_mask.dim() == 4: + attention_mask = attention_mask.to(dtype=self.dtype) + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + else: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + if encoder_attention_mask.dim() == 4: + encoder_attention_mask = encoder_attention_mask.to(dtype=self.dtype) + encoder_attention_mask = (1.0 - encoder_attention_mask) * torch.finfo(self.dtype).min + else: + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # print("time_step", time_step) + time_step, embedded_timestep = self.adaln_single( + time_step, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + time_step=time_step, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_flashattn.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_flashattn.py new file mode 100644 index 0000000000000000000000000000000000000000..beba8398dcfac848213d78ec8a775892e9ee5ced --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope2_time_new_flashattn.py @@ -0,0 +1,2056 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + # print("GPT2Attention") + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + # print("GPT2FlashAttention2") + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Add ropes to the query + query = query.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + query = query.transpose(1, 2) + if query.shape == key.shape: + key = key.transpose(1, 2) + key = apply_rotary_emb(key, freqs_cis) + key = key.transpose(1, 2) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + causal = False + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + # print("1") + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size**0.5) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + time_step: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + time_step.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + hidden_states = hidden_states.squeeze(1) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + attn_output = attn_output * gate_msa + + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + feed_forward_hidden_states = self.mlp(hidden_states) + feed_forward_hidden_states = feed_forward_hidden_states * gate_mlp + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.proj_out=torch.nn.Linear(self.embed_dim, self.embed_dim) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + self.scale_shift_table = nn.Parameter(torch.randn(2, config.hidden_size) / config.hidden_size**0.5) + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + time_step: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + # print("time_step", time_step) + time_step, embedded_timestep = self.adaln_single( + time_step, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + time_step=time_step, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.ln_f(hidden_states) + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope3.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope3.py new file mode 100644 index 0000000000000000000000000000000000000000..06bb425843f2b8b64afca89c1d54795864bcfcff --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_rope3.py @@ -0,0 +1,2015 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch OpenAI GPT-2 model.""" + +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.cuda.amp import autocast +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel, SequenceSummary +from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from models.gpt2_config import GPT2Config + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input + + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "openai-community/gpt2" +_CONFIG_FOR_DOC = "GPT2Config" + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): + """Load tf checkpoints in a pytorch model""" + try: + import re + + import tensorflow as tf + except ImportError: + logger.error( + "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see " + "https://www.tensorflow.org/install/ for installation instructions." + ) + raise + tf_path = os.path.abspath(gpt2_checkpoint_path) + logger.info(f"Converting TensorFlow checkpoint from {tf_path}") + # Load weights from TF model + init_vars = tf.train.list_variables(tf_path) + names = [] + arrays = [] + for name, shape in init_vars: + logger.info(f"Loading TF weight {name} with shape {shape}") + array = tf.train.load_variable(tf_path, name) + names.append(name) + arrays.append(array.squeeze()) + + for name, array in zip(names, arrays): + name = name[6:] # skip "model/" + name = name.split("/") + pointer = model + for m_name in name: + if re.fullmatch(r"[A-Za-z]+\d+", m_name): + scope_names = re.split(r"(\d+)", m_name) + else: + scope_names = [m_name] + if scope_names[0] == "w" or scope_names[0] == "g": + pointer = getattr(pointer, "weight") + elif scope_names[0] == "b": + pointer = getattr(pointer, "bias") + elif scope_names[0] == "wpe" or scope_names[0] == "wte": + pointer = getattr(pointer, scope_names[0]) + pointer = getattr(pointer, "weight") + else: + pointer = getattr(pointer, scope_names[0]) + if len(scope_names) >= 2: + num = int(scope_names[1]) + pointer = pointer[num] + try: + if pointer.shape != array.shape: + raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched") + except ValueError as e: + e.args += (pointer.shape, array.shape) + raise + logger.info(f"Initialize PyTorch weight {name}") + pointer.data = torch.from_numpy(array) + return model + + +class GPT2Attention(nn.Module): + def __init__(self, config, is_cross_attention=False, layer_idx=None): + super().__init__() + self.config = config + max_positions = config.max_position_embeddings + self.register_buffer( + "bias", + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( + 1, 1, max_positions, max_positions + ), + persistent=False, + ) + self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False) + + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.split_size = self.embed_dim + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + + self.scale_attn_weights = config.scale_attn_weights + self.is_cross_attention = is_cross_attention + + # Layer-wise attention scaling, reordering, and upcasting + self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx + self.layer_idx = layer_idx + self.reorder_and_upcast_attn = config.reorder_and_upcast_attn + + if self.is_cross_attention: + self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim) + self.q_attn = Conv1D(self.embed_dim, self.embed_dim) + else: + self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) + self.c_proj = Conv1D(self.embed_dim, self.embed_dim) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + self.is_causal = True + + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads) + index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)]) + + # Prune conv1d layers + self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1) + self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0) + + # Update hyper params + self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads)) + self.num_heads = self.num_heads - len(heads) + self.pruned_heads = self.pruned_heads.union(heads) + + def _attn(self, query, key, value, attention_mask=None, head_mask=None): + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + freqs_cis= precompute_freqs_cis(dim=query.size(-1), end=query.size(1)).to(query.device) + query = apply_rotary_emb(query, freqs_cis) + key = apply_rotary_emb(key, freqs_cis) + value = apply_rotary_emb(value, freqs_cis) + value = value.transpose(1, 2) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + + + attn_weights = torch.matmul(query, key.transpose(-1, -2)) + + if self.scale_attn_weights: + attn_weights = attn_weights / torch.full( + [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device + ) + + # Layer-wise attention scaling + if self.scale_attn_by_inverse_layer_idx: + attn_weights = attn_weights / float(self.layer_idx + 1) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None): + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = query.size() + _, _, k_seq_len, _ = key.size() + + # Preallocate attn_weights for `baddbmm` + attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device) + + # Compute Scale Factor + scale_factor = 1.0 + if self.scale_attn_weights: + scale_factor /= float(value.size(-1)) ** 0.5 + + if self.scale_attn_by_inverse_layer_idx: + scale_factor /= float(self.layer_idx + 1) + + # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) + with autocast(enabled=False): + q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) + attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) + attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) + + if not self.is_cross_attention: + # if only "normal" attention layer implements causal mask + query_length, key_length = query.size(-2), key.size(-2) + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + mask_value = torch.finfo(attn_weights.dtype).min + # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. + # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` + mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device) + attn_weights = torch.where(causal_mask, attn_weights, mask_value) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise + if attn_weights.dtype != torch.float32: + raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32") + attn_weights = attn_weights.type(value.dtype) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = torch.matmul(attn_weights, value) + + return attn_output, attn_weights + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features) + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + tensor = tensor.permute(0, 2, 1, 3).contiguous() + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) + return tensor.view(new_shape) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key, past_value = layer_past + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + if self.reorder_and_upcast_attn: + attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) + else: + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.c_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + +def precompute_freqs_cis(dim: int, end: int, constant: float = 10000.0): + ''' + 计算cos和sin的值,cos值在实部,sin值在虚部,类似于 cosx+j*sinx + :param dim: q,k,v的最后一维,一般为emb_dim/head_num + :param end: 句长length + :param constant: 这里指10000 + :return: + 复数计算 torch.polar(a, t)输出, a*(cos(t)+j*sin(t)) + ''' + # freqs: 计算 1/(10000^(2i/d) ),将结果作为参数theta + # 形式化为 [theta_0, theta_1, ..., theta_(d/2-1)] + freqs = 1.0 / (constant ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [d/2] + + # 计算m + t = torch.arange(end, device=freqs.device) # [length] + # 计算m*theta + freqs = torch.outer(t, freqs).float() # [length, d/2] + # freqs形式化为 [m*theta_0, m*theta_1, ..., m*theta_(d/2-1)],其中 m=0,1,...,length-1 + + # 计算cos(m*theta)+j*sin(m*theta) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + # freqs_cis: [cos(m*theta_0)+j*sin(m*theta_0), cos(m*theta_1)+j*sin(m*theta_1),), ..., cos(m*theta_(d/2-1))+j*sin(m*theta_(d/2-1))] + # 其中j为虚数单位, m=0,1,...,length-1 + return freqs_cis # [length, d/2] + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # (1, length, 1, d/2) + return freqs_cis.view(*shape) # [1, length, 1, d/2] + +def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor,): + # 先将xq维度变为[bs, length, head, d/2, 2], 利用torch.view_as_complex转变为复数 + # xq:[q0, q1, .., q(d-1)] 转变为 xq_: [q0+j*q1, q2+j*q3, ..., q(d-2)+j*q(d-1)] + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [bs, length, head, d/2] + + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # [1, length, 1, d/2] + # 下式xq_ * freqs_cis形式化输出,以第一个为例, 如下 + # (q0+j*q1)(cos(m*theta_0)+j*sin(m*theta_0)) = q0*cos(m*theta_0)-q1*sin(m*theta_0) + j*(q1*cos(m*theta_0)+q0*sin(m*theta_0)) + # 上式的实部为q0*cos(m*theta_0)-q1*sin(m*theta_0),虚部为q1*cos(m*theta_0)+q0*sin(m*theta_0) + # 然后通过torch.view_as_real函数,取出实部和虚部,维度由[bs, length, head, d/2]变为[bs, length, head, d/2, 2],最后一维放实部与虚部 + # 最后经flatten函数将维度拉平,即[bs, length, head, d] + # 此时xq_out形式化为 [实部0,虚部0,实部1,虚部1,..., 实部(d/2-1), 虚部(d/2-1)] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # [bs, length, head, d] + + return xq_out.type_as(xq) + + + +class GPT2FlashAttention2(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + + + attn_dropout = self.attn_dropout.p if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = self._flash_attention_forward( + query, key, value, attention_mask, query_length, dropout=attn_dropout + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class GPT2MLP(nn.Module): + def __init__(self, intermediate_size, config): + super().__init__() + embed_dim = config.hidden_size + self.c_fc = Conv1D(intermediate_size, embed_dim) + self.c_proj = Conv1D(embed_dim, intermediate_size) + self.act = ACT2FN[config.activation_function] + self.dropout = nn.Dropout(config.resid_pdrop) + + def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.c_proj(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, +} + + +class GPT2Block(nn.Module): + def __init__(self, config, layer_idx=None): + super().__init__() + hidden_size = config.hidden_size + inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size + attention_class = GPT2_ATTENTION_CLASSES[config._attn_implementation] + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = attention_class(config=config, layer_idx=layer_idx) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + if config.add_cross_attention: + self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx) + self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = GPT2MLP(inner_dim, config) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + # residual connection + hidden_states = attn_output + residual + + if encoder_hidden_states is not None: + # add one self-attention block for cross-attention + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with " + "cross-attention layers by setting `config.add_cross_attention=True`" + ) + residual = hidden_states + hidden_states = self.ln_cross_attn(hidden_states) + cross_attn_outputs = self.crossattention( + hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + ) + attn_output = cross_attn_outputs[0] + # residual connection + hidden_states = residual + attn_output + outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions, cross_attentions) + + +class GPT2PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = GPT2Config + load_tf_weights = load_tf_weights_in_gpt2 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["GPT2Block"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (nn.Linear, Conv1D)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name == "c_proj.weight": + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) + + +@dataclass +class GPT2DoubleHeadsModelOutput(ModelOutput): + """ + Base class for outputs of models predicting if two sentences are consecutive or not. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + mc_loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `mc_labels` is provided): + Multiple choice classification loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + mc_logits (`torch.FloatTensor` of shape `(batch_size, num_choices)`): + Prediction scores of the multiple choice classification head (scores for each choice before SoftMax). + past_key_values (`Tuple[Tuple[torch.Tensor]]`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of length `config.n_layers`, containing tuples of tensors of shape `(batch_size, num_heads, + sequence_length, embed_size_per_head)`). + + Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + GPT2Attentions weights after the attention softmax, used to compute the weighted average in the + self-attention heads. + """ + + loss: Optional[torch.FloatTensor] = None + mc_loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + mc_logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +GPT2_START_DOCSTRING = r""" + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`GPT2Config`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +GPT2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`): + `input_ids_length` = `sequence_length` if `past_key_values` is `None` else + `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input + sequence tokens in the vocabulary. + + If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as + `input_ids`. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): + Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see + `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for + `past_key_values`. In other words, the `attention_mask` always has to have the length: + `len(past_key_values) + len(input_ids)` + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + + [What are position IDs?](../glossary#position-ids) + head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + + If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see + `past_key_values`). + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" +PARALLELIZE_DOCSTRING = r""" + This is an experimental feature and is a subject to change at a moment's notice. + + Uses a device map to distribute attention modules of the model across several devices. If no device map is given, + it will evenly distribute blocks across all devices. + + Args: + device_map (`Dict[int, list]`, optional, defaults to None): + A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always + automatically mapped to the first device (for esoteric reasons). That means that the first device should + have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the + following number of attention modules: + + - openai-community/gpt2: 12 + - openai-community/gpt2-medium: 24 + - openai-community/gpt2-large: 36 + - openai-community/gpt2-xl: 48 + + Example: + + ```python + # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-xl") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7, 8], + 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21], + 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34], + 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], + } + model.parallelize(device_map) + ``` +""" +DEPARALLELIZE_DOCSTRING = r""" + Moves the model to cpu from a model parallel state. + + Example: + + ```python + # On a 4 GPU machine with openai-community/gpt2-large: + model = GPT2LMHeadModel.from_pretrained("openai-community/gpt2-large") + device_map = { + 0: [0, 1, 2, 3, 4, 5, 6, 7], + 1: [8, 9, 10, 11, 12, 13, 14, 15], + 2: [16, 17, 18, 19, 20, 21, 22, 23], + 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35], + } + model.parallelize(device_map) # Splits the model across several devices + model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache() + ``` +""" + + +@add_start_docstrings( + "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.", + GPT2_START_DOCSTRING, +) +class GPT2Model(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.hidden_size + + self.wte = nn.Embedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + self._attn_implementation = config._attn_implementation + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + # Check validity of device_map + warnings.warn( + "`GPT2Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your" + " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1," + " ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map + ) + assert_device_map(self.device_map, len(self.h)) + self.model_parallel = True + self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys())) + self.last_device = "cuda:" + str(max(self.device_map.keys())) + self.wte = self.wte.to(self.first_device) + self.wpe = self.wpe.to(self.first_device) + # Load onto devices + for k, v in self.device_map.items(): + for block in v: + cuda_device = "cuda:" + str(k) + self.h[block] = self.h[block].to(cuda_device) + # ln_f to last + self.ln_f = self.ln_f.to(self.last_device) + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.model_parallel = False + self.device_map = None + self.first_device = "cpu" + self.last_device = "cpu" + self.wte = self.wte.to("cpu") + self.wpe = self.wpe.to("cpu") + for index in range(len(self.h)): + self.h[index] = self.h[index].to("cpu") + self.ln_f = self.ln_f.to("cpu") + torch.cuda.empty_cache() + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + """ + for layer, heads in heads_to_prune.items(): + self.h[layer].attn.prune_heads(heads) + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=BaseModelOutputWithPastAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask) + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].size(-2) + if position_ids is None: + position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + if self._attn_implementation == "flash_attention_2": + attention_mask = attention_mask if 0 in attention_mask else None + else: + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and the dtype's smallest value for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.config.add_cross_attention and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + if self._attn_implementation != "flash_attention_2": + encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # head_mask has shape n_layer x batch x n_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + all_hidden_states = () if output_hidden_states else None + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure layer_past is on same device as hidden_states (might not be correct) + if layer_past is not None: + layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if isinstance(head_mask, torch.Tensor): + head_mask = head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + outputs = self._gradient_checkpointing_func( + block.__call__, + hidden_states, + None, + attention_mask, + head_mask[i], + encoder_hidden_states, + encoder_attention_mask, + use_cache, + output_attentions, + ) + else: + outputs = block( + hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + if self.config.add_cross_attention: + all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + if v is not None + ) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input + embeddings). + """, + GPT2_START_DOCSTRING, +) +class GPT2LMHeadModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2LMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load" + " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own" + " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':" + " 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(lm_logits.device) + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + cross_attentions=transformer_outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ +The GPT2 Model transformer with a language modeling and a multiple-choice classification head on top e.g. for +RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the +input embeddings, the classification head takes as input the input of a specified classification token index in the +input sequence). +""", + GPT2_START_DOCSTRING, +) +class GPT2DoubleHeadsModel(GPT2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + config.num_labels = 1 + self.transformer = GPT2Model(config) + self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) + self.multiple_choice_head = SequenceSummary(config) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings(PARALLELIZE_DOCSTRING) + def parallelize(self, device_map=None): + warnings.warn( + "`GPT2DoubleHeadsModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should" + " load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your" + " own `device_map` but it needs to be a dictionary module_name to device, so for instance" + " {'transformer.h.0': 0, 'transformer.h.1': 1, ...}", + FutureWarning, + ) + self.device_map = ( + get_device_map(len(self.transformer.h), range(torch.cuda.device_count())) + if device_map is None + else device_map + ) + assert_device_map(self.device_map, len(self.transformer.h)) + self.transformer.parallelize(self.device_map) + self.lm_head = self.lm_head.to(self.transformer.first_device) + self.multiple_choice_head = self.multiple_choice_head.to(self.transformer.first_device) + self.model_parallel = True + + @add_start_docstrings(DEPARALLELIZE_DOCSTRING) + def deparallelize(self): + warnings.warn( + "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.", + FutureWarning, + ) + self.transformer.deparallelize() + self.transformer = self.transformer.to("cpu") + self.lm_head = self.lm_head.to("cpu") + self.multiple_choice_head = self.multiple_choice_head.to("cpu") + self.model_parallel = False + torch.cuda.empty_cache() + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def prepare_inputs_for_generation(self, input_ids, inputs_embeds=None, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # Omit tokens covered by past_key_values + if past_key_values: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + else: + position_ids = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids.contiguous()} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + ) + return model_inputs + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mc_token_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + mc_labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, GPT2DoubleHeadsModelOutput]: + r""" + mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input): + Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) - + 1]`. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. All labels set to + `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size - 1]` + mc_labels (`torch.LongTensor` of shape `(batch_size)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where *num_choices* is the size of the second dimension of the input tensors. (see *input_ids* above) + + Return: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoTokenizer, GPT2DoubleHeadsModel + + >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2") + >>> model = GPT2DoubleHeadsModel.from_pretrained("openai-community/gpt2") + + >>> # Add a [CLS] to the vocabulary (we should train it also!) + >>> num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"}) + >>> # Update the model embeddings with the new vocabulary size + >>> embedding_layer = model.resize_token_embeddings(len(tokenizer)) + + >>> choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"] + >>> encoded_choices = [tokenizer.encode(s) for s in choices] + >>> cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices] + + >>> input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2 + >>> mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1 + + >>> outputs = model(input_ids, mc_token_ids=mc_token_ids) + >>> lm_logits = outputs.logits + >>> mc_logits = outputs.mc_logits + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.transformer.first_device) + hidden_states = hidden_states.to(self.lm_head.weight.device) + + lm_logits = self.lm_head(hidden_states) + mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids).squeeze(-1) + + mc_loss = None + if mc_labels is not None: + loss_fct = CrossEntropyLoss() + mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)) + lm_loss = None + if labels is not None: + labels = labels.to(lm_logits.device) + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + + if not return_dict: + output = (lm_logits, mc_logits) + transformer_outputs[1:] + if mc_loss is not None: + output = (mc_loss,) + output + return ((lm_loss,) + output) if lm_loss is not None else output + + return GPT2DoubleHeadsModelOutput( + loss=lm_loss, + mc_loss=mc_loss, + logits=lm_logits, + mc_logits=mc_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + @staticmethod + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct + beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past_key_values + ) + + +@add_start_docstrings( + """ + The GPT2 Model transformer with a sequence classification head on top (linear layer). + + [`GPT2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-1) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForSequenceClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.score = nn.Linear(config.n_embd, self.num_labels, bias=False) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint="microsoft/DialogRPT-updown", + output_type=SequenceClassifierOutputWithPast, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + assert ( + self.config.pad_token_id is not None or batch_size == 1 + ), "Cannot handle batch sizes > 1 if no padding token is defined." + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + logger.warning( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + GPT2 Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for + Named-Entity-Recognition (NER) tasks. + """, + GPT2_START_DOCSTRING, +) +class GPT2ForTokenClassification(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.transformer = GPT2Model(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) + # fmt: off + @add_code_sample_docstrings( + checkpoint="brad1141/gpt2-finetuned-comp2", + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_loss=0.25, + expected_output=[ + "Lead", + "Lead", + "Lead", + "Position", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + "Lead", + ], + ) + # fmt: on + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ + The GPT-2 Model transformer with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + GPT2_START_DOCSTRING, +) +class GPT2ForQuestionAnswering(GPT2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.transformer = GPT2Model(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Model parallel + self.model_parallel = False + self.device_map = None + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=QuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + real_checkpoint=_CHECKPOINT_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config.py new file mode 100644 index 0000000000000000000000000000000000000000..06085dfead5a91a5e8cfbe9001c5b58c5d35fa8d --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1280, + n_embd=768, + n_layer=3, + n_head=4, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config_128.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config_128.py new file mode 100644 index 0000000000000000000000000000000000000000..45adb2d7fba1cb9a2b15b30766ab931b74cc6a84 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config_128.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1280, + n_embd=128, + n_layer=3, + n_head=4, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config_769.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config_769.py new file mode 100644 index 0000000000000000000000000000000000000000..63dba5f3a7f97a4537320d5c2e6f10a220716b9d --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/gpt2_textemb_config_769.py @@ -0,0 +1,270 @@ +# coding=utf-8 +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""OpenAI GPT-2 configuration""" + +from collections import OrderedDict +from typing import Any, List, Mapping, Optional + +from transformers import PreTrainedTokenizer, TensorType, is_torch_available +from transformers.configuration_utils import PretrainedConfig +from transformers.onnx import OnnxConfigWithPast, PatchingSpec +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class GPT2Config(PretrainedConfig): + """ + This is the configuration class to store the configuration of a [`GPT2Model`] or a [`TFGPT2Model`]. It is used to + instantiate a GPT-2 model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the GPT-2 + [openai-community/gpt2](https://huggingface.co/openai-community/gpt2) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 50257): + Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. + n_positions (`int`, *optional*, defaults to 1024): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + n_embd (`int`, *optional*, defaults to 768): + Dimensionality of the embeddings and hidden states. + n_layer (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + n_head (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + n_inner (`int`, *optional*): + Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd + activation_function (`str`, *optional*, defaults to `"gelu_new"`): + Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. + resid_pdrop (`float`, *optional*, defaults to 0.1): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + embd_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the embeddings. + attn_pdrop (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention. + layer_norm_epsilon (`float`, *optional*, defaults to 1e-05): + The epsilon to use in the layer normalization layers. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + summary_type (`string`, *optional*, defaults to `"cls_index"`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Has to be one of the following options: + + - `"last"`: Take the last token hidden state (like XLNet). + - `"first"`: Take the first token hidden state (like BERT). + - `"mean"`: Take the mean of all tokens hidden states. + - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). + - `"attn"`: Not implemented now, use multi-head attention. + summary_use_proj (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether or not to add a projection after the vector extraction. + summary_activation (`str`, *optional*): + Argument used when doing sequence summary. Used in for the multiple choice head in + [`GPT2DoubleHeadsModel`]. + + Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. + summary_proj_to_labels (`bool`, *optional*, defaults to `True`): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. + summary_first_dropout (`float`, *optional*, defaults to 0.1): + Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and + [`TFGPT2DoubleHeadsModel`]. + + The dropout ratio to be used after the projection and activation. + scale_attn_weights (`bool`, *optional*, defaults to `True`): + Scale attention weights by dividing by sqrt(hidden_size).. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). + bos_token_id (`int`, *optional*, defaults to 50256): + Id of the beginning of sentence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 50256): + Id of the end of sentence token in the vocabulary. + scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): + Whether to additionally scale attention weights by `1 / layer_idx + 1`. + reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): + Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention + dot-product/softmax to float() when training with mixed precision. + + Example: + + ```python + >>> from transformers import GPT2Config, GPT2Model + + >>> # Initializing a GPT2 configuration + >>> configuration = GPT2Config() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = GPT2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "gpt2" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "hidden_size": "n_embd", + "max_position_embeddings": "n_positions", + "num_attention_heads": "n_head", + "num_hidden_layers": "n_layer", + } + + def __init__( + self, + vocab_size=50257, + n_positions=1280, + n_embd=769, + n_layer=3, + n_head=4, + n_inner=None, + activation_function="gelu_new", + resid_pdrop=0.1, + embd_pdrop=0.1, + attn_pdrop=0.1, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + summary_type="cls_index", + summary_use_proj=True, + summary_activation=None, + summary_proj_to_labels=True, + summary_first_dropout=0.1, + scale_attn_weights=True, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + scale_attn_by_inverse_layer_idx=False, + reorder_and_upcast_attn=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.summary_type = summary_type + self.summary_use_proj = summary_use_proj + self.summary_activation = summary_activation + self.summary_first_dropout = summary_first_dropout + self.summary_proj_to_labels = summary_proj_to_labels + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx + self.reorder_and_upcast_attn = reorder_and_upcast_attn + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + +class GPT2OnnxConfig(OnnxConfigWithPast): + def __init__( + self, + config: PretrainedConfig, + task: str = "default", + patching_specs: List[PatchingSpec] = None, + use_past: bool = False, + ): + super().__init__(config, task=task, patching_specs=patching_specs, use_past=use_past) + if not getattr(self._config, "pad_token_id", None): + # TODO: how to do that better? + self._config.pad_token_id = 0 + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + common_inputs = OrderedDict({"input_ids": {0: "batch", 1: "sequence"}}) + if self.use_past: + self.fill_with_past_key_values_(common_inputs, direction="inputs") + common_inputs["attention_mask"] = {0: "batch", 1: "past_sequence + sequence"} + else: + common_inputs["attention_mask"] = {0: "batch", 1: "sequence"} + + return common_inputs + + @property + def num_layers(self) -> int: + return self._config.n_layer + + @property + def num_attention_heads(self) -> int: + return self._config.n_head + + def generate_dummy_inputs( + self, + tokenizer: PreTrainedTokenizer, + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + common_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs( + tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + + # We need to order the input in the way they appears in the forward() + ordered_inputs = OrderedDict({"input_ids": common_inputs["input_ids"]}) + + # Need to add the past_keys + if self.use_past: + if not is_torch_available(): + raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.") + else: + import torch + + batch, seqlen = common_inputs["input_ids"].shape + # Not using the same length for past_key_values + past_key_values_length = seqlen + 2 + past_shape = ( + batch, + self.num_attention_heads, + past_key_values_length, + self._config.hidden_size // self.num_attention_heads, + ) + ordered_inputs["past_key_values"] = [ + (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(self.num_layers) + ] + + ordered_inputs["attention_mask"] = common_inputs["attention_mask"] + if self.use_past: + mask_dtype = ordered_inputs["attention_mask"].dtype + ordered_inputs["attention_mask"] = torch.cat( + [ordered_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1 + ) + + return ordered_inputs + + @property + def default_onnx_opset(self) -> int: + return 13 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama.py new file mode 100644 index 0000000000000000000000000000000000000000..9e093239b28c1fc471c514f6b336470eb0b39745 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama.py @@ -0,0 +1,1624 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + # print(causal_mask) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + return causal_mask + + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_ca_config.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_ca_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d291ffb04373d7080e512e021b3052b420469f90 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_ca_config.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=11008, + num_hidden_layers=12, + num_attention_heads=8, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=1280, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_config.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d291ffb04373d7080e512e021b3052b420469f90 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_config.py @@ -0,0 +1,192 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LLaMA model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + + +class LlamaConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the LLaMA-7B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`LlamaModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens, + Llama 2 up to 4096, CodeLlama up to 16384. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/main/perf_train_gpu_many#tensor-parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. See the following thread for more information on how + these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an + experimental feature, subject to breaking API changes in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers in the MLP layers. + + ```python + >>> from transformers import LlamaModel, LlamaConfig + + >>> # Initializing a LLaMA llama-7b style configuration + >>> configuration = LlamaConfig() + + >>> # Initializing a model from the llama-7b style configuration + >>> model = LlamaModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "llama" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=11008, + num_hidden_layers=12, + num_attention_heads=8, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=1280, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self._rope_scaling_validation() + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_validation(self): + """ + Validate the `rope_scaling` configuration. + """ + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2: + raise ValueError( + "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, " f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_factor = self.rope_scaling.get("factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]: + raise ValueError( + f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}" + ) + if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0: + raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}") \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn.py new file mode 100644 index 0000000000000000000000000000000000000000..e56925267d9b81c210b3a991394e3bd55f017249 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn.py @@ -0,0 +1,2135 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def invert_attention_mask(encoder_attention_mask): + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if encoder_attention_mask.dim() == 4: + encoder_extended_attention_mask = encoder_attention_mask + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(torch.float32).min + + return encoder_extended_attention_mask + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + bsz, k_len, _ = encoder_hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + B,X = encoder_attention_mask.size() + _,T = attention_mask.size() + + encoder_attention_mask = encoder_attention_mask.view(B, 1, 1, X) + attention_mask = attention_mask.view(B, 1, T, 1) + cross_attention_mask = attention_mask * encoder_attention_mask + # print("attention_mask", attention_mask.shape) + # exit + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # cos, sin = self.rotary_emb(value_states, position_ids) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + # key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = invert_attention_mask(cross_attention_mask) + causal_mask = attention_mask + # print("attention_mask", attention_mask.shape) + # print("key_states", key_states.shape) + # exit() + # print("attn_weights", attn_weights.shape) + # print("causal_mask", causal_mask.shape) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaDecoderLayer_CrossAttn(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pre_cross_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + #Cross Attention + residual = hidden_states + + hidden_states = self.pre_cross_layernorm(hidden_states) + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaModel_crossattn(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer_CrossAttn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + attention_mask_pre=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb.py new file mode 100644 index 0000000000000000000000000000000000000000..c09d9770f788f102e30dd40e5369ba3717cb3673 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb.py @@ -0,0 +1,2172 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + # print("x.shape",x.shape) + # print("position_ids.shape",position_ids.shape) + # print("position_ids",position_ids) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + # print("q.shape",q.shape) + # print("cos.shape",cos.shape) + # print("sin.shape",sin.shape) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def invert_attention_mask(encoder_attention_mask): + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if encoder_attention_mask.dim() == 4: + encoder_extended_attention_mask = encoder_attention_mask + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(torch.float32).min + + return encoder_extended_attention_mask + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + bsz, k_len, _ = encoder_hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + # cos, sin = self.rotary_emb(hidden_states, position_ids) + # print("hidden_states.shape,",hidden_states.shape) + # print("self.head",self.head_dim) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # hidden_states = apply_rotary_pos_emb_one(hidden_states, cos, sin) + # cos_e, sin_e = self.rotary_emb(encoder_hidden_states, position_ids) + # encoder_hidden_states = apply_rotary_pos_emb_one(encoder_hidden_states, cos, sin) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + B,X = encoder_attention_mask.size() + _,T = attention_mask.size() + + encoder_attention_mask = encoder_attention_mask.view(B, 1, 1, X) + attention_mask = attention_mask.view(B, 1, T, 1) + cross_attention_mask = attention_mask * encoder_attention_mask + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(query_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + position_ids_k = torch.arange(0, k_len, device=hidden_states.device).unsqueeze(0) + cos_k, sin_k = self.rotary_emb(key_states, position_ids_k) + # print("cos_k",cos_k.shape) + # print("sin_k",sin_k.shape) + # print("key_states",key_states.shape) + key_states = apply_rotary_pos_emb_one(key_states, cos_k, sin_k) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = invert_attention_mask(cross_attention_mask) + causal_mask = attention_mask + # print("attention_mask", attention_mask.shape) + # print("key_states", key_states.shape) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaDecoderLayer_CrossAttn(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + self.norm1 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.norm2 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.scale_shift_table = nn.Parameter(torch.randn(6, self.hidden_size) / self.hidden_size**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + + hidden_states = hidden_states.squeeze(1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = hidden_states * gate_msa + + hidden_states = residual + hidden_states + + #Cross Attention + residual = hidden_states + + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + + # Fully Connected + residual = hidden_states + + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states * gate_mlp + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaModel_crossattn(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer_CrossAttn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=inputs_embeds.shape[0], hidden_dtype=inputs_embeds.dtype + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + attention_mask_pre=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_3rope.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_3rope.py new file mode 100644 index 0000000000000000000000000000000000000000..f93a8bc54e59cb618a4228403670a30c8f130908 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_3rope.py @@ -0,0 +1,2162 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def invert_attention_mask(encoder_attention_mask): + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if encoder_attention_mask.dim() == 4: + encoder_extended_attention_mask = encoder_attention_mask + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(torch.float32).min + + return encoder_extended_attention_mask + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + bsz, k_len, _ = encoder_hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + # cos, sin = self.rotary_emb(hidden_states, position_ids) + # print("hidden_states.shape,",hidden_states.shape) + # print("self.head",self.head_dim) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # hidden_states = apply_rotary_pos_emb_one(hidden_states, cos, sin) + # cos_e, sin_e = self.rotary_emb(encoder_hidden_states, position_ids) + # encoder_hidden_states = apply_rotary_pos_emb_one(encoder_hidden_states, cos, sin) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + B,X = encoder_attention_mask.size() + _,T = attention_mask.size() + + encoder_attention_mask = encoder_attention_mask.view(B, 1, 1, X) + attention_mask = attention_mask.view(B, 1, T, 1) + cross_attention_mask = attention_mask * encoder_attention_mask + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(query_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + position_ids_k = torch.arange(0, k_len, device=hidden_states.device).unsqueeze(0) + cos_k, sin_k = self.rotary_emb(key_states, position_ids_k) + key_states = apply_rotary_pos_emb_one(key_states, cos_k, sin_k) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = invert_attention_mask(cross_attention_mask) + causal_mask = attention_mask + # print("attention_mask", attention_mask.shape) + # print("key_states", key_states.shape) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaDecoderLayer_CrossAttn(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + self.norm1 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.norm2 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.scale_shift_table = nn.Parameter(torch.randn(6, self.hidden_size) / self.hidden_size**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + + hidden_states = hidden_states.squeeze(1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = hidden_states * gate_msa + + hidden_states = residual + hidden_states + + #Cross Attention + residual = hidden_states + + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + + # Fully Connected + residual = hidden_states + + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states * gate_mlp + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaModel_crossattn(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer_CrossAttn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=inputs_embeds.shape[0], hidden_dtype=inputs_embeds.dtype + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + attention_mask_pre=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_ori.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_ori.py new file mode 100644 index 0000000000000000000000000000000000000000..51b371c84bf0648a7e0847f18b4d298b7df80b9e --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_ori.py @@ -0,0 +1,2152 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def invert_attention_mask(encoder_attention_mask): + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if encoder_attention_mask.dim() == 4: + encoder_extended_attention_mask = encoder_attention_mask + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(torch.float32).min + + return encoder_extended_attention_mask + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + bsz, k_len, _ = encoder_hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + B,X = encoder_attention_mask.size() + _,T = attention_mask.size() + + encoder_attention_mask = encoder_attention_mask.view(B, 1, 1, X) + attention_mask = attention_mask.view(B, 1, T, 1) + cross_attention_mask = attention_mask * encoder_attention_mask + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # cos, sin = self.rotary_emb(value_states, position_ids) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + # key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = invert_attention_mask(cross_attention_mask) + causal_mask = attention_mask + # print("attention_mask", attention_mask.shape) + # print("key_states", key_states.shape) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaDecoderLayer_CrossAttn(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + self.norm1 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.norm2 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.scale_shift_table = nn.Parameter(torch.randn(6, self.hidden_size) / self.hidden_size**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + + hidden_states = hidden_states.squeeze(1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = hidden_states * gate_msa + + hidden_states = residual + hidden_states + + #Cross Attention + residual = hidden_states + + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + + # Fully Connected + residual = hidden_states + + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states * gate_mlp + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaModel_crossattn(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer_CrossAttn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=inputs_embeds.shape[0], hidden_dtype=inputs_embeds.dtype + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + attention_mask_pre=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_all.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_all.py new file mode 100644 index 0000000000000000000000000000000000000000..8b70139a897ce393d37598211255b8112a497ae9 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_all.py @@ -0,0 +1,2172 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + # print("x.shape",x.shape) + # print("position_ids.shape",position_ids.shape) + # print("position_ids",position_ids) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + # print("q.shape",q.shape) + # print("cos.shape",cos.shape) + # print("sin.shape",sin.shape) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def invert_attention_mask(encoder_attention_mask): + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if encoder_attention_mask.dim() == 4: + encoder_extended_attention_mask = encoder_attention_mask + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(torch.float32).min + + return encoder_extended_attention_mask + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + bsz, k_len, _ = encoder_hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + # cos, sin = self.rotary_emb(hidden_states, position_ids) + # print("hidden_states.shape,",hidden_states.shape) + # print("self.head",self.head_dim) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # hidden_states = apply_rotary_pos_emb_one(hidden_states, cos, sin) + # cos_e, sin_e = self.rotary_emb(encoder_hidden_states, position_ids) + # encoder_hidden_states = apply_rotary_pos_emb_one(encoder_hidden_states, cos, sin) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + B,X = encoder_attention_mask.size() + _,T = attention_mask.size() + + encoder_attention_mask = encoder_attention_mask.view(B, 1, 1, X) + # attention_mask = attention_mask.view(B, 1, T, 1) + cross_attention_mask = encoder_attention_mask.repeat(1, 1, T, 1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # cos, sin = self.rotary_emb(query_states, position_ids) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + # position_ids_k = torch.arange(0, k_len, device=hidden_states.device).unsqueeze(0) + # cos_k, sin_k = self.rotary_emb(key_states, position_ids_k) + # # print("cos_k",cos_k.shape) + # # print("sin_k",sin_k.shape) + # # print("key_states",key_states.shape) + # key_states = apply_rotary_pos_emb_one(key_states, cos_k, sin_k) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = invert_attention_mask(cross_attention_mask) + causal_mask = attention_mask + # print("attention_mask", attention_mask.shape) + # print("key_states", key_states.shape) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaDecoderLayer_CrossAttn(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + self.norm1 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.norm2 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.scale_shift_table = nn.Parameter(torch.randn(6, self.hidden_size) / self.hidden_size**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + + hidden_states = hidden_states.squeeze(1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = hidden_states * gate_msa + + hidden_states = residual + hidden_states + + #Cross Attention + residual = hidden_states + + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + + # Fully Connected + residual = hidden_states + + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states * gate_mlp + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaModel_crossattn(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer_CrossAttn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=inputs_embeds.shape[0], hidden_dtype=inputs_embeds.dtype + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + attention_mask_pre=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_all_valuerope.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_all_valuerope.py new file mode 100644 index 0000000000000000000000000000000000000000..786456e7b51b506f8b5d5fac9ed55aeba93878cd --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_all_valuerope.py @@ -0,0 +1,2172 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + # print("x.shape",x.shape) + # print("position_ids.shape",position_ids.shape) + # print("position_ids",position_ids) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + # print("q.shape",q.shape) + # print("cos.shape",cos.shape) + # print("sin.shape",sin.shape) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def invert_attention_mask(encoder_attention_mask): + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if encoder_attention_mask.dim() == 4: + encoder_extended_attention_mask = encoder_attention_mask + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(torch.float32).min + + return encoder_extended_attention_mask + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + bsz, k_len, _ = encoder_hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + # cos, sin = self.rotary_emb(hidden_states, position_ids) + # print("hidden_states.shape,",hidden_states.shape) + # print("self.head",self.head_dim) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # hidden_states = apply_rotary_pos_emb_one(hidden_states, cos, sin) + # cos_e, sin_e = self.rotary_emb(encoder_hidden_states, position_ids) + # encoder_hidden_states = apply_rotary_pos_emb_one(encoder_hidden_states, cos, sin) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + B,X = encoder_attention_mask.size() + _,T = attention_mask.size() + + encoder_attention_mask = encoder_attention_mask.view(B, 1, 1, X) + # attention_mask = attention_mask.view(B, 1, T, 1) + cross_attention_mask = encoder_attention_mask.repeat(1, 1, T, 1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # cos, sin = self.rotary_emb(query_states, position_ids) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + # position_ids_k = torch.arange(0, k_len, device=hidden_states.device).unsqueeze(0) + # cos_k, sin_k = self.rotary_emb(key_states, position_ids_k) + # # print("cos_k",cos_k.shape) + # # print("sin_k",sin_k.shape) + # # print("key_states",key_states.shape) + # key_states = apply_rotary_pos_emb_one(key_states, cos_k, sin_k) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = invert_attention_mask(cross_attention_mask) + causal_mask = attention_mask + # print("attention_mask", attention_mask.shape) + # print("key_states", key_states.shape) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaDecoderLayer_CrossAttn(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + self.norm1 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.norm2 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.scale_shift_table = nn.Parameter(torch.randn(6, self.hidden_size) / self.hidden_size**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + + hidden_states = hidden_states.squeeze(1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = hidden_states * gate_msa + + hidden_states = residual + hidden_states + + #Cross Attention + residual = hidden_states + + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + + # Fully Connected + residual = hidden_states + + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states * gate_mlp + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaModel_crossattn(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer_CrossAttn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=inputs_embeds.shape[0], hidden_dtype=inputs_embeds.dtype + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + attention_mask_pre=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_mask.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..d7d494b1bdf33df0cab30bbfa2a9eff114616e15 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_mask.py @@ -0,0 +1,2172 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + # print("x.shape",x.shape) + # print("position_ids.shape",position_ids.shape) + # print("position_ids",position_ids) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + # print("q.shape",q.shape) + # print("cos.shape",cos.shape) + # print("sin.shape",sin.shape) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def invert_attention_mask(encoder_attention_mask): + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if encoder_attention_mask.dim() == 4: + encoder_extended_attention_mask = encoder_attention_mask + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(torch.float32).min + + return encoder_extended_attention_mask + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + bsz, k_len, _ = encoder_hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + # cos, sin = self.rotary_emb(hidden_states, position_ids) + # print("hidden_states.shape,",hidden_states.shape) + # print("self.head",self.head_dim) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # hidden_states = apply_rotary_pos_emb_one(hidden_states, cos, sin) + # cos_e, sin_e = self.rotary_emb(encoder_hidden_states, position_ids) + # encoder_hidden_states = apply_rotary_pos_emb_one(encoder_hidden_states, cos, sin) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + B,X = encoder_attention_mask.size() + _,T = attention_mask.size() + + encoder_attention_mask = encoder_attention_mask.view(B, 1, 1, X) + # attention_mask = attention_mask.view(B, 1, T, 1) + cross_attention_mask = encoder_attention_mask.repeat(1, 1, T, 1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(query_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + position_ids_k = torch.arange(0, k_len, device=hidden_states.device).unsqueeze(0) + cos_k, sin_k = self.rotary_emb(key_states, position_ids_k) + # print("cos_k",cos_k.shape) + # print("sin_k",sin_k.shape) + # print("key_states",key_states.shape) + key_states = apply_rotary_pos_emb_one(key_states, cos_k, sin_k) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = invert_attention_mask(cross_attention_mask) + causal_mask = attention_mask + # print("attention_mask", attention_mask.shape) + # print("key_states", key_states.shape) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaDecoderLayer_CrossAttn(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + self.norm1 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.norm2 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.scale_shift_table = nn.Parameter(torch.randn(6, self.hidden_size) / self.hidden_size**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + + hidden_states = hidden_states.squeeze(1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = hidden_states * gate_msa + + hidden_states = residual + hidden_states + + #Cross Attention + residual = hidden_states + + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + + # Fully Connected + residual = hidden_states + + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states * gate_mlp + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaModel_crossattn(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer_CrossAttn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=inputs_embeds.shape[0], hidden_dtype=inputs_embeds.dtype + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + attention_mask_pre=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_rope.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..0a7dd65a5f63dfe28030ddfd6aff47eaef2945ab --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/llama_crossattn_timeemb_wo_rope.py @@ -0,0 +1,2172 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, StaticCache +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + QuestionAnsweringModelOutput, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from models.llama_config import LlamaConfig + + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + +from diffusers.models.normalization import AdaLayerNormSingle + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "LlamaConfig" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(LlamaRMSNorm) + + +class LlamaRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): + super().__init__() + self.scaling_factor = scaling_factor + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + # For BC we register cos and sin cached + self.max_seq_len_cached = max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids): + # x: [bs, num_attention_heads, seq_len, head_size] + # print("x.shape",x.shape) + # print("position_ids.shape",position_ids.shape) + # print("position_ids",position_ids) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev""" + + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin + + +class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding): + """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation + + cos, sin = super().forward(x, position_ids) + return cos, sin + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + +def apply_rotary_pos_emb_one(q, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + # print("q.shape",q.shape) + # print("cos.shape",cos.shape) + # print("sin.shape",sin.shape) + q_embed = (q * cos) + (rotate_half(q) * sin) + return q_embed + + +class LlamaMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + if self.config.pretraining_tp > 1: + slice = self.intermediate_size // self.config.pretraining_tp + gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) + up_proj_slices = self.up_proj.weight.split(slice, dim=0) + down_proj_slices = self.down_proj.weight.split(slice, dim=1) + + gate_proj = torch.cat( + [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 + ) + up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) + + intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) + down_proj = [ + F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) + ] + down_proj = sum(down_proj) + else: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + return down_proj + + +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def invert_attention_mask(encoder_attention_mask): + """ + Invert an attention mask (e.g., switches 0. and 1.). + + Args: + encoder_attention_mask (`torch.Tensor`): An attention mask. + + Returns: + `torch.Tensor`: The inverted attention mask. + """ + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + if encoder_attention_mask.dim() == 4: + encoder_extended_attention_mask = encoder_attention_mask + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow + # /transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = (encoder_extended_attention_mask == + # encoder_extended_attention_mask.transpose(-1, -2)) + encoder_extended_attention_mask = encoder_extended_attention_mask + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * torch.finfo(torch.float32).min + + return encoder_extended_attention_mask + +class LlamaCrossAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + bsz, k_len, _ = encoder_hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + # cos, sin = self.rotary_emb(hidden_states, position_ids) + # print("hidden_states.shape,",hidden_states.shape) + # print("self.head",self.head_dim) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # hidden_states = apply_rotary_pos_emb_one(hidden_states, cos, sin) + # cos_e, sin_e = self.rotary_emb(encoder_hidden_states, position_ids) + # encoder_hidden_states = apply_rotary_pos_emb_one(encoder_hidden_states, cos, sin) + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(encoder_hidden_states) + value_states = self.v_proj(encoder_hidden_states) + B,X = encoder_attention_mask.size() + _,T = attention_mask.size() + + encoder_attention_mask = encoder_attention_mask.view(B, 1, 1, X) + attention_mask = attention_mask.view(B, 1, T, 1) + cross_attention_mask = attention_mask * encoder_attention_mask + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, k_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + # cos, sin = self.rotary_emb(query_states, position_ids) + # # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + # query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + # position_ids_k = torch.arange(0, k_len, device=hidden_states.device).unsqueeze(0) + # cos_k, sin_k = self.rotary_emb(key_states, position_ids_k) + # # print("cos_k",cos_k.shape) + # # print("sin_k",sin_k.shape) + # # print("key_states",key_states.shape) + # key_states = apply_rotary_pos_emb_one(key_states, cos_k, sin_k) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + # if past_key_value is not None: + # # sin and cos are specific to RoPE models; cache_position needed for the static cache + # cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = invert_attention_mask(cross_attention_mask) + causal_mask = attention_mask + # print("attention_mask", attention_mask.shape) + # print("key_states", key_states.shape) + # exit() + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + +class LlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) + self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + self._init_rope() + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + query_states = apply_rotary_pos_emb_one(query_states, cos, sin) + key_states = apply_rotary_pos_emb_one(key_states, cos, sin) + # value_states = apply_rotary_pos_emb_one(value_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, -1) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaFlashAttention2(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +class LlamaSdpaAttention(LlamaAttention): + """ + Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from LlamaAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + # value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + is_causal = True if causal_mask is None and q_len > 1 else False + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, -1) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +LLAMA_ATTENTION_CLASSES = { + "eager": LlamaAttention, + "flash_attention_2": LlamaFlashAttention2, + "sdpa": LlamaSdpaAttention, +} + + +class LlamaDecoderLayer(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class LlamaDecoderLayer_CrossAttn(nn.Module): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) + self.cross_attn = LlamaCrossAttention(config=config, layer_idx=layer_idx) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + + self.norm1 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.norm2 = nn.LayerNorm(self.hidden_size, elementwise_affine=True,eps=1e-5) + self.scale_shift_table = nn.Parameter(torch.randn(6, self.hidden_size) / self.hidden_size**0.5) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + residual = hidden_states + batch_size = hidden_states.shape[0] + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) + ).chunk(6, dim=1) + + hidden_states = self.norm1(hidden_states) + hidden_states = hidden_states * (1 + scale_msa) + shift_msa + + hidden_states = hidden_states.squeeze(1) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = hidden_states * gate_msa + + hidden_states = residual + hidden_states + + #Cross Attention + residual = hidden_states + + hidden_states, self_attn_weights, present_key_value = self.cross_attn( + hidden_states=hidden_states, + attention_mask=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + hidden_states + + + # Fully Connected + residual = hidden_states + + hidden_states = self.norm2(hidden_states) + hidden_states = hidden_states * (1 + scale_mlp) + shift_mlp + hidden_states = self.mlp(hidden_states) + hidden_states = hidden_states * gate_mlp + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + +LLAMA_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`LlamaConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaPreTrainedModel(PreTrainedModel): + config_class = LlamaConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["LlamaDecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + _supports_quantized_cache = True + _supports_static_cache = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +LLAMA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. +""" + + +@add_start_docstrings( + "The bare LLaMA Model outputting raw hidden-states without any specific head on top.", + LLAMA_START_DOCSTRING, +) +class LlamaModel(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + +class LlamaModel_crossattn(LlamaPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] + + Args: + config: LlamaConfig + """ + + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [LlamaDecoderLayer_CrossAttn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + self.adaln_single = AdaLayerNormSingle(config.hidden_size, use_additional_conditions=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + attention_mask_pre: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." + ) + use_cache = False + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + added_cond_kwargs = {"resolution": None, "aspect_ratio": None} + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=inputs_embeds.shape[0], hidden_dtype=inputs_embeds.dtype + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + attention_mask_pre=attention_mask_pre, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + +class LlamaForCausalLM(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): + past_length = 0 + if past_key_values is not None: + # Past key values are always initialized with a `Cache` object -> no need for if-else anymore + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_length == 0: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. + model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + + model_inputs.update( + { + "position_ids": position_ids, + "cache_position": cache_position, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@add_start_docstrings( + """ + The LLaMa Model transformer with a sequence classification head on top (linear layer). + + [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForSequenceClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@add_start_docstrings( + """ +The Llama Model transformer with a span classification head on top for extractive question-answering tasks like +SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`). + """, + LLAMA_START_DOCSTRING, +) +class LlamaForQuestionAnswering(LlamaPreTrainedModel): + base_model_prefix = "transformer" + + # Copied from transformers.models.bloom.modeling_bloom.BloomForQuestionAnswering.__init__ with Bloom->Llama + def __init__(self, config): + super().__init__(config) + self.transformer = LlamaModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, 2) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.transformer.embed_tokens + + def set_input_embeddings(self, value): + self.transformer.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + start_positions: Optional[torch.LongTensor] = None, + end_positions: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, QuestionAnsweringModelOutput]: + r""" + start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.transformer( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1).contiguous() + end_logits = end_logits.squeeze(-1).contiguous() + + total_loss = None + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1).to(start_logits.device) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1).to(end_logits.device) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions = start_positions.clamp(0, ignored_index) + end_positions = end_positions.clamp(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((total_loss,) + output) if total_loss is not None else output + + return QuestionAnsweringModelOutput( + loss=total_loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + The Llama Model transformer with a token classification head on top (a linear layer on top of the hidden-states + output) e.g. for Named-Entity-Recognition (NER) tasks. + """, + LLAMA_START_DOCSTRING, +) +class LlamaForTokenClassification(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = LlamaModel(config) + if getattr(config, "classifier_dropout", None) is not None: + classifier_dropout = config.classifier_dropout + elif getattr(config, "hidden_dropout", None) is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.score = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + sequence_output = outputs[0] + sequence_output = self.dropout(sequence_output) + logits = self.score(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/__init__.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f22b4edd5d07f14186cde6ee16c23590319982a0 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/__init__.py @@ -0,0 +1,430 @@ +from .symbols import symbols + + + + +pinyin_dict = { + "a": ("^", "a"), + "ai": ("^", "ai"), + "an": ("^", "an"), + "ang": ("^", "ang"), + "ao": ("^", "ao"), + "ba": ("b", "a"), + "bai": ("b", "ai"), + "ban": ("b", "an"), + "bang": ("b", "ang"), + "bao": ("b", "ao"), + "be": ("b", "e"), + "bei": ("b", "ei"), + "ben": ("b", "en"), + "beng": ("b", "eng"), + "bi": ("b", "i"), + "bian": ("b", "ian"), + "biao": ("b", "iao"), + "bie": ("b", "ie"), + "bin": ("b", "in"), + "bing": ("b", "ing"), + "bo": ("b", "o"), + "bu": ("b", "u"), + "ca": ("c", "a"), + "cai": ("c", "ai"), + "can": ("c", "an"), + "cang": ("c", "ang"), + "cao": ("c", "ao"), + "ce": ("c", "e"), + "cen": ("c", "en"), + "ceng": ("c", "eng"), + "cha": ("ch", "a"), + "chai": ("ch", "ai"), + "chan": ("ch", "an"), + "chang": ("ch", "ang"), + "chao": ("ch", "ao"), + "che": ("ch", "e"), + "chen": ("ch", "en"), + "cheng": ("ch", "eng"), + "chi": ("ch", "iii"), + "chong": ("ch", "ong"), + "chou": ("ch", "ou"), + "chu": ("ch", "u"), + "chua": ("ch", "ua"), + "chuai": ("ch", "uai"), + "chuan": ("ch", "uan"), + "chuang": ("ch", "uang"), + "chui": ("ch", "uei"), + "chun": ("ch", "uen"), + "chuo": ("ch", "uo"), + "ci": ("c", "ii"), + "cong": ("c", "ong"), + "cou": ("c", "ou"), + "cu": ("c", "u"), + "cuan": ("c", "uan"), + "cui": ("c", "uei"), + "cun": ("c", "uen"), + "cuo": ("c", "uo"), + "da": ("d", "a"), + "dai": ("d", "ai"), + "dan": ("d", "an"), + "dang": ("d", "ang"), + "dao": ("d", "ao"), + "de": ("d", "e"), + "dei": ("d", "ei"), + "den": ("d", "en"), + "deng": ("d", "eng"), + "di": ("d", "i"), + "dia": ("d", "ia"), + "dian": ("d", "ian"), + "diao": ("d", "iao"), + "die": ("d", "ie"), + "ding": ("d", "ing"), + "diu": ("d", "iou"), + "dong": ("d", "ong"), + "dou": ("d", "ou"), + "du": ("d", "u"), + "duan": ("d", "uan"), + "dui": ("d", "uei"), + "dun": ("d", "uen"), + "duo": ("d", "uo"), + "e": ("^", "e"), + "ei": ("^", "ei"), + "en": ("^", "en"), + "ng": ("^", "en"), + "eng": ("^", "eng"), + "er": ("^", "er"), + "fa": ("f", "a"), + "fan": ("f", "an"), + "fang": ("f", "ang"), + "fei": ("f", "ei"), + "fen": ("f", "en"), + "feng": ("f", "eng"), + "fo": ("f", "o"), + "fou": ("f", "ou"), + "fu": ("f", "u"), + "ga": ("g", "a"), + "gai": ("g", "ai"), + "gan": ("g", "an"), + "gang": ("g", "ang"), + "gao": ("g", "ao"), + "ge": ("g", "e"), + "gei": ("g", "ei"), + "gen": ("g", "en"), + "geng": ("g", "eng"), + "gong": ("g", "ong"), + "gou": ("g", "ou"), + "gu": ("g", "u"), + "gua": ("g", "ua"), + "guai": ("g", "uai"), + "guan": ("g", "uan"), + "guang": ("g", "uang"), + "gui": ("g", "uei"), + "gun": ("g", "uen"), + "guo": ("g", "uo"), + "ha": ("h", "a"), + "hai": ("h", "ai"), + "han": ("h", "an"), + "hang": ("h", "ang"), + "hao": ("h", "ao"), + "he": ("h", "e"), + "hei": ("h", "ei"), + "hen": ("h", "en"), + "heng": ("h", "eng"), + "hong": ("h", "ong"), + "hou": ("h", "ou"), + "hu": ("h", "u"), + "hua": ("h", "ua"), + "huai": ("h", "uai"), + "huan": ("h", "uan"), + "huang": ("h", "uang"), + "hui": ("h", "uei"), + "hun": ("h", "uen"), + "huo": ("h", "uo"), + "ji": ("j", "i"), + "jia": ("j", "ia"), + "jian": ("j", "ian"), + "jiang": ("j", "iang"), + "jiao": ("j", "iao"), + "jie": ("j", "ie"), + "jin": ("j", "in"), + "jing": ("j", "ing"), + "jiong": ("j", "iong"), + "jiu": ("j", "iou"), + "ju": ("j", "v"), + "juan": ("j", "van"), + "jue": ("j", "ve"), + "jun": ("j", "vn"), + "ka": ("k", "a"), + "kai": ("k", "ai"), + "kan": ("k", "an"), + "kang": ("k", "ang"), + "kao": ("k", "ao"), + "ke": ("k", "e"), + "kei": ("k", "ei"), + "ken": ("k", "en"), + "keng": ("k", "eng"), + "kong": ("k", "ong"), + "kou": ("k", "ou"), + "ku": ("k", "u"), + "kua": ("k", "ua"), + "kuai": ("k", "uai"), + "kuan": ("k", "uan"), + "kuang": ("k", "uang"), + "kui": ("k", "uei"), + "kun": ("k", "uen"), + "kuo": ("k", "uo"), + "la": ("l", "a"), + "lai": ("l", "ai"), + "lan": ("l", "an"), + "lang": ("l", "ang"), + "lao": ("l", "ao"), + "le": ("l", "e"), + "lei": ("l", "ei"), + "leng": ("l", "eng"), + "li": ("l", "i"), + "lia": ("l", "ia"), + "lian": ("l", "ian"), + "liang": ("l", "iang"), + "liao": ("l", "iao"), + "lie": ("l", "ie"), + "lin": ("l", "in"), + "ling": ("l", "ing"), + "liu": ("l", "iou"), + "lo": ("l", "o"), + "long": ("l", "ong"), + "lou": ("l", "ou"), + "lu": ("l", "u"), + "lv": ("l", "v"), + "luan": ("l", "uan"), + "lve": ("l", "ve"), + "lue": ("l", "ve"), + "lun": ("l", "uen"), + "luo": ("l", "uo"), + "ma": ("m", "a"), + "mai": ("m", "ai"), + "man": ("m", "an"), + "mang": ("m", "ang"), + "mao": ("m", "ao"), + "me": ("m", "e"), + "mei": ("m", "ei"), + "men": ("m", "en"), + "meng": ("m", "eng"), + "mi": ("m", "i"), + "mian": ("m", "ian"), + "miao": ("m", "iao"), + "mie": ("m", "ie"), + "min": ("m", "in"), + "ming": ("m", "ing"), + "miu": ("m", "iou"), + "mo": ("m", "o"), + "mou": ("m", "ou"), + "mu": ("m", "u"), + "na": ("n", "a"), + "nai": ("n", "ai"), + "nan": ("n", "an"), + "nang": ("n", "ang"), + "nao": ("n", "ao"), + "ne": ("n", "e"), + "nei": ("n", "ei"), + "nen": ("n", "en"), + "neng": ("n", "eng"), + "ni": ("n", "i"), + "nia": ("n", "ia"), + "nian": ("n", "ian"), + "niang": ("n", "iang"), + "niao": ("n", "iao"), + "nie": ("n", "ie"), + "nin": ("n", "in"), + "ning": ("n", "ing"), + "niu": ("n", "iou"), + "nong": ("n", "ong"), + "nou": ("n", "ou"), + "nu": ("n", "u"), + "nv": ("n", "v"), + "nuan": ("n", "uan"), + "nve": ("n", "ve"), + "nue": ("n", "ve"), + "nuo": ("n", "uo"), + "o": ("^", "o"), + "ou": ("^", "ou"), + "pa": ("p", "a"), + "pai": ("p", "ai"), + "pan": ("p", "an"), + "pang": ("p", "ang"), + "pao": ("p", "ao"), + "pe": ("p", "e"), + "pei": ("p", "ei"), + "pen": ("p", "en"), + "peng": ("p", "eng"), + "pi": ("p", "i"), + "pian": ("p", "ian"), + "piao": ("p", "iao"), + "pie": ("p", "ie"), + "pin": ("p", "in"), + "ping": ("p", "ing"), + "po": ("p", "o"), + "pou": ("p", "ou"), + "pu": ("p", "u"), + "qi": ("q", "i"), + "qia": ("q", "ia"), + "qian": ("q", "ian"), + "qiang": ("q", "iang"), + "qiao": ("q", "iao"), + "qie": ("q", "ie"), + "qin": ("q", "in"), + "qing": ("q", "ing"), + "qiong": ("q", "iong"), + "qiu": ("q", "iou"), + "qu": ("q", "v"), + "quan": ("q", "van"), + "que": ("q", "ve"), + "qun": ("q", "vn"), + "ran": ("r", "an"), + "rang": ("r", "ang"), + "rao": ("r", "ao"), + "re": ("r", "e"), + "ren": ("r", "en"), + "reng": ("r", "eng"), + "ri": ("r", "iii"), + "rong": ("r", "ong"), + "rou": ("r", "ou"), + "ru": ("r", "u"), + "rua": ("r", "ua"), + "ruan": ("r", "uan"), + "rui": ("r", "uei"), + "run": ("r", "uen"), + "ruo": ("r", "uo"), + "sa": ("s", "a"), + "sai": ("s", "ai"), + "san": ("s", "an"), + "sang": ("s", "ang"), + "sao": ("s", "ao"), + "se": ("s", "e"), + "sen": ("s", "en"), + "seng": ("s", "eng"), + "sha": ("sh", "a"), + "shai": ("sh", "ai"), + "shan": ("sh", "an"), + "shang": ("sh", "ang"), + "shao": ("sh", "ao"), + "she": ("sh", "e"), + "shei": ("sh", "ei"), + "shen": ("sh", "en"), + "sheng": ("sh", "eng"), + "shi": ("sh", "iii"), + "shou": ("sh", "ou"), + "shu": ("sh", "u"), + "shua": ("sh", "ua"), + "shuai": ("sh", "uai"), + "shuan": ("sh", "uan"), + "shuang": ("sh", "uang"), + "shui": ("sh", "uei"), + "shun": ("sh", "uen"), + "shuo": ("sh", "uo"), + "si": ("s", "ii"), + "song": ("s", "ong"), + "sou": ("s", "ou"), + "su": ("s", "u"), + "suan": ("s", "uan"), + "sui": ("s", "uei"), + "sun": ("s", "uen"), + "suo": ("s", "uo"), + "ta": ("t", "a"), + "tai": ("t", "ai"), + "tan": ("t", "an"), + "tang": ("t", "ang"), + "tao": ("t", "ao"), + "te": ("t", "e"), + "tei": ("t", "ei"), + "teng": ("t", "eng"), + "ti": ("t", "i"), + "tian": ("t", "ian"), + "tiao": ("t", "iao"), + "tie": ("t", "ie"), + "ting": ("t", "ing"), + "tong": ("t", "ong"), + "tou": ("t", "ou"), + "tu": ("t", "u"), + "tuan": ("t", "uan"), + "tui": ("t", "uei"), + "tun": ("t", "uen"), + "tuo": ("t", "uo"), + "wa": ("^", "ua"), + "wai": ("^", "uai"), + "wan": ("^", "uan"), + "wang": ("^", "uang"), + "wei": ("^", "uei"), + "wen": ("^", "uen"), + "weng": ("^", "ueng"), + "wo": ("^", "uo"), + "wu": ("^", "u"), + "xi": ("x", "i"), + "xia": ("x", "ia"), + "xian": ("x", "ian"), + "xiang": ("x", "iang"), + "xiao": ("x", "iao"), + "xie": ("x", "ie"), + "xin": ("x", "in"), + "xing": ("x", "ing"), + "xiong": ("x", "iong"), + "xiu": ("x", "iou"), + "xu": ("x", "v"), + "xuan": ("x", "van"), + "xue": ("x", "ve"), + "xun": ("x", "vn"), + "ya": ("^", "ia"), + "yan": ("^", "ian"), + "yang": ("^", "iang"), + "yao": ("^", "iao"), + "ye": ("^", "ie"), + "yi": ("^", "i"), + "yin": ("^", "in"), + "ying": ("^", "ing"), + "yo": ("^", "iou"), + "yong": ("^", "iong"), + "you": ("^", "iou"), + "yu": ("^", "v"), + "yuan": ("^", "van"), + "yue": ("^", "ve"), + "yun": ("^", "vn"), + "za": ("z", "a"), + "zai": ("z", "ai"), + "zan": ("z", "an"), + "zang": ("z", "ang"), + "zao": ("z", "ao"), + "ze": ("z", "e"), + "zei": ("z", "ei"), + "zen": ("z", "en"), + "zeng": ("z", "eng"), + "zha": ("zh", "a"), + "zhai": ("zh", "ai"), + "zhan": ("zh", "an"), + "zhang": ("zh", "ang"), + "zhao": ("zh", "ao"), + "zhe": ("zh", "e"), + "zhei": ("zh", "ei"), + "zhen": ("zh", "en"), + "zheng": ("zh", "eng"), + "zhi": ("zh", "iii"), + "zhong": ("zh", "ong"), + "zhou": ("zh", "ou"), + "zhu": ("zh", "u"), + "zhua": ("zh", "ua"), + "zhuai": ("zh", "uai"), + "zhuan": ("zh", "uan"), + "zhuang": ("zh", "uang"), + "zhui": ("zh", "uei"), + "zhun": ("zh", "uen"), + "zhuo": ("zh", "uo"), + "zi": ("z", "ii"), + "zong": ("z", "ong"), + "zou": ("z", "ou"), + "zu": ("z", "u"), + "zuan": ("z", "uan"), + "zui": ("z", "uei"), + "zun": ("z", "uen"), + "zuo": ("z", "uo"), +} + + +def gen_vocabs(): + import yaml + vocab = [f"<{c}{i}>" for c in list(pinyin_dict.keys()) for i in range(1,6)] + yaml.dump(vocab, open('./vocab.yaml', 'w')) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/pinyin.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/pinyin.py new file mode 100644 index 0000000000000000000000000000000000000000..8d15692e80ac5290ed254814482bf263f8964ea1 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/pinyin.py @@ -0,0 +1,137 @@ +import re + +from pypinyin import Style +from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin +from pypinyin.converter import DefaultConverter +from pypinyin.core import Pinyin + +from . import pinyin_dict +import torch + + +class MyConverter(NeutralToneWith5Mixin, DefaultConverter): + pass + + +def is_chinese(uchar): + if uchar >= u'\u4e00' and uchar <= u'\u9fa5': + return True + else: + return False + + +def clean_chinese(text: str): + text = text.strip() + text_clean = [] + for char in text: + if (is_chinese(char)): + text_clean.append(char) + else: + if len(text_clean) > 1 and is_chinese(text_clean[-1]): + text_clean.append(',') + text_clean = ''.join(text_clean).strip(',') + return text_clean + + +class G2P_PinYin(): + + def __init__(self): + super(G2P_PinYin, self).__init__() + self.pinyin_parser = Pinyin(MyConverter()) + + def get_phoneme4pinyin(self, pinyins): + result = [] + count_phone = [] + for pinyin in pinyins: + if pinyin[:-1] in pinyin_dict: + tone = pinyin[-1] + a = pinyin[:-1] + a1, a2 = pinyin_dict[a] + result += [a1, a2 + tone] + count_phone.append(2) + return result, count_phone + + # def chinese_to_phonemes(self, text): + # text = clean_chinese(text) + # phonemes = ["sil"] + # chars = ['[PAD]'] + # all_pinyins = [] + # count_phone = [] + # count_phone.append(1) + # for subtext in text.split(","): + # if (len(subtext) == 0): + # continue + # pinyins = self.correct_pinyin_tone3(subtext) + # all_pinyins.append(' '.join(pinyins)) + # sub_p, sub_c = self.get_phoneme4pinyin(pinyins) + # phonemes.extend(sub_p) + # phonemes.append(",") + # count_phone.extend(sub_c) + # count_phone.append(1) + # chars.append(subtext) + # chars.append(',') + # phonemes.append("sil") + # count_phone.append(1) + # chars.append('[PAD]') + # # char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone) + # return " ".join(phonemes), " ".join(chars), ' , '.join(all_pinyins) + + def chinese_to_phonemes(self, text): + all_pinyins = [] + subtext = [] + for chr in text: + if is_chinese(chr): + subtext.append(chr) + else: + if subtext != []: + subtext = ''.join(subtext) + pinyins = self.correct_pinyin_tone3(subtext) + pinyins = [f"<{i}>" for i in pinyins] + all_pinyins.append(' '+ ' '.join(pinyins)+ ' ') + all_pinyins.append(chr) + subtext = [] + if subtext != []: + subtext = ''.join(subtext) + pinyins = self.correct_pinyin_tone3(subtext) + pinyins = [f"<{i}>" for i in pinyins] + all_pinyins.append(' '+ ' '.join(pinyins)+ ' ') + # char_embeds = self.prosody.expand_for_phone(char_embeds, count_phone) + return ''.join(all_pinyins) + + def correct_pinyin_tone3(self, text): + pinyin_list = [ + p[0] + for p in self.pinyin_parser.pinyin(text, + style=Style.TONE3, + strict=False, + neutral_tone_with_five=True) + ] + if len(pinyin_list) >= 2: + for i in range(1, len(pinyin_list)): + try: + if re.findall(r'\d', + pinyin_list[i - 1])[0] == '3' and re.findall( + r'\d', pinyin_list[i])[0] == '3': + pinyin_list[i - 1] = pinyin_list[i - 1].replace( + '3', '2') + except IndexError: + pass + return pinyin_list + + # def expand_for_phone(self, char_embeds, length): # length of phones for char + # if(char_embeds.size(0) > len(length)): + # print(char_embeds.shape, len(length)) + # char_embeds = char_embeds[0:len(length),:] + # elif(char_embeds.size(0) < len(length)): + # print(char_embeds.shape, len(length)) + # length = length[0:char_embeds.size(0)] + # expand_vecs = list() + # for vec, leng in zip(char_embeds, length): + # vec = vec.expand(leng, -1) + # expand_vecs.append(vec) + # expand_embeds = torch.cat(expand_vecs, 0) + # assert expand_embeds.size(0) == sum(length) + # return expand_embeds.numpy() + + def __call__(self, text): + return self.chinese_to_phonemes(text) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py new file mode 100644 index 0000000000000000000000000000000000000000..c387dfe285fa0a10624e9bb1d4aebc248bea8d78 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py @@ -0,0 +1,71 @@ +_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"] + +_initials = [ + "^", + "b", + "c", + "ch", + "d", + "f", + "g", + "h", + "j", + "k", + "l", + "m", + "n", + "p", + "q", + "r", + "s", + "sh", + "t", + "x", + "z", + "zh", +] + +_tones = ["1", "2", "3", "4", "5"] + +_finals = [ + "a", + "ai", + "an", + "ang", + "ao", + "e", + "ei", + "en", + "eng", + "er", + "i", + "ia", + "ian", + "iang", + "iao", + "ie", + "ii", + "iii", + "in", + "ing", + "iong", + "iou", + "o", + "ong", + "ou", + "u", + "ua", + "uai", + "uan", + "uang", + "uei", + "uen", + "ueng", + "uo", + "v", + "van", + "ve", + "vn", +] + +symbols = _pause + _initials + [i + j for i in _finals for j in _tones] diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/structure.yaml b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/structure.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cf82be7beeec2e1a637847805ebdcd9a820e6b2a --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/structure.yaml @@ -0,0 +1,10 @@ +- '[start]' +- '[verse]' +- '[chorus]' +- '[outro]' +- '[end]' +- '[intro]' +- '[solo]' +- '[inst]' +- '[bridge]' +- '[break]' diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/tokenizer1.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/tokenizer1.py new file mode 100644 index 0000000000000000000000000000000000000000..231fc870eb3358f15992a8d6df26540b71e1ad33 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/tokenizer1.py @@ -0,0 +1,72 @@ +import torch.nn as nn +from transformers import LlamaTokenizer +import os +import typing as tp +import torch +import sys +from pinyin.pinyin import G2P_PinYin + + +ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] # condition, mask + +def process_line(line): + line = line.strip()[2:] + if(line[0]=='\'' and line[-1]=='\''): + line = line[1:-1] + return line + +class LlamaTokenizerConditioner(nn.Module): + def __init__(self, device: str = 'cpu', max_len = 3000, padding_idx='', tokenizer_type=None, + pretrained="hfl/chinese-llama-2-13b"): #"hfl/chinese-llama-2-13b" + super().__init__() + print(f"text tokenizer from {pretrained}") + self.text_tokenizer = LlamaTokenizer.from_pretrained(pretrained,cache_dir="huggingface_cache") + print(f"tokenizer vocab size: {self.text_tokenizer.vocab_size}") + self.g2p = G2P_PinYin() + add_token_list = [] + with open(os.path.dirname(os.path.abspath(__file__))+'/vocab.yaml', 'r') as f: + for line in f: + if(line): + add_token_list.append(process_line(line)) + type_tokens = [] + with open(os.path.dirname(os.path.abspath(__file__))+'/structure.yaml', 'r') as f: + for line in f: + if(line): + type_tokens.append(process_line(line)) + if add_token_list != []: + self.text_tokenizer.add_tokens(add_token_list, special_tokens=True) + # voc_size = self.text_tokenizer.vocab_size + voc_size = len(self.text_tokenizer.get_vocab()) # 加了额外token之后vocab_size似乎不会额外增加 ——cyy + print( voc_size) + # import pdb; pdb.set_trace() + padding_idx = str(padding_idx) + + self.text_tokenizer.pad_token = padding_idx + self.max_len = max_len + self.padding_idx = padding_idx + + vocab = self.text_tokenizer.get_vocab() + self.type_token_ids = [vocab[i] for i in type_tokens if i in vocab] + struct_tokens = [padding_idx] + [i for i in add_token_list if i[0]=='[' and i[-1]==']'] + self.struct_token_ids = [vocab[i] for i in struct_tokens] + print("type tokens: ",{self.text_tokenizer.convert_ids_to_tokens(i):i for i in self.type_token_ids}, + "\t all structure tokens: ", {self.text_tokenizer.convert_ids_to_tokens(i):i for i in self.struct_token_ids}) + + def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: + x = [self.g2p(xi) if xi is not None else "" for xi in x] + inputs = self.text_tokenizer(x, return_tensors="pt", padding=True) + # print(x, [self.text_tokenizer.convert_ids_to_tokens(i.tolist()) for i in inputs['input_ids']]) + # import pdb; pdb.set_trace() + if inputs['input_ids'].shape[-1] > self.max_len: + warnings.warn(f"Max len limit ({self.max_len}) Exceed! {x}") + + # print(x, inputs['input_ids'].shape) + return inputs + + +if __name__ == "__main__": + tokenizer = LlamaTokenizerConditioner() + out = tokenizer.tokenize(["im ok today, and im happy now", "今天我很开心"]) + print(out) + print(tokenizer.text_tokenizer.decode(out['input_ids'][0][:4])) + print(tokenizer.text_tokenizer.convert_ids_to_tokens(out['input_ids'][0])) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/vocab.yaml b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/vocab.yaml new file mode 100644 index 0000000000000000000000000000000000000000..01c7fad863f85e4081bee46c54281d4c9589eb90 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/vocab.yaml @@ -0,0 +1,2086 @@ +- '[MUSIC]' +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- +- diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d.py new file mode 100644 index 0000000000000000000000000000000000000000..b754c557841aa157bc1f1cc4ccd561b6ad8b5d2c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d.py @@ -0,0 +1,488 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from models.attention import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + # print("before input hidden_states.shape", hidden_states.shape) + # print("is_input_patches", self.is_input_patches) + #true + # print("is_input_vectorized", self.is_input_vectorized) + # print("is_input_continuous", self.is_input_continuous) + #false + # print("use_linear_projection", self.use_linear_projection) + #false + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + # print("input 1 hidden_states.shape", hidden_states.shape) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + # print("input 2 hidden_states.shape", hidden_states.shape) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + # print("before patches hidden_states.shape", hidden_states.shape) + hidden_states = self.pos_embed(hidden_states) + # print("after patches hidden_states.shape", hidden_states.shape) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + # print("time step before adaln_single", timestep) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + # print("time step after adaln_single", timestep.shape) + # print("embedded_timestep after adaln_single", embedded_timestep.shape) + + + # 2. Blocks + # print("before blocks hidden_states.shape", hidden_states.shape) + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + + # 3. Output + # print("after blocks hidden_states.shape", hidden_states.shape) + if self.is_input_continuous: + # print("1") + if not self.use_linear_projection: + # print("1 hidden_states.shape", hidden_states.shape) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + # print("2 hidden_states.shape", hidden_states.shape) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + # print("2") + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + # print("3") + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + # print("before proj_out hidden_states.shape", hidden_states.shape) + hidden_states = self.proj_out(hidden_states) + # print("after proj_out hidden_states.shape", hidden_states.shape) + + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + # print("4") + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + # print("hidden_states.shape",hidden_states.shape) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_flow.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf331555534d63fcacc05ee5163530a41cfb910 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_flow.py @@ -0,0 +1,545 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +import math +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from models.attention import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import TimestepEmbedding + +class PixArtAlphaCombinedFlowEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.flow_t_size = 512 + self.outdim = size_emb_dim + self.timestep_embedder = TimestepEmbedding(in_channels=self.flow_t_size, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87 + def timestep_embedding(self, timesteps, max_period=10000, scale=1000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = self.flow_t_size // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type()) + args = timesteps[:, None] * freqs[None] * scale + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if self.flow_t_size % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.timestep_embedding(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + +class AdaLayerNormSingleFlow(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = PixArtAlphaCombinedFlowEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingleFlow(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + hidden_states = self.pos_embed(hidden_states) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_flow_diff_height_width.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_flow_diff_height_width.py new file mode 100644 index 0000000000000000000000000000000000000000..5a407c58aa81bcb048d537534f7e791c7fe3fd95 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_flow_diff_height_width.py @@ -0,0 +1,610 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +import math +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from models.attention import BasicTransformerBlock +from diffusers.models.embeddings import get_2d_sincos_pos_embed, PixArtAlphaTextProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import TimestepEmbedding + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + height=224, + width=224, + patch_size=[4,4], + in_channels=3, + embed_dim=768, + layer_norm=False, + flatten=True, + bias=True, + interpolation_scale=1, + ): + super().__init__() + + num_patches = (height // patch_size[0]) * (width // patch_size[1]) + self.flatten = flatten + self.layer_norm = layer_norm + + self.proj = nn.Conv2d( + in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias + ) + if layer_norm: + self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) + else: + self.norm = None + + self.patch_size = patch_size + # See: + # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161 + self.height, self.width = height // patch_size[0], width // patch_size[1] + self.base_size = patch_size[0] * patch_size[1] + self.interpolation_scale = interpolation_scale + pos_embed = get_2d_sincos_pos_embed( + embed_dim, (height // patch_size[0], width // patch_size[1]), base_size=1, interpolation_scale=self.interpolation_scale + ) + self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) + + def forward(self, latent): + height, width = latent.shape[-2] // self.patch_size[0], latent.shape[-1] // self.patch_size[1] + + latent = self.proj(latent) + if self.flatten: + latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC + if self.layer_norm: + latent = self.norm(latent) + + # Interpolate positional embeddings if needed. + # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160) + if self.height != height or self.width != width: + pos_embed = get_2d_sincos_pos_embed( + embed_dim=self.pos_embed.shape[-1], + grid_size=(height, width), + base_size=self.base_size, + interpolation_scale=self.interpolation_scale, + ) + pos_embed = torch.from_numpy(pos_embed) + pos_embed = pos_embed.float().unsqueeze(0).to(latent.device) + else: + pos_embed = self.pos_embed + + return (latent + pos_embed).to(latent.dtype) + +class PixArtAlphaCombinedFlowEmbeddings(nn.Module): + """ + For PixArt-Alpha. + + Reference: + https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 + """ + + def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): + super().__init__() + + self.flow_t_size = 512 + self.outdim = size_emb_dim + self.timestep_embedder = TimestepEmbedding(in_channels=self.flow_t_size, time_embed_dim=embedding_dim) + + self.use_additional_conditions = use_additional_conditions + if use_additional_conditions: + self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) + self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) + + # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87 + def timestep_embedding(self, timesteps, max_period=10000, scale=1000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = self.flow_t_size // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type()) + args = timesteps[:, None] * freqs[None] * scale + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if self.flow_t_size % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): + timesteps_proj = self.timestep_embedding(timestep) + timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) + + if self.use_additional_conditions: + resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) + resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) + aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) + aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) + conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) + else: + conditioning = timesteps_emb + + return conditioning + +class AdaLayerNormSingleFlow(nn.Module): + r""" + Norm layer adaptive layer norm single (adaLN-single). + + As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3). + + Parameters: + embedding_dim (`int`): The size of each embedding vector. + use_additional_conditions (`bool`): To use additional conditions for normalization or not. + """ + + def __init__(self, embedding_dim: int, use_additional_conditions: bool = False): + super().__init__() + + self.emb = PixArtAlphaCombinedFlowEmbeddings( + embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions + ) + + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True) + + def forward( + self, + timestep: torch.Tensor, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + batch_size: Optional[int] = None, + hidden_dtype: Optional[torch.dtype] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + # No modulation happening here. + embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + return self.linear(self.silu(embedded_timestep)), embedded_timestep + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`list`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[list] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size[0] + self.width = sample_size[1] + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size[0] + self.width = sample_size[1] + + self.patch_size = patch_size + interpolation_scale = 1 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size[0], + width=sample_size[1], + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size[0] * patch_size[1] * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size[0] * patch_size[1] * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingleFlow(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size[0], hidden_states.shape[-1] // self.patch_size[1] + hidden_states = self.pos_embed(hidden_states) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size[0], self.patch_size[1], self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size[0], width * self.patch_size[1]) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_rot_emb.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_rot_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ea15dad35f26f63fa7b31acfeb4f1907d8ea5b --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/transformer_2d_rot_emb.py @@ -0,0 +1,488 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version +from models.attention_add_rot_emb import BasicTransformerBlock +from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = linear_cls(in_channels, inner_dim) + else: + self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 + interpolation_scale = max(interpolation_scale, 1) + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + interpolation_scale=interpolation_scale, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = linear_cls(inner_dim, in_channels) + else: + self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches and norm_type != "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + elif self.is_input_patches and norm_type == "ada_norm_single": + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) + self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + # 5. PixArt-Alpha blocks. + self.adaln_single = None + self.use_additional_conditions = False + if norm_type == "ada_norm_single": + self.use_additional_conditions = self.config.sample_size == 128 + # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use + # additional conditions until we find better name + self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + added_cond_kwargs: Dict[str, torch.Tensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # Retrieve lora scale. + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + # 1. Input + # print("before input hidden_states.shape", hidden_states.shape) + # print("is_input_patches", self.is_input_patches) + #true + # print("is_input_vectorized", self.is_input_vectorized) + # print("is_input_continuous", self.is_input_continuous) + #false + # print("use_linear_projection", self.use_linear_projection) + #false + if self.is_input_continuous: + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + # print("input 1 hidden_states.shape", hidden_states.shape) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + # print("input 2 hidden_states.shape", hidden_states.shape) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = ( + self.proj_in(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_in(hidden_states) + ) + + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size + # print("before patches hidden_states.shape", hidden_states.shape) + hidden_states = self.pos_embed(hidden_states) + # print("after patches hidden_states.shape", hidden_states.shape) + + if self.adaln_single is not None: + if self.use_additional_conditions and added_cond_kwargs is None: + raise ValueError( + "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." + ) + batch_size = hidden_states.shape[0] + # print("time step before adaln_single", timestep) + timestep, embedded_timestep = self.adaln_single( + timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + # print("time step after adaln_single", timestep.shape) + # print("embedded_timestep after adaln_single", embedded_timestep.shape) + + + # 2. Blocks + # print("before blocks hidden_states.shape", hidden_states.shape) + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + + # 3. Output + # print("after blocks hidden_states.shape", hidden_states.shape) + if self.is_input_continuous: + # print("1") + if not self.use_linear_projection: + # print("1 hidden_states.shape", hidden_states.shape) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + # print("2 hidden_states.shape", hidden_states.shape) + else: + hidden_states = ( + self.proj_out(hidden_states, scale=lora_scale) + if not USE_PEFT_BACKEND + else self.proj_out(hidden_states) + ) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + # print("2") + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if self.is_input_patches: + # print("3") + if self.config.norm_type != "ada_norm_single": + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + elif self.config.norm_type == "ada_norm_single": + shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + # print("before proj_out hidden_states.shape", hidden_states.shape) + hidden_states = self.proj_out(hidden_states) + # print("after proj_out hidden_states.shape", hidden_states.shape) + + hidden_states = hidden_states.squeeze(1) + + # unpatchify + if self.adaln_single is None: + # print("4") + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + # print("hidden_states.shape",hidden_states.shape) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/README.md b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0f8b11e0432914396f6200821ebf9f30d88d34e6 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/README.md @@ -0,0 +1,65 @@ +# Our MERT & BEST-RQ +Our implementation on MERT model. Files modified: +- mert_fairseq/models/mert/mert_model.py +- mert_fairseq/data/mert_dataset.py +- run_training_mulNodes_wotorchdist_womodelparsize.sh + +# Prepare + +The MERT training is implemented with [fairseq](https://github.com/pytorch/fairseq). You need to clone the fairseq repo inside our repo at ./src/fairseq and MERT implementation codes as a fairseq example projcet. + +You can do that by following the steps: +``` +mkdir -c ./src/fairseq +cd ./src +git clone https://github.com/pytorch/fairseq +``` + + +# Docker +``` +mirrors.tencent.com/cloudezhou/mert:v3 +``` + +# Start + +### 1-node training + +``` +bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug1node +``` + + +### 1-node training (BEST-RQ) + +``` +bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_95M_bestrq +``` + +### 4-node training +``` +bash run_training_mulNodes_wotorchdist_womodelparsize.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes +bash run_training_mulNodes_wotorchdist_womodelparsize.sh 1 dummy MERT_RVQ-VAE_CQT_330M_multinodes +bash run_training_mulNodes_wotorchdist_womodelparsize.sh 2 dummy MERT_RVQ-VAE_CQT_330M_multinodes +bash run_training_mulNodes_wotorchdist_womodelparsize.sh 3 dummy MERT_RVQ-VAE_CQT_330M_multinodes +``` + + +### 4-node training (BEST-RQ) +``` +bash run_training_mulNodes_wotorchdist_womodelparsize.sh $INDEX dummy MERT_RVQ-VAE_CQT_95M_bestrq_multinodes BEST_RQ $CHIEF_IP +``` + +### 4-node training (MusicFM) +``` +bash run_training_mulNodes_wotorchdist_womodelparsize.sh $INDEX dummy MusicFM_95M_multinodes MUSICFM $CHIEF_IP +``` + +### 4-node training (EAT) +``` +bash run_training_eat.sh $INDEX dummy EAT_pretraining_music_multinodes EAT $CHIEF_IP +``` + +You could set the parameters in [mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml](mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml) + +Our latest checkpoints is loaded at [data/fairseq_savedir/ckpt_MERT_RVQ-VAE_CQT/MERT_RVQ-VAE_CQT_330M/checkpoint_last.pt](data/fairseq_savedir/ckpt_MERT_RVQ-VAE_CQT/MERT_RVQ-VAE_CQT_330M/checkpoint_last.pt) diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b2d31ab63754090138a7c097ca142d3582bf988 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_AS2M.yaml @@ -0,0 +1,122 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: true + user_dir: ${env:PWD} + seed: 1 + +checkpoint: + save_interval: 1 + save_interval_updates: 10000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: mae_image_pretraining + data: /hpc_stor03/sjtu_home/wenxi.chen/mydata/audio/unbalanced_train + rebuild_batches: true + key: source + precompute_mask_config: {} + downsr_16hz: true + audio_mae: true + h5_format: false + target_length: 1024 + flexible_mask: false + +dataset: + num_workers: 10 + batch_size: 12 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 4 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 400000 + lr: [ 0.0005 ] + debug_param_names: true + clip_norm: 4 + +optimizer: + _name: composite + dynamic_groups: true + groups: + default: + lr_float: 0.0005 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 53333 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + ema_decay: 0.9998 + ema_end_decay: 0.99999 + ema_anneal_end_step: 100000 + instance_norm_target_layer: true + layer_norm_target_layer: false + layer_norm_targets: true + end_of_block_targets: false + + depth: 12 + average_top_k_layers: 12 + clone_batch: 16 + + norm_eps: 1e-6 + + min_target_var: 0 + min_pred_var: 0 + + encoder_dropout: 0 + post_mlp_drop: 0 + attention_dropout: 0 + activation_dropout: 0 + + supported_modality: IMAGE + cls_loss: 1 + + ema_encoder_only: false + + modalities: + image: + in_chans: 1 + inverse_mask: true + mask_prob: 0.8 + mask_prob_adjust: 0.07 + mask_length: 5 + mask_noise_std: 0.01 + prenet_depth: 0 + ema_local_encoder: true + num_extra_tokens: 1 + init_extra_token_zero: false + use_alibi_encoder: false + decoder: + decoder_dim: 768 + decoder_groups: 16 + decoder_kernel: 3 + decoder_layers: 6 + input_dropout: 0 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..92318155dcb8ab93a395dbb49d9b99144b534ded --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/EAT_pretraining_music_multinodes.yaml @@ -0,0 +1,125 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + tensorboard_logdir: tb + min_loss_scale: 1e-6 + fp16_no_flatten_grads: true + user_dir: ${env:PWD} + seed: 1 + +checkpoint: + save_interval: 1 + save_interval_updates: 10000 + keep_interval_updates: 1000 + no_epoch_checkpoints: true + +task: + _name: mae_image_pretraining + data: music4all_sh/ + rebuild_batches: true + key: source + precompute_mask_config: {} + downsr_16hz: false + audio_mae: true + h5_format: false + target_length: 752 + flexible_mask: false + sample_rate: 24000 + fixed_duration: 30 + +dataset: + num_workers: 10 + batch_size: 12 + skip_invalid_size_inputs_valid_test: true + required_batch_size_multiple: 1 + disable_validation: true + +distributed_training: + distributed_world_size: 4 + ddp_backend: c10d + +criterion: + _name: model + log_keys: + - ema_decay + - target_var + - pred_var + - model_norm + - ema_norm + - masked_pct + +optimization: + max_update: 400000 + lr: [ 0.0001 ] + # debug_param_names: true + clip_norm: 4 + +optimizer: + _name: composite + # dynamic_groups: true + groups: + default: + lr_float: 0.0005 + optimizer: + _name: adam + adam_betas: [0.9,0.95] + weight_decay: 0.05 + lr_scheduler: + _name: cosine + warmup_updates: 10000 # 53333 + +lr_scheduler: pass_through + +model: + _name: data2vec_multi + + ema_decay: 0.9998 + ema_end_decay: 0.99999 + ema_anneal_end_step: 100000 + instance_norm_target_layer: true + layer_norm_target_layer: false + layer_norm_targets: true + end_of_block_targets: false + + depth: 12 + average_top_k_layers: 12 + clone_batch: 16 + + norm_eps: 1e-6 + + min_target_var: 0 + min_pred_var: 0 + + encoder_dropout: 0 + post_mlp_drop: 0 + attention_dropout: 0 + activation_dropout: 0 + + supported_modality: IMAGE + cls_loss: 1 + + ema_encoder_only: false + + modalities: + image: + in_chans: 1 + inverse_mask: true + mask_prob: 0.8 + mask_prob_adjust: 0.07 + mask_length: 5 + mask_noise_std: 0.01 + prenet_depth: 0 + ema_local_encoder: true + num_extra_tokens: 1 + init_extra_token_zero: false + use_alibi_encoder: false + decoder: + decoder_dim: 768 + decoder_groups: 16 + decoder_kernel: 3 + decoder_layers: 6 + input_dropout: 0 + target_length: 752 \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml new file mode 100644 index 0000000000000000000000000000000000000000..900a60b1ce41b671db2ac77a6d0cd4290dc8ff2f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M.yaml @@ -0,0 +1,137 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 100 + seed: 1337 + + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 5000 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + # reset-dataloader: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sharding_data: -1 #数据分块 + load_random_data_shard: false + sample_rate: 24000 + # crop to 5s + # max_sample_size: 120000 + # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. + max_sample_size: 122880 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + # normalize: true # must be consistent with extractor_mode: layer_norm + normalize: false # must be consistent with extractor_mode: default (groupnorm) + + +dataset: + num_workers: 6 + max_tokens: 900000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 1000000 + lr: [0.0015] + clip_norm: 1.0 + update_freq: [8] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 36000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + + final_dim: 128 + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + # default refers to group norm + extractor_mode: default + # extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + + layer_norm_first: true + feature_grad_mult: 1.0 + + untie_final_proj: true + activation_dropout: 0.0 + + deepnorm: false + attention_relax: 32.0 + + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6ddce0e95f235300e69375bed680234a015112e --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes.yaml @@ -0,0 +1,139 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 100 + seed: 1337 + # model_parallel_size: 8 + # amp: true + + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 5000 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + # reset-dataloader: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sharding_data: -1 #数据分块 + load_random_data_shard: false + sample_rate: 24000 + # crop to 5s + # max_sample_size: 120000 + # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. + max_sample_size: 122880 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + # normalize: true # must be consistent with extractor_mode: layer_norm + normalize: false # must be consistent with extractor_mode: default (groupnorm) + + +dataset: + num_workers: 6 + max_tokens: 900000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 1000000 + lr: [0.0015] + clip_norm: 1.0 + update_freq: [8] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 36000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + + final_dim: 128 + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + # default refers to group norm + extractor_mode: default + # extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + + layer_norm_first: true + feature_grad_mult: 1.0 + + untie_final_proj: true + activation_dropout: 0.0 + + deepnorm: false + attention_relax: 32.0 + + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: run + sweep: + dir: sweep + subdir: subdir diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e65613fd792771729436d20f48578fc88a5e7ad1 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug1node.yaml @@ -0,0 +1,138 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 100 + seed: 1337 + # amp: true + + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 5000 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + # reset-dataloader: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sharding_data: -1 #数据分块 + load_random_data_shard: false + sample_rate: 24000 + # crop to 5s + # max_sample_size: 120000 + # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. + max_sample_size: 122880 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + # normalize: true # must be consistent with extractor_mode: layer_norm + normalize: false # must be consistent with extractor_mode: default (groupnorm) + + +dataset: + num_workers: 6 + max_tokens: 900000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 1000000 + lr: [0.0015] + clip_norm: 1.0 + update_freq: [8] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 36000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + + final_dim: 128 + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + # default refers to group norm + extractor_mode: default + # extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + + layer_norm_first: true + feature_grad_mult: 1.0 + + untie_final_proj: true + activation_dropout: 0.0 + + deepnorm: false + attention_relax: 32.0 + + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: run + sweep: + dir: sweep + subdir: subdir diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de0ca018b11daebbee3a6b82c5356bf69d0bd839 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_multinodes_debug2node.yaml @@ -0,0 +1,139 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 100 + seed: 1337 + model_parallel_size: 8 + # amp: true + + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 5000 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + # reset-dataloader: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sharding_data: -1 #数据分块 + load_random_data_shard: false + sample_rate: 24000 + # crop to 5s + # max_sample_size: 120000 + # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. + max_sample_size: 122880 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + # normalize: true # must be consistent with extractor_mode: layer_norm + normalize: false # must be consistent with extractor_mode: default (groupnorm) + + +dataset: + num_workers: 6 + max_tokens: null + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 1000000 + lr: [0.0015] + clip_norm: 1.0 + update_freq: [8] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 36000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + + final_dim: 128 + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + # default refers to group norm + extractor_mode: default + # extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + + layer_norm_first: true + feature_grad_mult: 1.0 + + untie_final_proj: true + activation_dropout: 0.0 + + deepnorm: false + attention_relax: 32.0 + + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: run + sweep: + dir: sweep + subdir: subdir diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f1beb97fd08917202a5574535ad8badabbeae72c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_orig.yaml @@ -0,0 +1,135 @@ +# @package _group_ +common: + fp16: true + log_format: json + log_interval: 100 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 5000 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sharding_data: 6 + load_random_data_shard: false + sample_rate: 24000 + # crop to 5s + # max_sample_size: 120000 + # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. + max_sample_size: 122880 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + # normalize: true # must be consistent with extractor_mode: layer_norm + normalize: false # must be consistent with extractor_mode: default (groupnorm) + + +dataset: + num_workers: 6 + max_tokens: 900000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0015] + clip_norm: 1.0 + update_freq: [8] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 36000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + + final_dim: 128 + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + # default refers to group norm + extractor_mode: default + # extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + + layer_norm_first: true + feature_grad_mult: 1.0 + + untie_final_proj: true + activation_dropout: 0.0 + + deepnorm: false + attention_relax: 32.0 + + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml new file mode 100644 index 0000000000000000000000000000000000000000..82adc82cd824b65e0423221de7565307c893cd75 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_330M_tune.yaml @@ -0,0 +1,137 @@ +# @package _group_ +common: + fp16: true + log_format: json + log_interval: 100 + seed: 1337 + + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 5000 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + # reset-dataloader: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sharding_data: -1 #数据分块 + load_random_data_shard: false + sample_rate: 24000 + # crop to 5s + # max_sample_size: 120000 + # crop to 5.12s, refers to 384 token per audio, which can be devided by 8. + max_sample_size: 122880 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + # normalize: true # must be consistent with extractor_mode: layer_norm + normalize: false # must be consistent with extractor_mode: default (groupnorm) + + +dataset: + num_workers: 6 + max_tokens: 900000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0015] + clip_norm: 1.0 + update_freq: [8] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + # freeze_parameters:true + logit_temp: 0.1 + + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 36000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + + final_dim: 128 + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + # default refers to group norm + extractor_mode: default + # extractor_mode: layer_norm + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + encoder_layerdrop: 0.0 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + + layer_norm_first: true + feature_grad_mult: 1.0 + + untie_final_proj: true + activation_dropout: 0.0 + + deepnorm: false + attention_relax: 32.0 + + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1f5cea6050278a6af45f9e4d8f2b4c20476ab5a6 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M.yaml @@ -0,0 +1,116 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: w2v_conv + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bdb39618e657a1f2f2663943e94b4d5bc176cc75 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq.yaml @@ -0,0 +1,125 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 8 # 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: melspec # use melspec (instead of `w2v_conv`) + melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # best-rq loss + audio_rq_loss_m: true + audio_rq_loss_embed_dim: 16 + audio_rq_loss_num_codebooks: 1 + audio_rq_loss_num_embeds: 8192 + audio_rq_loss_seed: 42 + audio_rq_loss_use_norm: true + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3ea092f03084c911848080d58b470fe5dcb00f8d --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes.yaml @@ -0,0 +1,128 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: melspec # use melspec (instead of `w2v_conv`) + melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # best-rq loss + audio_rq_loss_m: true + audio_rq_loss_embed_dim: 16 + audio_rq_loss_num_codebooks: 1 + audio_rq_loss_num_embeds: 8192 + audio_rq_loss_seed: 42 + audio_rq_loss_use_norm: true + audio_rq_loss_use_chroma: true + audio_rq_loss_seed_chroma: 123 + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 32 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c7471c8e8482ad82bcadcdc7909513a6b17efd88 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_multinodes.yaml @@ -0,0 +1,126 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: melspec # use melspec (instead of `w2v_conv`) + melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # best-rq loss + audio_rq_loss_m: true + audio_rq_loss_embed_dim: 16 + audio_rq_loss_num_codebooks: 1 + audio_rq_loss_num_embeds: 8192 + audio_rq_loss_seed: 42 + audio_rq_loss_use_norm: true + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ce6f750f1cc7e4bccb947713fa6eef280673e53a --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes.yaml @@ -0,0 +1,128 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: melspec # use melspec (instead of `w2v_conv`) + melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # best-rq loss + audio_rq_loss_m: true + audio_rq_loss_embed_dim: 16 + audio_rq_loss_num_codebooks: 1 + audio_rq_loss_num_embeds: 8192 + audio_rq_loss_seed: 42 + audio_rq_loss_use_norm: true + audio_rq_loss_use_chroma: false + audio_rq_loss_seed_chroma: 123 + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b296cdc55d0c69e4e6630e29a12ba7acb0bb6727 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrq_norm_speech_multinodes.yaml @@ -0,0 +1,128 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0 # 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: melspec # use melspec (instead of `w2v_conv`) + melspec_n_bins: 80 # 120 # for melspec we use 120, means 12 bins per octave + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # best-rq loss + audio_rq_loss_m: true + audio_rq_loss_embed_dim: 16 + audio_rq_loss_num_codebooks: 1 + audio_rq_loss_num_embeds: 8192 + audio_rq_loss_seed: 42 + audio_rq_loss_use_norm: true + audio_rq_loss_use_chroma: false + audio_rq_loss_seed_chroma: 123 + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: false + audio_cqt_bins: 336 + # ----------- + final_dim: 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4c898b49f42845afeb8df7823e76d69f66a53ce5 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_bestrvq_multinodes.yaml @@ -0,0 +1,121 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: w2v_conv + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # ---- codec target + audio_codec_type: rvq + audio_codec_ckpt_path: RVQ_3000.pth + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..52d274734f088016f071c5c13b9d5c2add7af536 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_dac_multinodes.yaml @@ -0,0 +1,121 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: w2v_conv + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # ---- codec target + audio_codec_type: dac + audio_codec_dac_model_path: weights_24khz_8kbps_0.0.4.pth #nj + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35d50add44d5354045a8272d18fba4e151724bee --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes.yaml @@ -0,0 +1,125 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: melspec # use melspec (instead of `w2v_conv`) + melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # best-rq loss + audio_rq_loss_m: true + audio_rq_loss_embed_dim: 16 + audio_rq_loss_num_codebooks: 64 # 32 + audio_rq_loss_num_embeds: 1024 + audio_rq_loss_seed: 42 + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 16 # 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..83e0d8b62c7afa92a7f9931c8dd686cff2fff737 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MERT_RVQ-VAE_CQT_95M_mel_multinodes.yaml @@ -0,0 +1,124 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # crop to 5s + max_sample_size: 120000 + min_sample_size: 72000 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: hubert + pred_masked_weight: 1.0 + pred_nomask_weight: 0.0 + loss_weights: [10, 1] + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: mert + label_rate: ??? + skip_masked: false + skip_nomask: true + mask_prob: 0.8 + mask_length: 5 + + logit_temp: 0.1 + + # ----- mixture ------ + mixture_prob: 0.5 + inbatch_noise_augment_len_range: "[12000, 24000]" + inbatch_noise_augment_number_range: "[1, 3]" + inbatch_noise_augment_volume: 1.0 + # ------------------------ + extractor_mode: default + audio_extract_type: melspec # use melspec (instead of `w2v_conv`) + melspec_n_bins: 120 # for melspec we use 120, means 12 bins per octave + conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' + + # best-rq loss + audio_rq_loss_m: false + audio_rq_loss_embed_dim: 16 + audio_rq_loss_num_codebooks: 1 + audio_rq_loss_num_embeds: 8192 + + # ---- cqt reconstruction, need to add loss weight --- + audio_cqt_loss_m: true + audio_cqt_bins: 336 + # ----------- + final_dim: 64 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + dropout: 0.1 + attention_dropout: 0.1 + feature_grad_mult: 0.1 + untie_final_proj: true + activation_dropout: 0.0 + + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9a5fa87c965ce2aab465f0df5cd563a99a4ae20c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_bestrvq_multinodes.yaml @@ -0,0 +1,108 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # # crop to 5s + # max_sample_size: 120000 + # min_sample_size: 72000 + + # crop to 30s + max_sample_size: 720000 + min_sample_size: 432000 + clip_secs: 30 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: model + # log_keys: + # - accuracies + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: musicfm + label_rate: 25 + num_codebooks: 1 + codebook_dim: 16 + codebook_size: 8192 # 4096 + features: ["melspec_2048"] + hop_length: 240 + n_mels: 128 + conv_dim: 512 + encoder_dim: 1024 + encoder_depth: 12 + mask_hop: 0.4 + mask_prob: 0.6 + is_flash: false + + stat_path: msd_stats.json + model_path: null + w2v2_config_path: our-MERT/data/models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa + use_rvq_target: true + rvq_ckpt_path: RVQ_4000.pth + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd38c04f61129d7866b231df2bb6ec2cbf606d78 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_multinodes.yaml @@ -0,0 +1,105 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 12500 + keep_interval_updates: -1 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # # crop to 5s + # max_sample_size: 120000 + # min_sample_size: 72000 + + # crop to 30s + max_sample_size: 720000 + min_sample_size: 432000 + clip_secs: 30 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + +criterion: + _name: model + # log_keys: + # - accuracies + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: musicfm + label_rate: 25 + num_codebooks: 1 + codebook_dim: 16 + codebook_size: 4096 + features: ["melspec_2048"] + hop_length: 240 + n_mels: 128 + conv_dim: 512 + encoder_dim: 1024 + encoder_depth: 12 + mask_hop: 0.4 + mask_prob: 0.6 + is_flash: false + stat_path: msd_stats.json + model_path: pretrained_msd.pt + w2v2_config_path: models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0cc0d2a03c6f937ddbdf39fc46eb532eef98fd2f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/MusicFM_95M_speech_multinodes.yaml @@ -0,0 +1,106 @@ +# @package _group_ +common: + fp16: false + log_format: json + log_interval: 200 + seed: 1337 + # tensorboard_logdir: tblog_proj_name + # wandb_project: wandb_proj_name + +checkpoint: + save_interval_updates: 2500 + keep_interval_updates: 10000 + no_epoch_checkpoints: true + + +distributed_training: + ddp_backend: no_c10d + distributed_backend: 'nccl' + distributed_world_size: 64 + nprocs_per_node: 8 + find_unused_parameters: true + +task: + _name: mert_pretraining + data: ??? + label_dir: ??? + labels: ??? + label_rate: ${model.label_rate} + sample_rate: 24000 + # # crop to 5s + # max_sample_size: 120000 + # min_sample_size: 72000 + + # crop to 30s + max_sample_size: 720000 + min_sample_size: 12000 + # clip_secs: 30 + + pad_audio: false + random_crop: true + normalize: false # must be consistent with extractor + + +dataset: + num_workers: 6 + max_tokens: 2000000 + skip_invalid_size_inputs_valid_test: true + validate_interval: 1 + validate_interval_updates: 10000 + disable_validation: true + +criterion: + _name: model + # log_keys: + # - accuracies + +optimization: + max_update: 400000 + lr: [0.0005] + clip_norm: 10.0 + update_freq: [1] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: musicfm + label_rate: 25 + num_codebooks: 1 + codebook_dim: 16 + codebook_size: 4096 + features: ["melspec_2048"] + hop_length: 240 + n_mels: 128 + conv_dim: 512 + encoder_dim: 1024 + encoder_depth: 12 + mask_hop: 0.4 + mask_prob: 0.6 + is_flash: false + stat_path: msd_stats.json + model_path: null + w2v2_config_path: models--facebook--wav2vec2-conformer-rope-large-960h-ft/snapshots/6b36ef01c6443c67ae7ed0822876d091ab50e4aa + +hydra: + job: + config: + override_dirname: + kv_sep: '-' + item_sep: '__' + exclude_keys: + - run + - task.data + - task.label_dir + run: + dir: ??? + sweep: + dir: ??? + subdir: ${hydra.job.config_name}__${hydra.job.override_dirname} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46c979cd2835fe026b0a532a54533904d1001e54 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/config/pretrain/run/submitit_reg.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +hydra: + launcher: + cpus_per_task: 8 + gpus_per_node: 8 + tasks_per_node: ${hydra.launcher.gpus_per_node} + nodes: 4 + comment: null + mem_gb: 384 + timeout_min: 4320 + max_num_timeout: 100 + constraint: volta32gb + name: ${hydra.job.config_name}/${hydra.job.override_dirname} + submitit_folder: ${hydra.sweep.dir}/submitit/%j + +distributed_training: + distributed_world_size: 32 + distributed_port: 29671 + nprocs_per_node: 8 diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..17079090f01ca01bf4bab30fa298fd9762752fc8 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/__init__.py @@ -0,0 +1,2 @@ +from .mert_dataset import MERTDataset +from .eat_data import * \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..47e8b8507d823305152a2f296432498afc8946c1 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/ark_dataset.py @@ -0,0 +1,115 @@ +import logging +import torch +import torch.nn.functional as F +from fairseq.data.audio.raw_audio_dataset import RawAudioDataset +from typing import Tuple +try: + import kaldiio +except: + kaldiio = None +import warnings + +logger = logging.getLogger(__name__) + + +class ArkDataset(RawAudioDataset): + def __init__( + self, + wav_scp, + dur_scp, + sr = 24000, + max_dur = 20, + num_buckets=0, + normalize=False, + ): + super().__init__( + sample_rate=sr, + max_sample_size=max_dur*sr, + min_sample_size=1200, + shuffle=True, + pad=True, + normalize=normalize, + compute_mask=False, + ) + self.sr = sr + self.max_dur = max_dur + self.normalize = normalize + + logger.info("Loading Kaldi scp files from {}".format(wav_scp)) + + self.wav_data = kaldiio.load_scp(wav_scp) + self.keys = list(self.wav_data.keys()) + dur_data = {} + keys_set = set(self.keys) + + with open(dur_scp, 'r') as f: + for line in f: + line = line.strip().split() + if line[0] in keys_set: + dur_data[line[0]] = float(line[-1]) + self.sizes = [int(dur_data[k]*self.sr/100) for k in self.keys] + + logger.info("Loading Kaldi scp files done") + + self.dataset_len = len(self.keys) + self.set_bucket_info(num_buckets) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, idx): + # print("getitem idx: ", idx) + try_cnt = 0 + while True: + idx = idx + try_cnt + try: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + key = self.keys[idx] + # print(self.wav_data[key].keys()) + wav = self.wav_data[key]['wav'] + + wav = torch.from_numpy(wav).float() + wav = self.postprocess(wav) + # print("success load", idx, " shape =", wav.shape) + return {"id": idx, "source": wav} + except Exception as e: + # from traceback import print_exc + # print_exc() + # print("Error loadding ", idx) + # return {"id": idx, "source": None} + try_cnt += 1 + if try_cnt > 50: + return {"id": idx, "source": None} + continue + + def size(self, idx): + return self.sizes[idx] + + def postprocess(self, wav): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav + + def collater(self, samples): + # print("collate from:", [s['source'].shape for s in samples if s['source'] is not None]) + return super().collater(samples) + +if __name__ == '__main__': + import torch + raw_tensor_str = torch.Tensor.__repr__ + torch.Tensor.__str__ = torch.Tensor.__repr__ = lambda self: f'Tensor{{Size({[*self.shape]}) {self.device} {str(self.dtype)[6]}{str(self.dtype)[-2:]}}}' if self.numel() > 10 else raw_tensor_str(self) + + ds = ArkDataset( + wav_scp='data/ark_demo/wav_ark.scp', + dur_scp='data/ark_demo/dur_ark.scp', + sr=24000, + ) + + for i in range(len(ds)): + print(ds[i]) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/__init__.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92eae486fee1ed52c3550a8c47b17f760fe0933e --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +try: + from .mae_image_dataset import MaeImageDataset + from .raw_audio_dataset import FileAudioDataset +except: + import sys, os + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '.')) + from mae_image_dataset import MaeImageDataset + from raw_audio_dataset import FileAudioDataset + +__all__ = [ + "MaeImageDataset", + "FileAudioDataset", +] \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/add_class_target_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/add_class_target_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea93918d00fe2c5d4582bbc3d5b9e6035ccdc94 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/add_class_target_dataset.py @@ -0,0 +1,63 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from fairseq.data import BaseWrapperDataset + +# add labels for audio clips in fine-tuning +class AddClassTargetDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + labels, + multi_class, + num_classes=None, + label_indices=None, + add_to_input=True, + ): + super().__init__(dataset) + + self.label_indices = label_indices + self.labels = labels + self.multi_class = multi_class + self.add_to_input = add_to_input + if num_classes is None and multi_class: + assert self.label_indices is not None + num_classes = len(self.label_indices) + + self.num_classes = num_classes + + def __getitem__(self, index): + item = self.dataset[index] + item_labels = self.labels[index] + if self.multi_class: + item["label"] = torch.zeros(self.num_classes) + for il in item_labels: + if self.label_indices is not None: + il = self.label_indices[il] + item["label"][int(il)] = 1.0 + else: + item["label"] = torch.tensor( + self.labels[index] + if self.label_indices is None + else self.label_indices[self.labels[index]] + ) + + return item + + def collater(self, samples): + collated = self.dataset.collater(samples) + if len(collated) == 0: + return collated + + indices = set(collated["id"].tolist()) + target = [s["label"] for s in samples if s["id"] in indices] + collated["label"] = torch.stack(target, dim=0) + + if self.add_to_input: + collated["net_input"]["label"] = collated["label"] + + return collated diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/mae_image_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/mae_image_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cce5086ca31263e8eb5d66c084b0edc4a57f63fe --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/mae_image_dataset.py @@ -0,0 +1,296 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +from functools import partial +import logging +import random +import time +import numpy as np +import os +import torch + +from fairseq.data import FairseqDataset +try: + from ..utils.data_utils import compute_block_mask_1d, compute_block_mask_2d +except: + import sys, os + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) + from utils.data_utils import compute_block_mask_1d, compute_block_mask_2d +try: + from .raw_audio_dataset import FileAudioDataset +except: + import sys, os + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '.')) + from raw_audio_dataset import FileAudioDataset + +from shutil import copyfile + +logger = logging.getLogger(__name__) + + +def load(path, loader, cache): + if hasattr(caching_loader, "cache_root"): + cache = caching_loader.cache_root + + cached_path = cache + path + + num_tries = 3 + for curr_try in range(num_tries): + try: + if curr_try == 2: + return loader(path) + if not os.path.exists(cached_path) or curr_try > 0: + os.makedirs(os.path.dirname(cached_path), exist_ok=True) + copyfile(path, cached_path) + os.chmod(cached_path, 0o777) + return loader(cached_path) + except Exception as e: + logger.warning(str(e)) + if "Errno 13" in str(e): + caching_loader.cache_root = f"/scratch/{random.randint(0, 69420)}" + logger.warning(f"setting cache root to {caching_loader.cache_root}") + cached_path = caching_loader.cache_root + path + if curr_try == (num_tries - 1): + raise + time.sleep(2) + + +def caching_loader(cache_root: str, loader): + if cache_root is None: + return loader + + if cache_root == "slurm_tmpdir": + cache_root = os.environ["SLURM_TMPDIR"] + assert len(cache_root) > 0 + + if not cache_root.endswith("/"): + cache_root += "/" + + return partial(load, loader=loader, cache=cache_root) + + +class MaeImageDataset(FairseqDataset): + def __init__( + self, + root: str, + split: str, + input_size, + shuffle=True, + key="imgs", + compute_mask=False, + patch_size: int = 16, + mask_prob: float = 0.75, + mask_prob_adjust: float = 0, + mask_length: int = 1, + inverse_mask: bool = False, + expand_adjacent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, + require_same_masks: bool = True, + clone_batch: int = 1, + audio_mae:bool = False, + h5_format:bool = False, + downsr_16hz:bool = False, + target_length:int = 1024, + esc50_eval:bool = False, + spcv2_eval:bool = False, + roll_aug: bool = False, + noise: bool = False, + dataset_type: str = "imagefolder", + num_samples: int = 200000, + replacement: bool = False, + AS2M_finetune: bool = False, + spcv1_finetune: bool =False, + weights_file: str="", + flexible_mask: bool = False, + sample_rate=24000, + fixed_duration=10, + ): + FairseqDataset.__init__(self) + + self.shuffle = shuffle + self.key = key + self.audio_mae = audio_mae + if self.audio_mae: + self.h5_format = h5_format + self.downsr_16hz = downsr_16hz + self.target_length = target_length + self.esc50_eval = esc50_eval + self.spcv2_eval = spcv2_eval + self.noise = noise + self.num_samples = num_samples + self.replacement = replacement + self.split = split + self.AS2M_finetune = AS2M_finetune + self.spcv1_finetune= spcv1_finetune + self.weights_file = weights_file + self.flexible_mask = flexible_mask + + self.transform_source = None + self.transform_target = None + self.img_shape = None + self.roll_aug = roll_aug + + # load wav files + mask_args = {} + if self.audio_mae: + min_sample_size = 10000 + + input_size = (self.target_length,128) + manifest_path = os.path.join(root, "{}.jsonl".format(split)) + self.dataset = FileAudioDataset( + manifest_path=manifest_path, + sample_rate=sample_rate, + fixed_duration=fixed_duration, + max_sample_size=sample_rate*fixed_duration, + min_sample_size=min_sample_size, + pad=False, + normalize=True, + num_buckets=0, + compute_mask=False, + h5_format=self.h5_format, + downsr_16hz=self.downsr_16hz, + wav2fbank=True, + target_length=self.target_length, + esc50_eval=self.esc50_eval, + spcv2_eval=self.spcv2_eval, + roll_mag_aug=self.roll_aug, + train_mode=split, + noise=self.noise, + **mask_args, + ) + self.skipped_indices = self.dataset.skipped_indices + + else: + raise Exception(f"invalid dataset type {dataset_type}") + + + logger.info(f"loaded {len(self.dataset)} examples") + + self.is_compute_mask = compute_mask + + if type(input_size) == tuple: + self.patches = (input_size[0] // patch_size ) * ( input_size[1] // patch_size ) + self.img_shape = (input_size[0] // patch_size,input_size[1] // patch_size ) + + else: + self.patches = (input_size // patch_size) ** 2 + self.mask_prob = mask_prob + self.mask_prob_adjust = mask_prob_adjust + self.mask_length = mask_length + self.inverse_mask = inverse_mask + self.expand_adjacent = expand_adjacent + self.mask_dropout = mask_dropout + self.non_overlapping = non_overlapping + self.require_same_masks = require_same_masks + self.clone_batch = clone_batch + + def __getitem__(self, index): + if self.audio_mae: + img = self.dataset[index]['source'] + else: + img, _ = self.dataset[index] + + source = None + target = None + + v = {"id": index, self.key: source if source is not None else img} + if target is not None: + v["target"] = target + + # inverse block mask on audio patches + if self.is_compute_mask: + if self.mask_length == 1: + mask = compute_block_mask_1d( + shape=(self.clone_batch, self.patches), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + ) + else: # mask_length==5 + mask = compute_block_mask_2d( + shape=(self.clone_batch, self.patches), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + expand_adjcent=self.expand_adjacent, + mask_dropout=self.mask_dropout, + non_overlapping=self.non_overlapping, + img_shape=self.img_shape, + flexible_mask=self.flexible_mask + ) + + if mask.shape[1] < self.patches: + padding = torch.zeros((mask.shape[0], self.patches - mask.shape[1])) + mask = torch.cat((mask, padding), dim=1) + + v["precomputed_mask"] = mask + + return v + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if len(samples) == 0: + return {} + + collated_img = torch.stack([s[self.key] for s in samples], dim=0) + + res = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": { + self.key: collated_img, + }, + } + + if "target" in samples[0]: + collated_target = torch.stack([s["target"] for s in samples], dim=0) + res["net_input"]["target"] = collated_target + + if "precomputed_mask" in samples[0]: + collated_mask = torch.cat([s["precomputed_mask"] for s in samples], dim=0) + res["net_input"]["precomputed_mask"] = collated_mask + + return res + + def num_tokens(self, index): + return 1 + + def size(self, index): + return 1 + + @property + def sizes(self): + return np.full((len(self),), 1) + + # shuffle data (for pre-training and fine-tuning) + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle and (self.AS2M_finetune or self.spcv1_finetune) and self.split == "train" : + weights = np.loadtxt(self.weights_file) + normalized_weights = weights / np.sum(weights) + weights_tensor = torch.from_numpy(normalized_weights) + + subsample_balanced_indicies = torch.multinomial(weights_tensor, self.num_samples, self.replacement) + order = subsample_balanced_indicies.numpy() + + # order = [np.random.choice(order[0], size=len(self), replace=True, p=weights)] + return order + + elif self.shuffle and self.split == "train": + order = [np.random.permutation(len(self))] + return order[0] + + + else: + order = [np.arange(len(self))] + return order[0] diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/raw_audio_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/raw_audio_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d8475ec02fcc3c0de9cebf16c8232c685037d6c1 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/eat_data/raw_audio_dataset.py @@ -0,0 +1,545 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import os +import sys +import time +import io +try: + import h5py +except: + h5py = None + +import numpy as np +import torch +import torch.nn.functional as F +import torchaudio + +from fairseq.data import FairseqDataset +try: + from ..utils.data_utils import compute_block_mask_1d, get_buckets, get_bucketed_sizes +except: + import sys, os + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) + from utils.data_utils import compute_block_mask_1d, get_buckets, get_bucketed_sizes +from fairseq.data.audio.audio_utils import ( + parse_path, + read_from_stored_zip, + is_sf_audio_data, +) +from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel + +import math +from typing import Tuple +import json + +logger = logging.getLogger(__name__) + +def load_audio_by_json(json_path, max_keep, min_keep): + # read json file + n_long, n_short = 0, 0 + datas = [] + inds = [] + sizes = [] + with open(json_path) as fp: + for ind,line in enumerate(fp): + data = json.loads(line) + sz = int(data['duration'] * data['sample_rate']) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + datas.append(data) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"json_path={json_path}, " + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(datas)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return datas, inds, tot, sizes + +class RawAudioDataset(FairseqDataset): + def __init__( + self, + sample_rate, + max_sample_size=None, + min_sample_size=0, + shuffle=True, + pad=False, + normalize=False, + compute_mask=False, + feature_encoder_spec: str = "None", + mask_prob: float = 0.75, + mask_prob_adjust: float = 0, + mask_length: int = 1, + inverse_mask: bool = False, + require_same_masks: bool = True, + clone_batch: int = 1, + expand_adjacent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, + corpus_key=None, + ): + super().__init__() + + self.sample_rate = sample_rate + self.sizes = [] + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.min_sample_size = min_sample_size + self.pad = pad + self.shuffle = shuffle + self.normalize = normalize + + self.is_compute_mask = compute_mask + self.feature_encoder_spec = eval(feature_encoder_spec) + self._features_size_map = {} + self.mask_prob = mask_prob + self.mask_prob_adjust = mask_prob_adjust + self.mask_length = mask_length + self.inverse_mask = inverse_mask + self.require_same_masks = require_same_masks + self.clone_batch = clone_batch + self.expand_adjacent = expand_adjacent + self.mask_dropout = mask_dropout + self.non_overlapping = non_overlapping + self.corpus_key = corpus_key + + def __getitem__(self, index): + raise NotImplementedError() + + def __len__(self): + return len(self.sizes) + + def _roll_mag_aug(self, waveform): + waveform=waveform.numpy() + idx=np.random.randint(len(waveform)) + rolled_waveform=np.roll(waveform,idx) + mag = np.random.beta(10, 10) + 0.5 + return torch.Tensor(rolled_waveform*mag) + + + def postprocess(self, feats, curr_sample_rate, roll_aug = False): + if feats.dim() == 2: + feats = feats.mean(-1) + + if curr_sample_rate != self.sample_rate: + raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}") + + assert feats.dim() == 1, feats.dim() + # if self.normalize: + # with torch.no_grad(): + # feats = F.layer_norm(feats, feats.shape) + feats = feats - feats.mean() + + if roll_aug: + feats = self._roll_mag_aug(feats) + + return feats + + def crop_to_max_size(self, t, target_size, dim=0): + size = t.size(dim) + diff = size - target_size + if diff <= 0: + return t + + start = np.random.randint(0, diff + 1) + end = size - diff + start + + slices = [] + for d in range(dim): + slices.append(slice(None)) + slices.append(slice(start, end)) + + return t[slices] + + @staticmethod + def _bucket_tensor(tensor, num_pad, value): + return F.pad(tensor, (0, num_pad), value=value) + + def collater(self, samples): + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + sources = [s["source"] for s in samples] + sizes = [len(s) for s in sources] + + if self.pad: + target_size = min(max(sizes), self.max_sample_size) + else: + target_size = min(min(sizes), self.max_sample_size) + + collated_sources = sources[0].new_zeros(len(sources), target_size) + padding_mask = ( + torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None + ) + for i, (source, size) in enumerate(zip(sources, sizes)): + diff = size - target_size + if diff == 0: + collated_sources[i] = source + elif diff < 0: + assert self.pad + collated_sources[i] = torch.cat( + [source, source.new_full((-diff,), 0.0)] + ) + padding_mask[i, diff:] = True + else: + collated_sources[i] = self.crop_to_max_size(source, target_size) + + input = {"source": collated_sources} + if self.corpus_key is not None: + input["corpus_key"] = [self.corpus_key] * len(sources) + out = {"id": torch.LongTensor([s["id"] for s in samples])} + if self.pad: + input["padding_mask"] = padding_mask + + if hasattr(self, "num_buckets") and self.num_buckets > 0: + assert self.pad, "Cannot bucket without padding first." + bucket = max(self._bucketed_sizes[s["id"]] for s in samples) + num_pad = bucket - collated_sources.size(-1) + if num_pad: + input["source"] = self._bucket_tensor(collated_sources, num_pad, 0) + input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True) + + if "precomputed_mask" in samples[0]: + target_size = self._get_mask_indices_dims(target_size) + collated_mask = torch.cat( + [ + self.crop_to_max_size(s["precomputed_mask"], target_size, dim=1) + for s in samples + ], + dim=0, + ) + input["precomputed_mask"] = collated_mask + + out["net_input"] = input + return out + + def _get_mask_indices_dims(self, size, padding=0, dilation=1): + if size not in self.feature_encoder_spec: + L_in = size + for (_, kernel_size, stride) in self.feature_encoder_spec: + L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1 + L_out = 1 + L_out // stride + L_in = L_out + self._features_size_map[size] = L_out + return self._features_size_map[size] + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + if self.pad: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + + if self.shuffle: + order = [np.random.permutation(len(self))] + order.append( + np.minimum( + np.array(self.sizes), + self.max_sample_size, + ) + ) + return np.lexsort(order)[::-1] + else: + return np.arange(len(self)) + + def set_bucket_info(self, num_buckets): + self.num_buckets = num_buckets + if self.num_buckets > 0: + self._collated_sizes = np.minimum( + np.array(self.sizes), + self.max_sample_size, + ) + self.buckets = get_buckets( + self._collated_sizes, + self.num_buckets, + ) + self._bucketed_sizes = get_bucketed_sizes( + self._collated_sizes, self.buckets + ) + logger.info( + f"{len(self.buckets)} bucket(s) for the audio dataset: " + f"{self.buckets}" + ) + + def filter_indices_by_size(self, indices, max_sizes): + return indices, [] + +class Read_and_PadCrop_Normalized_T(torch.nn.Module): + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + + def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: + if(duration<(float(self.n_samples)/self.sample_rate+1)): + # print(duration,(float(self.n_samples)/self.sample_rate+1)) + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) + offset = 0 + # print('c1:',chunk.shape) + else: + offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + t_start = offset / float(cur_sample_rate) / duration + t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration + chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + # print('offset:',offset) + # print('c0:',chunk.shape) + # Pad with silence if necessary. + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + if(cur_sample_rate!=self.sample_rate): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) + # print('b:',self.sample_rate,chunk.shape) + if chunk.shape[-1] < self.n_samples: + chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) + else: + chunk = chunk[:,0:self.n_samples] + seconds_start = math.floor(offset / cur_sample_rate) + seconds_total = math.floor(duration) + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total + ) + +class FileAudioDataset(RawAudioDataset): + def __init__( + self, + manifest_path, + sample_rate, + fixed_duration=None, + max_sample_size=None, + min_sample_size=0, + shuffle=True, + pad=False, + normalize=False, + num_buckets=0, + compute_mask=False, + text_compression_level=TextCompressionLevel.none, + h5_format=False, + downsr_16hz=False, + wav2fbank=False, + target_length=1024, + esc50_eval=False, + spcv2_eval=False, + roll_mag_aug=False, + noise=False, + train_mode='train', + **mask_compute_kwargs, + ): + super().__init__( + sample_rate=sample_rate, + max_sample_size=max_sample_size, + min_sample_size=min_sample_size, + shuffle=shuffle, + pad=pad, + normalize=normalize, + compute_mask=compute_mask, + **mask_compute_kwargs, + ) + + self.text_compressor = TextCompressor(level=text_compression_level) + self.h5_format = h5_format + self.downsr_16hz = downsr_16hz + self.wav2fbank = wav2fbank + self.target_length = target_length + self.esc50_eval = esc50_eval + self.spcv2_eval = spcv2_eval + self.roll_mag_aug = roll_mag_aug + self.noise = noise + self.train_mode = train_mode + self.reader = Read_and_PadCrop_Normalized_T(n_samples = int(fixed_duration*sample_rate), sample_rate = sample_rate) + + skipped = 0 + self.fnames = [] + sizes = [] + self.skipped_indices = set() + + # exclude data not in sample rate range 10.h5/****.wav 320000 + self.durations = [] + self.raw_srs = [] + datas, inds, tot, sizes = load_audio_by_json(manifest_path, max_keep=None, min_keep=None) + for data in datas: + self.fnames.append(self.text_compressor.compress(data['path'])) + self.durations.append(data['duration']) + self.raw_srs.append(data['sample_rate']) + + logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") + + if self.esc50_eval: + task_dataset = "ESC-50" + elif self.spcv2_eval: + task_dataset = "SPC-2" + else: + task_dataset = "AS" + + logger.info( + f"sample rate: {sample_rate}\t" + f"target length: {self.target_length}\t" + f"current task: {task_dataset}\t" + ) + + # self.sizes = np.array(sizes, dtype=np.int64) + self.sizes = np.array(sizes, dtype=np.int64) + self.durations = np.array(self.durations) + self.raw_srs = np.array(self.raw_srs) + + try: + import pyarrow + + self.fnames = pyarrow.array(self.fnames) + except: + logger.debug( + "Could not create a pyarrow array. Please install pyarrow for better performance" + ) + pass + + self.set_bucket_info(num_buckets) + # print("skipped_index: {}".format(self.skipped_indices)) + # print(len(self.skipped_indices)) + + # two file format. h5_format = true -> .h5(.hdf5) ; h5_format = false -> .wav + def __getitem__(self, index): + import soundfile as sf + + fn = self.fnames[index] + fn = fn if isinstance(self.fnames, list) else fn.as_py() + fn = self.text_compressor.decompress(fn) + path_or_fp = fn # os.path.join(self.root_dir, fn) + _path, slice_ptr = parse_path(path_or_fp) + if len(slice_ptr) == 2: + byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) # root/10.h5/***.wav + assert is_sf_audio_data(byte_data) + path_or_fp = io.BytesIO(byte_data) + + retry = 3 + wav = None + for i in range(retry): + try: + # if self.h5_format and self.train_mode == 'train': + # parts = path_or_fp.split("/") + # path_or_fp = "/".join(parts[:-1]) + # path_or_fp = h5py.File(path_or_fp,'r') + # wav = path_or_fp[parts[-1]][:] + # curr_sample_rate = 32000 + # break + # else: + # wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32") + # break + wav, *ignored = self.reader(fn, self.durations[index], self.raw_srs[index]) + curr_sample_rate = self.reader.sample_rate + except Exception as e: + logger.warning( + f"Failed to read {path_or_fp}: {e}. Sleeping for {1 * i}" + ) + time.sleep(1 * i) + + if wav is None: + raise Exception(f"Failed to load {path_or_fp}") + + if self.h5_format: + feats = torch.tensor(wav).float() + else: + if not isinstance(wav, torch.Tensor): + feats = torch.from_numpy(wav).float() + else: + feats = wav + if len(feats.shape) == 2: + feats = feats.squeeze(dim=0) + + if self.downsr_16hz: + feats = torchaudio.functional.resample(feats, orig_freq=curr_sample_rate, new_freq=16000) + curr_sample_rate = 16000 + self.sample_rate = curr_sample_rate + + # whether to use roll augmentation on waveform + use_roll = self.roll_mag_aug and self.train_mode == 'train' + + feats = self.postprocess(feats, curr_sample_rate, use_roll) + + # convert waveform to spectrogram + if self.wav2fbank: #这里,将wav转换为mel谱 + feats = feats.unsqueeze(dim=0) + feats = torchaudio.compliance.kaldi.fbank(feats, htk_compat=True, sample_frequency=curr_sample_rate, use_energy=False, + window_type='hanning', num_mel_bins=128, dither=0.0, frame_shift=40).unsqueeze(dim=0) + + # padding + n_frames = feats.shape[1] + diff = self.target_length - n_frames #时间维度上补齐至1024 + if diff > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, diff)) + feats = m(feats) + + elif diff < 0: + feats = feats[0:self.target_length, :] + + # global normalization for AS + self.norm_mean = -4.268 + self.norm_std = 4.569 + + # global normalization for ESC-50 + if self.esc50_eval: + self.norm_mean = -6.627 + self.norm_std = 5.359 + + # global normalization for spcv2 + if self.spcv2_eval: + self.norm_mean = -6.846 + self.norm_std = 5.565 + + feats = (feats - self.norm_mean) / (self.norm_std * 2) + + if self.noise and self.train_mode == 'train': + feats = feats + torch.rand(feats.shape[1], feats.shape[2]) * np.random.rand() / 10 # 这个加noise的方式视情况可能要换成inbatch noise + feats = torch.roll(feats, np.random.randint(-10, 10), 1) + + v = {"id": index, "source": feats} + + if self.is_compute_mask: + T = self._get_mask_indices_dims(feats.size(-1)) + mask = compute_block_mask_1d( + shape=(self.clone_batch, T), + mask_prob=self.mask_prob, + mask_length=self.mask_length, + mask_prob_adjust=self.mask_prob_adjust, + inverse_mask=self.inverse_mask, + require_same_masks=True, + expand_adjcent=self.expand_adjacent, + mask_dropout=self.mask_dropout, + non_overlapping=self.non_overlapping, + ) + + v["precomputed_mask"] = mask + + return v diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..8d4852b0cf07390f8ae5bb55044888a0d00969ee --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/mert_dataset.py @@ -0,0 +1,682 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os +import sys +from typing import Any, List, Optional, Union + +import numpy as np +from typing import Tuple +import torch +import torch.nn.functional as F +from fairseq.data import data_utils +from fairseq.data.fairseq_dataset import FairseqDataset +from fairseq.data.audio.audio_utils import ( + parse_path, + read_from_stored_zip, +) + +import math +import io +import torchaudio +# this is in the user_dir +from nnAudio import features as nnAudioFeatures + +# from tqdm import tqdm +import tqdm +import json +import random +import traceback +# from scripts.prepare_codecs_from_manifest import * + +logger = logging.getLogger(__name__) + +class model_cqt_pred(torch.nn.Module): + def __init__(self, n_bins=84, sr=16000, freq=50): + super().__init__() + self.epsilon=1e-10 + # Getting Mel Spectrogram on the fly + self.spec_layer = nnAudioFeatures.cqt.CQT(sr=sr, hop_length=sr//freq, fmin=32.7, + fmax=None, n_bins=n_bins, bins_per_octave=n_bins//7, + filter_scale=1, norm=1, window='hann', center=True, + pad_mode='constant', trainable=False, + output_format='Magnitude', verbose=True) + + # self.fc = nn.Linear(input_dim, n_bins) + + # self.criterion = nn.MSELoss() + self.forward_dict = { + # 'masked_transformer_output': self.plain_forward + 'compute_cqt': self.compute_cqt + } + def compute_cqt(self, x): + ''' + convert waveform to CQT -> [batch, bins, len] -> transpose + ''' + # align with the padding of HuBERT model, + # the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different + # x = x[..., :-560] + return torch.transpose(self.spec_layer(x), -1, -2) + + def forward(self, x, forward_type='masked_transformer_output'): + ''' + take input from transformer hidden states: [batch, len_seq, channel] + output: [batch, len_seq, n_bins] + ''' + + return self.forward_dict[forward_type](x) +# def audio2label(wav,sr): +# wav = convert_audio(wav, sr, model.sample_rate, model.channels) +# wav = wav.unsqueeze(0) +# wav = wav.to(device) +# with torch.no_grad(): +# encoded_frames = model.encode(wav) +# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T] +# codes = codes.to('cpu')[0] + +# # for i in range(args.n_codebook): +# # f_codecs[i].write(' '.join([str(x) for x in codes[i].numpy()]) + '\n') +def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate, clip_secs=5): + # read json file + print(json_path) + datas = [] + inds = [] + sizes = [] + with open(json_path) as fp: + for ind,line in enumerate(fp): + data = json.loads(line) + datas.append(data) + inds.append(ind) + # sz = int(data['duration'] * data['sample_rate']) + sz = int(tgt_sample_rate * clip_secs) + sizes.append(sz) + tot = ind + 1 + return datas,inds,tot,sizes +def load_audio(manifest_path, max_keep, min_keep): #读取tsv文件(原本) + print(manifest_path) + + n_long, n_short = 0, 0 + names, inds, sizes = [], [], [] + with open(manifest_path) as f: + root = f.readline().strip() + for ind, line in enumerate(f): + items = line.strip().split("\t") + assert len(items) == 2, line + sz = int(items[1]) + if min_keep is not None and sz < min_keep: + n_short += 1 + elif max_keep is not None and sz > max_keep: + n_long += 1 + else: + names.append(items[0]) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long, " + f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}" + ) + ) + return root, names, inds, tot, sizes + + +def load_label(label_path, inds, tot): + with open(label_path) as f: + labels = [] + for line in tqdm.tqdm(f): + labels.append(line.rstrip()) + # labels = [line.rstrip() ] + assert ( + len(labels) == tot + ), f"number of labels does not match ({len(labels)} != {tot})" + labels = [labels[i] for i in inds] + return labels + +def load_numpy_label(label_path, inds, tot): + labels = np.load(label_path, mmap_mode='r') + assert (labels.shape[0] == tot), f"number of labels does not match ({labels.shape[0]} != {tot})" + return labels + + +# def load_label_offset(label_path, inds, tot): +# with open(label_path) as f: +# code_lengths = [len(line.encode("utf-8")) for line in f] +# assert ( +# len(code_lengths) == tot +# ), f"number of labels does not match ({len(code_lengths)} != {tot})" +# offsets = list(itertools.accumulate([0] + code_lengths)) +# offsets = [(offsets[i], offsets[i + 1]) for i in inds] +# return offsets + + +def verify_label_lengths( + audio_sizes, + audio_rate, + label_path, + label_rate, + inds, + tot, + tol=0.1, # tolerance in seconds +): + if label_rate < 0: + logger.info(f"{label_path} is sequence label. skipped") + return + + with open(label_path) as f: + lengths = [] + for line in tqdm.tqdm(f): + lengths.append(len(line.rstrip().split())) + assert len(lengths) == tot + lengths = [lengths[i] for i in inds] + num_invalid = 0 + for i, ind in enumerate(inds): + dur_from_audio = audio_sizes[i] / audio_rate + dur_from_label = lengths[i] / label_rate + if abs(dur_from_audio - dur_from_label) > tol: + logger.warning( + ( + f"audio and label duration differ too much " + f"(|{dur_from_audio} - {dur_from_label}| > {tol}) " + f"in line {ind+1} of {label_path}. Check if `label_rate` " + f"is correctly set (currently {label_rate}). " + f"num. of samples = {audio_sizes[i]}; " + f"label length = {lengths[i]}" + ) + ) + num_invalid += 1 + if num_invalid > 0: + logger.warning( + f"total {num_invalid} (audio, label) pairs with mismatched lengths" + ) + +class Read_and_PadCrop_Normalized_T(torch.nn.Module): + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + + def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: + if(duration<(float(self.n_samples)/self.sample_rate+1)): + # print(duration,(float(self.n_samples)/self.sample_rate+1)) + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) + offset = 0 + # print('c1:',chunk.shape) + else: + offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + t_start = offset / float(cur_sample_rate) / duration + t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration + chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + # print('offset:',offset) + # print('c0:',chunk.shape) + # Pad with silence if necessary. + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + if(cur_sample_rate!=self.sample_rate): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) + # print('b:',self.sample_rate,chunk.shape) + if chunk.shape[-1] < self.n_samples: + chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) + else: + chunk = chunk[:,0:self.n_samples] + seconds_start = math.floor(offset / cur_sample_rate) + seconds_total = math.floor(duration) + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total + ) + + +class MERTDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_paths: List[str], + label_rates: Union[List[float], float], # -1 for sequence labels + pad_list: List[str], + eos_list: List[str], + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + max_sample_size: Optional[int] = None, + shuffle: bool = True, + pad_audio: bool = False, + normalize: bool = False, + store_labels: bool = True, + npmemmap: bool = False, + random_crop: bool = False, + single_target: bool = False, + augmentation_effects: List[str] = [], + augmentation_probs: List[float] = [], + inbatch_noise_augment_len_range: List[int] = [8000, 24000], + inbatch_noise_augment_number_range: List[int] = [1, 3], + inbatch_noise_augment_volume: float = 1.0, + cqt_prediction_bin: int = -1, + dataset_len:int = 128*3000, + clip_secs = 5, + ): + # self.audio_root, self.audio_names, inds, tot, self.sizes = load_audio( + # manifest_path, max_keep_sample_size, min_keep_sample_size + # ) + + # manifest_path = '/apdcephfs_cq2/share_1297902/speech_user/erichtchen/shixisheng/zhouyz/MERT/music_data/all_v4/train.json' + self.sample_rate = sample_rate + self.shuffle = shuffle + self.random_crop = random_crop + self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path,max_keep_sample_size,min_keep_sample_size, self.sample_rate, clip_secs) + + self.num_labels = len(label_paths) + self.pad_list = pad_list + self.eos_list = eos_list + self.label_processors = label_processors + self.single_target = single_target + self.label_rates = ( + [label_rates for _ in range(len(label_paths))] + if isinstance(label_rates, float) + else label_rates + ) + self.store_labels = store_labels + self.npmemmap = npmemmap + + # self.dataset_len = dataset_len + self.dataset_len = len(self.datas) + logger.info('preparing labels') + logger.info('========dataset len: {}=========='.format(self.dataset_len)) + if store_labels: + if self.npmemmap: + self.label_list = [load_numpy_label(p+'.npy', inds, tot) for p in label_paths] + else: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + else: + self.label_paths = label_paths + # self.label_offsets_list = [ + # load_label_offset(p, inds, tot) for p in label_paths + # ] + assert label_processors is None or len(label_processors) == self.num_labels + # logger.info('skip verify labels and audio lengths') + # for label_path, label_rate in zip(label_paths, self.label_rates): + # verify_label_lengths( + # self.sizes, sample_rate, label_path, label_rate, inds, tot + # ) + + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.pad_audio = pad_audio + self.normalize = normalize + logger.info( + f"pad_audio={pad_audio}, random_crop={random_crop}, " + f"normalize={normalize}, max_sample_size={self.max_sample_size}" + ) + + self.augmentation_effects = augmentation_effects + self.augmentation_probs = augmentation_probs + # if len(self.augmentation_effects) > 0: + # self.augmentor_init() + # self.apply_augmentation = self.augmentation_factry(sample_rate) + + self.inbatch_noise_augment_len_range = inbatch_noise_augment_len_range + self.inbatch_noise_augment_number_range = inbatch_noise_augment_number_range + self.inbatch_noise_augment_volume = inbatch_noise_augment_volume + + + self.cqt_prediction_bin = cqt_prediction_bin + if self.cqt_prediction_bin > 0: + self.encoder_cqt_model = model_cqt_pred(n_bins=self.cqt_prediction_bin) + logger.info('preparing cqt loss objective in dataloader with cpu') + + self.epoch = -1 + + self.reader = Read_and_PadCrop_Normalized_T(n_samples=clip_secs*sample_rate,sample_rate = self.sample_rate) + + + + @property + def can_reuse_epoch_itr_across_epochs(self): + """ + Whether we can reuse the :class:`fairseq.data.EpochBatchIterator` for + this dataset across epochs. + + This needs to return ``False`` if the sample sizes can change across + epochs, in which case we may need to regenerate batches at each epoch. + If your dataset relies in ``set_epoch`` then you should consider setting + this to ``False``. + """ + return False + def set_epoch(self, epoch): + """Will receive the updated epoch number at the beginning of the epoch.""" + self.epoch = epoch + + def inbatch_noise_augment(self, + target_audio: torch.Tensor, target_audio_idx: int , + batch_audios: torch.Tensor, # [bsz, audio_lengths] + noise_len_min: int, noise_len_max: int, + n_noise_min: int, n_noise_max: int, + noise_vol: float = 1.0): + ''' + augmenation that leverages in-batch noise audios. + noise_len_min and noise_len_max are the range of the lengths of noises (counted as samples) + n_noise_min and n_noise_max are the range of number of noises, + ''' + # assert noise_len_max <= target_audio.shape[0] and noise_len_min >= 1 # should assert this outside? + + augmented_audio = torch.clone(target_audio) + + # exclude the target audio and use the rest as noise candidates + noise_pool = torch.cat( batch_audios[:target_audio_idx] + batch_audios[target_audio_idx+1:], dim=0).view(-1) + + n_noise = np.random.randint(n_noise_min, n_noise_max) + # n_noise + random_start_idxs = np.random.randint(0, noise_pool.shape[0] - noise_len_max, size=(n_noise,)) + random_durations = np.random.randint(noise_len_min, noise_len_max, size=(n_noise,)) + + for noise_idx in range(n_noise): + augmentation_position = np.random.randint(0, target_audio.shape[0] - random_durations[noise_idx], size=None) + # assign noise to the original audio + augmented_audio[augmentation_position:augmentation_position+random_durations[noise_idx]] += \ + noise_vol * noise_pool[random_start_idxs[noise_idx]: random_start_idxs[noise_idx]+random_durations[noise_idx]] + + return augmented_audio + def get_audio_by_slice(self,index): + + # wav_path = os.path.join('/apdcephfs/share_1316500/cloudezhou/MERT/MERT/converted', self.audio_names[index]) + wav_path = self.datas[index]['path'] + # print(wav_path) + audio_info = torchaudio.info(wav_path) + origin_sample_rate = audio_info.sample_rate + origin_duration = audio_info.num_frames / origin_sample_rate + + wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate) + wav = wav.float() + + # _path, slice_ptr = parse_path(wav_path) #这个应该也要改 + # original way + # if len(slice_ptr) == 0: + # wav, cur_sample_rate = sf.read(_path) + # else: + # assert _path.endswith(".zip") + # data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) + # f = io.BytesIO(data) + # wav, cur_sample_rate = sf.read(f) + # wav = torch.from_numpy(wav).float() + # print(wav.shape) + wav = wav.permute(1,0) + wav = self.postprocess(wav, self.sample_rate) #降至单个声道,确认采样率,归一化 + # print(wav.shape) + + # wav = wav.squeeze(0) + return wav + def get_audio(self, index): + import soundfile as sf + + # wav_path = os.path.join(self.audio_root, self.audio_names[index]) + wav_path = os.path.join('/apdcephfs/share_1316500/cloudezhou/MERT/MERT/converted', self.audio_names[index]) + # print(wav_path) + # self.reader() + _path, slice_ptr = parse_path(wav_path) #这个应该也要改 + # original way + if len(slice_ptr) == 0: + wav, cur_sample_rate = sf.read(_path) + else: + assert _path.endswith(".zip") + data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) + f = io.BytesIO(data) + wav, cur_sample_rate = sf.read(f) + wav = torch.from_numpy(wav).float() + + wav = self.postprocess(wav, cur_sample_rate) #降至单个声道,确认采样率,归一化 + # print(wav.shape) + return wav + + def get_label(self, index, label_idx): + #label_idx 表示第label_idx个字典,默认8个 + + if self.store_labels and (not self.npmemmap): + label = self.label_list[label_idx][index] + elif self.store_labels and self.npmemmap: + label = self.label_list[label_idx][index] + else: + with open(self.label_paths[label_idx]) as f: + offset_s, offset_e = self.label_offsets_list[label_idx][index] + f.seek(offset_s) + label = f.read(offset_e - offset_s) + + if self.label_processors is not None: + label = self.label_processors[label_idx](label) + return 0 + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + #在这里修改,将raw_data直接处理完放在里面;如果已经处理过则直接读取 + def __getitem__(self, i): + # WORLD_SIZE = int(torch.distributed.get_world_size()) + # WORLD_RANK = int(torch.distributed.get_rank()) + # np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i) + # index = random.randint(0,len(self.sizes) - 1) + index = i + item = None + while item is None: + try: + wav = self.get_audio_by_slice(index) + # labels = self.get_labels(index) #这个得改 + # labels = None + # item = {"id": index, "source": wav, "label_list": labels} + item = {"id": index, "source": wav} + except Exception as e: + # print(e) + traceback.print_exc() + print(f'skip damaged data {index}') + index = np.random.randint(0,len(self.sizes)-1) + return item + + def __len__(self): + return self.dataset_len + + def crop_to_max_size(self, wav, target_size): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + + start, end = 0, target_size + if self.random_crop: + start = np.random.randint(0, diff + 1) + end = size - diff + start + return wav[start:end], start + + def collater(self, samples): + #这个方法类似collate_fn + samples = [s for s in samples if s["source"] is not None] + if len(samples) == 0: + return {} + + audios = [s["source"] for s in samples] + audio_sizes = [len(s) for s in audios] + if self.pad_audio: + audio_size = min(max(audio_sizes), self.max_sample_size) + else: + audio_size = min(min(audio_sizes), self.max_sample_size) + collated_audios, padding_mask, audio_starts, collated_cqt_labels = self.collater_audio( + audios, audio_size + ) + + # targets_by_label = [ + # [s["label_list"][i] for s in samples] for i in range(self.num_labels) + # ] + # targets_list, lengths_list, ntokens_list = self.collater_label( + # targets_by_label, audio_size, audio_starts + # ) + + net_input = {"source": collated_audios, "padding_mask": padding_mask, "cqt_labels": collated_cqt_labels} + + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": net_input, + } + + if self.single_target: + batch["target_lengths"] = None + batch["ntokens"] = None + batch["target"] = None + else: + batch["target_lengths_list"] = None + batch["ntokens_list"] = None + batch["target_list"] = None + return batch + + def collater_audio(self, audios, audio_size): + collated_audios = audios[0].new_zeros(len(audios), audio_size) + padding_mask = ( + torch.BoolTensor(collated_audios.shape).fill_(False) + # if self.pad_audio else None + ) + audio_starts = [0 for _ in audios] + + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat([audio, audio.new_full((-diff,), 0.0)]) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size + ) + + cqt_labels = None + if self.cqt_prediction_bin > 0: + cqt_labels = self.encoder_cqt_model(collated_audios.float(), forward_type='compute_cqt') + + for i, _ in enumerate(audios): + # compute cqt labels in advance + # cqt_labels + + # yizhilll: apply audio augmentation effects here + # the audio should be as the type torch.Tensor, in the shape [1, length] TODO? + if len(self.augmentation_effects) > 0: + with torch.no_grad(): + for effect, prob in zip(self.augmentation_effects, self.augmentation_probs): + if torch.rand(1).item() > prob: + if effect == 'composed_augmentation_v1': + # collated_audios[i] = self.composed_augment_v1(collated_audios[i]) + pass + elif effect == 'inbatch_noise_augment': + assert len(audios) > 1 + collated_audios[i] = self.inbatch_noise_augment( + target_audio = collated_audios[i], target_audio_idx = i, batch_audios = audios, + noise_len_min = self.inbatch_noise_augment_len_range[0], noise_len_max = self.inbatch_noise_augment_len_range[1], + n_noise_min = self.inbatch_noise_augment_number_range[0], n_noise_max = self.inbatch_noise_augment_number_range[1], + noise_vol = self.inbatch_noise_augment_volume) + else: + raise NotImplementedError() + + + return collated_audios, padding_mask, audio_starts, cqt_labels + + def collater_frm_label(self, targets, audio_size, audio_starts, label_rate, pad): + assert label_rate > 0 + s2f = label_rate / self.sample_rate # @yizhilll: 0.00625 for 100Hz and 16k sr + frm_starts = [int(round(s * s2f)) for s in audio_starts] # @yizhilll: should be all 0 if the audios are not croped + frm_size = int(round(audio_size * s2f)) # @yizhilll: this is the expected total number of given pseudo labels + if not self.pad_audio: + rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] # @yizhilll: what does this mean? + frm_size = min(frm_size, *rem_size) # @yizhilll: anyway, this should keep 3000 for 30s audio + targets = [t[s : s + frm_size] for t, s in zip(targets, frm_starts)] + logger.debug(f"audio_starts={audio_starts}") + logger.debug(f"frame_starts={frm_starts}") + logger.debug(f"frame_size={frm_size}") + + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens(targets, pad_idx=pad, left_pad=False) + return targets, lengths, ntokens + + def collater_label(self, targets_by_label, audio_size, audio_starts): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, self.label_rates, self.pad_list) + for targets, label_rate, pad in itr: + if label_rate == -1.0: + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + else: + targets, lengths, ntokens = self.collater_frm_label( + targets, audio_size, audio_starts, label_rate, pad + ) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + if self.pad_audio: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + # def ordered_indices(self): + # if self.shuffle: + # order = [np.random.permutation(len(self.sizes))] + # else: + # order = [np.arange(len(self.sizes))] + + # order.append(self.sizes) + # return np.lexsort(order)[::-1] + + def ordered_indices(self): + if self.shuffle: + try: + print("========Local rank :",torch.distributed.get_rank(),"========") + WORLD_SIZE = int(torch.distributed.get_world_size()) + WORLD_RANK = int(torch.distributed.get_rank()) + np.random.seed(self.epoch * WORLD_SIZE + WORLD_RANK) + order = np.random.permutation(len(self.sizes)) + print("==================multinode multigpu shuffle==================") + except: + print("==================singlenode shuffle==================") + order = np.random.permutation(len(self.sizes)) + else: + order = np.arange(len(self.sizes)) + + return order + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/_test.ipynb b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/_test.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..283e1e94ebba98991f9e3ed3f4f05d9956eaf585 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/_test.ipynb @@ -0,0 +1,143 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1034, (94, 11))" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from data_utils import compute_block_mask_2d\n", + "\n", + "\n", + "input_size = (752, 88) # 94 * 11 # (1024, 128) # 64 * 8\n", + "patch_size = 8\n", + "patches = (input_size[0] // patch_size ) * ( input_size[1] // patch_size )\n", + "img_shape = (input_size[0] // patch_size,input_size[1] // patch_size )\n", + "mask = compute_block_mask_2d( \n", + " shape=(16, patches),\n", + " mask_prob=0.8,\n", + " mask_length=2,\n", + " mask_prob_adjust=0.07,\n", + " inverse_mask=True,\n", + " require_same_masks=True,\n", + " expand_adjcent=False,\n", + " mask_dropout=0.0,\n", + " non_overlapping=False,\n", + " img_shape=img_shape,\n", + " flexible_mask=False\n", + ")\n", + "patches, img_shape" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([16, 1034])\n", + "tensor([828., 828., 828., 828., 828., 828., 828., 828., 828., 828., 828., 828.,\n", + " 828., 828., 828., 828.])\n" + ] + } + ], + "source": [ + "print(mask.shape)\n", + "print(mask.sum(dim=1))" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# library 导入库\n", + "import seaborn as sns\n", + "import pandas as pd\n", + "import numpy as np\n", + "# jupyter notebook显示多行输出\n", + "from IPython.core.interactiveshell import InteractiveShell \n", + "InteractiveShell.ast_node_interactivity = 'all'\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAABXgAAADLCAYAAADUbftLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA46klEQVR4nO3da3xU1b3/8e8kJJOAJ6CmSQjIzRsiCJVAGhBpaypaDxbbI6jUIFosChZJRQgoUTkaPFYup4IoFFFbBe1RpIVCMYVWJTXlJqVVLoKlBRLgpVzkkoSZ9X/gPykDk5nZM2smmeTz9rUfZM/M/q4985s1a5abNS5jjBEAAAAAAAAAIO4kNHQDAAAAAAAAAADhYYIXAAAAAAAAAOIUE7wAAAAAAAAAEKeY4AUAAAAAAACAOMUELwAAAAAAAADEKSZ4AQAAAAAAACBOMcELAAAAAAAAAHGKCV4AAAAAAAAAiFNM8AIAAAAAAABAnGKCFwAAAAAAAADiFBO8AAAAAAAAABChP/3pTxo8eLCys7Plcrm0dOnSoI9Zu3atrr76arndbl1yySVatGiR41wmeAEAAAAAAAAgQsePH1fPnj01Z86ckO6/e/du3XTTTfrWt76lzZs368EHH9SPfvQjrVq1ylGuyxhjwmkwAAAAAAAAAOBcLpdLb7/9toYMGVLvfSZOnKjly5dr69atdftuu+02HT58WCtXrgw5iyt4AQAAAAAAAMCPqqoqHT161GerqqqycuyysjLl5+f77Bs0aJDKysocHaeFldZY0CK5nePHnNz3XhRagniVmj3A8WOoIQDxKJz+Llz0k01XrOqIGgLOxfsPkaKGANQnKb1LQzch7tQc2hXw9pLnXtHjjz/us6+4uFiPPfZYxNkVFRXKzMz02ZeZmamjR4/q5MmTSk1NDek4jWaCFwAAAAAAAABiylMT8OaioiIVFhb67HO73dFskWNM8AIAAAAAAABonrzegDe73e6oTehmZWWpsrLSZ19lZaXS0tJCvnpXYoIXAAAAAAAAQDNlPKcbLDsvL08rVqzw2bd69Wrl5eU5Og4/sgYAAAAAAACgefLUBN4c+PLLL7V582Zt3rxZkrR7925t3rxZe/bskfTVcg8FBQV19x89erR27dqlhx9+WJ988onmzp2rN954Q+PHj3eU6/gK3kOHDmnhwoUqKytTRUWFpK8uJ+7Xr5/uuusufe1rX3N6SAAAAAAAAACIPRN4iQYn1q9fr29961t1f9eu3TtixAgtWrRI+/fvr5vslaTOnTtr+fLlGj9+vGbPnq327dtrwYIFGjRokKNcRxO8f/nLXzRo0CC1bNlS+fn5uuyyyyR9tTbE//7v/2r69OlatWqVcnJyAh6nqqpKVVVVPvuMMXK5XI4aDwAAAAAAAADhsrlEwze/+U0ZY+q9fdGiRX4fs2nTpohyHU3wPvDAA7r11ls1b968cyZjjTEaPXq0HnjgAZWVlQU8TklJiR5//HGffa6E8+RKTHPSHAAAAAAAAAAIXwOuwWuLozV4P/roI40fP97vlbYul0vjx4+vW2MikKKiIh05csRncyX8h5OmAAAAAAAAAEBkvJ7AWxxwdAVvVlaWysvL1bVrV7+3l5eXKzMzM+hx3G633G63zz6WZwAAAAAAAAAQU03gCl5HE7wPPfSQ7r33Xm3YsEHXXXdd3WRuZWWlSktLNX/+fP3sZz+LSkMBAAAAAAAAwCqLP7LWUBxN8I4ZM0bp6emaOXOm5s6dK4/nq8uUExMT1bt3by1atEhDhw6NSkMBAAAAAAAAwCbjqWnoJkTM0QSvJA0bNkzDhg1TTU2NDh06JElKT09XUlKS9cYBAAAAAAAAQNQ0tyUazpSUlKS2bdvabAsAAAAAAAAAxE4TWKLBZYwxDd0ISao5tKuhm9AopGYPcPyYk/vei0JL4k84z104Gvvz3RRrKFavbTga+3MXDmoofI39eQAAm+hbEammOOZoamI5Due1bZqoodhLSu/S0E2IO6fK3wx4e0rfW2PUkvCFfQUvAAAAAAAAAMS15rxEAwAAAAAAAADENW/8L9HABC8AAAAAAACAZsl4ahq6CRFjghcAAAAAAABA88QSDQAAAAAAAAAQp1iiAQAAAAAAAADiFFfwAgAAAAAAAECcYoIXAAAAAAAAAOIUSzQAAAAAAAAAQJziCl4AAAAAAAAAiFNM8AIAAAAAAABAnDIs0QAAAAAAAAAA8el0/F/Bm9DQDQAAAAAAAACABuHxBN7CMGfOHHXq1EkpKSnKzc1VeXl5wPvPmjVLl19+uVJTU3XRRRdp/PjxOnXqVMh5XMHbyJzc915DNyFuNcXnLjV7gOPHNMXnoSmeU6xQQ19piucExIum2A81xXNqzMJ5vsMR7msUq/aFo7HXXTjta8zPtxS757yxvy9ihf74K9RD+KghSJK8dpdoWLJkiQoLCzVv3jzl5uZq1qxZGjRokLZt26aMjIxz7v/aa69p0qRJWrhwofr166ft27frrrvuksvl0owZM0LK5ApeAAAAAAAAAM2T53TgzaEZM2Zo1KhRGjlypLp166Z58+apZcuWWrhwod/7r1u3Tv3799cdd9yhTp066frrr9ftt98e9KrfMzHBCwAAAAAAAKB5CrJEQ1VVlY4ePeqzVVVV+T1UdXW1NmzYoPz8/Lp9CQkJys/PV1lZmd/H9OvXTxs2bKib0N21a5dWrFih7373uyGfAhO8AAAAAAAAAJonrzfgVlJSotatW/tsJSUlfg916NAheTweZWZm+uzPzMxURUWF38fccccdeuKJJ3TNNdcoKSlJF198sb75zW9q8uTJIZ8CE7wAAAAAAAAAmiXj8QTcioqKdOTIEZ+tqKjIWv7atWv11FNPae7cudq4caPeeustLV++XNOmTQv5GPzIGgAAAAAAAIDmKcg6u263W263O6RDpaenKzExUZWVlT77KysrlZWV5fcxjz76qO6880796Ec/kiT16NFDx48f17333qspU6YoISH49bmOr+A9efKk3n//ff39738/57ZTp07plVdeCXoMJ2tXAAAAAAAAAEBUeE3gzYHk5GT17t1bpaWl/z6816vS0lLl5eX5fcyJEyfOmcRNTEyUJBkTWr6jCd7t27friiuu0LXXXqsePXpo4MCB2r9/f93tR44c0ciRI4Mex9/aFU/PnuekKQAAAAAAAAAQmdOnA28OFRYWav78+Xr55Zf18ccf67777tPx48fr5kwLCgp8lngYPHiwnn/+eS1evFi7d+/W6tWr9eijj2rw4MF1E73BOFqiYeLEierevbvWr1+vw4cP68EHH1T//v21du1adejQIeTjFBUVqbCw0GdfwrG9TpoCAAAAAAAAAJHxeKwebtiwYTp48KCmTp2qiooK9erVSytXrqz74bU9e/b4XLH7yCOPyOVy6ZFHHtHevXv1ta99TYMHD9aTTz4ZcqajCd5169bp3XffVXp6utLT0/Wb3/xG999/vwYMGKA1a9aoVatWIR3H39oVNdWHnDQFAAAAAAAAACLjcBmGUIwdO1Zjx471e9vatWt9/m7RooWKi4tVXFwcdp6jJRpOnjypFi3+PSfscrn0/PPPa/DgwRo4cKC2b98edkMAAAAAAAAAIKY8nsBbHHB0BW/Xrl21fv16XXHFFT77n3vuOUnSzTffbK9lAAAAAAAAABBF5nR8TOIG4ugK3ltuuUWvv/6639uee+453X777SH/uhsAAAAAAAAANCjjDbzFAUcTvEVFRVqxYkW9t8+dO1deb3ycOAAAAAAAAIBm7rQn8BYHXKaRXHJbc2iX48ekZg+IQkvOdXLfe44fE6u2hSucc0LT1ZjfS4gMr23T1Ng/Y8JBDcUe/UPTFM7rymuEM8XyM4baa/wa+5iDGmr8qKHYS0rv0tBNiDvHHx0a8PZW096IUUvC52gNXgAAAAAAAABoMryN4trXiDDBCwAAAAAAAKBZago/ssYELwAAAAAAAIDmycMELwAAAAAAAADEJ5ZoAAAAAAAAAID4ZE57G7oJEWOCFwAAAAAAAEDzxBq8AAAAAAAAABCnWKIBAAAAAAAAAOKT8bBEAwAAAAAAAADEJdbgBQAAAAAAAIB4xRINAAAAAAAAABCfzGkmeAEAAAAAAAAgPjHBCwAAAAAAAADxyTSBJRpcxphGcRY1h3Y1dBOaldTsATHJObnvvZjkIPbCqSHqITK8bxEpagjxglpFpBinAIB99K3hC3dsE87zl5TeJays5uzzWwYGvP2Ct/8Yo5aEjyt4AQAAAAAAADRL5nRDtyByCQ3dAAAAAAAAAABoEN4gWxjmzJmjTp06KSUlRbm5uSovLw94/8OHD2vMmDFq27at3G63LrvsMq1YsSLkPK7gBQAAAAAAANAs2b6Cd8mSJSosLNS8efOUm5urWbNmadCgQdq2bZsyMjLOuX91dbW+853vKCMjQ7/+9a/Vrl07/eMf/1CbNm1CzrQywWuMkcvlsnEoAAAAAAAAAIgJr+UJ3hkzZmjUqFEaOXKkJGnevHlavny5Fi5cqEmTJp1z/4ULF+rzzz/XunXrlJSUJEnq1KmTo0wrSzS43W59/PHHNg4FAAAAAAAAADFhvIG3qqoqHT161Gerqqrye6zq6mpt2LBB+fn5dfsSEhKUn5+vsrIyv49ZtmyZ8vLyNGbMGGVmZqp79+566qmn5PF4Qj4HR1fwFhYW+t3v8Xg0ffp0XXjhhZK+mqkOpKqq6pwnIqGqSm6320lzAAAAAAAAACBsxhN4VYKSkhI9/vjjPvuKi4v12GOPnXPfQ4cOyePxKDMz02d/ZmamPvnkE7/H37Vrl/7whz9o+PDhWrFihXbu3Kn7779fNTU1Ki4uDukcHE3wzpo1Sz179jxnDQhjjD7++GO1atUqpKUa/D0xj0z4iaY+PM5JcwAAAAAAAAAgbN7Tgecyi4qKzrno1eZFql6vVxkZGXrxxReVmJio3r17a+/evXrmmWeiM8H71FNP6cUXX9Szzz6rb3/723X7k5KStGjRInXr1i2k4/h7YhKO7XXSFAAAAAAAAACIiDGBJ3jdbnfIE7rp6elKTExUZWWlz/7KykplZWX5fUzbtm2VlJSkxMTEun1XXHGFKioqVF1dreTk5KC5jtbgnTRpkpYsWaL77rtPDz30kGpqapw8vI7b7VZaWprPxvIMAAAAAAAAAGLJe9oVcHMiOTlZvXv3Vmlp6b+P7/WqtLRUeXl5fh/Tv39/7dy5U16vt27f9u3b1bZt25Amd6UwfmStT58+2rBhgw4ePKicnBxt3bo1pGUZAAAAAAAAAKAx8XpcATenCgsLNX/+fL388sv6+OOPdd999+n48eMaOXKkJKmgoEBFRUV197/vvvv0+eefa9y4cdq+fbuWL1+up556SmPGjAk509ESDbXOO+88vfzyy1q8eLHy8/Md/aobAAAAAAAAADQGxmv3wtVhw4bp4MGDmjp1qioqKtSrVy+tXLmy7ofX9uzZo4SEf19ze9FFF2nVqlUaP368rrrqKrVr107jxo3TxIkTQ84Ma4K31m233aZrrrlGGzZsUMeOHSM5FAAAAAAAAADEVDhX6QYzduxYjR071u9ta9euPWdfXl6e/vznP4edF9EEryS1b99e7du3j/QwAAAAAAAAABBT0ZjgjTWXMcY0dCMkqebQroZuQr1Sswc4fszJfe9FoSX+hdO+cIRzTrFqmxTb57ypacw1FEs8D4gUNYSGwGctIhVuDVEPAAA0PknpXRq6CXFn+xU3BLz9so9Xxqgl4Yv4Cl4AAAAAAAAAiEdeT0LwOzVyTPACAAAAAAAAaJaawhINTPACAAAAAAAAaJa8hgleAAAAAAAAAIhLXi8TvAAAAAAAAAAQlzxe1uAFAAAAAAAAgLhkTEO3IHJM8AIAAAAAAABolriCFwAAAAAAAADilIcfWQMAAAAAAACA+ORlghcAAAAAAAAA4hNX8AIAAAAAAABAnGKCFwAAAAAAAADilFdM8AIAAAAAAABAXPIwwds8nNz3XkM3IaDG3L7G3Db8G6/TV3geEClqCJFKzR4Qs6ymVq88d1+J5fPQ1FBDAIBQ8HnR9DDBCwAAAAAAAABxytvQDbCACV4AAAAAAAAAzZLHxRW8AAAAAAAAABCXmsKPrCU0dAMAAAAAAAAAoCF4gmzhmDNnjjp16qSUlBTl5uaqvLw8pMctXrxYLpdLQ4YMcZTHBC8AAAAAAACAZsnjcgXcnFqyZIkKCwtVXFysjRs3qmfPnho0aJAOHDgQ8HGfffaZHnroIQ0Y4PyH/BxN8G7cuFG7d++u+/vVV19V//79ddFFF+maa67R4sWLQzpOVVWVjh496rNVVVU5azkAAAAAAAAARMAbZHNqxowZGjVqlEaOHKlu3bpp3rx5atmypRYuXFjvYzwej4YPH67HH39cXbp0cZzpaIJ35MiR+vTTTyVJCxYs0I9//GPl5ORoypQp6tOnj0aNGhWwsbVKSkrUunVrn+3p2fMcNx4AAAAAAAAAwnXa5Qq4OVFdXa0NGzYoPz+/bl9CQoLy8/NVVlZW7+OeeOIJZWRk6J577gnrHBz9yNqOHTt06aWXSpLmzp2r2bNna9SoUXW39+nTR08++aTuvvvugMcpKipSYWGhz76EY3udNAUAAAAAAAAAIuIJModbVVV1zsoDbrdbbrf7nPseOnRIHo9HmZmZPvszMzP1ySef+D3++++/r1/84hfavHmzo3afydEVvC1bttShQ4ckSXv37lXfvn19bs/NzfVZwqE+brdbaWlpPpu/JwUAAAAAAAAAoiXYEg3+ViIoKSmxkn3s2DHdeeedmj9/vtLT08M+jqMreG+88UY9//zzWrBggQYOHKhf//rX6tmzZ93tb7zxhi655JKwGwMAAAAAAAAAsRLsCl5/KxHUd6Fqenq6EhMTVVlZ6bO/srJSWVlZ59z/008/1WeffabBgwfX7fN6v1r5t0WLFtq2bZsuvvjioOfgaIL36aefVv/+/TVw4EDl5OTo2Wef1dq1a3XFFVdo27Zt+vOf/6y3337bySEBAAAAAAAAoEGcDnJ7fcsx+JOcnKzevXurtLRUQ4YMkfTVhG1paanGjh17zv27du2qv/71rz77HnnkER07dkyzZ8/WRRddFFKuowne7Oxsbdq0SdOnT9dvfvMbGWNUXl6uf/7zn+rfv78++OAD5eTkODkkAAAAAAAAADQI4+x31IIqLCzUiBEjlJOTo759+2rWrFk6fvy4Ro4cKUkqKChQu3btVFJSopSUFHXv3t3n8W3atJGkc/YH4miCtzZk+vTpmj59utOHAgAAAAAAAECjEewKXqeGDRumgwcPaurUqaqoqFCvXr20cuXKuh9e27NnjxISHP0sWlAuY4yxesQw1RzaFZOc1OwBjh9zct97UWhJ/AnnuWvswnltqaH4EKt6jeVrS+0hUk2tH6e+m66m2Ic3RY35dYplf0cdha+pfS6FixqKrcbcd4WrKZ4TwpeU3qWhmxB3Znf4YcDbx+35ZYxaEj7HV/ACAAAAAAAAQFPgbegGWMAELwAAAAAAAIBmydPQDbCACV4AAAAAAAAAzdJpyz+y1hCY4AUAAAAAAADQLDWKHyeLEBO8AAAAAAAAAJql001gipcJXgAAAAAAAADNEmvwAgAAAAAAAECc8rIGLwAAAAAAAADEJw9LNAAAAAAAAABAfGINXgAAAAAAAACIU/E/vcsELwAAAAAAAIBmiit4AQAAAAAAACBOeRq6ARYwwQsAAAAAAACgWTJN4ApelzGmUZxFi+R2Dd2ERuHkvvcauglxKzV7QEM3oVGghmIrlnXX1F5bnjs0lFjVHnXXdFFDAJoLxmvxoTF/LjXFGmrs55SU3iUKLWna7u80NODtcz97I0YtCR9X8AIAAAAAAABoljxN4ApeJngBAAAAAAAANEvehm6ABUzwAgAAAAAAAGiWuIIXAAAAAAAAAOIUE7wAAAAAAAAAEKe8Jv4neBOcPuC5555TQUGBFi9eLEl69dVX1a1bN3Xt2lWTJ0/W6dOngx6jqqpKR48e9dlME3gyAQAAAAAAAMQPj0zALR44muD97//+b02ePFknTpzQ+PHj9fTTT2v8+PEaPny4RowYoQULFmjatGlBj1NSUqLWrVv7bMZ7LOyTAAAAAAAAAACnojHBO2fOHHXq1EkpKSnKzc1VeXl5vfedP3++BgwYoPPPP1/nn3++8vPzA97fH0cTvIsWLdKiRYv061//WitXrtSUKVM0e/ZsTZkyRUVFRXrhhRf02muvBT1OUVGRjhw54rO5Ev7DUcMBAAAAAAAAIBJemYCbU0uWLFFhYaGKi4u1ceNG9ezZU4MGDdKBAwf83n/t2rW6/fbbtWbNGpWVlemiiy7S9ddfr71794ac6WiCd9++fcrJyZEk9ezZUwkJCerVq1fd7VdffbX27dsX9Dhut1tpaWk+m8vlctIUAAAAAAAAAIiI7St4Z8yYoVGjRmnkyJHq1q2b5s2bp5YtW2rhwoV+7/+rX/1K999/v3r16qWuXbtqwYIF8nq9Ki0tDTnT0QRvVlaW/v73v0uSduzYIY/HU/e3JP3tb39TRkaGk0MCAAAAAAAAQIPwGG/Azd9viVVVVfk9VnV1tTZs2KD8/Py6fQkJCcrPz1dZWVlI7Tlx4oRqamp0wQUXhHwOjiZ4hw8froKCAo0aNUqDBg3Sww8/rIceekjz5s3TCy+8oNGjR+uWW25xckgAAAAAAAAAaBDeIJu/3xIrKSnxe6xDhw7J4/EoMzPTZ39mZqYqKipCas/EiROVnZ3tM0kcTIuQ7ynp8ccfV2pqqsrKyjRq1ChNmjRJPXv21MMPP6wTJ05o8ODBIf3IGgAAAAAAAAA0NI+8AW8vKipSYWGhzz632x2VtkyfPl2LFy/W2rVrlZKSEvLjHE3wJiQkaPLkyT77brvtNt12221ODgMAAAAAAAAADc5jAq+z63a7Q57QTU9PV2JioiorK332V1ZWKisrK+Bjf/azn2n69Ol69913ddVVV4WUV8vRBG80ndz3nuPHpGYPiEkOvhLO8x2ucF6nWL22sXwemppwnzte26apsffHTbEeGvtzHis8D+FrzO+LWL6uTe1zKZbPHTUUW419/I7Gj9c1PjTm16kxty1cTfGcmjsTxg+p1Sc5OVm9e/dWaWmphgwZIkl1P5g2duzYeh/3P//zP3ryySe1atUq5eTkOM5tNBO8AAAAAAAAABBLHhN4iQanCgsLNWLECOXk5Khv376aNWuWjh8/rpEjR0qSCgoK1K5du7p1fJ9++mlNnTpVr732mjp16lS3Vu95552n8847L6RMJngBAAAAAAAANEvB1uB1atiwYTp48KCmTp2qiooK9erVSytXrqz74bU9e/YoISGh7v7PP/+8qqur9V//9V8+xykuLtZjjz0WUiYTvAAAAAAAAACaJW+QNXjDMXbs2HqXZFi7dq3P35999lnEeUzwAgAAAAAAAGiWPBbX4G0oTPACAAAAAAAAaJZsr8HbEJjgBQAAAAAAANAsebmCFwAAAAAAAADik5creAEAAAAAAAAgPrFEAwAAAAAAAADEKZZoAAAAAAAAAIA4xRW8AAAAAAAAABCnmOAFAAAAAAAAgDhlWKIBAAAAAAAAAOJTU7iC12WMaRTT1DWHdjV0E6xKzR4Qs6yT+96LWRZiK1Z1RA3FB+oBtRr7Z0w47aPuEE9i+R50ivdSfKCGmi7Ga4gUNYRIJaV3aegmxJ0u6V8PePuuQ5ti1JLwcQUvAAAAAAAAgGbJNIEreMOa4K2urtbSpUtVVlamiooKSVJWVpb69eun733ve0pOTrbaSAAAAAAAAACwrSks0ZDg9AE7d+7UFVdcoREjRmjTpk3yer3yer3atGmTCgoKdOWVV2rnzp3RaCsAAAAAAAAAWOMx3oBbPHB8Be99992nHj16aNOmTUpLS/O57ejRoyooKNCYMWO0atUqa40EAAAAAAAAANu8jePnySLieIL3gw8+UHl5+TmTu5KUlpamadOmKTc310rjAAAAAAAAACBavHFylW4gjid427Rpo88++0zdu3f3e/tnn32mNm3aBDxGVVWVqqqqfPYlVFXJ7XY7bQ4AAAAAAAAAhCVelmEIxPEavD/60Y9UUFCgmTNnasuWLaqsrFRlZaW2bNmimTNn6q677tK9994b8BglJSVq3bq1z/b07HlhnwQAAAAAAAAAOOU1JuAWDxxfwfvEE0+oVatWeuaZZ/TTn/5ULpdLkmSMUVZWliZOnKiHH3444DGKiopUWFjosy/h2F6nTQEAAAAAAACAsDWFK3gdT/BK0sSJEzVx4kTt3r1bFRUVkqSsrCx17tw5pMe73e5zlmOoqT4UTlMAAAAAAAAAICweb/xP8DpeouFMnTt3Vl5envLy8uomd//5z3/q7rvvttI4AAAAAAAAAIgWE+S/cMyZM0edOnVSSkqKcnNzVV5eHvD+b775prp27aqUlBT16NFDK1ascJQX0QSvP59//rlefvll24cFAAAAAAAAAKs8Xm/AzaklS5aosLBQxcXF2rhxo3r27KlBgwbpwIEDfu+/bt063X777brnnnu0adMmDRkyREOGDNHWrVtDznS8RMOyZcsC3r5r1y6nhwQAAAAAAACAmPNaXoN3xowZGjVqlEaOHClJmjdvnpYvX66FCxdq0qRJ59x/9uzZuuGGGzRhwgRJ0rRp07R69Wo999xzmjdvXkiZjid4hwwZIpfLJRPgV+Rqf3gNAAAAAAAAABqrQHOcklRVVaWqqiqfff5+X0ySqqurtWHDBhUVFdXtS0hIUH5+vsrKyvwev6ysTIWFhT77Bg0apKVLl4Z4BpKMQ9nZ2Wbp0qX13r5p0yaTkJDg9LB+nTp1yhQXF5tTp05ZOV5jyGpqObHM4pwaf04sszinxp8Ty6ymlhPLLM6p8efEMotzavw5scxqajmxzOKcGn9OLLOaWk4sszinxp8TyyzOCQ2luLjYSPLZiouL/d537969RpJZt26dz/4JEyaYvn37+n1MUlKSee2113z2zZkzx2RkZITcRscTvIMHDzaPPvpovbdv3rzZuFwup4f168iRI0aSOXLkiJXjNYasppYTyyzOqfHnxDKLc2r8ObHMamo5sczinBp/TiyzOKfGnxPLrKaWE8sszqnx58Qyq6nlxDKLc2r8ObHM4pzQUE6dOmWOHDnis9U3Kd9QE7yOl2iYMGGCjh8/Xu/tl1xyidasWeP0sAAAAAAAAADQqNS3HIM/6enpSkxMVGVlpc/+yspKZWVl+X1MVlaWo/v7kxDyPf+/AQMG6IYbbqj39latWmngwIFODwsAAAAAAAAAcSs5OVm9e/dWaWlp3T6v16vS0lLl5eX5fUxeXp7P/SVp9erV9d7fH8dX8AIAAAAAAAAAzlVYWKgRI0YoJydHffv21axZs3T8+HGNHDlSklRQUKB27dqppKREkjRu3DgNHDhQzz77rG666SYtXrxY69ev14svvhhyZqOe4HW73SouLg75Muh4yGpqObHM4pwaf04sszinxp8Ty6ymlhPLLM6p8efEMotzavw5scxqajmxzOKcGn9OLLOaWk4sszinxp8TyyzOCfFi2LBhOnjwoKZOnaqKigr16tVLK1euVGZmpiRpz549Skj496IK/fr102uvvaZHHnlEkydP1qWXXqqlS5eqe/fuIWe6jDHG+pkAAAAAAAAAAKLO8Rq8AAAAAAAAAIDGgQleAAAAAAAAAIhTTPACAAAAAAAAQJxighcAAAAAAAAA4lSjnuCdM2eOOnXqpJSUFOXm5qq8vNzq8UtKStSnTx/9x3/8hzIyMjRkyBBt27bNaoY/06dPl8vl0oMPPhiV4+/du1c//OEPdeGFFyo1NVU9evTQ+vXrrWZ4PB49+uij6ty5s1JTU3XxxRdr2rRpsvGbfX/60580ePBgZWdny+VyaenSpT63G2M0depUtW3bVqmpqcrPz9eOHTus5tTU1GjixInq0aOHWrVqpezsbBUUFGjfvn1ROaczjR49Wi6XS7NmzYpKzscff6ybb75ZrVu3VqtWrdSnTx/t2bPHas6XX36psWPHqn379kpNTVW3bt00b948x+cTynv01KlTGjNmjC688EKdd955+sEPfqDKykrrWZ9//rkeeOABXX755UpNTVWHDh30k5/8REeOHLF+TrWMMbrxxhuD1kwkOWVlZfr2t7+tVq1aKS0tTddee61OnjxpPauiokJ33nmnsrKy1KpVK1199dX6v//7P0c5zz//vK666iqlpaUpLS1NeXl5+t3vfld3u61aCJZlqxZCOadakdRCqDk2aiGULBu14I+/z1abNVFfjs16CJZ1pkhrIpQcWzURKMdWPTz22GNyuVw+W9euXetut1kLgbJs1kOwc6oVaS2EkmOrFoJl2ewfgo2HbY0nA+XYHk86GeNHMp4MJcfGeDJYjq3xZKdOnc6pO5fLpTFjxkiy1z8EyrH9WRHsnGpF2j+EkmOrfwiWZat/COV7rI3+IViOzf7B6XfzcPuHUHMi7R9CybHVP0jSsWPH9OCDD6pjx45KTU1Vv3799Je//KXudlufF4FybH9eoJkyjdTixYtNcnKyWbhwofnb3/5mRo0aZdq0aWMqKyutZQwaNMi89NJLZuvWrWbz5s3mu9/9runQoYP58ssvrWWcrby83HTq1MlcddVVZty4cdaP//nnn5uOHTuau+66y3z44Ydm165dZtWqVWbnzp1Wc5588klz4YUXmt/+9rdm9+7d5s033zTnnXeemT17dsTHXrFihZkyZYp56623jCTz9ttv+9w+ffp007p1a7N06VLz0UcfmZtvvtl07tzZnDx50lrO4cOHTX5+vlmyZIn55JNPTFlZmenbt6/p3bt3VM6p1ltvvWV69uxpsrOzzcyZM63n7Ny501xwwQVmwoQJZuPGjWbnzp3mnXfecfy+CpYzatQoc/HFF5s1a9aY3bt3mxdeeMEkJiaad955x1FOKO/R0aNHm4suusiUlpaa9evXm2984xumX79+jnJCyfrrX/9qvv/975tly5aZnTt3mtLSUnPppZeaH/zgB9bPqdaMGTPMjTfeGLBmIslZt26dSUtLMyUlJWbr1q3mk08+MUuWLDGnTp2ynvWd73zH9OnTx3z44Yfm008/NdOmTTMJCQlm48aNIecsW7bMLF++3Gzfvt1s27bNTJ482SQlJZmtW7caY+zVQrAsW7UQyjnViqQWQsmxVQuhZNmohbPV99lqsybqy7FZD6GcU61IayJYjs2aCJRjqx6Ki4vNlVdeafbv31+3HTx4sO52m7UQKMtmPQQ7p1qR1kKwHJu1ECzLVj2EMh62MZ4MlmNzPOlkjB/JeDKUHBvjyVBybI0nDxw44FNzq1evNpLMmjVrjDH2+odAObY/K4KdU61I+4dgOTb7h2BZtvqHUL7H2ugfguXY7B+cfDePpH8IJcdG/xBKjq3+wRhjhg4darp162b++Mc/mh07dpji4mKTlpZm/vWvfxlj7M0/BMqxPf+A5qnRTvD27dvXjBkzpu5vj8djsrOzTUlJSdQyDxw4YCSZP/7xj1E5/rFjx8yll15qVq9ebQYOHBiVCd6JEyeaa665xvpxz3bTTTeZu+++22ff97//fTN8+HCrOWcPRrxer8nKyjLPPPNM3b7Dhw8bt9ttXn/9dWs5/pSXlxtJ5h//+EfYOYGy/vWvf5l27dqZrVu3mo4dO4Y1wRssZ9iwYeaHP/xhRMcNJefKK680TzzxhM++q6++2kyZMiWirLPfo4cPHzZJSUnmzTffrLvPxx9/bCSZsrIyq1n+vPHGGyY5OdnU1NRYz9m0aZNp166d2b9/f8QTOPXl5ObmmkceeSSi44aa1apVK/PKK6/43O+CCy4w8+fPjyjr/PPPNwsWLIhqLZyd5Y+NWqgvx3Yt+MuJVi34y7JdC/V9ttquCSef4ZHWQ7AsWzURKMdmTQTKsVUPxcXFpmfPnn5vs10LgbL8CbceQsmxUQvBcmzWQrAsW/UQbDxsazwZzrg73PFkqFmRjidDybExngwlJ1rjyXHjxpmLL77YeL3eqI4fzszxx+bYwV9WNMYPZ+dEc/xwdpat/iHY91hb/UM435fD7R9CzYq0fwglx0b/EEqOrf7hxIkTJjEx0fz2t7/1eyxb9RAsxx9b8w9oPhrlEg3V1dXasGGD8vPz6/YlJCQoPz9fZWVlUcut/ScyF1xwQVSOP2bMGN10000+52XbsmXLlJOTo1tvvVUZGRn6+te/rvnz51vP6devn0pLS7V9+3ZJ0kcffaT3339fN954o/WsM+3evVsVFRU+z2Hr1q2Vm5sb1dqQvqoPl8ulNm3aWD+21+vVnXfeqQkTJujKK6+0fvzajOXLl+uyyy7ToEGDlJGRodzc3Ij/aa8//fr107Jly7R3714ZY7RmzRpt375d119/fUTHPfs9umHDBtXU1PjUQ9euXdWhQ4eI6yGU/uDIkSNKS0tTixYtrOacOHFCd9xxh+bMmaOsrKywjx0o58CBA/rwww+VkZGhfv36KTMzUwMHDtT7779vPUv6qiaWLFmizz//XF6vV4sXL9apU6f0zW9+M6wMj8ejxYsX6/jx48rLy4tqLZyd5Y+NWvCXE41aODsnmrXg75xs10J9n622a8LJZ3ik9RAoy2ZN1JdjuyYCnY/NetixY4eys7PVpUsXDR8+vO6fg0ajf6gvy59I6iFQjs1aqC8nGv1DoHOyVQ/BxsO2xpPhjLvDHU+GkmVjPBksx9Z4MpTzicZ4srq6Wr/85S919913y+VyRW38cHaOPzbGDvVlRWP8cHZONMcP/s7JVv8Q7Husrf4hnO/L4fYPoWTZ6B+C5djqH0I5H1v9w+nTp+XxeJSSkuKzPzU1Ve+//761egiW40805x/QRDXwBLNfe/fuNZLMunXrfPZPmDDB9O3bNyqZHo/H3HTTTaZ///5ROf7rr79uunfvXncZf7Su4HW73cbtdpuioiKzceNG88ILL5iUlBSzaNEiqzkej8dMnDjRuFwu06JFC+NyucxTTz1lNcOYc68O/eCDD4wks2/fPp/73XrrrWbo0KHWcs528uRJc/XVV5s77rgj7IxAWU899ZT5zne+U/d/qKNxBW/t/71v2bKlmTFjhtm0aZMpKSkxLpfLrF271lqOMcacOnXKFBQUGEmmRYsWJjk52bz88sthZxjj/z36q1/9yiQnJ59z3z59+piHH37YatbZDh48aDp06GAmT55sPefee+8199xzT93fweoznJyysjIjyVxwwQVm4cKFZuPGjebBBx80ycnJZvv27VazjDHmiy++MNdff31dTaSlpZlVq1Y5Pv6WLVtMq1atTGJiomndurVZvny5MSY6tVBf1tkirYVAOTZrob6caNRCoHOyVQvGBP5stVkTTj7DI62HYFm2aiJQjs2aCHY+tuphxYoV5o033jAfffSRWblypcnLyzMdOnQwR48etd4/BMo6WyT1ECzHVi0EyrHdPwQ7J1v1EGw8bGs86XTcHcl4MpQsG+PJYDm2xpOhnE80xpNLliwxiYmJZu/evcaY6I0lz845m41xZKAs22NJfznRGkv6yzLGXv8Q7Husrf7B6fflSPqHULJs9A/Bcmz1D6Gcj83+IS8vzwwcONDs3bvXnD592rz66qsmISHBXHbZZVbnHwLlnM3m/AOaDyZ4/7/Ro0ebjh07mn/+85/Wj71nzx6TkZFhPvroo7p90ZrgTUpKMnl5eT77HnjgAfONb3zDas7rr79u2rdvb15//XWzZcsW88orr5gLLrjA+kRyY5jgra6uNoMHDzZf//rXzZEjR8LOqC9r/fr1JjMz02cAE40J3tr31e233+5zv8GDB5vbbrvNWo4xxjzzzDPmsssuM8uWLTMfffSR+fnPf27OO+88s3r16rBz/L1HozUoD9YfHDlyxPTt29fccMMNprq62mrOO++8Yy655BJz7Nixun2RDsr95dS+l4qKinzu26NHDzNp0iSrWcYYM3bsWNO3b1/z7rvvms2bN5vHHnvMtG7d2mzZssXR8auqqsyOHTvM+vXrzaRJk0x6err529/+FpVaqC/rTDZqob4c27VQX040aiHQc2erFoJ9ttqqCSef4ZHWQ7AsWzURLMdWTYTy3Nmqh7N98cUXJi0tzSxYsCBqnxX+ss5k67PCX040Piv85UTrs8JfljH26iHYeNjWeNLJuDvS8WSwLFvjyWA5tsaToTx30RhPXn/99eY///M/6/6OVv9wds6ZbPcNZ2dFq384Oyea/YO/589W/xDse6yt/sHJ9+VI+4dgWbb6h2A5tvqHUJ47m/3Dzp07zbXXXmskmcTERNOnTx8zfPhw07VrV6vzD4FyzmR7/gHNR6Oc4K2qqjKJiYnnfAgVFBSYm2++2XremDFjTPv27c2uXbusH9sYY95+++26N3HtJsm4XC6TmJhoTp8+bS2rQ4cOPv+31hhj5s6da7Kzs61lGGNM+/btzXPPPeezb9q0aebyyy+3mnP2YOTTTz81ksymTZt87nfttdean/zkJ9ZyalVXV5shQ4aYq666yhw6dCjs4wfKmjlzZl0tnFkfCQkJpmPHjtZyqqqqTIsWLcy0adN87vfwww9H9KNDZ+ecOHHCJCUlnbO+0D333GMGDRoUVkZ979HS0lIjyXzxxRc++zt06GBmzJhhNavW0aNHTV5enrnuuuscL6wfSs64cePqrYeBAwday9m1a5eRZF599VWf/UOHDg37/xTXl7Vz504j6ZwfDrvuuuvMj3/847CyzjzGvffeG5VaqC+rlq1aqC/Hdi3UlxONWqgvy2YtBPtsfffdd63URKif4TbqIVjW2LFjrdREsJza1ynSmgg1Jxp9gzHG5OTkmEmTJsWkf6jNqhWt/qE2J9r9Q21OLPqH2iyb9RBsPGxrPBnquNvGeDJYlq3xZLAcW+PJYDnRGE9+9tlnJiEhwSxdurRuXzT6B385tWz3Df6yotE/+MuJVv/gL8tm/xDse6yt/iHU78s2+odgWbb6h2A5tvqHYDnR6B+MMebLL7+sm8gdOnSo+e53vxuV+Qd/ObWiMf+A5qNRrsGbnJys3r17q7S0tG6f1+tVaWlpvesfhsMYo7Fjx+rtt9/WH/7wB3Xu3Nnasc903XXX6a9//as2b95ct+Xk5Gj48OHavHmzEhMTrWX1799f27Zt89m3fft2dezY0VqG9NW6TgkJvuWTmJgor9drNedsnTt3VlZWlk9tHD16VB9++KHV2pCkmpoaDR06VDt27NC7776rCy+80Orxa915553asmWLT31kZ2drwoQJWrVqlbWc5ORk9enTJ+r1UVNTo5qaGiv1Eew92rt3byUlJfnUw7Zt27Rnzx7H9RBKf3D06FFdf/31Sk5O1rJly85ZQ8lGzqRJk86pB0maOXOmXnrpJWs5nTp1UnZ2tpV6CJZ14sQJSYpKn+H1elVVVWW1FoJlSXZqIViOrVoIlmOzFoJl2ayFYJ+tOTk5VmoilM9wW/UQLGvKlClWaiJYTpcuXazURLCcaPYNX375pT799FO1bds26v3DmVlS9PqHM3Oi2T+cmRPt/uHMLJv1EGw8bGs8Gcq429Z4MliWrfFksBxb48lgOTbHk7VeeuklZWRk6KabbqrbF43+wV+OFJ2+wV9WNPoHfznR6h/8ZdnsH4J9j7XVP4TyfdlW/xAsy1b/ECzHVv8QLCca/YMktWrVSm3bttUXX3yhVatW6Xvf+15U5h/85dSeVyzmH9CENeTsciCLFy82brfbLFq0yPz973839957r2nTpo2pqKiwlnHfffeZ1q1bm7Vr15r9+/fXbSdOnLCWUZ9oLdFQXl5uWrRoYZ588kmzY8cO86tf/cq0bNnS/PKXv7SaM2LECNOuXTvz29/+1uzevdu89dZbJj093co/dTx27JjZtGmT2bRpk5FUt35P7a9HTp8+3bRp08a88847ZsuWLeZ73/ue6dy5s+P/Cx4op7q62tx8882mffv2ZvPmzT71UVVVZf2czhbuEg3Bct566y2TlJRkXnzxRbNjxw7z85//3CQmJpr33nvPas7AgQPNlVdeadasWWN27dplXnrpJZOSkmLmzp3rKCeU9+jo0aNNhw4dzB/+8Aezfv16k5eXd84/+bORdeTIEZObm2t69Ohhdu7c6XMfJ1fhh9PvKIx/VhdKzsyZM01aWpp58803zY4dO8wjjzxiUlJSzM6dO61mVVdXm0suucQMGDDAfPjhh2bnzp3mZz/7mXG5XPWua+vPpEmTzB//+Eeze/dus2XLFjNp0iTjcrnM73//e2OMvVoIlmWrFkI5p7OFUwuh5NiqhWBZtmqhPmd/ttqsifpybNZDsCx/wq2JYDk2a6K+HJv18NOf/tSsXbvW7N6923zwwQcmPz/fpKenmwMHDhhj7NZCoCyb9RDsnM4Wbi0Ey7FZC4GybNZDKONhG+PJYDk2x5PhjPHDGU+GkmNjPBlKjq3xpDFfrenZoUMHM3HixHNus9k/1JcTjc+KQOd0tkg+KwLl2P6sqC/LZv8QyvdYG/1DsByb/UM4383D6R9CybHRP4SSY7N/WLlypfnd735ndu3aZX7/+9+bnj17mtzc3LolVGzNPwTKsT3/gOap0U7wGmPMz3/+c9OhQweTnJxs+vbta/785z9bPb4kv9tLL71kNcefaE3wGmPMb37zG9O9e3fjdrtN165dzYsvvmg94+jRo2bcuHGmQ4cOJiUlxXTp0sVMmTLFSuezZs0av6/LiBEjjDHGeL1e8+ijj5rMzEzjdrvNddddZ7Zt22Y1Z/fu3fXWx5o1a6yf09nCneANJecXv/iFueSSS0xKSorp2bOn338+FmnO/v37zV133WWys7NNSkqKufzyy82zzz5bt6h/qEJ5j548edLcf//95vzzzzctW7Y0t9xyi9m/f7/jcwqWVd85SzK7d++2ek7+HuN0UB5qTklJiWnfvr1p2bKlycvLczzZH2rW9u3bzfe//32TkZFhWrZsaa666irzyiuvOMq5++67TceOHU1ycrL52te+Zq677jqfiVBbtRAsy1YthHJOZwv3C1ooOTZqIZQsG7VQn7M/W23WRH05NushWJY/0ZrgNcZeTQTKsVUPw4YNM23btjXJycmmXbt2ZtiwYT4TDDZrIVCWzXoIdk5nC7cWQsmxVQvBsmz2D8HGw7bGk4FybI8nnY7xwx1PhpJjYzwZLMfWeNIYY1atWmUk+X2NbfYP9eVE47Mi0DmdLZLPimA5Nj8rAmXZ6h9C+R5ro38IlmOzfwjnu3k4/UOoOZH2D6Hk2OwflixZYrp06WKSk5NNVlaWGTNmjDl8+HDd7bY+LwLl2P68QPPkMsYYAQAAAAAAAADiTqNcgxcAAAAAAAAAEBwTvAAAAAAAAAAQp5jgBQAAAAAAAIA4xQQvAAAAAAAAAMQpJngBAAAAAAAAIE4xwQsAAAAAAAAAcYoJXgAAAAAAAACIU0zwAgAAAAAAAECcYoIXAAAAAAAAAOIUE7wAAAAAAAAAEKeY4AUAAAAAAACAOMUELwAAAAAAAADEqf8HB1LbLLNtYSkAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.figure(figsize=(20, 2))\n", + "# sns.heatmap(mask.numpy())\n", + "sns.heatmap(mask[2].reshape(11, -1).numpy())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "fairseq", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.0" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/data_utils.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0234692a83f93272e868452e8ba13743264ce6d --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/data_utils.py @@ -0,0 +1,535 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import math +import numpy as np +import torch + +from typing import Optional, Tuple + + + +logger = logging.getLogger(__name__) + + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, + require_same_masks: bool = True, + mask_dropout: float = 0.0, + add_masks: bool = False, + seed: Optional[int] = None, + epoch: Optional[int] = None, + indices: Optional[torch.Tensor] = None, + idc_select_ver: int = 1, # 2 to reproduce mask_tokens_dataset + num_mask_ver: int = 2, # 2 to reproduce mask_tokens_dataset +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample + mask_dropout: randomly dropout this percentage of masks in each example + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + if num_mask_ver == 1: + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if seed is not None and epoch is not None and indices is not None: + seed_i = int(hash((seed, epoch, indices[i].item())) % 1e6) + else: + seed_i = None + + rng = np.random.default_rng(seed_i) + + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + assert sz >= 0, sz + else: + sz = all_sz + + if num_mask_ver == 1: + if padding_mask is not None: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + num_mask = all_num_mask + elif num_mask_ver == 2: + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + rng.random() + ) + num_mask = max(min_masks, num_mask) + else: + raise ValueError() + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = rng.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = rng.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = rng.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + if mask_type == "static": + raise ValueError(f"this should never happens") + else: + lengths = [min(mask_length, sz - 1)] + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = rng.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = rng.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + if idc_select_ver == 1: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + mask_idc = rng.choice(sz - min_len, num_mask, replace=False) + elif idc_select_ver == 2: + mask_idc = rng.choice(sz, num_mask, replace=False) + else: + raise ValueError() + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idc = np.unique(mask_idc[mask_idc < sz]) + if len(mask_idc) >= sz: + raise ValueError( + ( + f"the entire sequence is masked. " + f"sz={sz}; mask_idc[mask_idc]; " + f"index={indices[i] if indices is not None else None}" + ) + ) + mask_idcs.append(mask_idc) + + target_len = None + if require_same_masks: + if add_masks: + target_len = max([len(m) for m in mask_idcs]) + else: + target_len = min([len(m) for m in mask_idcs]) + + for i, mask_idc in enumerate(mask_idcs): + if target_len is not None and len(mask_idc) > target_len: + mask_idc = rng.choice(mask_idc, target_len, replace=False) + + mask[i, mask_idc] = True + + if target_len is not None and len(mask_idc) < target_len: + unmasked = np.flatnonzero(~mask[i]) + to_mask = rng.choice(unmasked, target_len - len(mask_idc), replace=False) + mask[i, to_mask] = True + + if mask_dropout > 0: + masked = np.flatnonzero(mask[i]) + num_holes = np.rint(len(masked) * mask_dropout).astype(int) + to_drop = rng.choice(masked, num_holes, replace=False) + mask[i, to_drop] = False + + return mask + + +def compute_block_mask_2d( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + mask_prob_adjust: float = 0, + inverse_mask: bool = False, + require_same_masks: bool = True, + expand_adjcent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, + img_shape: tuple = None, # For the situation when d[0] != d[1], especially in audio spce ways + flexible_mask: bool = False, +) -> torch.Tensor: + + assert mask_length > 1 + + B, L = shape + + d = (int(L**0.5),int(L**0.5)) + + if img_shape: + d = (img_shape[0],img_shape[1]) + + if flexible_mask: + index = np.random.randint(0,3) + block_size_options = np.array([(6, 4), (5, 5), (8, 3)]) + block_size = block_size_options[index] + + if inverse_mask: + mask_prob = 1 - mask_prob + + if flexible_mask: + mask = torch.zeros((B, d[0], d[1])) + mask_inds = torch.randint( + 0, + L, + size=( + B, + int( + L + * ((mask_prob + mask_prob_adjust) / (block_size[0]*block_size[1])) + * (1 + mask_dropout) + ), + ), + ) + mask.view(B, -1).scatter_(1, mask_inds, 1) + centers = mask.nonzero(as_tuple=True) + + inds = ([], [], []) + + offset = mask_length // 2 + for i in range(block_size[0]): + for j in range(block_size[1]): + k1 = i - offset + k2 = j - offset + inds[0].append(centers[0]) + inds[1].append(centers[1] + k1) + inds[2].append(centers[2] + k2) + + i0 = torch.cat(inds[0]) + i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1) + i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1) + + mask[(i0, i1, i2)] = 1 + + elif non_overlapping: + sz = math.ceil(d[0] / mask_length) + inp_len = sz * sz + + inp = torch.zeros((B, 1, sz, sz)) + w = torch.ones((1, 1, mask_length, mask_length)) + + mask_inds = torch.multinomial( + 1 - inp.view(B, -1), + int(inp_len * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), + replacement=False, + ) + inp.view(B, -1).scatter_(1, mask_inds, 1) + + mask = torch.nn.functional.conv_transpose2d(inp, w, stride=mask_length).squeeze( + 1 + ) + if mask.size(-1) > d[0]: + mask = mask[..., :d, :d] + else: + mask = torch.zeros((B, d[0], d[1])) + mask_inds = torch.randint( + 0, + L, + size=( + B, + int( + L + * ((mask_prob + mask_prob_adjust) / mask_length**2) + * (1 + mask_dropout) + ), + ), + ) + mask.view(B, -1).scatter_(1, mask_inds, 1) + centers = mask.nonzero(as_tuple=True) + + inds = ([], [], []) + + offset = mask_length // 2 + for i in range(mask_length): + for j in range(mask_length): + k1 = i - offset + k2 = j - offset + inds[0].append(centers[0]) + inds[1].append(centers[1] + k1) + inds[2].append(centers[2] + k2) + + i0 = torch.cat(inds[0]) + i1 = torch.cat(inds[1]).clamp_(min=0, max=d[0] - 1) + i2 = torch.cat(inds[2]).clamp_(min=0, max=d[1] - 1) + + mask[(i0, i1, i2)] = 1 + + def get_nbs(b, m, w): + all_nbs = torch.nn.functional.conv2d(m.unsqueeze(1), w, padding="same") + all_nbs = all_nbs.clamp_max_(1).view(b, -1) + return all_nbs + + if require_same_masks and expand_adjcent: + w = torch.zeros((1, 1, 3, 3)) + w[..., 0, 1] = 1 + w[..., 2, 1] = 1 + w[..., 1, 0] = 1 + w[..., 1, 2] = 1 + + all_nbs = get_nbs(B, mask, w) + + mask = mask.reshape(B, -1) + + if require_same_masks: + n_masks = mask.sum(dim=-1) + final_target_len = int(L * (mask_prob)) + target_len = int(final_target_len * (1 + mask_dropout)) + + for i in range(len(mask)): + n = n_masks[i] + m = mask[i] + r = 0 + while expand_adjcent and n < target_len: + if r == 0: + nbs = all_nbs[i] + else: + nbs = get_nbs(1, m.view(1, d[0], d[1]), w).flatten() + + cands = (1 - m + nbs) > 1 + cand_sz = int(cands.sum().item()) + + assert cand_sz > 0, f"{nbs} {cand_sz}" + + to_mask = torch.multinomial( + cands.float(), min(cand_sz, int(target_len - n)), replacement=False + ) + m[to_mask] = 1 + assert to_mask.numel() > 0 + n += to_mask.numel() + r += 1 + + if n > final_target_len: + to_unmask = torch.multinomial( + m, int(n - final_target_len), replacement=False + ) + m[to_unmask] = 0 + elif n < final_target_len: + to_mask = torch.multinomial( + (1 - m), int(final_target_len - n), replacement=False + ) + m[to_mask] = 1 + + if inverse_mask: + mask = 1 - mask + + return mask + + +def compute_block_mask_1d( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + mask_prob_adjust: float = 0, + inverse_mask: bool = False, + require_same_masks: bool = True, + expand_adjcent: bool = False, + mask_dropout: float = 0, + non_overlapping: bool = False, +) -> torch.Tensor: + + B, L = shape + + if inverse_mask: + mask_prob = 1 - mask_prob + + if non_overlapping: + sz = math.ceil(L / mask_length) + + inp = torch.zeros((B, 1, sz)) + w = torch.ones((1, 1, mask_length)) + + mask_inds = torch.multinomial( + 1 - inp.view(B, -1), + int(sz * (mask_prob + mask_prob_adjust) * (1 + mask_dropout)), + replacement=False, + ) + inp.view(B, -1).scatter_(1, mask_inds, 1) + + mask = torch.nn.functional.conv_transpose1d(inp, w, stride=mask_length).squeeze( + 1 + ) + if mask.size(-1) > L: + mask = mask[..., :L] + + else: + mask = torch.zeros((B, L)) + mask_inds = torch.randint( + 0, + L, + size=( + B, + int( + L + * ((mask_prob + mask_prob_adjust) / mask_length) + * (1 + mask_dropout) + ), + ), + ) + + mask.view(B, -1).scatter_(1, mask_inds, 1) + centers = mask.nonzero(as_tuple=True) + + inds = ([], []) + + offset = mask_length // 2 + for i in range(mask_length): + k1 = i - offset + inds[0].append(centers[0]) + inds[1].append(centers[1] + k1) + + i0 = torch.cat(inds[0]) + i1 = torch.cat(inds[1]).clamp_(min=0, max=L - 1) + + mask[(i0, i1)] = 1 + + def get_nbs(b, m, w): + all_nbs = torch.nn.functional.conv1d(m.unsqueeze(1), w, padding="same") + all_nbs = all_nbs.clamp_max_(1).view(b, -1) + return all_nbs + + if require_same_masks and expand_adjcent: + w = torch.ones((1, 1, 3)) + w[..., 1] = 0 + all_nbs = get_nbs(B, mask, w) + + mask = mask.view(B, -1) + + if require_same_masks: + n_masks = mask.sum(dim=-1) + final_target_len = int(L * (mask_prob)) + target_len = int(final_target_len * (1 + mask_dropout)) + + for i in range(len(mask)): + n = n_masks[i] + m = mask[i] + r = 0 + while expand_adjcent and n < target_len: + if r == 0: + nbs = all_nbs[i] + else: + nbs = get_nbs(1, m.unsqueeze(0), w).squeeze(0) + + cands = (1 - m + nbs) > 1 + cand_sz = int(cands.sum().item()) + + assert cand_sz > 0, f"{nbs} {cand_sz}" + + to_mask = torch.multinomial( + cands.float(), min(cand_sz, int(target_len - n)), replacement=False + ) + m[to_mask] = 1 + assert to_mask.numel() > 0 + n += to_mask.numel() + r += 1 + + if n > final_target_len: + to_unmask = torch.multinomial( + m, int(n - final_target_len), replacement=False + ) + m[to_unmask] = 0 + elif n < final_target_len: + to_mask = torch.multinomial( + (1 - m), int(final_target_len - n), replacement=False + ) + m[to_mask] = 1 + + if inverse_mask: + mask = 1 - mask + + return mask + + +def get_buckets(sizes, num_buckets): + buckets = np.unique( + np.percentile( + sizes, + np.linspace(0, 100, num_buckets + 1), + interpolation="lower", + )[1:] + ) + return buckets + + +def get_bucketed_sizes(orig_sizes, buckets): + sizes = np.copy(orig_sizes) + assert np.min(sizes) >= 0 + start_val = -1 + for end_val in buckets: + mask = (sizes > start_val) & (sizes <= end_val) + sizes[mask] = end_val + start_val = end_val + return sizes + + diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/mixup.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/mixup.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd0d2c333a8973e1c994bc686d935117854ef40 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/data/utils/mixup.py @@ -0,0 +1,220 @@ +""" Mixup and Cutmix + +Papers: +mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) + +CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) + +Code Reference: +CutMix: https://github.com/clovaai/CutMix-PyTorch + +Hacked together by / Copyright 2019, Ross Wightman +""" +import numpy as np +import torch + + +def one_hot(x, num_classes, on_value=1., off_value=0.): + x = x.long().view(-1, 1) + return torch.full((x.size()[0], num_classes), off_value, device=x.device).scatter_(1, x, on_value) + +# adapted from using one_hot to directly using target values +def mixup_target(target, num_classes, lam=1., smoothing=0.0): + # off_value = smoothing / num_classes + # on_value = 1. - smoothing + off_value + # y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value) + # y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value) + y1 = target + y2 = target.flip(0) + return y1 * lam + y2 * (1. - lam) + + +def rand_bbox(img_shape, lam, margin=0., count=None): + """ Standard CutMix bounding-box + Generates a random square bbox based on lambda value. This impl includes + support for enforcing a border margin as percent of bbox dimensions. + + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) + count (int): Number of bbox to generate + """ + ratio = np.sqrt(1 - lam) + img_h, img_w = img_shape[-2:] + cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) + margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) + cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) + cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) + yl = np.clip(cy - cut_h // 2, 0, img_h) + yh = np.clip(cy + cut_h // 2, 0, img_h) + xl = np.clip(cx - cut_w // 2, 0, img_w) + xh = np.clip(cx + cut_w // 2, 0, img_w) + return yl, yh, xl, xh + + +def rand_bbox_minmax(img_shape, minmax, count=None): + """ Min-Max CutMix bounding-box + Inspired by Darknet cutmix impl, generates a random rectangular bbox + based on min/max percent values applied to each dimension of the input image. + + Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. + + Args: + img_shape (tuple): Image shape as tuple + minmax (tuple or list): Min and max bbox ratios (as percent of image size) + count (int): Number of bbox to generate + """ + assert len(minmax) == 2 + img_h, img_w = img_shape[-2:] + cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) + cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) + yl = np.random.randint(0, img_h - cut_h, size=count) + xl = np.random.randint(0, img_w - cut_w, size=count) + yu = yl + cut_h + xu = xl + cut_w + return yl, yu, xl, xu + + +def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): + """ Generate bbox and apply lambda correction. + """ + if ratio_minmax is not None: + yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) + else: + yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) + if correct_lam or ratio_minmax is not None: + bbox_area = (yu - yl) * (xu - xl) + lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) + return (yl, yu, xl, xu), lam + + +class Mixup: + """ Mixup/Cutmix that applies different params to each element or whole batch + + Args: + mixup_alpha (float): mixup alpha value, mixup is active if > 0. + cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. + cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. + prob (float): probability of applying mixup or cutmix per batch or element + switch_prob (float): probability of switching to cutmix instead of mixup when both are active + mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element) + correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders + label_smoothing (float): apply label smoothing to the mixed target tensor + num_classes (int): number of classes for target + """ + def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, + mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): + self.mixup_alpha = mixup_alpha + self.cutmix_alpha = cutmix_alpha + self.cutmix_minmax = cutmix_minmax + if self.cutmix_minmax is not None: + assert len(self.cutmix_minmax) == 2 + # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe + self.cutmix_alpha = 1.0 + self.mix_prob = prob + self.switch_prob = switch_prob + self.label_smoothing = label_smoothing + self.num_classes = num_classes + self.mode = mode + self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix + self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) + + def _params_per_elem(self, batch_size): + lam = np.ones(batch_size, dtype=np.float32) + use_cutmix = np.zeros(batch_size, dtype=bool) + if self.mixup_enabled: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand(batch_size) < self.switch_prob + lam_mix = np.where( + use_cutmix, + np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), + np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) + elif self.cutmix_alpha > 0.: + use_cutmix = np.ones(batch_size, dtype=bool) + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." + lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) + return lam, use_cutmix + + def _params_per_batch(self): + lam = 1. + use_cutmix = False + if self.mixup_enabled and np.random.rand() < self.mix_prob: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand() < self.switch_prob + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ + np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.cutmix_alpha > 0.: + use_cutmix = True + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." + lam = float(lam_mix) + return lam, use_cutmix + + def _mix_elem(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_pair(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size // 2): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + x[j] = x[j] * lam + x_orig[i] * (1 - lam) + lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_batch(self, x): + lam, use_cutmix = self._params_per_batch() + if lam == 1.: + return 1. + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] + else: + x_flipped = x.flip(0).mul_(1. - lam) + x.mul_(lam).add_(x_flipped) + return lam + + def __call__(self, x, target): + assert len(x) % 2 == 0, 'Batch size should be even when using this' + if self.mode == 'elem': + lam = self._mix_elem(x) + elif self.mode == 'pair': + lam = self._mix_pair(x) + else: + lam = self._mix_batch(x) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing) + return x, target \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/EAT_audio_classification.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/EAT_audio_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..0164e9a996951c731b806c6ba91c2defc02959e2 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/EAT_audio_classification.py @@ -0,0 +1,487 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# The code in this file is adapted from the BeiT implementation which can be found here: +# https://github.com/microsoft/unilm/tree/master/beit +import sys, os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +import logging +import torch +import torchaudio +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from dataclasses import dataclass,field +from enum import Enum, auto +from typing import Any, Optional +from omegaconf import II, MISSING +from fairseq import checkpoint_utils, tasks +from omegaconf import open_dict +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.tasks import FairseqTask + +from mae import interpolate_pos_embed +from mae import get_2d_sincos_pos_embed_flexible + +logger = logging.getLogger(__name__) + + +# EAT utilize cls token for prediction in most downstream tasks +class PredictionMode(Enum): + MEAN_POOLING = auto() + CLS_TOKEN = auto() + LIN_SOFTMAX = auto() + +# we follow the work of data2vec 2.0 on image modality and Audio-MAE in EAT +@dataclass +class MaeImageClassificationConfig(FairseqDataclass): + model_path: str = MISSING + no_pretrained_weights: bool = False + linear_classifier: bool = False + num_classes: int = 1000 + mixup: float = 0.0 + cutmix: float = 0.0 + label_smoothing: float = 0.0 + + drop_path_rate: float = 0.1 + layer_decay: float = 0.65 + + mixup_prob: float = 1.0 + mixup_switch_prob: float = 0.0 + mixup_mode: str = "batch" + + pretrained_model_args: Any = None + data: str = II("task.data") + + norm_eps: Optional[float] = None + + remove_alibi: bool = False + + # regularization overwrites + encoder_dropout: float = 0 + post_mlp_drop: float = 0 + attention_dropout: float = 0 + activation_dropout: float = 0.0 + dropout_input: float = 0.0 + layerdrop: float = 0.0 + + prenet_layerdrop: float = 0 + prenet_dropout: float = 0 + + use_fc_norm: bool = True + prediction_mode: PredictionMode = PredictionMode.CLS_TOKEN + + no_decay_blocks: bool = True + + # settings for specific downstream task + audio_mae: bool = field(default=False, metadata={"help": "if true, the task is to realize audio classification"}) + esc50_eval: bool = field(default=False, metadata={"help": "if true, the task is to finetune model on esc50 dataset"}) + spcv2_eval: bool = field(default=False, metadata={"help": "if true, the task is to finetune model on speech command v2 dataset"}) + target_length: int = field(default=1024,metadata={"help": "This setting will pad the input sequence will zeros."}) + + # specaug for specific downstream task + specaug: bool = field(default=False, metadata={"help": "if true, use the specaug technique (frame and frequency masked 30%)"}) + freqm: int = field(default=25, metadata={"help": "the mask ratio of frequency dimension in audio spectrogram by default"}) + timem: int = field(default=200, metadata={"help": "the mask ratio of time dimension in audio spectrogram by default"}) + mask_ratio: float = field(default=0.0, metadata={"help": "the mask ratio of both time and freq "}) + + +def get_layer_id_for_vit(name, num_layers): + """ + Assign a parameter with its layer id + Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33 + """ + if name in ["cls_token", "pos_embed"]: + return 0 + elif name.startswith("patch_embed"): + return 0 + elif name.startswith("rel_pos_bias"): + return num_layers - 1 + elif name.startswith("blocks"): + return int(name.split(".")[1]) + 1 + else: + return num_layers + + +@register_model("mae_image_classification", dataclass=MaeImageClassificationConfig) +class MaeImageClassificationModel(BaseFairseqModel): + def __init__(self, cfg: MaeImageClassificationConfig): + super().__init__() + self.cfg = cfg + self.audio_mae = self.cfg.audio_mae + self.esc50_eval = self.cfg.esc50_eval + self.spcv2_eval = self.cfg.spcv2_eval + self.target_length = self.cfg.target_length + + # adjust pre-training config into fine-tuning + if cfg.pretrained_model_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {}) + pretrained_args = state.get("cfg", None) + + pretrained_args.criterion = None + pretrained_args.lr_scheduler = None + + logger.info(pretrained_args.model) + + with open_dict(pretrained_args.model): + pretrained_args.model.drop_path_rate = cfg.drop_path_rate + if cfg.norm_eps is not None: + pretrained_args.model.norm_eps = cfg.norm_eps + + cfg.pretrained_model_args = pretrained_args + + logger.info(pretrained_args) + else: + state = None + pretrained_args = cfg.pretrained_model_args + + if "data" in pretrained_args.task: + pretrained_args.task.data = cfg.data + elif "image" in pretrained_args.task: + pretrained_args.task.image.data = cfg.data + + if "modalities" in pretrained_args.model: + prenet_blocks = pretrained_args.model["modalities"]["image"]["prenet_depth"] + model_blocks = pretrained_args.model["depth"] + with open_dict(pretrained_args): + dpr = np.linspace(0, cfg.drop_path_rate, model_blocks).tolist() + pretrained_args.model["modalities"]["image"][ + "start_drop_path_rate" + ] = dpr[0] + pretrained_args.model["modalities"]["image"][ + "end_drop_path_rate" + ] = max(0, dpr[prenet_blocks - 1]) + pretrained_args.model["start_drop_path_rate"] = dpr[prenet_blocks] + pretrained_args.model["end_drop_path_rate"] = dpr[-1] + + if "mae_masking" in pretrained_args.model["modalities"]["image"]: + del pretrained_args.model["modalities"]["image"]["mae_masking"] + + if cfg.remove_alibi: + pretrained_args.model["modalities"]["image"][ + "use_alibi_encoder" + ] = False + if ( + state is not None + and "modality_encoders.IMAGE.alibi_bias" in state["model"] + ): + del state["model"]["modality_encoders.IMAGE.alibi_bias"] + + pretrained_args.model["encoder_dropout"] = cfg.encoder_dropout + pretrained_args.model["post_mlp_drop"] = cfg.post_mlp_drop + pretrained_args.model["attention_dropout"] = cfg.attention_dropout + pretrained_args.model["activation_dropout"] = cfg.activation_dropout + pretrained_args.model["dropout_input"] = cfg.dropout_input + pretrained_args.model["layerdrop"] = cfg.layerdrop + + pretrained_args.model["modalities"]["image"][ + "prenet_layerdrop" + ] = cfg.prenet_layerdrop + pretrained_args.model["modalities"]["image"][ + "prenet_dropout" + ] = cfg.prenet_dropout + + pretrained_args.model["modalities"]["image"]['target_length'] = cfg.target_length + else: + # not d2v multi + with open_dict(pretrained_args): + pretrained_args.model["drop_path_rate"] = cfg.drop_path_rate + pretrained_args.model["block_dropout"] = cfg.encoder_dropout + pretrained_args.model["attention_dropout"] = cfg.attention_dropout + pretrained_args.model["activation_dropout"] = cfg.activation_dropout + + task = tasks.setup_task(pretrained_args.task) + model = task.build_model(pretrained_args.model, from_checkpoint=True) + + self.d2v_multi = "data2vec_multi" in pretrained_args.model._name + self.linear_classifier = cfg.linear_classifier + + self.model = model + + # adjust position embedding for specific downstream task (due to different fixed clip length) + if state is not None and not cfg.no_pretrained_weights: + interpolate_pos_embed(model, state) + + if "modality_encoders.IMAGE.positional_encoder.pos_embed" in state["model"]: + state["model"][ + "modality_encoders.IMAGE.positional_encoder.positions" + ] = state["model"][ + "modality_encoders.IMAGE.positional_encoder.pos_embed" + ] + + + del state["model"][ + "modality_encoders.IMAGE.positional_encoder.pos_embed" + ] + if "modality_encoders.IMAGE.encoder_mask" in state["model"]: + del state["model"]["modality_encoders.IMAGE.encoder_mask"] + + if cfg.esc50_eval: + num_patches = 256 + embed_dim = 768 + pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) + emb = get_2d_sincos_pos_embed_flexible(pos_embed.shape[-1],(32,8),cls_token=False) + pos_embed.data.copy_(torch.from_numpy(emb[:num_patches,:]).float().unsqueeze(0)) + state['model']["modality_encoders.IMAGE.fixed_positional_encoder.positions"] = pos_embed + state['model']['_ema']["modality_encoders.IMAGE.fixed_positional_encoder.positions"] = pos_embed + + if cfg.spcv2_eval: + num_patches = 64 + embed_dim = 768 + pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim), requires_grad=False) + emb = get_2d_sincos_pos_embed_flexible(pos_embed.shape[-1],(8,8),cls_token=False) + pos_embed.data.copy_(torch.from_numpy(emb[:num_patches,:]).float().unsqueeze(0)) + state['model']["modality_encoders.IMAGE.fixed_positional_encoder.positions"] = pos_embed + state['model']['_ema']["modality_encoders.IMAGE.fixed_positional_encoder.positions"] = pos_embed + + + model.load_state_dict(state["model"], strict=True) + + if self.d2v_multi: + model.remove_pretraining_modules(modality="image") + else: + model.remove_pretraining_modules() + + if self.linear_classifier: + model.requires_grad_(False) + + self.fc_norm = None + if self.cfg.use_fc_norm: + self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim, eps=1e-6) + nn.init.constant_(self.fc_norm.bias, 0) + nn.init.constant_(self.fc_norm.weight, 1.0) + + self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes) + + nn.init.trunc_normal_(self.head.weight, std=0.02) + nn.init.constant_(self.head.bias, 0) + + self.mixup_fn = None + self.specaug = cfg.specaug + self.mask_ratio = cfg.mask_ratio + + # spectrogram mixup for fine-tuning + if cfg.mixup > 0 or cfg.cutmix > 0: + from ..utils.mixup import Mixup + + self.mixup_fn = Mixup( + mixup_alpha=cfg.mixup, + cutmix_alpha=cfg.cutmix, + cutmix_minmax=None, + prob=cfg.mixup_prob, + switch_prob=cfg.mixup_switch_prob, + mode=cfg.mixup_mode, + label_smoothing=cfg.label_smoothing, + num_classes=cfg.num_classes, + ) + + # specaug for fine-tuning, you could set mask_ratio = 0 to setup specific freqm and timem + if self.specaug: + self.freqm = cfg.freqm + self.timem = cfg.timem + + if self.mask_ratio != 0.0: + self.freqm = 128 * self.mask_ratio + self.timem = self.target_length * self.mask_ratio + + # group optimizer initialization with layer decay + if self.model.norm is not None: + for pn, p in self.model.norm.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias"): + p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}} + + if self.fc_norm is not None: + for pn, p in self.fc_norm.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias"): + p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}} + + for pn, p in self.head.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias"): + p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}} + + if self.d2v_multi: + mod_encs = list(model.modality_encoders.values()) + assert len(mod_encs) == 1, len(mod_encs) + blocks = list(mod_encs[0].context_encoder.blocks) + list(model.blocks) + else: + blocks = model.blocks + + num_layers = len(blocks) + 1 + layer_scales = list( + cfg.layer_decay ** (num_layers - i) for i in range(num_layers + 1) + ) + + if self.d2v_multi: + for n, p in self.model.named_parameters(): + optimizer_override_dict = {} + + if len(p.shape) == 1 or n.endswith(".bias"): + optimizer_override_dict["weight_decay_scale"] = 0 + + p.optim_overrides = {"optimizer": optimizer_override_dict} + + if cfg.layer_decay > 0: + for i, b in enumerate(blocks): + lid = i + 1 + if layer_scales[lid] == 1.0: + continue + + for n, p in b.named_parameters(): + optim_override = getattr(p, "optim_overrides", {}) + if "optimizer" not in optim_override: + optim_override["optimizer"] = {} + + if cfg.no_decay_blocks: + optim_override["optimizer"]["lr_scale"] = layer_scales[lid] + p.optim_overrides = optim_override + else: + optim_override["optimizer"] = { + "lr_scale": layer_scales[lid] + } + p.optim_overrides = optim_override + + else: + for n, p in self.model.named_parameters(): + optimizer_override_dict = {} + layer_id = get_layer_id_for_vit(n, num_layers) + + if len(p.shape) == 1 or n.endswith(".bias"): + optimizer_override_dict["weight_decay_scale"] = 0 + + if cfg.layer_decay > 0: + optimizer_override_dict["lr_scale"] = layer_scales[layer_id] + p.optim_overrides = {"optimizer": optimizer_override_dict} + + @classmethod + def build_model(cls, cfg: MaeImageClassificationConfig, task=FairseqTask): + """Build a new model instance.""" + assert hasattr(task, "labels"), f"Task {task} must have an attribute 'labels'" + + return cls(cfg) + + def forward( + self, + imgs, + label=None, + ): + labels = label + if self.training and self.mixup_fn is not None and labels is not None: + imgs, labels = self.mixup_fn(imgs, labels) + + if self.training and self.specaug: + imgs = self.spectrogram_augment(imgs) + + if self.linear_classifier: + with torch.no_grad(): + x = self.model_forward(imgs) + else: + x = self.model_forward(imgs) + + # different prediction mode + if self.cfg.prediction_mode == PredictionMode.MEAN_POOLING: + x = x.mean(dim=1) + elif self.cfg.prediction_mode == PredictionMode.CLS_TOKEN: + x = x[:, 0] + elif self.cfg.prediction_mode == PredictionMode.LIN_SOFTMAX: + dtype = x.dtype + x = F.logsigmoid(x.float()) + x = torch.logsumexp(x + x, dim=1) - torch.logsumexp(x + 1e-6, dim=1) + x = x.clamp(max=0) + x = x - torch.log(-(torch.expm1(x))) + x = torch.nan_to_num(x, nan=0, posinf=0, neginf=0) + x = x.to(dtype=dtype) + else: + raise Exception(f"unknown prediction mode {self.cfg.prediction_mode.name}") + + # layer norm and project + if self.fc_norm is not None: + x = self.fc_norm(x) + + x = self.head(x) + + if labels is None: + return x + + x = torch.nan_to_num(x) + + # logs for different downstream task ESC-50 && SPC-2 -> single label AS (AS2M,AS20K) -> multilabel + if not self.audio_mae or (self.audio_mae and (self.esc50_eval or self.spcv2_eval )): + if self.training and self.mixup_fn is not None and not self.spcv2_eval: + loss = -labels * F.log_softmax(x.float(), dim=-1) + + elif self.mixup_fn is not None and self.spcv2_eval: + loss = F.binary_cross_entropy_with_logits( + x, labels.float(), reduction="none" + ) + + else: + loss = F.cross_entropy( + x.float(), + labels, + label_smoothing=self.cfg.label_smoothing if self.training else 0, + reduction="none", + ) + + result = { + "losses": {"regression": loss}, + "sample_size": imgs.size(0), + } + + if not self.training: + with torch.no_grad(): + pred = x.argmax(-1) + labels = labels.argmax(-1) + correct = (pred == labels).sum() + result["correct"] = correct + + else: + loss = F.binary_cross_entropy_with_logits( + x, labels.float(), reduction="none" + ) + + result = { + "losses": { + "main": loss, + }, + "sample_size": labels.sum(), + } + + if not self.training: + result["_predictions"] = torch.sigmoid(x) + result["_targets"] = labels + + + return result + + def model_forward(self, imgs): + if self.d2v_multi: + x = self.model.extract_features( + imgs, + mode="IMAGE", + mask=False, + remove_extra_tokens=( + self.cfg.prediction_mode != PredictionMode.CLS_TOKEN + ), + )["x"] + else: + x = self.model(imgs, predictions_only=True) + if ( + "no_cls" not in self.model.cfg or not self.model.cfg.no_cls + ) and not self.cfg.prediction_mode == PredictionMode.CLS_TOKEN: + x = x[:, 1:] + return x + + # specaug + def spectrogram_augment(self,spec): + freq_masking = torchaudio.transforms.FrequencyMasking(self.freqm,iid_masks=True) + time_masking = torchaudio.transforms.TimeMasking(self.timem,iid_masks=True) + spec_ = spec.transpose(2,3) + input_with_freq_mask = freq_masking(spec_) + input_with_time_freq_mask = time_masking(input_with_freq_mask) + input_with_time_freq_mask = torch.transpose(input_with_time_freq_mask, 2, 3) + return input_with_time_freq_mask \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/EAT_pretraining.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/EAT_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..d44fa90bdc5d766d12af9a124c237b4a792cde41 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/EAT_pretraining.py @@ -0,0 +1,881 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist +import numpy as np + +from dataclasses import dataclass, field +from typing import Optional, Callable +from functools import partial +from omegaconf import II +from enum import Enum, auto +from fairseq.modules import EMAModule, EMAModuleConfig +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model + +import sys, os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from base import ( + MaskSeed, + D2vModalityConfig, + ModalitySpecificEncoder, + get_annealed_rate, +) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from modules import ( + D2vDecoderConfig, + AltBlock, + Decoder1d, +) +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from images import ( + D2vImageConfig, + ImageEncoder, +) + +logger = logging.getLogger(__name__) + +# we follow the work of data2vec 2.0 on image modality and Audio-MAE in EAT +class Modality(Enum): + AUDIO = auto() + IMAGE = auto() + TEXT = auto() + +@dataclass +class D2vModalitiesConfig(FairseqDataclass): + image: D2vImageConfig = D2vImageConfig() + +@dataclass +class Data2VecMultiConfig(FairseqDataclass): + + loss_beta: float = field( + default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"} + ) + + loss_scale: Optional[float] = field( + default=None, + metadata={ + "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)" + }, + ) + + depth: int = 12 + + # standard vision Transformer + start_drop_path_rate: float = 0 + end_drop_path_rate: float = 0 + num_heads: int = 12 + norm_eps: float = 1e-6 + norm_affine: bool = True + encoder_dropout: float = 0.1 + post_mlp_drop: float = 0.1 + attention_dropout: float = 0.1 + activation_dropout: float = 0.0 + dropout_input: float = 0.0 + layerdrop: float = 0.0 + embed_dim: int = 768 + mlp_ratio: float = 4 + layer_norm_first: bool = False + + # EAT averages all Transformer block output (12 layers in total) + average_top_k_layers: int = field( + default=12, metadata={"help": "how many layers to average"} + ) + + end_of_block_targets: bool = False + + # clone batch for multi-mask strategy + clone_batch: int = 16 + + # normalization for teacher Transformer layer output + layer_norm_target_layer: bool = False + batch_norm_target_layer: bool = False + instance_norm_target_layer: bool = False + instance_norm_targets: bool = False + layer_norm_targets: bool = False + + # EMA settings + ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"}) + ema_same_dtype: bool = True + log_norms: bool = True + ema_end_decay: float = field( + default=0.9999, metadata={"help": "final ema decay rate"} + ) + + ema_anneal_end_step: int = II("optimization.max_update") + + # In EAT, the Transformer encoder and the CNN encoder are both EMA updated + ema_encoder_only: bool = field( + default=True, + metadata={ + "help": "whether to momentum update only the shared transformer encoder" + }, + ) + + max_update: int = II("optimization.max_update") + + modalities: D2vModalitiesConfig = D2vModalitiesConfig() + + shared_decoder: Optional[D2vDecoderConfig] = None + + min_target_var: float = field( + default=0.1, metadata={"help": "stop training if target var falls below this"} + ) + min_pred_var: float = field( + default=0.01, + metadata={"help": "stop training if prediction var falls below this"}, + ) + + supported_modality: Optional[Modality] = None + mae_init: bool = False + + seed: int = II("common.seed") + + skip_ema: bool = False + + # d2v_loss is the frame-level loss while cls_loss is the utterance-level loss + cls_loss: float = 0 + recon_loss: float = 0 + d2v_loss: float = 1 + + decoder_group: bool = False + + # the experiment of using dino loss instead of direct utterance loss (not included in our paper) + utterance_level: bool = field(default=False, metadata={"help": "if true, we will add utterance-level loss to the total loss"}) + init_center_token_zero: bool = field(default=False, metadata={"help": "if true, we will initialize the centor token with zero vertors"}) + center_exp: float = field(default=0.9, metadata={"help": "this value control the exponent decay of center value's coefficient"}) + softmax_temperature_student: float = field(default=0.1, metadata={"help": "this value control the temperature of softmax function of student output in the dino loss"}) + softmax_temperature_teacher: float = field(default=0.05, metadata={"help": "this value control the temperature of softmax function in teacher output the dino loss"}) + + +@register_model("data2vec_multi", dataclass=Data2VecMultiConfig) +class Data2VecMultiModel(BaseFairseqModel): + def make_modality_encoder( + self, + cfg: D2vModalityConfig, + embed_dim: int, + make_block: Callable[[float], nn.ModuleList], + norm_layer: Callable[[int], nn.LayerNorm], + layer_norm_first: bool, + alibi_biases, + task, + ) -> ModalitySpecificEncoder: + if cfg.type.value == Modality.IMAGE.value: + enc_cls = ImageEncoder + else: + raise Exception(f"unsupported modality {cfg.type}") + + return enc_cls( + cfg, + embed_dim, + make_block, + norm_layer, + layer_norm_first, + alibi_biases, + task, + ) + + def __init__(self, cfg: Data2VecMultiConfig, modalities, skip_ema=False, task=None): + super().__init__() + self.cfg = cfg + self.modalities = modalities + self.task = task + + make_layer_norm = partial( + nn.LayerNorm, eps=cfg.norm_eps, elementwise_affine=cfg.norm_affine + ) + + def make_block(drop_path, dim=None, heads=None): + return AltBlock( + cfg.embed_dim if dim is None else dim, + cfg.num_heads if heads is None else heads, + cfg.mlp_ratio, + qkv_bias=True, + drop=cfg.encoder_dropout, + attn_drop=cfg.attention_dropout, + mlp_drop=cfg.activation_dropout, + post_mlp_drop=cfg.post_mlp_drop, + drop_path=drop_path, + norm_layer=make_layer_norm, + layer_norm_first=cfg.layer_norm_first, + ffn_targets=not cfg.end_of_block_targets, + ) + + self.alibi_biases = {} + self.modality_encoders = nn.ModuleDict() + + # extract CNN encoder and CNN decoder from modified data2vec image modality (see image.py) + for mod in self.modalities: + mod_cfg = getattr(cfg.modalities, mod.name.lower()) + enc = self.make_modality_encoder( + mod_cfg, + cfg.embed_dim, + make_block, + make_layer_norm, + cfg.layer_norm_first, + self.alibi_biases, + task, + ) + self.modality_encoders[mod.name] = enc + + self.ema = None + + self.average_top_k_layers = cfg.average_top_k_layers + self.loss_beta = cfg.loss_beta + self.loss_scale = cfg.loss_scale + self.utterance_level = cfg.utterance_level + + self.dropout_input = nn.Dropout(cfg.dropout_input) + + dpr = np.linspace(cfg.start_drop_path_rate, cfg.end_drop_path_rate, cfg.depth) + + self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)]) + + self.norm = None + if cfg.layer_norm_first: + self.norm = make_layer_norm(cfg.embed_dim) + + if self.cfg.mae_init: + self.apply(self._init_weights) + else: + from fairseq.modules.transformer_sentence_encoder import init_bert_params + + self.apply(init_bert_params) + + for mod_enc in self.modality_encoders.values(): + mod_enc.reset_parameters() + + # make teacher model + if not skip_ema: + self.ema = self.make_ema_teacher(cfg.ema_decay) + self.shared_decoder = ( + Decoder1d(cfg.shared_decoder, cfg.embed_dim) + if self.cfg.shared_decoder is not None + else None + ) + if self.shared_decoder is not None: + self.shared_decoder.apply(self._init_weights) + + self.recon_proj = None + if cfg.recon_loss > 0: + self.recon_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim//3) + + self.cls_proj = None + if cfg.utterance_level: + self.cls_proj = nn.Linear(cfg.embed_dim, cfg.embed_dim) + + for pn, p in self.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias") or "alibi_scale" in pn: + p.optim_overrides = {"optimizer": {"weight_decay_scale": 0}} + if cfg.decoder_group and "decoder" in pn: + p.param_group = "decoder" + + # dino loss experiment + self.center = None + if self.utterance_level: + self.center_exp = cfg.center_exp + self.soft_tem_s = cfg.softmax_temperature_student + self.soft_tem_t = cfg.softmax_temperature_teacher + self.center = nn.Parameter( + torch.zeros(1, 1, cfg.embed_dim, requires_grad=False) + ) + if not cfg.init_center_token_zero: + nn.init.normal_(self.center) + elif self.center.size(1) > 1: + nn.init.normal_(self.center[:, 1:]) + + self.num_updates = 0 + + def _init_weights(self, m): + + try: + from apex.normalization import FusedLayerNorm + + fn = FusedLayerNorm + except: + fn = nn.LayerNorm + + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, fn): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + @torch.no_grad() + def make_ema_teacher(self, ema_decay): + ema_config = EMAModuleConfig( + ema_decay=ema_decay, + ema_fp32=True, + log_norms=self.cfg.log_norms, + add_missing_params=False, + ) + + model_copy = self.make_target_model() + + return EMAModule( + model_copy, + ema_config, + copy_model=False, + ) + + # teacher model (with independent CNN encoder and Transformer encoder) + def make_target_model(self): + logger.info("making target model") + + model_copy = Data2VecMultiModel( + self.cfg, self.modalities, skip_ema=True, task=self.task + ) + + if self.cfg.ema_encoder_only: + model_copy = model_copy.blocks + for p_s, p_t in zip(self.blocks.parameters(), model_copy.parameters()): + p_t.data.copy_(p_s.data) + else: + for p_s, p_t in zip(self.parameters(), model_copy.parameters()): + p_t.data.copy_(p_s.data) + + for mod_enc in model_copy.modality_encoders.values(): + mod_enc.decoder = None + if not mod_enc.modality_cfg.ema_local_encoder: + mod_enc.local_encoder = None + mod_enc.project_features = None + + model_copy.requires_grad_(False) + return model_copy + + # teacher model updated with EMA + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + + if self.ema is not None and ( + (self.num_updates == 0 and num_updates > 1) + or self.num_updates >= num_updates + ): + pass + elif self.training and self.ema is not None: + ema_weight_decay = None + if self.cfg.ema_decay != self.cfg.ema_end_decay: + if num_updates >= self.cfg.ema_anneal_end_step: + decay = self.cfg.ema_end_decay + else: + decay = get_annealed_rate( + self.cfg.ema_decay, + self.cfg.ema_end_decay, + num_updates, + self.cfg.ema_anneal_end_step, + ) + self.ema.set_decay(decay, weight_decay=ema_weight_decay) + if self.ema.get_decay() < 1: + self.ema.step(self.blocks if self.cfg.ema_encoder_only else self) + + self.num_updates = num_updates + + def state_dict(self, destination=None, prefix="", keep_vars=False): + state = super().state_dict(destination, prefix, keep_vars) + + if self.ema is not None: + state[prefix + "_ema"] = self.ema.fp32_params + + return state + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + k = prefix + "_ema" + if self.ema is not None: + assert k in state_dict + self.ema.restore(state_dict[k], True) + del state_dict[k] + elif k in state_dict: + del state_dict[k] + + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + @classmethod + def build_model(cls, cfg: Data2VecMultiConfig, task=None): + """Build a new model instance.""" + if task is None or not hasattr(task, "supported_modalities"): + modalities = ( + [cfg.supported_modality] + if cfg.supported_modality is not None + else [ + Modality.AUDIO, + Modality.IMAGE, + Modality.TEXT, + ] + ) + else: + modalities = task.supported_modalities + return cls(cfg, modalities, task=task, skip_ema=cfg.skip_ema) + + def forward( + self, + source, + target=None, + id=None, + mode=None, + padding_mask=None, + mask=True, + features_only=False, + force_remove_masked=False, + remove_extra_tokens=True, + precomputed_mask=None, + ): + # print('source:', source.shape) + if mode is None: + assert self.cfg.supported_modality is not None + mode = self.cfg.supported_modality + + if isinstance(mode, Modality): + mode = mode.name + + feature_extractor = self.modality_encoders[mode] + + mask_seeds = None + if id is not None: + mask_seeds = MaskSeed(seed=self.cfg.seed, update=self.num_updates, ids=id) + + # extract (unmasked) features using CNN encoder + extractor_out = feature_extractor( + source, + padding_mask, + mask, + remove_masked=not features_only or force_remove_masked, + clone_batch=self.cfg.clone_batch if not features_only else 1, + mask_seeds=mask_seeds, + precomputed_mask=precomputed_mask, + ) + + # x in shape ( batch_size * clone batch, patch_frame(64) * patch_freqency(8) * unmask_ratio(0.2) + 1(cls_token), 768(feature dimension) ) + # EAT does not employ the ablibi mechanism in Transformer + x = extractor_out["x"] + encoder_mask = extractor_out["encoder_mask"] + masked_padding_mask = extractor_out["padding_mask"] + masked_alibi_bias = extractor_out.get("alibi_bias", None) + alibi_scale = extractor_out.get("alibi_scale", None) + + if self.dropout_input is not None: + x = self.dropout_input(x) + + # standard Transformer (for student encoder) + layer_results = [] + for i, blk in enumerate(self.blocks): + if ( + not self.training + or self.cfg.layerdrop == 0 + or (np.random.random() > self.cfg.layerdrop) + ): + ab = masked_alibi_bias + if ab is not None and alibi_scale is not None: + scale = ( + alibi_scale[i] + if alibi_scale.size(0) > 1 + else alibi_scale.squeeze(0) + ) + ab = ab * scale.type_as(ab) + + x, lr = blk( + x, + padding_mask=masked_padding_mask, + alibi_bias=ab, + ) + if features_only: + layer_results.append(lr) + + if self.norm is not None: + x = self.norm(x) + + # extract features for fine-tuning + if features_only: + if remove_extra_tokens: + x = x[:, feature_extractor.modality_cfg.num_extra_tokens :] + if masked_padding_mask is not None: + masked_padding_mask = masked_padding_mask[ + :, feature_extractor.modality_cfg.num_extra_tokens : + ] + + return { + "x": x, + "padding_mask": masked_padding_mask, + "layer_results": layer_results, + "mask": encoder_mask, + } + + # decode features merged with masked tokens, dx in shape (batch_size * clone_batch, patch, 768) + xs = [] + if self.shared_decoder is not None: + dx = self.forward_decoder( + x, + feature_extractor, + self.shared_decoder, + encoder_mask, + ) + xs.append(dx) + if feature_extractor.decoder is not None: + dx = self.forward_decoder( + x, + feature_extractor, + feature_extractor.decoder, + encoder_mask, + ) + xs.append(dx) + orig_x = x + + assert len(xs) > 0 + + p = next(self.ema.model.parameters()) + device = x.device + dtype = x.dtype + ema_device = p.device + ema_dtype = p.dtype + + if not self.cfg.ema_same_dtype: + dtype = ema_dtype + + if ema_device != device or ema_dtype != dtype: + logger.info(f"adjusting ema dtype to {dtype} and device to {device}") + self.ema.model = self.ema.model.to(dtype=dtype, device=device) + ema_dtype = dtype + + def to_device(d): + for k, p in d.items(): + if isinstance(d[k], dict): + to_device(d[k]) + else: + d[k] = p.to(device=device) + + to_device(self.ema.fp32_params) + tm = self.ema.model + + # encode audio spectrogram using teacher model + with torch.no_grad(): + tm.eval() + + if self.cfg.ema_encoder_only: + assert target is None + ema_input = extractor_out["local_features"] + ema_input = feature_extractor.contextualized_features( + ema_input.to(dtype=ema_dtype), + padding_mask, + mask=False, + remove_masked=False, + ) + ema_blocks = tm + else: + ema_blocks = tm.blocks + if feature_extractor.modality_cfg.ema_local_encoder: + inp = ( + target.to(dtype=ema_dtype) + if target is not None + else source.to(dtype=ema_dtype) + ) + ema_input = tm.modality_encoders[mode]( + inp, + padding_mask, + mask=False, + remove_masked=False, + ) + else: + assert target is None + ema_input = extractor_out["local_features"] + ema_feature_enc = tm.modality_encoders[mode] + ema_input = ema_feature_enc.contextualized_features( + ema_input.to(dtype=ema_dtype), + padding_mask, + mask=False, + remove_masked=False, + ) + + ema_padding_mask = ema_input["padding_mask"] + ema_alibi_bias = ema_input.get("alibi_bias", None) + ema_alibi_scale = ema_input.get("alibi_scale", None) + ema_input = ema_input["x"] + + # extract target features using teacher CNN encoder + # ema_input in shape (batch_size, patch + 1(cls_token), feature_dimension) + y = [] + ema_x = [] + extra_tokens = feature_extractor.modality_cfg.num_extra_tokens + for i, blk in enumerate(ema_blocks): + ab = ema_alibi_bias + if ab is not None and alibi_scale is not None: + scale = ( + ema_alibi_scale[i] + if ema_alibi_scale.size(0) > 1 + else ema_alibi_scale.squeeze(0) + ) + ab = ab * scale.type_as(ab) + + ema_input, lr = blk( + ema_input, + padding_mask=ema_padding_mask, + alibi_bias=ab, + ) + y.append(lr[:, extra_tokens:]) + ema_x.append(ema_input[:, extra_tokens:]) + + # EAT utilize total 12 Transformer block layer output average as target + y = self.make_targets(y, self.average_top_k_layers) + orig_targets = y + + # multiply the target value according to the number of clone batch + if self.cfg.clone_batch > 1: + y = y.repeat_interleave(self.cfg.clone_batch, 0) + + # extract values in masked position to make prediction + masked = encoder_mask.mask.unsqueeze(-1) + masked_b = encoder_mask.mask.bool() + y = y[masked_b] + + if xs[0].size(1) == masked_b.size(1): + xs = [x[masked_b] for x in xs] + else: + xs = [x.reshape(-1, x.size(-1)) for x in xs] + + + sample_size = masked.sum().long() + + result = { + "losses": {}, + "sample_size": sample_size, + } + + sample_size = result["sample_size"] + + # EAT employ utterance-level loss by using mean pooling in patch dimension + if self.cfg.cls_loss > 0 and not self.utterance_level: + assert extra_tokens > 0 + cls_target = orig_targets.mean(dim=1) + if self.cfg.clone_batch > 1: + cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0) + cls_pred = x[:, extra_tokens - 1] + + result["losses"]["cls"] = self.d2v_loss(cls_pred, cls_target) * ( + self.cfg.cls_loss * sample_size + ) + + # dino loss experiment + if self.cfg.cls_loss > 0 and self.utterance_level: + assert extra_tokens > 0 + cls_target = orig_targets.mean(dim=1) + if self.cfg.clone_batch > 1: + cls_target = cls_target.repeat_interleave(self.cfg.clone_batch, 0) #(btz*clone,1,768) + cls_pred = x[:, extra_tokens - 1] + cls_target = cls_target - self.center + + cls_pred = cls_pred.squeeze(dim=1) + cls_target = cls_target.squeeze(dim=1) + + result["losses"]["cls"] = self.dino_loss(cls_pred, cls_target) * ( + self.cfg.cls_loss * sample_size + ) + + self.center = self.center_exp * self.center + (1 - self.center_exp) * (cls_target.mean(dim=0)) + + if self.cfg.recon_loss > 0: + + with torch.no_grad(): + target = feature_extractor.patchify(source) #(btz,1,512,16*16) + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 #(btz,1,512,1) + + if self.cfg.clone_batch > 1: + target = target.repeat_interleave(self.cfg.clone_batch, 0) #(btz*clone_btz,1,512,1) + + if masked_b is not None: + target = target[masked_b] + + recon = xs[0] + if self.recon_proj is not None: + recon = self.recon_proj(recon) + + result["losses"]["recon"] = ( + self.d2v_loss(recon, target.float()) * self.cfg.recon_loss + ) + + if self.cfg.d2v_loss > 0: + for i, x in enumerate(xs): + reg_loss = self.d2v_loss(x, y) + n = f"{mode}_regression_{i}" if len(xs) > 1 else f"{mode}_regression" + result["losses"][n] = reg_loss * self.cfg.d2v_loss + + # compute state for logs + suffix = "" if len(self.modalities) == 1 else f"_{mode}" + with torch.no_grad(): + if encoder_mask is not None: + result["masked_pct"] = 1 - ( + encoder_mask.ids_keep.size(1) / encoder_mask.ids_restore.size(1) + ) + for i, x in enumerate(xs): + n = f"pred_var{suffix}_{i}" if len(xs) > 1 else f"pred_var{suffix}" + result[n] = self.compute_var(x.float()) + if self.ema is not None: + for k, v in self.ema.logs.items(): + result[k] = v + + y = y.float() + result[f"target_var{suffix}"] = self.compute_var(y) + + if self.num_updates > 5000: + if result[f"target_var{suffix}"] < self.cfg.min_target_var: + logger.error( + f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})" + ) + raise Exception( + f"target var is {result[f'target_var{suffix}'].item()} < {self.cfg.min_target_var}, exiting ({mode})" + ) + + for k in result.keys(): + if k.startswith("pred_var") and result[k] < self.cfg.min_pred_var: + logger.error( + f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})" + ) + raise Exception( + f"{k} is {result[k].item()} < {self.cfg.min_pred_var}, exiting ({mode})" + ) + + result["ema_decay"] = self.ema.get_decay() * 1000 + + return result + + def forward_decoder( + self, + x, + feature_extractor, + decoder, + mask_info, + ): + x = feature_extractor.decoder_input(x, mask_info) + x = decoder(*x) + + return x + + def d2v_loss(self, x, y): + x = x.view(-1, x.size(-1)).float() + y = y.view(-1, x.size(-1)) + + if self.loss_beta == 0: + loss = F.mse_loss(x, y, reduction="none") + else: + loss = F.smooth_l1_loss(x, y, reduction="none", beta=self.loss_beta) + + if self.loss_scale is not None: + scale = self.loss_scale + else: + scale = 1 / math.sqrt(x.size(-1)) + + reg_loss = loss * scale + + return reg_loss + + def dino_loss(self,s,t): + t = t.detach() + s = F.softmax(s/self.soft_tem_s,dim=1) + t = F.softmax((t-self.center)/self.soft_tem_t,dim=1) + return - (t * torch.log(s)).sum(dim=1).mean() + + # average top-k layers output from teacher model + def make_targets(self, y, num_layers): + + with torch.no_grad(): + target_layer_results = y[-num_layers:] + + permuted = False + if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer: + target_layer_results = [ + tl.transpose(1, 2) for tl in target_layer_results # BTC -> BCT + ] + permuted = True + if self.cfg.batch_norm_target_layer: + target_layer_results = [ + F.batch_norm( + tl.float(), running_mean=None, running_var=None, training=True + ) + for tl in target_layer_results + ] + if self.cfg.instance_norm_target_layer: + target_layer_results = [ + F.instance_norm(tl.float()) for tl in target_layer_results + ] + if permuted: + target_layer_results = [ + tl.transpose(1, 2) for tl in target_layer_results # BCT -> BTC + ] + if self.cfg.layer_norm_target_layer: + target_layer_results = [ + F.layer_norm(tl.float(), tl.shape[-1:]) + for tl in target_layer_results + ] + + y = target_layer_results[0].float() + for tl in target_layer_results[1:]: + y.add_(tl.float()) + y = y.div_(len(target_layer_results)) + + if self.cfg.layer_norm_targets: + y = F.layer_norm(y, y.shape[-1:]) + + if self.cfg.instance_norm_targets: + y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2) + + return y + + @staticmethod + def compute_var(y): + y = y.view(-1, y.size(-1)) + if dist.is_initialized(): + zc = torch.tensor(y.size(0)).cuda() + zs = y.sum(dim=0) + zss = (y**2).sum(dim=0) + + dist.all_reduce(zc) + dist.all_reduce(zs) + dist.all_reduce(zss) + + var = zss / (zc - 1) - (zs**2) / (zc * (zc - 1)) + return torch.sqrt(var + 1e-6).mean() + else: + return torch.sqrt(y.var(dim=0) + 1e-6).mean() + + def extract_features( + self, source, mode=None, padding_mask=None, mask=False, remove_extra_tokens=True + ): + res = self.forward( + source, + mode=mode, + padding_mask=padding_mask, + mask=mask, + features_only=True, + remove_extra_tokens=remove_extra_tokens, + ) + return res + + def remove_pretraining_modules(self, modality=None, keep_decoder=False): + self.ema = None + self.cfg.clone_batch = 1 + self.recon_proj = None + + if not keep_decoder: + self.shared_decoder = None + + modality = modality.lower() if modality is not None else None + for k in list(self.modality_encoders.keys()): + if modality is not None and k.lower() != modality: + del self.modality_encoders[k] + else: + self.modality_encoders[k].remove_pretraining_modules( + keep_decoder=keep_decoder + ) + if not keep_decoder: + self.modality_encoders[k].decoder = None diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/__init__.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a9aa01ce7af8b3e2b9e9651b5ce384cca1bf0d64 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/__init__.py @@ -0,0 +1,6 @@ +try: + from .EAT_pretraining import * +except: + import sys, os + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '.')) + from EAT_pretraining import * \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/base.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/base.py new file mode 100644 index 0000000000000000000000000000000000000000..4c70d7fc083f0763d7c4d1a089e45649fea87f85 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/base.py @@ -0,0 +1,693 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import sys, os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from modules import D2vDecoderConfig + +import logging +import math +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from collections import namedtuple +from dataclasses import dataclass +from functools import partial +from omegaconf import MISSING, II +from typing import Optional, Callable +from fairseq.data.data_utils import compute_mask_indices +from fairseq.modules import GradMultiply +from fairseq.utils import index_put +from enum import Enum, auto + + +logger = logging.getLogger(__name__) + +class Modality(Enum): + AUDIO = auto() + IMAGE = auto() + TEXT = auto() + +@dataclass +class D2vModalityConfig: + type: Modality = MISSING + prenet_depth: int = 4 + prenet_layerdrop: float = 0 + prenet_dropout: float = 0 + start_drop_path_rate: float = 0 + end_drop_path_rate: float = 0 + + num_extra_tokens: int = 0 + init_extra_token_zero: bool = True + + mask_noise_std: float = 0.01 + mask_prob_min: Optional[float] = None + mask_prob: float = 0.7 + inverse_mask: bool = False + mask_prob_adjust: float = 0 + keep_masked_pct: float = 0 + + mask_length: int = 5 + add_masks: bool = False + remove_masks: bool = False + mask_dropout: float = 0.0 + encoder_zero_mask: bool = True + + mask_channel_prob: float = 0.0 + mask_channel_length: int = 64 + + ema_local_encoder: bool = False # used in data2vec_multi + local_grad_mult: float = 1.0 + + use_alibi_encoder: bool = False + alibi_scale: float = 1.0 + learned_alibi: bool = False + alibi_max_pos: Optional[int] = None + learned_alibi_scale: bool = False + learned_alibi_scale_per_head: bool = False + learned_alibi_scale_per_layer: bool = False + + num_alibi_heads: int = II("model.num_heads") + model_depth: int = II("model.depth") + + decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig() + + +MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"]) +MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"]) + + +class ModalitySpecificEncoder(nn.Module): + def __init__( + self, + modality_cfg: D2vModalityConfig, + embed_dim: int, + local_encoder: nn.Module, + project_features: nn.Module, + fixed_positional_encoder: Optional[nn.Module], + relative_positional_encoder: Optional[nn.Module], + context_encoder: nn.Module, + decoder: nn.Module, + get_alibi_bias: Optional[Callable[[int, int, str, str], torch.Tensor]], + ): + super().__init__() + + self.modality_cfg = modality_cfg + self.local_encoder = local_encoder + self.project_features = project_features + self.fixed_positional_encoder = fixed_positional_encoder + self.relative_positional_encoder = relative_positional_encoder + self.context_encoder = context_encoder + + self.decoder = decoder + self.get_alibi_bias = get_alibi_bias if modality_cfg.use_alibi_encoder else None + + self.local_grad_mult = self.modality_cfg.local_grad_mult + + self.extra_tokens = None + if modality_cfg.num_extra_tokens > 0: + self.extra_tokens = nn.Parameter( + torch.zeros(1, modality_cfg.num_extra_tokens, embed_dim) + ) + if not modality_cfg.init_extra_token_zero: + nn.init.normal_(self.extra_tokens) + elif self.extra_tokens.size(1) > 1: + nn.init.normal_(self.extra_tokens[:, 1:]) + + self.alibi_scale = None + if self.get_alibi_bias is not None: + self.alibi_scale = nn.Parameter( + torch.full( + ( + (modality_cfg.prenet_depth + modality_cfg.model_depth) + if modality_cfg.learned_alibi_scale_per_layer + else 1, + 1, + self.modality_cfg.num_alibi_heads + if modality_cfg.learned_alibi_scale_per_head + else 1, + 1, + 1, + ), + modality_cfg.alibi_scale, + dtype=torch.float, + ), + requires_grad=modality_cfg.learned_alibi_scale, + ) + + if modality_cfg.learned_alibi and self.get_alibi_bias is not None: + assert modality_cfg.alibi_max_pos is not None + alibi_bias = self.get_alibi_bias( + batch_size=1, + time_steps=modality_cfg.alibi_max_pos, + heads=modality_cfg.num_alibi_heads, + scale=1.0, + dtype=torch.float, + device="cpu", + ) + self.alibi_bias = nn.Parameter(alibi_bias) + self.get_alibi_bias = partial( + _learned_alibi_bias, alibi_bias=self.alibi_bias + ) + + def upgrade_state_dict_named(self, state_dict, name): + k = f"{name}.alibi_scale" + if k in state_dict and state_dict[k].dim() == 4: + state_dict[k] = state_dict[k].unsqueeze(0) + + return state_dict + + def convert_padding_mask(self, x, padding_mask): + return padding_mask + + def decoder_input(self, x, mask_info: MaskInfo): + inp_drop = self.modality_cfg.decoder.input_dropout + if inp_drop > 0: + x = F.dropout(x, inp_drop, training=self.training, inplace=True) + + num_extra = self.modality_cfg.num_extra_tokens + + if mask_info is not None: + num_masked = mask_info.ids_restore.shape[1] - x.shape[1] + num_extra + + mask_tokens = x.new_empty( + x.size(0), + num_masked, + x.size(-1), + ).normal_(0, self.modality_cfg.mask_noise_std) + + x_ = torch.cat([x[:, num_extra:], mask_tokens], dim=1) + x = torch.gather(x_, dim=1, index=mask_info.ids_restore) + + if self.modality_cfg.decoder.add_positions_masked: + assert self.fixed_positional_encoder is not None + pos = self.fixed_positional_encoder(x, None) + x = x + (pos * mask_info.mask.unsqueeze(-1)) + else: + x = x[:, num_extra:] + + if self.modality_cfg.decoder.add_positions_all: + assert self.fixed_positional_encoder is not None + x = x + self.fixed_positional_encoder(x, None) + + return x, mask_info + + def local_features(self, features): + if self.local_grad_mult > 0: + if self.local_grad_mult == 1.0: + x = self.local_encoder(features) + else: + x = GradMultiply.apply( + self.local_encoder(features), self.local_grad_mult + ) + else: + with torch.no_grad(): + x = self.local_encoder(features) + + x = self.project_features(x) + return x + + def contextualized_features( + self, + x, + padding_mask, + mask, + remove_masked, + clone_batch: int = 1, + mask_seeds: Optional[torch.Tensor] = None, + precomputed_mask=None, + ): + + if padding_mask is not None: + padding_mask = self.convert_padding_mask(x, padding_mask) + + local_features = x + if mask and clone_batch == 1: + local_features = local_features.clone() + + orig_B, orig_T, _ = x.shape + pre_mask_B = orig_B + mask_info = None + + x_pos = None + if self.fixed_positional_encoder is not None: + x = x + self.fixed_positional_encoder(x, padding_mask) + + if mask: + if clone_batch > 1: + x = x.repeat_interleave(clone_batch, 0) + if mask_seeds is not None: + clone_hash = [ + int(hash((mask_seeds.seed, ind)) % 1e10) + for ind in range(clone_batch - 1) + ] + clone_hash = torch.tensor([0] + clone_hash).long().view(1, -1) + + id = mask_seeds.ids + id = id.repeat_interleave(clone_batch, 0) + id = id.view(-1, clone_batch) + clone_hash.to(id) + id = id.view(-1) + mask_seeds = MaskSeed( + seed=mask_seeds.seed, update=mask_seeds.update, ids=id + ) + if padding_mask is not None: + padding_mask = padding_mask.repeat_interleave(clone_batch, 0) + + x, mask_info = self.compute_mask( + x, + padding_mask, + mask_seed=mask_seeds, + apply=self.relative_positional_encoder is not None or not remove_masked, + precomputed_mask=precomputed_mask, + ) + + if self.relative_positional_encoder is not None: + x_pos = self.relative_positional_encoder(x) + + masked_padding_mask = padding_mask + if mask and remove_masked: + x = mask_info.x_unmasked + if x_pos is not None: + x = x + gather_unmasked(x_pos, mask_info) + + if padding_mask is not None and padding_mask.any(): + masked_padding_mask = gather_unmasked_mask(padding_mask, mask_info) + if not masked_padding_mask.any(): + masked_padding_mask = None + else: + masked_padding_mask = None + + elif x_pos is not None: + x = x + x_pos + + alibi_bias = None + alibi_scale = self.alibi_scale + + if self.get_alibi_bias is not None: + alibi_bias = self.get_alibi_bias( + batch_size=pre_mask_B, + time_steps=orig_T, + heads=self.modality_cfg.num_alibi_heads, + dtype=torch.float32, + device=x.device, + ) + + if alibi_scale is not None: + alibi_scale = alibi_scale.clamp_min(0) + if alibi_scale.size(0) == 1: + alibi_bias = alibi_bias * alibi_scale.squeeze(0).type_as(alibi_bias) + alibi_scale = None + + if clone_batch > 1: + alibi_bias = alibi_bias.repeat_interleave(clone_batch, 0) + + if mask_info is not None and remove_masked: + alibi_bias = masked_alibi(alibi_bias, mask_info) + + if self.extra_tokens is not None: + num = self.extra_tokens.size(1) + x = torch.cat([self.extra_tokens.expand(x.size(0), -1, -1), x], dim=1) + if masked_padding_mask is not None: + # B x T + masked_padding_mask = F.pad(masked_padding_mask, (num, 0)) + if alibi_bias is not None: + # B x H x T x T + alibi_bias = F.pad(alibi_bias, (num, 0, num, 0)) + + x = self.context_encoder( + x, + masked_padding_mask, + alibi_bias, + alibi_scale[: self.modality_cfg.prenet_depth] + if alibi_scale is not None + else None, + ) + + return { + "x": x, + "local_features": local_features, + "padding_mask": masked_padding_mask, + "alibi_bias": alibi_bias, + "alibi_scale": alibi_scale[self.modality_cfg.prenet_depth :] + if alibi_scale is not None and alibi_scale.size(0) > 1 + else alibi_scale, + "encoder_mask": mask_info, + } + + def forward( + self, + features, + padding_mask, + mask: bool, + remove_masked: bool, + clone_batch: int = 1, + mask_seeds: Optional[torch.Tensor] = None, + precomputed_mask=None, + ): + x = self.local_features(features) + return self.contextualized_features( + x, + padding_mask, + mask, + remove_masked, + clone_batch, + mask_seeds, + precomputed_mask, + ) + + def reset_parameters(self): + pass + + def compute_mask( + self, + x, + padding_mask, + mask_seed: Optional[MaskSeed], + apply, + precomputed_mask, + ): + if precomputed_mask is not None: + mask = precomputed_mask + mask_info = self.make_maskinfo(x, mask) + else: + B, T, C = x.shape + cfg = self.modality_cfg + + mask_prob = cfg.mask_prob + + if ( + cfg.mask_prob_min is not None + and cfg.mask_prob_min >= 0 + and cfg.mask_prob_min < mask_prob + ): + mask_prob = np.random.uniform(cfg.mask_prob_min, mask_prob) + + if mask_prob > 0: + if cfg.mask_length == 1: + mask_info = random_masking(x, mask_prob, mask_seed) + else: + if self.modality_cfg.inverse_mask: + mask_prob = 1 - mask_prob + + mask = compute_mask_indices( + (B, T), + padding_mask, + mask_prob, + cfg.mask_length, + min_masks=1, + require_same_masks=True, + mask_dropout=cfg.mask_dropout, + add_masks=cfg.add_masks, + seed=mask_seed.seed if mask_seed is not None else None, + epoch=mask_seed.update if mask_seed is not None else None, + indices=mask_seed.ids if mask_seed is not None else None, + ) + + mask = torch.from_numpy(mask).to(device=x.device) + if self.modality_cfg.inverse_mask: + mask = 1 - mask + mask_info = self.make_maskinfo(x, mask) + else: + mask_info = None + + if apply: + x = self.apply_mask(x, mask_info) + + return x, mask_info + + + def make_maskinfo(self, x, mask, shape=None): + if shape is None: + B, T, D = x.shape + else: + B, T, D = shape + + mask = mask.to(torch.uint8) + ids_shuffle = mask.argsort(dim=1) + ids_restore = ids_shuffle.argsort(dim=1).unsqueeze(-1).expand(-1, -1, D) + + len_keep = T - mask[0].sum() + if self.modality_cfg.keep_masked_pct > 0: + len_keep += round((T - int(len_keep)) * self.modality_cfg.keep_masked_pct) + + ids_keep = ids_shuffle[:, :len_keep] + + if shape is not None: + x_unmasked = None + else: + ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D) + x_unmasked = torch.gather(x, dim=1, index=ids_keep) + + mask_info = MaskInfo( + x_unmasked=x_unmasked, + mask=mask, + ids_restore=ids_restore, + ids_keep=ids_keep, + ) + return mask_info + + def apply_mask(self, x, mask_info): + cfg = self.modality_cfg + B, T, C = x.shape + + if mask_info is not None: + mask = mask_info.mask + if cfg.encoder_zero_mask: + x = x * (1 - mask.type_as(x).unsqueeze(-1)) + else: + num_masks = mask.sum().item() + masks = x.new_empty(num_masks, x.size(-1)).normal_( + 0, cfg.mask_noise_std + ) + x = index_put(x, mask, masks) + if cfg.mask_channel_prob > 0: + mask_channel = compute_mask_indices( + (B, C), + None, + cfg.mask_channel_prob, + cfg.mask_channel_length, + ) + mask_channel = ( + torch.from_numpy(mask_channel) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x = index_put(x, mask_channel, 0) + return x + + def remove_pretraining_modules(self, keep_decoder=False): + if not keep_decoder: + self.decoder = None + + +def get_annealed_rate(start, end, curr_step, total_steps): + if curr_step >= total_steps: + return end + r = end - start + pct_remaining = 1 - curr_step / total_steps + return end - r * pct_remaining + + +# adapted from MAE +def random_masking(x, mask_ratio, mask_seed: Optional[MaskSeed]): + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + generator = None + if mask_seed is not None: + seed = int( + hash((mask_seed.seed, mask_seed.update, mask_seed.ids.sum().item())) % 1e6 + ) + generator = torch.Generator(device=x.device) + generator.manual_seed(seed) + + noise = torch.rand(N, L, generator=generator, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = noise.argsort(dim=1) # ascend: small is keep, large is remove + ids_restore = ids_shuffle.argsort(dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + ids_keep = ids_keep.unsqueeze(-1).expand(-1, -1, D) + x_unmasked = torch.gather(x, dim=1, index=ids_keep) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], dtype=x.dtype, device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + ids_restore = ids_restore.unsqueeze(-1).expand(-1, -1, D) + + return MaskInfo( + x_unmasked=x_unmasked, mask=mask, ids_restore=ids_restore, ids_keep=ids_keep + ) + + +def gather_unmasked(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor: + return torch.gather( + x, + dim=1, + index=mask_info.ids_keep, + ) + + +def gather_unmasked_mask(x: torch.Tensor, mask_info: MaskInfo) -> torch.Tensor: + return torch.gather( + x, + dim=1, + index=mask_info.ids_keep[..., 0], # ignore the feature dimension + ) + + +def get_alibi( + max_positions: int, + attention_heads: int, + dims: int = 1, + distance: str = "manhattan", +): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + # In the paper, we only train models that have 2^a heads for some + # a. This function has some good properties that only occur when + # the input is a power of 2. To maintain that even when the number + # of heads is not a power of 2, we use this workaround. + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + maxpos = max_positions + attn_heads = attention_heads + slopes = torch.Tensor(get_slopes(attn_heads)) + + if dims == 1: + # prepare alibi position linear bias. Note that wav2vec2 is non + # autoregressive model so we want a symmetric mask with 0 on the + # diagonal and other wise linear decreasing valuees + pos_bias = ( + torch.abs( + torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1) + ) + * -1 + ) + elif dims == 2: + if distance == "manhattan": + df = lambda x1, y1, x2, y2: abs(x1 - x2) + abs(y1 - y2) + elif distance == "euclidean": + df = lambda x1, y1, x2, y2: math.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) + + n = math.sqrt(max_positions) + assert n.is_integer(), n + n = int(n) + + pos_bias = torch.zeros((max_positions, max_positions)) + + for i in range(n): + for j in range(n): + for k in range(n): + for l in range(n): + new_x = i * n + j + new_y = k * n + l + pos_bias[new_x, new_y] = -df(i, j, k, l) + + else: + raise Exception(f"unsupported number of alibi dims: {dims}") + + alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand( + attn_heads, -1, -1 + ) + + return alibi_bias + + +def get_alibi_bias( + alibi_biases, + batch_size, + time_steps, + heads, + dtype, + device, + dims=1, + distance="manhattan", +): + cache_key = f"{dims}_{heads}_{distance}" + + buffered = alibi_biases.get(cache_key, None) + + target_size = heads * batch_size + if ( + buffered is None + or buffered.size(0) < target_size + or buffered.size(1) < time_steps + or buffered.dtype != dtype + or buffered.device != device + ): + bt = max(time_steps, buffered.size(1) if buffered is not None else 0) + bn = max(target_size, buffered.size(0) if buffered is not None else 0) // heads + + buffered = ( + get_alibi(bt, heads, dims=dims, distance=distance) + .to(dtype=dtype, device=device) + .repeat(bn, 1, 1) + ) + + alibi_biases[cache_key] = buffered + + b = buffered[:target_size, :time_steps, :time_steps] + b = b.view(batch_size, heads, time_steps, time_steps) + return b + + +def _learned_alibi_bias( + alibi_bias, + batch_size, + time_steps, + heads, + scale, + dtype, + device, +): + assert alibi_bias.size(1) == heads, alibi_bias.shape + assert alibi_bias.dtype == dtype, alibi_bias.dtype + assert alibi_bias.device == device, alibi_bias.device + + if alibi_bias.size(-1) < time_steps: + psz = math.ceil((time_steps - alibi_bias.size(-1)) / 2) + alibi_bias = F.pad(alibi_bias, (psz, psz, psz, psz), mode="replicate") + + alibi_bias = alibi_bias.expand(batch_size, -1, -1, -1) * scale + return alibi_bias[..., :time_steps, :time_steps] + + +def masked_alibi(alibi_bias, mask_info): + H = alibi_bias.size(1) + + orig_bias = alibi_bias + + index = mask_info.ids_keep.unsqueeze(1)[..., 0].unsqueeze(-1) + alibi_bias = torch.gather( + orig_bias, + dim=-2, + index=index.expand(-1, H, -1, mask_info.ids_restore.size(1)), + ) + alibi_bias = torch.gather( + alibi_bias, + dim=-1, + index=index.transpose(-1, -2).expand(-1, H, alibi_bias.size(-2), -1), + ) + + return alibi_bias diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/images.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/images.py new file mode 100644 index 0000000000000000000000000000000000000000..e2892025b1b90d8d4dd65f61969c3abbf61f38be --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/images.py @@ -0,0 +1,293 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import sys, os +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from functools import partial +from dataclasses import dataclass +from typing import Callable, Dict, Optional +try: + from timm.models.layers import to_2tuple +except: + to_2tuple = None +from fairseq.tasks import FairseqTask +from enum import Enum, auto + +from mae import PatchEmbed,get_2d_sincos_pos_embed_flexible + + +from base import ( + D2vModalityConfig, + ModalitySpecificEncoder, + get_alibi_bias, + MaskSeed, +) +from modules import ( + BlockEncoder, + Decoder2d, + FixedPositionalEncoder, + TransformerDecoder, + EncDecTransformerDecoder, +) + + +class Modality(Enum): + AUDIO = auto() + IMAGE = auto() + TEXT = auto() + + +@dataclass +class D2vImageConfig(D2vModalityConfig): + type: Modality = Modality.IMAGE + + input_size: int = 224 + in_chans: int = 3 + patch_size: int = 16 + embed_dim: int = 768 + + alibi_dims: int = 2 + alibi_distance: str = "manhattan" + + fixed_positions: bool = True + + transformer_decoder: bool = False + enc_dec_transformer: bool = False + target_length: int = 1024 + + +class ImageEncoder(ModalitySpecificEncoder): + + modality_cfg: D2vImageConfig + + def __init__( + self, + modality_cfg: D2vImageConfig, + embed_dim: int, + make_block: Callable[[float, Optional[int], Optional[int]], nn.ModuleList], + norm_layer: Callable[[int], nn.LayerNorm], + layer_norm_first: bool, + alibi_biases: Dict, + task: Optional[FairseqTask], + ): + + if modality_cfg.in_chans == 1 : + img_size = (modality_cfg.target_length,128) + else: + img_size = to_2tuple(modality_cfg.input_size) + + patch_size = to_2tuple(modality_cfg.patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) # number of patch -> 512 + self.H = img_size[0] // patch_size[0] # 64 + self.W = img_size[1] // patch_size[1] # 8 + self.hw = (self.H,self.W) + + # (B,512,768) + local_encoder = PatchEmbed( + img_size, + modality_cfg.patch_size, + modality_cfg.in_chans, + modality_cfg.embed_dim, + ) + + # CNN initialize + w = local_encoder.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + if modality_cfg.embed_dim != embed_dim: + local_encoder = nn.Sequential( + local_encoder, + nn.Linear(modality_cfg.embed_dim, embed_dim), + ) + + project_features = nn.Identity() + + pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim), requires_grad=False + ) + + # side_n = int(num_patches ** 0.5) + emb = get_2d_sincos_pos_embed_flexible( + pos_embed.shape[-1], + self.hw, + cls_token=False, + ) + + pos_embed.data.copy_(torch.from_numpy(emb[:num_patches,:]).float().unsqueeze(0)) + fixed_positional_encoder = ( + FixedPositionalEncoder(pos_embed) if modality_cfg.fixed_positions else None + ) + + dpr = np.linspace( + modality_cfg.start_drop_path_rate, + modality_cfg.end_drop_path_rate, + modality_cfg.prenet_depth, + ) + + context_encoder = BlockEncoder( + nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)), + norm_layer(embed_dim) if not layer_norm_first else None, + layer_norm_first, + modality_cfg.prenet_layerdrop, + modality_cfg.prenet_dropout, + ) + + # EAT utilize the CNN decoder + if modality_cfg.transformer_decoder: + if modality_cfg.enc_dec_transformer: + decoder = EncDecTransformerDecoder(modality_cfg.decoder, embed_dim) + else: + dec_enc = BlockEncoder( + nn.ModuleList( + make_block(0, modality_cfg.decoder.decoder_dim, 8) + for _ in range(modality_cfg.decoder.decoder_layers) + ), + None, + layer_norm_first, + 0, + 0, + ) + decoder = TransformerDecoder(modality_cfg.decoder, embed_dim, dec_enc) + else: + decoder = ( + Decoder2d(modality_cfg.decoder, embed_dim, self.H, self.W) + if modality_cfg.decoder is not None + else None + ) + + alibi_bias_fn = partial( + get_alibi_bias, + alibi_biases=alibi_biases, + heads=modality_cfg.num_alibi_heads, + dims=modality_cfg.alibi_dims, + distance=modality_cfg.alibi_distance, + ) + + super().__init__( + modality_cfg=modality_cfg, + embed_dim=embed_dim, + local_encoder=local_encoder, + project_features=project_features, + fixed_positional_encoder=fixed_positional_encoder, + relative_positional_encoder=None, + context_encoder=context_encoder, + decoder=decoder, + get_alibi_bias=alibi_bias_fn, + ) + + def reset_parameters(self): + super().reset_parameters() + if self.decoder is not None: + self.decoder.reset_parameters() + + @torch.no_grad() + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) audio: (N,1,H,W) 1024/16 = 64 128/16 = 8 + x: (N, L, patch_size**2 *3) + """ + if self.modality_cfg.in_chans == 1: + p = self.modality_cfg.patch_size + h = imgs.shape[2] // p + w = imgs.shape[3] // p + #h,w = self.patch_embed.patch_hw + x = imgs.reshape(shape=(imgs.shape[0], 1, h, p, w, p)) + x = torch.einsum('nchpwq->nhwpqc', x) + x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 1)) + + else: + p = self.modality_cfg.patch_size + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + + return x + + @torch.no_grad() + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.modality_cfg.patch_size + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def compute_mask( + self, + x, + padding_mask, + mask_seed: Optional[MaskSeed], + apply, + shape=None, + precomputed_mask=None, + ): + mlen = self.modality_cfg.mask_length + if mlen <= 1: + return super().compute_mask( + x, padding_mask, mask_seed, apply, precomputed_mask + ) + + if precomputed_mask is not None: + mask = precomputed_mask + else: + from ..utils.data_utils import compute_block_mask_2d + + if shape is not None: + B, L, D = shape + else: + B, L, D = x.shape + + mask = compute_block_mask_2d( + shape=(B, L), + mask_prob=self.modality_cfg.mask_prob, + mask_length=self.modality_cfg.mask_length, + mask_prob_adjust=self.modality_cfg.mask_prob_adjust, + inverse_mask=self.modality_cfg.inverse_mask, + require_same_masks=True, + mask_dropout=self.modality_cfg.mask_dropout, + img_shape=self.hw + ) + + + mask_info = self.make_maskinfo(x, mask, shape) + if apply: + x = self.apply_mask(x, mask_info) + + return x, mask_info + + def decoder_input(self, x, mask_info): + if ( + not self.modality_cfg.transformer_decoder + or not self.modality_cfg.enc_dec_transformer + ): + return super().decoder_input(x, mask_info) + + inp_drop = self.modality_cfg.decoder.input_dropout + if inp_drop > 0: + x = F.dropout(x, inp_drop, training=self.training, inplace=True) + + kv = x[:, self.modality_cfg.num_extra_tokens :] + + assert self.fixed_positional_encoder is not None + pos = self.fixed_positional_encoder(x, None).expand(x.size(0), -1, -1) + + mask = mask_info.mask.bool() + if self.modality_cfg.decoder.add_positions_all: + kv = kv + pos[~mask].view(kv.shape) + + q = pos[mask].view(x.size(0), -1, x.size(-1)) + + return q, kv diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/mae.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/mae.py new file mode 100644 index 0000000000000000000000000000000000000000..2d6ef21168d4d8db11111c7c363e33917a34119a --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/mae.py @@ -0,0 +1,846 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# The code in this file is adapted from the BeiT implementation which can be found here: +# https://github.com/microsoft/unilm/tree/master/beit + +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +from dataclasses import dataclass +from functools import partial +try: + from timm.models.vision_transformer import PatchEmbed, Block +except: + PatchEmbed, Block = None, None +from fairseq.dataclass import FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import TransformerSentenceEncoderLayer + +try: + from apex.normalization import FusedLayerNorm +except: + FusedLayerNorm = nn.LayerNorm + + + +logger = logging.getLogger(__name__) + + +@dataclass +class MaeConfig(FairseqDataclass): + input_size: int = 224 + in_chans: int = 3 + patch_size: int = 16 + embed_dim: int = 768 + depth: int = 12 + num_heads: int = 12 + decoder_embed_dim: int = 512 + decoder_depth: int = 8 + decoder_num_heads: int = 16 + mlp_ratio: int = 4 + norm_eps: float = 1e-6 + + drop_path_rate: float = 0.0 + + mask_ratio: float = 0.75 + norm_pix_loss: bool = True + + w2v_block: bool = False + alt_block: bool = False + alt_block2: bool = False + alt_attention: bool = False + block_dropout: float = 0 + attention_dropout: float = 0 + activation_dropout: float = 0 + layer_norm_first: bool = False + + fused_ln: bool = True + end_of_block_targets: bool = True + + no_decoder_embed: bool = False + no_decoder_pos_embed: bool = False + mask_noise_std: float = 0 + + single_qkv: bool = False + use_rel_pos_bias: bool = False + no_cls: bool = False + + +def modify_relative_position_bias(orig_bias, bsz, mask): + if mask is None: + return orig_bias.unsqueeze(0).repeat( + bsz, 1, 1, 1 + ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len + heads, max_seq_len, max_seq_len = orig_bias.shape # includes CLS token + mask_for_rel_pos_bias = torch.cat( + (torch.zeros(bsz, 1, dtype=mask.dtype, device=mask.device), mask), dim=1 + ).bool() # bsz x seqlen (add CLS token) + unmasked_for_rel_pos_bias = ~mask_for_rel_pos_bias + unmasked_for_rel_pos_bias = unmasked_for_rel_pos_bias.unsqueeze(1).repeat( + 1, heads, 1 + ) # bsz x seq_len => bsz x heads x seq_len + b_t_t_rel_pos_bias = orig_bias.unsqueeze(0).repeat( + bsz, 1, 1, 1 + ) # heads x seq_len x seq_len => bsz x heads x seq_len x seq_len + b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select( + unmasked_for_rel_pos_bias.unsqueeze(-1) + ) + b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, -1, max_seq_len) + new_len = b_t_t_rel_pos_bias.size(-2) + b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.masked_select( + unmasked_for_rel_pos_bias.unsqueeze(-2) + ) + b_t_t_rel_pos_bias = b_t_t_rel_pos_bias.view(bsz, heads, new_len, new_len) + return b_t_t_rel_pos_bias + + +class AltBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_norm_first=True, + ffn_targets=False, + use_rel_pos_bias=False, + window_size=None, + alt_attention=False, + ): + super().__init__() + + self.layer_norm_first = layer_norm_first + self.ffn_targets = ffn_targets + + from timm.models.vision_transformer import Attention, DropPath, Mlp + + self.norm1 = norm_layer(dim) + self.use_rel_pos_bias = use_rel_pos_bias + if use_rel_pos_bias: + self.attn = AltAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + window_size=window_size, + ) + else: + if alt_attention: + from .modules import AltAttention as AltAttention2 + self.attn = AltAttention2( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + else: + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + def forward(self, x, rel_pos_bias=None, pos_mask=None): + if self.layer_norm_first: + if self.use_rel_pos_bias: + x = x + self.drop_path( + self.attn( + self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask + ) + ) + else: + x = x + self.drop_path(self.attn(self.norm1(x))) + t = self.mlp(self.norm2(x)) + x = x + self.drop_path(t) + if not self.ffn_targets: + t = x + return x, t + else: + if self.use_rel_pos_bias: + x = x + self.drop_path( + self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask) + ) + else: + x = x + self.drop_path(self.attn(x)) + r = x = self.norm1(x) + x = self.mlp(x) + t = x + x = self.norm2(r + self.drop_path(x)) + if not self.ffn_targets: + t = x + return x, t + + +class AltAttention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + window_size=None, + attn_head_dim=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + if window_size: + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + # cls to token & token 2 cls & cls to cls + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, + dtype=relative_coords.dtype, + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + else: + self.window_size = None + self.relative_position_bias_table = None + self.relative_position_index = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, rel_pos_bias=None, pos_mask=None): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + ( + self.q_bias, + torch.zeros_like(self.v_bias, requires_grad=False), + self.v_bias, + ) + ) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if self.relative_position_bias_table is not None: + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + modify_relative_position_bias( + relative_position_bias, x.size(0), pos_mask + ) + + if rel_pos_bias is not None: + attn = attn + rel_pos_bias + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class RelativePositionBias(nn.Module): + def __init__(self, window_size, num_heads): + super().__init__() + self.window_size = window_size + self.num_relative_distance = (2 * window_size[0] - 1) * ( + 2 * window_size[1] - 1 + ) + 3 + self.relative_position_bias_table = nn.Parameter( + torch.zeros(self.num_relative_distance, num_heads) + ) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(window_size[0]) + coords_w = torch.arange(window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * window_size[1] - 1 + relative_position_index = torch.zeros( + size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype + ) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = self.num_relative_distance - 3 + relative_position_index[0:, 0] = self.num_relative_distance - 2 + relative_position_index[0, 0] = self.num_relative_distance - 1 + + self.register_buffer("relative_position_index", relative_position_index) + + def forward(self): + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1] + 1, + self.window_size[0] * self.window_size[1] + 1, + -1, + ) # Wh*Ww,Wh*Ww,nH + return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + +def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size[0], dtype=np.float32) + grid_w = np.arange(grid_size[1], dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, grid_size[0], grid_size[1]]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb + + +def interpolate_pos_embed(model, checkpoint_model): + if "pos_embed" in checkpoint_model: + pos_embed_checkpoint = checkpoint_model["pos_embed"] + embedding_size = pos_embed_checkpoint.shape[-1] + num_patches = model.patch_embed.num_patches + num_extra_tokens = model.pos_embed.shape[-2] - num_patches + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + new_size = int(num_patches ** 0.5) + # class_token and dist_token are kept unchanged + if orig_size != new_size: + print( + "Position interpolate from %dx%d to %dx%d" + % (orig_size, orig_size, new_size, new_size) + ) + extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] + pos_tokens = pos_tokens.reshape( + -1, orig_size, orig_size, embedding_size + ).permute(0, 3, 1, 2) + pos_tokens = torch.nn.functional.interpolate( + pos_tokens, + size=(new_size, new_size), + mode="bicubic", + align_corners=False, + ) + pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) + new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) + checkpoint_model["pos_embed"] = new_pos_embed + + +@register_model("mae", dataclass=MaeConfig) +class MaeModel(BaseFairseqModel): + def __init__(self, cfg: MaeConfig): + super().__init__() + self.cfg = cfg + + self.mask_ratio = cfg.mask_ratio + + # -------------------------------------------------------------------------- + # MAE encoder specifics + self.patch_embed = PatchEmbed( + cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim), requires_grad=False + ) # fixed sin-cos embedding + + norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps) + + dpr = [ + x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth) + ] # stochastic depth decay rule + + def make_block(drop_path): + if cfg.w2v_block: + return TransformerSentenceEncoderLayer( + embedding_dim=cfg.embed_dim, + ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio, + num_attention_heads=cfg.num_heads, + dropout=cfg.block_dropout, + attention_dropout=cfg.attention_dropout, + activation_dropout=cfg.activation_dropout, + activation_fn="gelu", + layer_norm_first=cfg.layer_norm_first, + drop_path=drop_path, + norm_eps=1e-6, + single_qkv=cfg.single_qkv, + fused_ln=cfg.fused_ln, + ) + elif cfg.alt_block: + window_size = ( + cfg.input_size // self.patch_embed.patch_size[0], + cfg.input_size // self.patch_embed.patch_size[1], + ) + return AltBlock( + cfg.embed_dim, + cfg.num_heads, + cfg.mlp_ratio, + qkv_bias=True, + qk_scale=None, + norm_layer=norm_layer, + drop_path=drop_path, + layer_norm_first=cfg.layer_norm_first, + ffn_targets=not cfg.end_of_block_targets, + use_rel_pos_bias=cfg.use_rel_pos_bias, + window_size=window_size + if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias) + else None, + alt_attention=cfg.alt_attention, + ) + elif cfg.alt_block2: + from .modules import AltBlock as AltBlock2 + return AltBlock2( + cfg.embed_dim, + cfg.num_heads, + cfg.mlp_ratio, + qkv_bias=True, + qk_scale=None, + norm_layer=norm_layer, + drop_path=drop_path, + layer_norm_first=cfg.layer_norm_first, + ffn_targets=not cfg.end_of_block_targets, + ) + else: + return Block( + cfg.embed_dim, + cfg.num_heads, + cfg.mlp_ratio, + qkv_bias=True, + qk_scale=None, + norm_layer=norm_layer, + drop_path=drop_path, + ) + + self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)]) + self.norm = norm_layer(cfg.embed_dim) + # -------------------------------------------------------------------------- + + # -------------------------------------------------------------------------- + # MAE decoder specifics + self.decoder_embed = ( + nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True) + if not cfg.no_decoder_embed + else None + ) + + self.mask_token = ( + nn.Parameter( + torch.zeros( + 1, + 1, + cfg.decoder_embed_dim + if not cfg.no_decoder_embed + else cfg.embed_dim, + ) + ) + if cfg.mask_noise_std <= 0 + else None + ) + + self.decoder_pos_embed = ( + nn.Parameter( + torch.zeros( + 1, + num_patches + 1, + cfg.decoder_embed_dim + if not cfg.no_decoder_embed + else cfg.embed_dim, + ), + requires_grad=False, + ) + if not cfg.no_decoder_pos_embed + else None + ) + + self.decoder_blocks = nn.ModuleList( + [ + Block( + cfg.decoder_embed_dim, + cfg.decoder_num_heads, + cfg.mlp_ratio, + qkv_bias=True, + qk_scale=None, + norm_layer=norm_layer, + ) + for _ in range(cfg.decoder_depth) + ] + ) + + self.decoder_norm = norm_layer(cfg.decoder_embed_dim) + self.decoder_pred = nn.Linear( + cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True + ) # decoder to patch + # -------------------------------------------------------------------------- + + self.norm_pix_loss = cfg.norm_pix_loss + + self.initialize_weights() + + for pn, p in self.named_parameters(): + if len(p.shape) == 1 or pn.endswith(".bias"): + p.param_group = "no_decay" + else: + p.param_group = "with_decay" + + def initialize_weights(self): + # initialization + # initialize (and freeze) pos_embed by sin-cos embedding + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + int(self.patch_embed.num_patches ** 0.5), + cls_token=not self.cfg.no_cls, + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + if self.decoder_pos_embed is not None: + decoder_pos_embed = get_2d_sincos_pos_embed( + self.decoder_pos_embed.shape[-1], + int(self.patch_embed.num_patches ** 0.5), + cls_token=not self.cfg.no_cls, + ) + self.decoder_pos_embed.data.copy_( + torch.from_numpy(decoder_pos_embed).float().unsqueeze(0) + ) + + # initialize patch_embed like nn.Linear (instead of nn.Conv2d) + w = self.patch_embed.proj.weight.data + torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + + # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) + if self.cls_token is not None: + torch.nn.init.normal_(self.cls_token, std=0.02) + + if self.mask_token is not None: + torch.nn.init.normal_(self.mask_token, std=0.02) + + # initialize nn.Linear and nn.LayerNorm + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def patchify(self, imgs): + """ + imgs: (N, 3, H, W) + x: (N, L, patch_size**2 *3) + """ + p = self.patch_embed.patch_size[0] + assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0 + + h = w = imgs.shape[2] // p + x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p)) + x = torch.einsum("nchpwq->nhwpqc", x) + x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3)) + return x + + def unpatchify(self, x): + """ + x: (N, L, patch_size**2 *3) + imgs: (N, 3, H, W) + """ + p = self.patch_embed.patch_size[0] + h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, 3)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p)) + return imgs + + def random_masking(self, x, mask_ratio): + """ + Perform per-sample random masking by per-sample shuffling. + Per-sample shuffling is done by argsort random noise. + x: [N, L, D], sequence + """ + N, L, D = x.shape # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + ids_shuffle = torch.argsort( + noise, dim=1 + ) # ascend: small is keep, large is remove + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) + + # generate the binary mask: 0 is keep, 1 is remove + mask = torch.ones([N, L], device=x.device) + mask[:, :len_keep] = 0 + # unshuffle to get the binary mask + mask = torch.gather(mask, dim=1, index=ids_restore) + + return x_masked, mask, ids_restore # x_masked is actually unmasked x + + @classmethod + def build_model(cls, cfg: MaeConfig, task=None): + """Build a new model instance.""" + + return cls(cfg) + + def forward_encoder(self, x, mask_ratio): + # embed patches + x = self.patch_embed(x) + + # add pos embed w/o cls token + # if self.cls_token is not None: + # x = x + self.pos_embed + # else: + x = x + self.pos_embed[:, 1:, :] + + # masking: length -> length * mask_ratio + if mask_ratio > 0: + x, mask, ids_restore = self.random_masking(x, mask_ratio) + else: + mask = ids_restore = None + + # append cls token + if self.cls_token is not None: + cls_token = self.cls_token + self.pos_embed[:, :1, :] + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # apply Transformer blocks + for blk in self.blocks: + x = blk(x) + + if self.norm is not None: + x = self.norm(x) + + return x, mask, ids_restore + + def forward_decoder(self, x, ids_restore): + # embed tokens + x = self.decoder_embed(x) + + # append mask tokens to sequence + mask_tokens = self.mask_token.repeat( + x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1 + ) + if self.cls_token is not None: + x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token + else: + x_ = torch.cat([x, mask_tokens], dim=1) # no cls token + + x_ = torch.gather( + x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]) + ) # unshuffle + + if self.cls_token is not None: + x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token + + # add pos embed + x = x + self.decoder_pos_embed + + # apply Transformer blocks + for blk in self.decoder_blocks: + x = blk(x) + x = self.decoder_norm(x) + + # predictor projection + x = self.decoder_pred(x) + + if self.cls_token is not None: + # remove cls token + x = x[:, 1:, :] + + return x + + def forward_loss(self, imgs, pred, mask): + """ + imgs: [N, 3, H, W] + pred: [N, L, p*p*3] + mask: [N, L], 0 is keep, 1 is remove, + """ + target = self.patchify(imgs) + if self.norm_pix_loss: + mean = target.mean(dim=-1, keepdim=True) + var = target.var(dim=-1, keepdim=True) + target = (target - mean) / (var + 1.0e-6) ** 0.5 + + loss = (pred - target) ** 2 + loss = loss.mean(dim=-1) # [N, L], mean loss per patch + + loss = (loss * mask).sum() + return loss, mask.sum() + + def forward(self, imgs, predictions_only=False): + latent, mask, ids_restore = self.forward_encoder( + imgs, self.mask_ratio if not predictions_only else 0 + ) + + if predictions_only: + return latent + + pred = self.forward_decoder(latent, ids_restore) # [N, L, p*p*3] + loss, sample_size = self.forward_loss(imgs, pred, mask) + + result = { + "losses": {"regression": loss}, + "sample_size": sample_size, + } + return result + + def remove_pretraining_modules(self): + self.decoder_embed = None + self.decoder_blocks = None + self.decoder_norm = None + self.decoder_pos_embed = None + self.decoder_pred = None + self.mask_token = None + if self.cfg.layer_norm_first: + self.norm = None diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/modules.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..db82e4ce9dec598b049939c16ed6bf97c5f0395e --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/eat/modules.py @@ -0,0 +1,616 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +from dataclasses import dataclass +from fairseq.modules import ( + LayerNorm, + SamePad, + # SamePad2d, + # TransposeLast, +) + +try: + from fairseq.modules import SamePad2d +except: + class SamePad2d(nn.Module): + def __init__(self, kernel_size): + super().__init__() + self.remove = 1 if kernel_size % 2 == 0 else 0 + + def forward(self, x): + assert len(x.size()) == 4 + if self.remove > 0: + x = x[:, :, : -self.remove, : -self.remove] + return x + +class TransposeLast(nn.Module): + def __init__(self, deconstruct_idx=None, tranpose_dim=-2): + super().__init__() + self.deconstruct_idx = deconstruct_idx + self.tranpose_dim = tranpose_dim + + def forward(self, x): + if self.deconstruct_idx is not None: + x = x[self.deconstruct_idx] + return x.transpose(self.tranpose_dim, -1) + + +@dataclass +class D2vDecoderConfig: + decoder_dim: int = 384 + decoder_groups: int = 16 + decoder_kernel: int = 5 + decoder_layers: int = 5 + input_dropout: float = 0.1 + + add_positions_masked: bool = False + add_positions_all: bool = False + + decoder_residual: bool = True + projection_layers: int = 1 + projection_ratio: float = 2.0 + + +class FixedPositionalEncoder(nn.Module): + def __init__(self, pos_embed): + super().__init__() + self.positions = pos_embed + + def forward(self, x, padding_mask): + return self.positions + + +class TextFeatPositionalEncoder(nn.Module): + """ + Original encoder expects (B, T) long input. This module wraps it to take + local_encoder output which are (B, T, D) float tensors + """ + + def __init__(self, pos_encoder): + super().__init__() + self.pos_encoder = pos_encoder + + def forward(self, x, padding_mask): + # assume padded token embeddings are 0s + # TODO: consider using padding_mask as input + return self.pos_encoder(x[..., 0]) + + +class BlockEncoder(nn.Module): + def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout): + super().__init__() + self.blocks = blocks + self.norm = norm_layer + self.layer_norm_first = layer_norm_first + self.layerdrop = layerdrop + self.dropout = nn.Dropout(dropout, inplace=True) + + def forward(self, x, padding_mask, alibi_bias, alibi_scale): + if self.norm is not None and not self.layer_norm_first: + x = self.norm(x) + + x = self.dropout(x) + + for i, blk in enumerate(self.blocks): + if ( + not self.training + or self.layerdrop == 0 + or (np.random.random() > self.layerdrop) + ): + ab = alibi_bias + if ab is not None and alibi_scale is not None: + scale = ( + alibi_scale[i] + if alibi_scale.size(0) > 1 + else alibi_scale.squeeze(0) + ) + ab = ab * scale.type_as(ab) + x, _ = blk(x, padding_mask, ab) + + if self.norm is not None and self.layer_norm_first: + x = self.norm(x) + + return x + + +class DecoderBase(nn.Module): + decoder_cfg: D2vDecoderConfig + + def __init__(self, cfg: D2vDecoderConfig): + super().__init__() + + self.decoder_cfg = cfg + + def reset_parameters(self): + for mod in self.proj.modules(): + if isinstance(mod, nn.Linear): + mod.reset_parameters() + + def add_residual(self, x, residual, i, mask_info): + if ( + residual is None + or not self.decoder_cfg.decoder_residual + or residual.size(1) != x.size(1) + ): + return x + + ret = x + residual + + return ret + + +class Decoder1d(DecoderBase): + def __init__(self, cfg: D2vDecoderConfig, input_dim): + super().__init__(cfg) + + def make_block(in_dim): + block = [ + nn.Conv1d( + in_dim, + cfg.decoder_dim, + kernel_size=cfg.decoder_kernel, + padding=cfg.decoder_kernel // 2, + groups=cfg.decoder_groups, + ), + SamePad(cfg.decoder_kernel), + TransposeLast(), + LayerNorm(cfg.decoder_dim, elementwise_affine=False), + TransposeLast(), + nn.GELU(), + ] + + return nn.Sequential(*block) + + self.blocks = nn.Sequential( + *[ + make_block(input_dim if i == 0 else cfg.decoder_dim) + for i in range(cfg.decoder_layers) + ] + ) + + projs = [] + curr_dim = cfg.decoder_dim + for i in range(cfg.projection_layers - 1): + next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim + projs.append(nn.Linear(curr_dim, next_dim)) + projs.append(nn.GELU()) + curr_dim = next_dim + projs.append(nn.Linear(curr_dim, input_dim)) + if len(projs) == 1: + self.proj = projs[0] + else: + self.proj = nn.Sequential(*projs) + + def forward(self, x, mask_info): + + x = x.transpose(1, 2) + + residual = x + + for i, layer in enumerate(self.blocks): + x = layer(x) + x = self.add_residual(x, residual, i, mask_info) + residual = x + + x = x.transpose(1, 2) + x = self.proj(x) + return x + + +class Decoder2d(DecoderBase): + def __init__(self, cfg: D2vDecoderConfig, input_dim, h_size, w_size): + super().__init__(cfg) + + self.h_size = h_size + self.w_size = w_size + + def make_block(in_dim): + block = [ + nn.Conv2d( + in_dim, + cfg.decoder_dim, + kernel_size=cfg.decoder_kernel, + padding=cfg.decoder_kernel // 2, + groups=cfg.decoder_groups, + ), + SamePad2d(cfg.decoder_kernel), + TransposeLast(tranpose_dim=-3), + LayerNorm(cfg.decoder_dim, elementwise_affine=False), + TransposeLast(tranpose_dim=-3), + nn.GELU(), + ] + + return nn.Sequential(*block) + + self.blocks = nn.Sequential( + *[ + make_block(input_dim if i == 0 else cfg.decoder_dim) + for i in range(cfg.decoder_layers) + ] + ) + + self.proj = nn.Linear(cfg.decoder_dim, input_dim) + + def forward(self, x, mask_info): + B, T, C = x.shape + + x = x.transpose(1, 2).reshape(B, C, self.h_size, self.w_size) + + residual = x + + for i, layer in enumerate(self.blocks): + x = layer(x) + x = self.add_residual(x, residual, i, mask_info) + residual = x + + x = x.reshape(B, -1, T).transpose(1, 2) + x = self.proj(x) + return x + + +class TransformerDecoder(nn.Module): + decoder_cfg: D2vDecoderConfig + + def __init__(self, cfg: D2vDecoderConfig, input_dim, encoder): + super().__init__() + + self.decoder_cfg = cfg + + self.input_proj = nn.Linear(input_dim, cfg.decoder_dim) + + self.encoder = encoder + + self.proj = nn.Linear(cfg.decoder_dim, input_dim) + + def reset_parameters(self): + from fairseq.modules.transformer_sentence_encoder import init_bert_params + + self.apply(init_bert_params) + + def forward(self, x, mask_info): + x = self.input_proj(x) + x = self.encoder(x, None, None, 1) + x = self.proj(x) + return x + + +class AltBlock(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + mlp_drop=0.0, + post_mlp_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_norm_first=True, + ffn_targets=False, + cosine_attention=False, + ): + super().__init__() + + self.layer_norm_first = layer_norm_first + self.ffn_targets = ffn_targets + + from timm.models.vision_transformer import DropPath, Mlp + + self.norm1 = norm_layer(dim) + self.attn = AltAttention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + cosine_attention=cosine_attention, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop, + ) + self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False) + + def forward(self, x, padding_mask=None, alibi_bias=None): + if self.layer_norm_first: + x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias)) + r = x = self.mlp(self.norm2(x)) + t = x + x = r + self.drop_path(self.post_mlp_dropout(x)) + if not self.ffn_targets: + t = x + else: + x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias)) + r = x = self.norm1(x) + x = self.mlp(x) + t = x + x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x))) + if not self.ffn_targets: + t = x + + return x, t + + +class AltAttention(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + cosine_attention=False, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.cosine_attention = cosine_attention + + if cosine_attention: + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True + ) + + def forward(self, x, padding_mask=None, alibi_bias=None): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + dtype = q.dtype + + if self.cosine_attention: + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp( + self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01)) + ).exp() + attn = attn * logit_scale + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if alibi_bias is not None: + attn = attn.type_as(alibi_bias) + attn[:, : alibi_bias.size(1)] += alibi_bias + + if padding_mask is not None and padding_mask.any(): + attn = attn.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + + attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) # + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class EncDecAttention(nn.Module): + def __init__( + self, + q_dim, + kv_dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + cosine_attention=False, + ): + super().__init__() + self.num_heads = num_heads + head_dim = q_dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.q_proj = nn.Linear(q_dim, q_dim, bias=qkv_bias) + self.kv_proj = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(q_dim, q_dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.cosine_attention = cosine_attention + + if cosine_attention: + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True + ) + + def forward(self, q, kv, padding_mask=None, alibi_bias=None): + B, N, C = q.shape + + q = ( + self.q_proj(q) + .reshape(B, N, self.num_heads, C // self.num_heads) + .permute(0, 2, 1, 3) + ) # B x H x L x D + kv = ( + self.kv_proj(kv) + .reshape(B, -1, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) # kv x B x H x L x D + k, v = ( + kv[0], + kv[1], + ) # make torchscript happy (cannot use tensor as tuple) + + dtype = q.dtype + + if self.cosine_attention: + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp( + self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01)) + ).exp() + attn = attn * logit_scale + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + if alibi_bias is not None: + attn = attn.type_as(alibi_bias) + attn[:, : alibi_bias.size(1)] += alibi_bias + + if padding_mask is not None and padding_mask.any(): + attn = attn.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + float("-inf"), + ) + + attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) # + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class EncDecBlock(nn.Module): + def __init__( + self, + q_dim, + kv_dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + mlp_drop=0.0, + post_mlp_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_norm_first=True, + cosine_attention=False, + first_residual=True, + ): + super().__init__() + + self.layer_norm_first = layer_norm_first + + from timm.models.vision_transformer import DropPath, Mlp + + self.norm1 = norm_layer(q_dim) + self.attn = EncDecAttention( + q_dim, + kv_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + cosine_attention=cosine_attention, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(q_dim) + mlp_hidden_dim = int(q_dim * mlp_ratio) + self.mlp = Mlp( + in_features=q_dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=mlp_drop, + ) + self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False) + self.first_residual = first_residual + + def forward(self, q, kv, padding_mask=None, alibi_bias=None): + r = q if self.first_residual else 0 + if self.layer_norm_first: + x = r + self.drop_path( + self.attn(self.norm1(q), kv, padding_mask, alibi_bias) + ) + r = x = self.mlp(self.norm2(x)) + x = r + self.drop_path(self.post_mlp_dropout(x)) + else: + x = r + self.drop_path(self.attn(q, kv, padding_mask, alibi_bias)) + r = x = self.norm1(x) + x = self.mlp(x) + x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x))) + + return x + + +class EncDecTransformerDecoder(nn.Module): + def __init__(self, cfg: D2vDecoderConfig, input_dim): + super().__init__() + + self.input_proj = nn.Linear(input_dim, cfg.decoder_dim) + + self.blocks = nn.Sequential( + *[ + EncDecBlock( + q_dim=cfg.decoder_dim, + kv_dim=input_dim, + num_heads=8, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + mlp_drop=0.0, + post_mlp_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + layer_norm_first=False, + cosine_attention=False, + first_residual=i > 0, + ) + for i in range(cfg.decoder_layers) + ] + ) + + self.proj = nn.Linear(cfg.decoder_dim, input_dim) + + def reset_parameters(self): + from fairseq.modules.transformer_sentence_encoder import init_bert_params + + self.apply(init_bert_params) + + def forward(self, x, kv): + x = self.input_proj(x) + for i, layer in enumerate(self.blocks): + x = layer(x, kv) + + x = self.proj(x) + return x diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/README.md b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/README.md new file mode 100644 index 0000000000000000000000000000000000000000..08aeb44af99358a78e0d9d3f7e841aea738c7034 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/README.md @@ -0,0 +1,5 @@ +add cauchy extension from https://github.com/HazyResearch/state-spaces +```shell +cd state-spaces/extensions/cauchy +python setup.py install +``` diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/__init__.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef6629345e3b792abc2414a4b07403ac1afc88d2 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .mert_model import * # noqa \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/chroma_torch.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/chroma_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..d29cdd422cf85917c0dae9b3862d6985faa3368c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/chroma_torch.py @@ -0,0 +1,261 @@ +from typing import Callable, Optional +import torch +from torchaudio.transforms import Spectrogram +from nnAudio.features.cqt import CQT + +def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12): + a440 = 440.0 * 2.0 ** (tuning / bins_per_octave) + return torch.log2(freqs / (a440 / 16)) + +def chroma_filterbank( + sample_rate: int, + n_freqs: int, + n_chroma: int, + *, + tuning: float = 0.0, + ctroct: float = 5.0, + octwidth: Optional[float] = 2.0, + norm: int = 2, + base_c: bool = True, +): + """Create a frequency-to-chroma conversion matrix. Implementation adapted from librosa. + + Args: + sample_rate (int): Sample rate. + n_freqs (int): Number of input frequencies. + n_chroma (int): Number of output chroma. + tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0) + ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0) + octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves. + If ``None``, then disable weighting altogether. (Default: 2.0) + norm (int, optional): order of norm to normalize filter bank by. (Default: 2) + base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True) + + Returns: + torch.Tensor: Chroma filter bank, with shape `(n_freqs, n_chroma)`. + """ + # Skip redundant upper half of frequency range. + freqs = torch.linspace(0, sample_rate // 2, n_freqs)[1:] # ->[n_freqs - 1], 均分sample_rate//2 ; 对哪些频率感兴趣 + freq_bins = n_chroma * _hz_to_octs(freqs, bins_per_octave=n_chroma, tuning=tuning) # 这些频率对应的octave坐标下的值(类似于MIDI的音高序号) + freq_bins = torch.cat((torch.tensor([freq_bins[0] - 1.5 * n_chroma]), freq_bins)) + freq_bin_widths = torch.cat( + ( + torch.maximum(freq_bins[1:] - freq_bins[:-1], torch.tensor(1.0)), + torch.tensor([1]), + ) + ) #每个波带对应的octave的宽度(至少为1) + + # (n_freqs, n_chroma) + D = freq_bins.unsqueeze(1) - torch.arange(0, n_chroma) + + n_chroma2 = round(n_chroma / 2) + + # Project to range [-n_chroma/2, n_chroma/2 - 1] #D:[1025, 12] + D = torch.remainder(D + n_chroma2, n_chroma) - n_chroma2 + + fb = torch.exp(-0.5 * (2 * D / torch.tile(freq_bin_widths.unsqueeze(1), (1, n_chroma))) ** 2) + fb = torch.nn.functional.normalize(fb, p=norm, dim=1) #->[1025, 12] + + if octwidth is not None: + fb *= torch.tile( + torch.exp(-0.5 * (((freq_bins.unsqueeze(1) / n_chroma - ctroct) / octwidth) ** 2)), + (1, n_chroma), + ) #->[1025, 12] + + if base_c: + fb = torch.roll(fb, -3 * (n_chroma // 12), dims=1) + + return fb + +class ChromaScale(torch.nn.Module): + r"""Converts spectrogram to chromagram. + + .. devices:: CPU CUDA + + .. properties:: Autograd + + Args: + sample_rate (int): Sample rate of audio signal. + n_freqs (int): Number of frequency bins in STFT. See ``n_fft`` in :class:`Spectrogram`. + n_chroma (int, optional): Number of chroma. (Default: ``12``) + tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0) + ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0) + octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves. + If ``None``, then disable weighting altogether. (Default: 2.0) + norm (int, optional): order of norm to normalize filter bank by. (Default: 2) + base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True) + + Example + >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) + >>> spectrogram_transform = transforms.Spectrogram(n_fft=1024) + >>> spectrogram = spectrogram_transform(waveform) + >>> chroma_transform = transforms.ChromaScale(sample_rate=sample_rate, n_freqs=1024 // 2 + 1) + >>> chroma_spectrogram = chroma_transform(spectrogram) + + See also: + :py:func:`torchaudio.prototype.functional.chroma_filterbank` — function used to + generate the filter bank. + """ + + def __init__( + self, + sample_rate: int, + n_freqs: int, + *, + n_chroma: int = 12, + tuning: float = 0.0, + ctroct: float = 5.0, + octwidth: Optional[float] = 2.0, + norm: int = 2, + base_c: bool = True, + ): + super().__init__() + fb = chroma_filterbank( + sample_rate, n_freqs, n_chroma, tuning=tuning, ctroct=ctroct, octwidth=octwidth, norm=norm, base_c=base_c + ) + self.register_buffer("fb", fb) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + r""" + Args: + specgram (torch.Tensor): Spectrogram of dimension (..., ``n_freqs``, time). + + Returns: + torch.Tensor: Chroma spectrogram of size (..., ``n_chroma``, time). + """ + return torch.matmul(x.transpose(-1, -2), self.fb).transpose(-1, -2) #[376, 1025] @ [1025, 12]-> [12, 376] + + +class ChromaSpectrogram(torch.nn.Module): + r"""Generates chromagram for audio signal. + + .. devices:: CPU CUDA + + .. properties:: Autograd + + Composes :py:func:`torchaudio.transforms.Spectrogram` and + and :py:func:`torchaudio.prototype.transforms.ChromaScale`. + + Args: + sample_rate (int): Sample rate of audio signal. + n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. + win_length (int or None, optional): Window size. (Default: ``n_fft``) + hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``) + pad (int, optional): Two sided padding of signal. (Default: ``0``) + window_fn (Callable[..., torch.Tensor], optional): A function to create a window tensor + that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``) + power (float, optional): Exponent for the magnitude spectrogram, + (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``) + normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``) + wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``) + center (bool, optional): whether to pad :attr:`waveform` on both sides so + that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`. + (Default: ``True``) + pad_mode (string, optional): controls the padding method used when + :attr:`center` is ``True``. (Default: ``"reflect"``) + n_chroma (int, optional): Number of chroma. (Default: ``12``) + tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0) + ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0) + octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves. + If ``None``, then disable weighting altogether. (Default: 2.0) + norm (int, optional): order of norm to normalize filter bank by. (Default: 2) + base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True) + + Example + >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True) + >>> transform = transforms.ChromaSpectrogram(sample_rate=sample_rate, n_fft=400) + >>> chromagram = transform(waveform) # (channel, n_chroma, time) + """ + + def __init__( + self, + sample_rate: int, + n_fft: int, + *, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + pad: int = 0, + window_fn: Callable[..., torch.Tensor] = torch.hann_window, + power: float = 2.0, + normalized: bool = False, + wkwargs: Optional[dict] = None, + center: bool = True, + pad_mode: str = "reflect", + n_chroma: int = 12, + n_bins: Optional[int] = None, + tuning: float = 0.0, + ctroct: float = 5.0, + octwidth: Optional[float] = 2.0, + norm: int = 2, + base_c: bool = True, + use_cqt: bool = False, + ): + super().__init__() + if n_bins is None: + n_bins = n_fft // 2 + 1 + if use_cqt: + self.spectrogram = CQT( + sr=sample_rate, + hop_length=hop_length, + n_bins=n_bins, + bins_per_octave=n_bins//7, + filter_scale=1, + norm=1, + window='hann', + center=True, + pad_mode='constant', + trainable=False, + output_format='Magnitude', + verbose=True, + ) + else: + self.spectrogram = Spectrogram( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + pad=pad, + window_fn=window_fn, + power=power, + normalized=normalized, + wkwargs=wkwargs, + center=center, + pad_mode=pad_mode, + onesided=True, + ) + self.chroma_scale = ChromaScale( + sample_rate, + n_bins, + n_chroma=n_chroma, + tuning=tuning, + base_c=base_c, + ctroct=ctroct, + octwidth=octwidth, + norm=norm, + ) + + def forward(self, waveform: torch.Tensor, normalize=True) -> torch.Tensor: + r""" + Args: + waveform (Tensor): Tensor of audio of dimension (..., time). + + Returns: + Tensor: Chromagram of size (..., ``n_chroma``, time). + """ + spectrogram = self.spectrogram(waveform) #[1025, 376] + chroma_spectrogram = self.chroma_scale(spectrogram) + if normalize: + chroma_spectrogram[chroma_spectrogram < 0] = 0.0 + chroma_spectrogram = torch.nn.functional.normalize(chroma_spectrogram, p=2, dim=-2) + return chroma_spectrogram + +if __name__ == '__main__': + import numpy as np + import librosa + audio_path = 'speech_data/pretrain/music_42/226849998.flac' + sr = 24000 + freq = 75 + hop = int(sr // freq) + y, _sr = librosa.load(audio_path, duration=5, sr=sr) + + chroma_extractor = ChromaSpectrogram(sample_rate=sr, hop_length=hop, n_fft=2048, use_cqt=True) + chroma_tr = chroma_extractor(torch.from_numpy(y)).numpy() \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/mert_model.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/mert_model.py new file mode 100644 index 0000000000000000000000000000000000000000..f60183d71a262b919d43d2efa1298945d7eb8cab --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/mert_model.py @@ -0,0 +1,2470 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# from lib2to3.pytree import _Results +import logging +from dataclasses import dataclass, field +from re import L +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter + +from torchaudio import transforms +# from torch import autocast +from omegaconf import II + +from fairseq import utils +from fairseq.data.data_utils import compute_mask_indices +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import ( + EXTRACTOR_MODE_CHOICES, + MASKING_DISTRIBUTION_CHOICES, + LAYER_TYPE_CHOICES, + ConvFeatureExtractionModel, + TransformerEncoder, + TransformerSentenceEncoderLayer, + ConformerWav2Vec2EncoderLayer, +) +from fairseq.modules import GradMultiply, LayerNorm, MultiheadAttention +from fairseq.tasks.hubert_pretraining import ( + HubertPretrainingConfig, + HubertPretrainingTask, +) + +import os +import math + +from nnAudio import features as nnAudioFeatures +from torchaudio.transforms import MelSpectrogram +# from chroma_torch import ChromaSpectrogram + +logger = logging.getLogger(__name__) + +MASK_REPLACE_TYPE_CHOICES = ChoiceEnum(["in_batch", "in_sample"]) +AUDIO_FEAT_EXTRACTOR_TYPE_CHOICES = ChoiceEnum(["w2v_conv", "hstft_conv", "melspec"]) + +@dataclass +class MERTConfig(FairseqDataclass): + label_rate: float = II("task.label_rate") + + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + layer_type: LAYER_TYPE_CHOICES = field( + default="transformer", metadata={"help": "layer type in encoder"} + ) + + # dropouts + dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for the transformer"}, + ) + attention_dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for attention weights"}, + ) + activation_dropout: float = field( + default=0.0, + metadata={"help": "dropout probability after activation in FFN"}, + ) + encoder_layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a tarnsformer layer"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={"help": "dropout to apply to the features (after feat extr)"}, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many " + "dimensions. set to encoder_embed_dim is <= 0" + }, + ) + untie_final_proj: bool = field( + default=False, + metadata={"help": "use separate projection for each target"}, + ) + layer_norm_first: bool = field( + default=False, + metadata={"help": "apply layernorm first in the transformer"}, + ) + + # parameters for feature extractors + # cr: yinghao implementation + audio_extract_type: AUDIO_FEAT_EXTRACTOR_TYPE_CHOICES = field( + default="w2v_conv", + metadata={"help": "the type of audio feature extractor used to extract audio features"}, + ) + melspec_n_bins: int = field( + default=84, metadata={'help':'the papra meter for mel bins in the input feature extractor'} + ) + music_conv_nmel: int = field( + default=80, metadata={'help':'the papra meter for logmel transformation in the input feature extractor'} + ) + music_conv_hoplen: int = field( + default=40, metadata={'help':'the papra meter for STFT hop length in the input feature extractor, default set to 40 for 16k audio to align with HuBERT'} + ) + conv_feature_layers: str = field( + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + metadata={ + "help": "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + + + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply feature extractor var grads by this"}, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + pos_conv_depth: int = field( + default=1, + metadata={"help": "depth of positional encoder network"}, + ) + + # masking + mask_length: int = field(default=10, metadata={"help": "mask length"}) + mask_prob: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + + # dynamic mask prob + mask_dynamic_prob_step: str = field( + default="[]", + metadata={ + "help": "string describing steps to update mask prob strategy " + "eg, [10000, 40000, 80000, 150000,] " + "the last step should be less than max_updates " + }, + ) + mask_dynamic_prob: str = field( + default="[]", + metadata={ + "help": "string describing mask prob strategy " + "the len() should be len(mask_dynamic_prob_step) + 1" + "eg, [0.1, 0.2, 0.4, 0.6, 0.8]" + }, + ) + + # dynamic mask length + mask_dynamic_len_step: str = field( + default="[]", + metadata={ + "help": "string describing steps to update mask prob strategy " + "eg, [20000, 80000, 150000,] " + "the last step should be less than max_updates " + }, + ) + mask_dynamic_len: str = field( + default="[]", + metadata={ + "help": "string describing mask prob strategy " + "the len() should be len(mask_dynamic_prob_step) + 1" + "eg, [2, 5, 10, 15]" + }, + ) + + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + # replacement in mask + mask_replace: float = field( + default=0.0, + metadata={"help": "probability of replacing the mask embeddings with other embeddings"}, + ) + mask_replace_type: MASK_REPLACE_TYPE_CHOICES = field( + default="in_sample", + metadata={"help": "the strategy of mask replacement; in_sampe or in_batch"}, + ) + mask_origin: float = field( + default=0.0, + metadata={"help": "probability of keeping the original embeddings at the mask position"}, + ) + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + mask_channel_min_space: int = field( + default=1, + metadata={"help": "min space between spans (if no overlap is enabled)"}, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={"help": "number of filters for convolutional positional embeddings"}, + ) + conv_pos_groups: int = field( + default=16, + metadata={"help": "number of groups for convolutional positional embedding"}, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={"help": "legacy (to be removed)"}, + ) + + # loss computation + skip_masked: bool = field( + default=False, + metadata={"help": "skip computing losses over masked frames"}, + ) + skip_nomask: bool = field( + default=False, + metadata={"help": "skip computing losses over unmasked frames"}, + ) + + checkpoint_activations: bool = field( + default=False, + metadata={"help": "recompute activations and save memory for extra compute"}, + ) + + # FP16 optimization + required_seq_len_multiple: int = field( + default=2, + metadata={ + "help": "pad the input to encoder such that the sequence length is divisible by multiple" + }, + ) + + # Conformer + depthwise_conv_kernel_size: int = field( + default=31, + metadata={ + "help": "depthwise-conv-kernel-size for convolution in conformer layer" + }, + ) + attn_type: str = field( + default="", + metadata={"help": "if espnet use ESPNET MHA"}, + ) + pos_enc_type: str = field( + default="abs", + metadata={"help": "Positional encoding type to use in conformer"}, + ) + fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) + + # codec loss + audio_codec_type: str = field( + default='encodec', + metadata={"help": "audio codec type, candidates: encodec/dac/rvq"}, + ) + audio_codec_dac_model_path: Optional[str] = field( + default=None, + metadata={'help': 'The path to DAC model file. By default download from internet.'} + ) + audio_codec_ckpt_path: Optional[str] = field( + default=None, + metadata={'help': 'The path to the ckpt file of codec model'} + ) + + # cqt loss + audio_cqt_loss_m: bool = field( + default=False, + metadata={"help": "whether to predict the CQT of the audio of masked pard"}, + ) + audio_cqt_bins: int = field( + default=84, + metadata={"help": "the bins of CQT feature"}, + ) + # mel loss + audio_mel_loss_m: bool = field( + default=False, + metadata={"help": "whether to predict the CQT of the audio of masked pard"}, + ) + audio_mel_bins: int = field( + default=84, + metadata={"help": "the bins of CQT feature"}, + ) + + # random quantizer(BEST-RQ) loss + audio_rq_loss_m: bool = field( + default=False, + metadata={"help": "whether to predict the random quantizer code (BEST-RQ) of the audio of masked pard"}, + ) + audio_rq_loss_embed_dim: int = field( + default=16, + metadata={"help": "the dimension of the embeddings of random quantizer"}, + ) + audio_rq_loss_num_codebooks: int = field( + default=1, + metadata={"help": "the codebooks number of random quantizer"}, + ) + audio_rq_loss_num_embeds: int = field( + default=8192, + metadata={"help": "the embeddings number of random quantizer"}, + ) + audio_rq_loss_seed: int = field( + default=-1, + metadata={"help": "random seed for BEST RQ"}, + ) + audio_rq_loss_use_norm: bool = field( + default=False, + metadata={"help": "whether to use norm of random quantizer; default to False, but it is not recommended"}, + ) + audio_rq_loss_use_chroma: bool = field( + default=False, + metadata={"help": "whether to use chroma feature of random quantizer"}, + ) + audio_rq_loss_seed_chroma: int = field( + default=-1, + metadata={"help": "random seed for chromagram in BEST RQ"}, + ) + + + # cqt extractor + feature_extractor_cqt: bool = field( + default=False, + metadata={"help": "whether to use CQT feature as extra input of transformer"}, + ) + feature_extractor_cqt_bins: int = field( + default=84, + metadata={"help": "the number of bins of CQT feature as extra input of transformer"}, + ) + + mixture_prob: float = field( + default=-1.0, + metadata={"help": "whether to do in-batch noise mixture during training"}, + ) + + inbatch_noise_augment_len_range: Optional[str] = field( + default = "[8000, 24000]", + metadata={ + "help": ( + "the range of length of the mix-up noise augmentation, unit in smaples" + ) + }, + ) + + inbatch_noise_augment_number_range: Optional[str] = field( + default = "[1, 3]", + metadata={ + "help": ( + "the range of numbers of the mix-up noise augmentation" + ) + }, + ) + inbatch_noise_augment_volume: float = field( + default = 1.0, + metadata={ + "help": ( + "the coefficient used to modify the volume of the noise audios wavs" + ) + }, + ) + + learnable_temp: bool = field( + default=False, + metadata={"help": "whether to learn nce temperature duing training"}, + ) + learnable_temp_init: float = field( + default = 0.1, + metadata={ + "help": ( + "initial value for the learnable tempatures" + ) + }, + ) + learnable_temp_max: float = field( + default = 100.0, + metadata={ + "help": ( + "maximum scale value of the exp(learnable tempatures)" + ) + }, + ) + + chunk_nce_cal: int = field( + default = -1, + metadata={ + "help": ( + "maximum scale value of the exp(learnable tempatures)" + ) + }, + ) + + pretrained_weights: str = field( + default="", + metadata={"help": "a path of model checkpoint to initialize the weights of the model"}, + ) + + random_codebook: int = field( + default=-1, + metadata={"help": "whether to randomly select n of the codebooks during training"}, + ) + + deepnorm: bool = field( + default=False, + metadata={"help": "whether to use deepnorm from DeepNet"}, + ) + + subln: bool = field( + default=False, + metadata={"help": "whether to use deepnorm from SubLN"}, + ) + + + emb_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply word embedding var grads by this"}, + ) + + attention_relax: float = field( + default=-1.0, + metadata={"help": "whether to use additional relaxing scale for attention module"}, + ) + + do_cnn_feat_stable_layernorm: bool = field( + default=False, + metadata={"help": "whether to modify and add additional non-affine layer_norm after feature and proj(feature)"}, + ) + + wav_normalize: bool = field( + default=False, + metadata={"help": "whether to do layernorm on waveform before fed to CNN"}, + ) + + +class model_mel_pred(torch.nn.Module): + def __init__(self, input_dim, n_bins=84, sr=16000, freq=50, use_as_target=True): + super().__init__() + # self.epsilon=1e-10 + # Getting Mel Spectrogram on the fly + self.spec_layer = transforms.MelSpectrogram(sample_rate=sr, n_fft=2048, hop_length=sr//freq, f_min=32.7, # win_length=None + f_max=None, n_mels=n_bins, window_fn=torch.hann_window, center=True, # normalized=False, + pad_mode='constant', # pad=0, + mel_scale='htk', normalized=True) # norm=None, nrom on slaney mel_scale power: float = 2.0, + + if use_as_target: + self.fc = nn.Linear(input_dim, n_bins) + + self.criterion = nn.MSELoss() + self.forward_dict = { + 'masked_transformer_output': self.plain_forward + } + + def compute_mel(self, x): + ''' + convert waveform to CQT -> [batch, bins, len] -> transpose + finally turns out [batch, len, bins] + ''' + # align with the padding of HuBERT model, + # the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different + # x = x[..., :-560] + mels = torch.transpose(self.spec_layer(x), -1, -2) + 1e-5 # [batch, len, bins] + # compute log mel + logmel = torch.log(mels) + # return logmel + # # normalize + S = (logmel - logmel.mean()) / (logmel.std() + 1e-5) + return S + + def plain_forward(self, x): + ''' + take input from transformer hidden states: [batch * len_seq, channel] + output: [batch * len_seq, n_bins] + ''' + # x = self.fc1(x) + # x = self.bn(self.relu(x)) + # x = self.fc2(x) + + x = self.fc(x) + + return x + + def forward(self, x, forward_type='masked_transformer_output'): + ''' + take input from transformer hidden states: [batch, len_seq, channel] + output: [batch, len_seq, n_bins] + ''' + + return self.forward_dict[forward_type](x) + +class model_cqt_pred(torch.nn.Module): + def __init__(self, input_dim, n_bins=84, sr=16000, freq=50): + super().__init__() + self.epsilon=1e-10 + # Getting Mel Spectrogram on the fly + self.spec_layer = nnAudioFeatures.cqt.CQT(sr=sr, hop_length=sr//freq, fmin=32.7, + fmax=None, n_bins=n_bins, bins_per_octave=n_bins//7, + filter_scale=1, norm=1, window='hann', center=True, + pad_mode='constant', trainable=False, + output_format='Magnitude', verbose=True) + + + # Initializing a Hubert facebook/hubert-base-ls960 style configuration + # configuration = HubertConfig() + + # Initializing a model from the facebook/hubert-base-ls960 style configuration + # self.hubert = HubertModel(configuration) + # self.encoder = NewConvFeatureExtractionModel(n_fft=n_fft, hop_len=sr//(freq*8)) + + # 2-layer & non-linear version TODO: 增加一个参数,可以使用两种不同的模型结构 + # self.fc1 = nn.Linear(input_dim, 1024) + # self.relu = nn.ReLU(inplace=True) + # self.bn = nn.BatchNorm1d(1024) + # self.fc2 = nn.Linear(1024, n_bins) + + # 1-layer version + + self.fc = nn.Linear(input_dim, n_bins) + + self.criterion = nn.MSELoss() + self.forward_dict = { + 'masked_transformer_output': self.plain_forward + } + def compute_cqt(self, x): + ''' + convert waveform to CQT -> [batch, bins, len] -> transpose + ''' + # align with the padding of HuBERT model, + # the truncation is calculated by bruteforce search since the nnAudio padding strategy and fairseq models are different + # x = x[..., :-560] + return torch.transpose(self.spec_layer(x), -1, -2) + + def keep_dim_forward(self, x): + # if the input is conv output: [batch, channel, len_seq] + # the output will be [batch, n_bins, len_seq] + # z = self.spec_layer(x) + # z = self.hubert(x) + # x = self.encoder(x) + # print(x.shape) + x = self.fc1(torch.transpose(x,1,2)) + # print(x.shape) + x = self.bn(self.relu(torch.transpose(x,1,2))) + # print(x.shape) + x = self.fc2(torch.transpose(x,1,2)) + # print(x.shape) + return torch.transpose(x,1,2) + + def plain_forward(self, x): + ''' + take input from transformer hidden states: [batch * len_seq, channel] + output: [batch * len_seq, n_bins] + ''' + # x = self.fc1(x) + # x = self.bn(self.relu(x)) + # x = self.fc2(x) + + x = self.fc(x) + + return x + + def forward(self, x, forward_type='masked_transformer_output'): + ''' + take input from transformer hidden states: [batch, len_seq, channel] + output: [batch, len_seq, n_bins] + ''' + + return self.forward_dict[forward_type](x) + +def quantize_vector(latent: torch.Tensor, codebook: torch.Tensor): + """ + Symbols in comments: + B: batch_size. + D: latent_dim. + C: num_latent_classes per group + G: num of codebook groups. + + Args: + latent: [B, D] + codebook: [C, G, D // G] + + Returns: + (quantized, codes, onehot). + - quantized: [B, D] + - codes: [B, G] + - onehot: [B, G, C] + """ + + assert len(codebook.size()) == 3 + b, d = latent.size() + c, g, _ = codebook.size() + assert d % g == 0 + + device_l, device_c = latent.device, codebook.device + + latent = latent.reshape(b, g, d // g).to(device_c) + + # [B, G, C] + # torch.transpose(codebook, [2,1,0]) + distance = ( + # [b, g, 1] + torch.sum(latent**2, -1, keepdim=True) - + # [b, g, c] + 2 * torch.einsum('bgd,cgd->bgc', latent, codebook) + + # [1, g, c] + torch.sum(codebook.permute([2, 1, 0])**2, 0, keepdim=True)) + + # [B, G] + codes = torch.argmin(distance, dim=-1) + + # [B, G, C] + one_hot = torch.nn.functional.one_hot(codes, c).type(codebook.dtype) + quantized = torch.einsum('bgc,cgd->bgd', one_hot, codebook) + quantized = torch.reshape(quantized, [b, d]) + return quantized, codes.to(device_l), one_hot + +def l2norm(x, dim=-1): + return x / x.norm(p=2, dim=dim, keepdim=True) + +import librosa +CHROMA_N_BINS = 12 +class WavToChroma: + def __init__(self, sr, freq): + self.sr = sr + self.freq = freq + + def __call__(self, wav:torch.Tensor): #[Batch, Time] -> [Batch, N_Bins=12, SeqLen] + device = wav.device + chroma = librosa.feature.chroma_cqt(y=wav.cpu().numpy(), sr=self.sr, hop_length=int(self.sr//self.freq)) + chroma = torch.from_numpy(chroma).to(device) + return chroma + +from contextlib import contextmanager +@contextmanager +def torch_rand_seed(seed): + if seed is not None and seed > 0: + rng_state = torch.random.get_rng_state() + torch.manual_seed(seed) + yield rng_state + torch.random.set_rng_state(rng_state) + else: + yield None + + +@register_model("mert", dataclass=MERTConfig) +class MERTModel(BaseFairseqModel): + def __init__( + self, + cfg: MERTConfig, + task_cfg: HubertPretrainingConfig, + dictionaries: List[Dictionary], + ) -> None: + super().__init__() + logger.info(f"MERTModel Config: {cfg}") + self.cfg = cfg + self.task_cfg = task_cfg + + + if self.use_encodec_target: + from encodec import EncodecModel + from encodec.utils import convert_audio + # TODO: add encodec module + # Done + def get_encodec(cfg: MERTConfig): + if cfg.audio_codec_type == 'encodec' or cfg.audio_codec_type is None: + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + elif cfg.audio_codec_type == 'dac': + import dac + if cfg.audio_codec_dac_model_path is None: + model_path = dac.utils.download(model_type="44khz") + else: + model_path = cfg.audio_codec_dac_model_path + model = dac.DAC.load(model_path) + elif cfg.audio_codec_type == 'rvq': + from .rvq import ResidualVectorQuantize + model = ResidualVectorQuantize( + input_dim = 120, + n_codebooks = 8, + codebook_size = 1024, + codebook_dim = 16, + quantizer_dropout = 0.0, + ) + if cfg.audio_codec_ckpt_path is None or not os.path.exists(cfg.audio_codec_ckpt_path): + logger.warning('No checkpoint path for rvq model. Assume it in inference mode.') + else: + state_dict = torch.load(cfg.audio_codec_ckpt_path, map_location="cpu") + model.load_state_dict(state_dict) + import torchaudio + self.rvq_preprocess = torchaudio.transforms.MelSpectrogram( + sample_rate = 24000, + n_mels = 120, + n_fft=2048, + win_length = int(24000//75), + hop_length = int(24000//75), + center = True, + pad_mode='constant', # pad=0, + mel_scale='htk', + normalized=True, + ) + else: + raise ValueError(f"Unknown audio encoder type: {cfg.codec_type}.") + + # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # model.to(device) + return model + + self.encodec = get_encodec(cfg).eval() #这里添加了self.encodec + self.rsp_src2encodec = nn.Identity() + for param in self.encodec.parameters(): + param.requires_grad = False + if self.use_rq_target: + total_n_codebooks = 0 + if cfg.audio_rq_loss_use_norm: + self.rq_norm = l2norm + else: + self.rq_norm = lambda x, *args, **kw: x + with torch_rand_seed(cfg.audio_rq_loss_seed): + # random projectoin + self.rq_projection = torch.nn.parameter.Parameter( + torch.empty(cfg.melspec_n_bins, cfg.audio_rq_loss_embed_dim * cfg.audio_rq_loss_num_codebooks), + requires_grad=False, + ) + torch.nn.init.xavier_uniform_(self.rq_projection) + + # codebooks + # [num_embeddings, num_codebooks, num_embeddings] means + # [C, G, D] see quantize_vector + self.rq_embeddings = torch.nn.parameter.Parameter( + torch.empty(cfg.audio_rq_loss_num_embeds, cfg.audio_rq_loss_num_codebooks, cfg.audio_rq_loss_embed_dim), + requires_grad=False, + ) + torch.nn.init.normal_(self.rq_embeddings) + self.rq_embeddings.data = self.rq_norm(self.rq_embeddings.data, dim=-1) + total_n_codebooks += cfg.audio_rq_loss_num_codebooks + + if cfg.audio_rq_loss_use_chroma: + self.rq_wav2chroma = WavToChroma(sr=task_cfg.sample_rate, freq=cfg.label_rate) + with torch_rand_seed(cfg.audio_rq_loss_seed_chroma): + # random projectoin + self.rq_projection_chroma = torch.nn.parameter.Parameter( + torch.empty(CHROMA_N_BINS, cfg.audio_rq_loss_embed_dim * cfg.audio_rq_loss_num_codebooks), + requires_grad=False, + ) + torch.nn.init.xavier_uniform_(self.rq_projection_chroma) + + # codebooks + # [num_embeddings, num_codebooks, num_embeddings] means + # [C, G, D] see quantize_vector + self.rq_embeddings_chroma = torch.nn.parameter.Parameter( + torch.empty(cfg.audio_rq_loss_num_embeds, cfg.audio_rq_loss_num_codebooks, cfg.audio_rq_loss_embed_dim), + requires_grad=False, + ) + torch.nn.init.normal_(self.rq_embeddings_chroma) + self.rq_embeddings_chroma.data = self.rq_norm(self.rq_embeddings_chroma.data, dim=-1) + total_n_codebooks += cfg.audio_rq_loss_num_codebooks + + assert len(dictionaries) == 1, "BEST-RQ target requires a single dictionary" + + dictionaries = [dictionaries[0] for _ in range(total_n_codebooks)] + + feature_enc_layers = eval(cfg.conv_feature_layers) # noqa + + if self.cfg.feature_extractor_cqt: + self.feature_extractor_cqt = nnAudioFeatures.cqt.CQT(sr=task_cfg.sample_rate, hop_length=task_cfg.sample_rate//50, fmin=32.7, + fmax=None, n_bins=cfg.feature_extractor_cqt_bins, bins_per_octave=cfg.feature_extractor_cqt_bins//7, + filter_scale=1, norm=1, window='hann', center=True, + pad_mode='constant', trainable=False, + output_format='Magnitude', verbose=True) + + if cfg.audio_extract_type == 'w2v_conv': + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, + ) + self.embed = feature_enc_layers[-1][0] + elif cfg.audio_extract_type == 'melspec': + self.feature_extractor_melspec = model_mel_pred( + input_dim=None, + n_bins=cfg.melspec_n_bins, + sr=int(task_cfg.sample_rate), + freq=int(cfg.label_rate), + use_as_target=False, + ) + self.embed = cfg.melspec_n_bins + else: + raise NotImplementedError + + if self.cfg.feature_extractor_cqt: + self.embed = feature_enc_layers[-1][0] + cfg.feature_extractor_cqt_bins + + feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers]) + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate + + self.do_cnn_feat_stable_layernorm = cfg.do_cnn_feat_stable_layernorm + + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + if self.post_extract_proj is not None and self.do_cnn_feat_stable_layernorm: + self.post_proj_layer_norm = LayerNorm(cfg.encoder_embed_dim, elementwise_affine=False) + else: + self.post_proj_layer_norm = None + + + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + self.mask_replace = cfg.mask_replace + self.mask_replace_type = cfg.mask_replace_type + self.mask_origin = cfg.mask_origin + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.emb_grad_mult = cfg.emb_grad_mult + + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + + self.learnable_temp = cfg.learnable_temp + + + self.wav_normalize = cfg.wav_normalize + + if not self.learnable_temp: + self.logit_temp = cfg.logit_temp + else: + self.logit_temp_list = nn.Parameter(torch.FloatTensor(len(dictionaries))) + nn.init.constant_(self.logit_temp_list, np.log(1/cfg.learnable_temp_init)) + # self.logit_temp_list = [ + # nn.Parameter(torch.tensor([np.log(1/cfg.learnable_temp_init)])) for _ in range(len(dictionaries)) + # ] + + self.learnable_temp_max = cfg.learnable_temp_max + + self.chunk_nce_cal = cfg.chunk_nce_cal + + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + if ( + cfg.attention_relax > 0 or \ + cfg.deepnorm or \ + cfg.subln + ): + if cfg.subln: + assert cfg.layer_norm_first + if cfg.deepnorm: + assert not cfg.layer_norm_first + + self.encoder = TransformerEncoder_extend(cfg) + + else: + self.encoder = TransformerEncoder(cfg) + + if self.do_cnn_feat_stable_layernorm: + self.layer_norm = LayerNorm(self.embed, elementwise_affine=False) + else: + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.untie_final_proj = cfg.untie_final_proj + self.random_codebook = cfg.random_codebook + + # use all codebook + if self.random_codebook <=0: + if self.untie_final_proj: + self.final_proj = nn.Linear( + cfg.encoder_embed_dim, final_dim * len(dictionaries) + ) + else: + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + else: + assert self.random_codebook <= len(dictionaries) + if self.untie_final_proj: + self.final_projs = nn.ModuleList([nn.Linear(cfg.encoder_embed_dim, final_dim) for _ in range(len(dictionaries))]) + else: + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + # modules below are not needed during fine-tuning + if any([d is None for d in dictionaries]): + logger.info("cannot find dictionary. assume will be used for fine-tuning") + else: + self.num_classes = [len(d) for d in dictionaries] + self.label_embs_concat = nn.Parameter( + torch.FloatTensor(sum(self.num_classes), final_dim), + # requires_grad=False, + ) + nn.init.uniform_(self.label_embs_concat) + + if cfg.audio_cqt_loss_m: + logger.info("train the model with extra task: reconstruct cqt from transformer output") + self.encoder_cqt_model = model_cqt_pred( + input_dim=cfg.encoder_embed_dim, + n_bins=cfg.audio_cqt_bins, + sr=int(task_cfg.sample_rate), + freq=int(cfg.label_rate) + ) + if cfg.audio_mel_loss_m: + logger.info("train the model with extra task: reconstruct mel from transformer output") + self.encoder_mel_model = model_mel_pred( + input_dim=cfg.encoder_embed_dim, + n_bins=cfg.audio_mel_bins, + sr=int(task_cfg.sample_rate), + freq=int(cfg.label_rate)) + + self.num_updates = 0 + + self.mask_dynamic_prob_step = eval(cfg.mask_dynamic_prob_step) + self.mask_dynamic_prob = eval(cfg.mask_dynamic_prob) + + if len(self.mask_dynamic_prob_step) > 0 and len(self.mask_dynamic_prob) > 0: + self.initialize_dynamic_mask_prob() + else: + self.mask_dynamic_prob_stage = -1 + + self.mask_dynamic_len_step = eval(cfg.mask_dynamic_len_step) + self.mask_dynamic_len = eval(cfg.mask_dynamic_len) + if len(self.mask_dynamic_len_step) > 0 and len(self.mask_dynamic_len) > 0: + self.initialize_dynamic_mask_len() + else: + self.mask_dynamic_len_stage = -1 + + self.mixture_prob = cfg.mixture_prob + self.inbatch_noise_augment_len_range = eval(cfg.inbatch_noise_augment_len_range) + self.inbatch_noise_augment_number_range = eval(cfg.inbatch_noise_augment_number_range) + self.inbatch_noise_augment_volume = cfg.inbatch_noise_augment_volume + + if os.path.isfile(cfg.pretrained_weights): + load_patterns = ['feature_extractor.'] # ['feature_extractor.', 'encoder.'] + logger.info(f"initialize {load_patterns} weights with given checkpoint") + pretrained_dict = torch.load(cfg.pretrained_weights)['model'] + + def filter_keys(patterns, keys): + toloads = [] + for k in keys: + for pattern in patterns: + if pattern in k: + toloads.append(k) + return toloads + modules_to_load = filter_keys(load_patterns, pretrained_dict.keys()) + logger.info(f"found modules to load: {modules_to_load}") + pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in modules_to_load} + logger.info(f"extractor sample weitghts before loading:\n{self.feature_extractor.conv_layers[0][2].weight}") + self.load_state_dict(pretrained_dict, strict=False) + logger.info(f"extractor sample weitghts after loading:\n{self.feature_extractor.conv_layers[0][2].weight}") + + if cfg.feature_grad_mult <= 0: + self.feature_extractor.eval() + # for param in self.feature_extractor.parameters(): + # param.requires_grad = False + + + # from collections import defaultdict + # param_dict = defaultdict(lambda:0) + # for name, param in self.named_parameters(): + # name = str(name) + # if '.' in name: + # name = name[:name.index('.')] + # param_dict[name] += param.nelement() + + # print(param_dict) + + @property + def use_encodec_target(self): + return not self.cfg.audio_rq_loss_m + + @property + def use_rq_target(self): + return self.cfg.audio_rq_loss_m + + def inbatch_noise_augment(self, + target_audio: torch.Tensor, target_audio_idx: int , + batch_audios: torch.Tensor, # [bsz, audio_lengths] + noise_len_min: int, noise_len_max: int, + n_noise_min: int, n_noise_max: int, + noise_vol: float = 1.0): + ''' + augmenation that leverages in-batch noise audios. + noise_len_min and noise_len_max are the range of the lengths of noises (counted as samples) + n_noise_min and n_noise_max are the range of number of noises, + ''' + # assert noise_len_max <= target_audio.shape[0] and noise_len_min >= 1 # should assert this outside? + + augmented_audio = torch.clone(target_audio) + + # exclude the target audio and use the rest as noise candidates + noise_pool = torch.flatten(torch.cat((batch_audios[:target_audio_idx,:], batch_audios[target_audio_idx+1:,:]), dim=0)) + + # n_noise = np.random.randint(n_noise_min, n_noise_max+1) + n_noise = torch.randint(n_noise_min, n_noise_max+1, size=(1,)) + + # random_start_idxs = np.random.randint(0, noise_pool.shape[0] - noise_len_max + 1, size=(n_noise,)) + # random_durations = np.random.randint(noise_len_min, noise_len_max+1, size=(n_noise,)) + random_start_idxs = torch.randint(0, noise_pool.shape[0] - noise_len_max + 1, size=(n_noise,)) + random_durations = torch.randint(noise_len_min, noise_len_max+1, size=(n_noise,)) + + + for noise_idx in range(n_noise): + # augmentation_position = np.random.randint(0, target_audio.shape[0] - random_durations[noise_idx]+1, size=None) + augmentation_position = torch.randint(0, target_audio.shape[0] - random_durations[noise_idx]+1, size=(1,)) + + # assign noise to the original audio + augmented_audio[augmentation_position:augmentation_position+random_durations[noise_idx]] += \ + noise_vol * noise_pool[random_start_idxs[noise_idx]: random_start_idxs[noise_idx]+random_durations[noise_idx]] + + return augmented_audio + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: MERTConfig, task: HubertPretrainingTask): + """Build a new model instance.""" + + model = MERTModel(cfg, task.cfg, task.dictionaries) + return model + + def compute_replace_mask(self, padding_mask, mask_indices): + ''' + all variables are numpy array + ''' + original_prob = np.random.rand(*mask_indices.shape)<=self.mask_origin + original_indices = np.all([mask_indices, original_prob], axis=0) + replace_prob = np.random.rand(*mask_indices.shape)<=self.mask_replace + replace_indices = np.all([mask_indices, replace_prob], axis=0) + mask_emb_indices = np.all([mask_indices, ~original_indices, ~replace_indices], axis=0) + + replace_target_indices = np.zeros(mask_indices.shape,dtype=bool) + all_indices = np.ones(mask_indices.shape,dtype=bool) + all_indices = np.all([~padding_mask, all_indices], axis=0) # exclude the padding part + if self.mask_replace_type == 'in_batch': + # replaces with anyone within the batch, no duplicated + n_replace = np.sum(replace_indices) + all_indices = np.where(all_indices) # turn into tuple indices + + replace_target = np.random.choice(len(all_indices[0]), n_replace, replace=False) + replace_target_indices[(all_indices[0][replace_target], all_indices[1][replace_target])] = True + + elif self.mask_replace_type == 'in_sample': + # replaces with anyone within the same sample, no duplicated + for i in range(mask_indices.shape[0]): + # find replacement for each sample + n_replace_insample = np.sum(replace_indices[i]) + all_indices_insample = np.where(all_indices[i]) # (T - padding,) + replace_target_insample = np.random.choice(len(all_indices_insample[0]), n_replace_insample, replace=False) + replace_target_indices[i][all_indices_insample[0][replace_target_insample]] = True + + return mask_emb_indices, replace_indices, replace_target_indices + + def apply_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + if self.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + if self.mask_replace > 0: + mask_emb_indices, replace_indices, replace_target_indices = self.compute_replace_mask(padding_mask, mask_indices) + + mask_indices = torch.from_numpy(mask_indices).to(x.device) # tokens involved in mask prediction task + mask_emb_indices = torch.from_numpy(mask_emb_indices).to(x.device) # tokens replaced with [MASK] + # origin_indices = torch.from_numpy(origin_indices).to(x.device) # tokens remains the same, no need to do assignment + replace_indices = torch.from_numpy(replace_indices).to(x.device) # tokens that are replaced with other tokens + replace_target_indices = torch.from_numpy(replace_target_indices).to(x.device) # tokens that are used to replace + + x[mask_emb_indices] = self.mask_emb + x[replace_indices] = x[replace_target_indices] + + else: + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + + + else: + mask_indices = None + + + if self.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + if self.chunk_nce_cal > 0: + logits = [] + + for start in range(0, x.shape[0], self.chunk_nce_cal): + end = start + self.chunk_nce_cal + a = x[start:end] + b = targets[start:end] + # assert a.shape[0] == b.shape[0], f'mismatch shape of a {a.shape} and b {b.shape}, x {x.shape} and targets {targets.shape}' + logits.append(torch.cosine_similarity(a.float(), b.float(), dim=-1).type_as(a)) + logits = torch.cat(logits,dim=0) + else: + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") # 忽略正样本对自身的相似性 + logits = logits.transpose(0, 1) # (num_x, num_cls+1) #相当于变成了 [Batch, Num_Code + 1, Num_Coodebook] 的相似度矩阵,其中第0个是目标 + return logits + + def compute_nce_learned_temp(self, x, pos, negs, logit_temp): + + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + # logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + if self.chunk_nce_cal > 0: + logits = [] + for start in range(0, x.shape[0], self.chunk_nce_cal): + logits.append(torch.cosine_similarity(x[start:start+self.chunk_nce_cal].float(), targets[start:start+self.chunk_nce_cal].float(), dim=-1).type_as(x)) + logits = torch.cat(logits,dim=0) + else: + logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) + + logit_scale = torch.clamp(logit_temp.exp(), max=self.learnable_temp_max) + logits *= logit_scale + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits + + + def forward_features(self, source: torch.Tensor) -> torch.Tensor: + ''' + features: BxCxT + ''' + if self.wav_normalize: + assert source.dim() == 2 + with torch.no_grad(): + source = torch.nn.functional.layer_norm(source, source.shape) + + if self.cfg.audio_extract_type == 'w2v_conv': + if self.feature_grad_mult > 0: + features = self.feature_extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + self.feature_extractor.eval() + features = self.feature_extractor(source) + return features + elif self.cfg.audio_extract_type == 'melspec': + with torch.no_grad(): + features = self.feature_extractor_melspec.compute_mel(source).transpose(1, 2) + return features + + def forward_targets( + self, + features: torch.Tensor, + target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + # @yizhilll: if feature * 2 > 3000, then crop the features + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + # @yizhilll: select only the first pseoudo label if there are multiple labels + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, target_list + + def forward_padding_mask( + self, + features: torch.Tensor, + padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1) + padding_mask = padding_mask.all(-1) + return padding_mask + + def set_num_updates(self, num_updates): + super().set_num_updates(num_updates) + + if self.mask_dynamic_prob_stage >=0: + if num_updates == self.mask_dynamic_prob_step[self.mask_dynamic_prob_stage]: + logger.info(f'updating mask_prob from {self.mask_dynamic_prob[self.mask_dynamic_prob_stage]} to {self.mask_dynamic_prob[self.mask_dynamic_prob_stage+1]} at step {num_updates}') + self.mask_prob = self.mask_dynamic_prob[self.mask_dynamic_prob_stage+1] + # stop updating since it gets to the last stage + self.mask_dynamic_prob_stage = self.mask_dynamic_prob_stage + 1 if self.mask_dynamic_prob_stage < len(self.mask_dynamic_prob_step)-1 else -1 + + if self.mask_dynamic_len_stage >=0: + if num_updates == self.mask_dynamic_len_step[self.mask_dynamic_len_stage]: + logger.info(f'updating mask_length from {self.mask_dynamic_len[self.mask_dynamic_len_stage]} to {self.mask_dynamic_len[self.mask_dynamic_len_stage+1]} at step {num_updates}') + self.mask_length = self.mask_dynamic_len[self.mask_dynamic_len_stage+1] + + # stop updating since it gets to the last stage + self.mask_dynamic_len_stage = self.mask_dynamic_len_stage + 1 if self.mask_dynamic_len_stage < len(self.mask_dynamic_len_step)-1 else -1 + + self.num_updates = num_updates + + def encodec_encode(self, inp): + if self.cfg.audio_codec_type == 'encodec': + encoded_frames = self.encodec.encode(inp) #list, B,[ 8,T ] + codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1) + return codes + elif self.cfg.audio_codec_type == 'dac': + self.encodec.eval() + x = self.encodec.preprocess(inp, self.task_cfg.sample_rate) + z, codes, latents, _, _ = self.encodec.encode(x) + return codes #[Batch, NumCode=32, Freq*Secs] + elif self.cfg.audio_codec_type == 'rvq': + if len(inp.shape) == 3: + inp = inp.squeeze(1) + x = self.rvq_preprocess(inp) + self.encodec.eval() + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.encodec(x) + return codes + else: + raise NotImplementedError + + #MERTforward + def forward( + self, + source: torch.Tensor, # B,L + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + cqt_labels: Optional[torch.Tensor] = None, + mel_labels: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None, + ) -> Dict[str, torch.Tensor]: + """output layer is 1-based""" + # with autocast(device_type=source.device.type, dtype=torch.float32): + # import pdb + # pdb.set_trace() + # print('source_shape:',str(source.shape)) + # TODO: extrace target_list + #这里修改过 + # print('running') + if not features_only and self.use_encodec_target: + with torch.no_grad(): + encodec_inp = self.rsp_src2encodec(source) + # if(self.num_updates==2):print(encodec_inp);exit() + encodec_inp = encodec_inp.unsqueeze(1) # B,1,T + # print(encodec_inp.shape) + codes = self.encodec_encode(encodec_inp) + # print(codes.shape) + # print(target_list.shape) + target_list = [codes[:,codebook_inx,:].detach() for codebook_inx in range(8)] + + if self.mixture_prob > 0: + # compute cqt before mixture + if self.cfg.audio_cqt_loss_m: + cqt_targets = self.encoder_cqt_model.compute_cqt(source) + if self.cfg.audio_mel_loss_m: + mel_targets = self.encoder_mel_model.compute_mel(source) + with torch.no_grad(): + batch_audios = torch.clone(source) + for i in range(source.shape[0]): + if torch.rand(1).item() > self.mixture_prob: + try: + #这里加入批内噪声 + source[i] = self.inbatch_noise_augment( + target_audio = batch_audios[i], target_audio_idx = i, batch_audios = batch_audios, + noise_len_min = self.inbatch_noise_augment_len_range[0], noise_len_max = self.inbatch_noise_augment_len_range[1], + n_noise_min = self.inbatch_noise_augment_number_range[0], n_noise_max = self.inbatch_noise_augment_number_range[1], + noise_vol = self.inbatch_noise_augment_volume) + except: + source[i] = batch_audios[i] + + features = self.forward_features(source) + + if not features_only and self.use_rq_target: + target_list = [] + with torch.no_grad(): + features_projected = torch.matmul(features.transpose(1, 2), self.rq_projection.to(features.device)) # input_dim -> embedding_dim * num_codebooks + + B, T, C = features_projected.size() + features_flatten = features_projected.view(B * T, C) + _, codes, _ = quantize_vector(self.rq_norm(features_flatten, dim=-1), self.rq_embeddings) + codes = codes.reshape(B, T, -1) # [B, T, num_codebooks] + target_list += [codes[:, :, codebook_inx] for codebook_inx in range(codes.shape[-1])] + + if self.cfg.audio_rq_loss_use_chroma: + chroma_features = self.rq_wav2chroma(source) + features_projected = torch.matmul(chroma_features.transpose(1, 2), self.rq_projection_chroma.to(chroma_features.device)) # input_dim -> embedding_dim * num_codebooks + + B, T, C = features_projected.size() + features_flatten = features_projected.view(B * T, C) + _, codes, _ = quantize_vector(self.rq_norm(features_flatten, dim=-1), self.rq_embeddings_chroma) + codes = codes.reshape(B, T, -1) # [B, T, num_codebooks] + target_list += [codes[:, :, codebook_inx] for codebook_inx in range(codes.shape[-1])] + + + if target_list is not None: + features, target_list = self.forward_targets(features, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) # BxTxC + + if self.cfg.feature_extractor_cqt: + features_cqt = self.feature_extractor_cqt(source).transpose(1, 2) + features_cqt = features_cqt[:,:features.shape[1],:] # align shape + # version 1 + # features = features + features_cqt + # features = self.layer_norm(features) + # version 2 + # features_cqt = self.post_cqt_feature_proj(features_cqt) # v2 + # features = self.layer_norm(features) + self.layer_norm(features_cqt) + # version 3 + features = torch.cat([features,features_cqt], 2) + features = self.layer_norm(features) # BxTxC + else: + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + if self.post_proj_layer_norm is not None: + features = self.post_proj_layer_norm(features) + + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + + if not features_only and mask: + x, mask_indices = self.apply_mask(features, padding_mask, target_list) + else: + x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, layer_results = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1, + ) + # 很美妙的设定,取第layer层的特征,如果超出范围则取最后一层 + + if features_only: + #这里有所修改,增加了一个所有层的结果 + return {"x": x, "padding_mask": padding_mask, "features": features,"layer_results": layer_results} + + def compute_pred(proj_x, target, label_embs, logit_temp=None): + # skip the codebook that is not selected + if proj_x is None: + return None + # compute logits for the i-th label set + y = torch.index_select(label_embs, 0, target.long()) # 选择目标对应的嵌入向量 + negs = label_embs.unsqueeze(1).expand(-1, proj_x.size(0), -1) #将嵌入表重复batch次 + if self.target_glu: + y = self.target_glu(y) + negs = self.target_glu(negs) + # proj_x: (S, D) + # y: (S, D) + # negs: (Neg, S, D) + if logit_temp is not None: + return self.compute_nce_learned_temp(proj_x, y, negs, logit_temp) + else: + return self.compute_nce(proj_x, y, negs) + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + + if not self.skip_masked: + masked_indices = torch.logical_and(~padding_mask, mask_indices) + + # @yizhilll: TODO merge the codes heredui + if self.random_codebook <= 0: + proj_x_m = self.final_proj(x[masked_indices]) #将特征投射到一个更低维的空间 + if self.untie_final_proj: + proj_x_m_list = proj_x_m.chunk(len(target_list), dim=-1) #按照不同的码本切成多个向量 + else: + proj_x_m_list = [proj_x_m for _ in range(len(target_list))] # no extra RAM taken here + else: + # pass + selected_books = np.random.choice(len(target_list),self.random_codebook) + proj_x_m_list = [] + for i in range(len(target_list)): + if i in selected_books: + if self.untie_final_proj: + proj_x_m_list.append(self.final_projs[i](x[masked_indices])) + else: + proj_x_m_list.append(self.final_proj(x[masked_indices])) + else: + proj_x_m_list.append(None) + + + if self.learnable_temp: + logit_m_list = [ + compute_pred(proj_x_m, t[masked_indices], label_embs_list[i], logit_temp) + for i, (proj_x_m, t, logit_temp), in enumerate(zip(proj_x_m_list, target_list, self.logit_temp_list)) + ] + else: + logit_m_list = [ + compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) + for i, (proj_x_m, t) in enumerate(zip(proj_x_m_list, target_list)) + ] + + # else: + # # # mute to optimize the codes + # proj_x_m_list = [proj_x_m for _ in range(len(target_list))] + # if self.learnable_temp: + # logit_m_list = [ + # compute_pred(proj_x_m, t[masked_indices], label_embs_list[i], logit_temp) + # for i, (t, logit_temp), in enumerate(zip(target_list, self.logit_temp_list)) + # ] + + # else: + # logit_m_list = [ + # compute_pred(proj_x_m, t[masked_indices], label_embs_list[i]) + # for i, t in enumerate(target_list) + # ] + else: + logit_m_list = [None for _ in target_list] + + if not self.skip_nomask: + nomask_indices = torch.logical_and(~padding_mask, ~mask_indices) + proj_x_u = self.final_proj(x[nomask_indices]) + if self.untie_final_proj: + proj_x_u_list = proj_x_u.chunk(len(target_list), dim=-1) + else: + proj_x_u_list = [proj_x_u for _ in range(len(target_list))] + + logit_u_list = [ + compute_pred(proj_x_u, t[nomask_indices], label_embs_list[i]) + for i, (proj_x_u, t) in enumerate(zip(proj_x_u_list, target_list)) + ] + else: + logit_u_list = [None for _ in target_list] + + # if self.emb_grad_mult > 0 and self.emb_grad_mult !=1.0: + # self.label_embs_concat = GradMultiply.apply(self.label_embs_concat, self.emb_grad_mult) + + # word_embeddin = ∗α+word_embedding .detach()∗(1−α). + # self.label_embs_concat = self.label_embs_concat * self.emb_grad_mult + self.label_embs_concat.detach()*(1-self.emb_grad_mult) + + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "padding_mask": padding_mask, + "features_pen": features_pen, + } + + if self.cfg.audio_cqt_loss_m: + if cqt_labels is not None: + cqt_targets = cqt_labels[:masked_indices.shape[0],:masked_indices.shape[1]] # dump the last + else: + if self.mixture_prob > 0: + # no need to compute again + assert cqt_targets is not None + else: + cqt_targets = self.encoder_cqt_model.compute_cqt(source) + cqt_targets = cqt_targets[:masked_indices.shape[0],:masked_indices.shape[1]] # dump the last + + cqt_pred_m = self.encoder_cqt_model(x[masked_indices]) + # logger.info(x[masked_indices].shape, cqt_pred_m.shape, cqt_targets.shape) + cqt_loss_m = self.encoder_cqt_model.criterion(cqt_pred_m, cqt_targets[masked_indices]) + if self.use_rq_target and self.cfg.audio_rq_loss_num_codebooks > 16: + cqt_loss_m = cqt_loss_m * 10 + result["cqt_pred_m"] = cqt_loss_m + + if self.cfg.audio_mel_loss_m: + if mel_labels is not None: + mel_targets = mel_labels[:masked_indices.shape[0],:masked_indices.shape[1]] # dump the last + else: + if self.mixture_prob > 0: + # no need to compute again + assert mel_targets is not None + else: + mel_targets = self.encoder_mel_model.compute_mel(source) + mel_targets = mel_targets[:masked_indices.shape[0],:masked_indices.shape[1]] # dump the last + + mel_pred_m = self.encoder_mel_model(x[masked_indices]) + # logger.info(x[masked_indices].shape, cqt_pred_m.shape, cqt_targets.shape) + mel_loss_m = self.encoder_mel_model.criterion(mel_pred_m, mel_targets[masked_indices]) + result["mel_pred_m"] = mel_loss_m + + if self.learnable_temp: + # for i in range(len(self.logit_temp_list)): + for i in range(self.logit_temp_list.shape[0]): + result[f"logit_temp_{i}"] = self.logit_temp_list[i].item() + + return result + # def extract_all_feature( + # self, + # source: torch.Tensor, # B,L + # target_list: Optional[List[torch.Tensor]] = None, + # padding_mask: Optional[torch.Tensor] = None, + # cqt_labels: Optional[torch.Tensor] = None, + # mel_labels: Optional[torch.Tensor] = None, + # mask: bool = True, + # features_only: bool = False, + # output_layer: Optional[int] = None, + # ) -> Dict[str, torch.Tensor]: + # """output layer is 1-based""" + # # with autocast(device_type=source.device.type, dtype=torch.float32): + # # import pdb + # # pdb.set_trace() + # # print('source_shape:',str(source.shape)) + # # TODO: extrace target_list + # #这里修改过 + # # print('running') + # with torch.no_grad(): + # encodec_inp = self.rsp_src2encodec(source) + # encodec_inp = encodec_inp.unsqueeze(1) # B,1,T + # # print(encodec_inp.shape) + # encoded_frames = self.encodec.encode(encodec_inp) #list, B,[ 8,T ] + # codes = torch.cat([encoded[0].detach() for encoded in encoded_frames], dim=-1) + # # print(codes.shape) + # # print(target_list.shape) + # target_list = [codes[:,codebook_inx,:].detach() for codebook_inx in range(8)] + + # if self.mixture_prob > 0: + # # compute cqt before mixture + # if self.cfg.audio_cqt_loss_m: + # cqt_targets = self.encoder_cqt_model.compute_cqt(source) + # if self.cfg.audio_mel_loss_m: + # mel_targets = self.encoder_mel_model.compute_mel(source) + # with torch.no_grad(): + # batch_audios = torch.clone(source) + # for i in range(source.shape[0]): + # if torch.rand(1).item() > self.mixture_prob: + # try: + # #这里加入批内噪声 + # source[i] = self.inbatch_noise_augment( + # target_audio = batch_audios[i], target_audio_idx = i, batch_audios = batch_audios, + # noise_len_min = self.inbatch_noise_augment_len_range[0], noise_len_max = self.inbatch_noise_augment_len_range[1], + # n_noise_min = self.inbatch_noise_augment_number_range[0], n_noise_max = self.inbatch_noise_augment_number_range[1], + # noise_vol = self.inbatch_noise_augment_volume) + # except: + # source[i] = batch_audios[i] + + # features = self.forward_features(source) + # if target_list is not None: + # features, target_list = self.forward_targets(features, target_list) + + # features_pen = features.float().pow(2).mean() + + # features = features.transpose(1, 2) # BxTxC + + # if self.cfg.feature_extractor_cqt: + # features_cqt = self.feature_extractor_cqt(source).transpose(1, 2) + # features_cqt = features_cqt[:,:features.shape[1],:] # align shape + # # version 1 + # # features = features + features_cqt + # # features = self.layer_norm(features) + # # version 2 + # # features_cqt = self.post_cqt_feature_proj(features_cqt) # v2 + # # features = self.layer_norm(features) + self.layer_norm(features_cqt) + # # version 3 + # features = torch.cat([features,features_cqt], 2) + # features = self.layer_norm(features) # BxTxC + # else: + # features = self.layer_norm(features) + # unmasked_features = features.clone() + + # if padding_mask is not None: + # padding_mask = self.forward_padding_mask(features, padding_mask) + + # if self.post_extract_proj is not None: + # features = self.post_extract_proj(features) + # if self.post_proj_layer_norm is not None: + # features = self.post_proj_layer_norm(features) + + # features = self.dropout_input(features) + # unmasked_features = self.dropout_features(unmasked_features) + + # if mask: + # x, mask_indices = self.apply_mask(features, padding_mask, target_list) + # else: + # x = features + # mask_indices = None + + # # feature: (B, T, D), float + # # target: (B, T), long + # # x: (B, T, D), float + # # padding_mask: (B, T), bool + # # mask_indices: (B, T), bool + + # all_layer_outputs = [] + # all_layer_outputs.append(x.unsqueeze(0)) + # with torch.no_grad(): + # for layer in self.encoder.layers: + # output = layer(x) + # all_layer_outputs.append(output[0].unsqueeze(0)) + # concatenated_output = torch.cat(all_layer_outputs,dim=0) + # return concatenated_output + + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def extract_all_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + feature_num = 24 + ): + + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + x = res['x'] # [bsz,T,1024] + all_features = [] + all_features.append(x.unsqueeze(0)) + # print('x0 shape :',x.shape) + for i in range(feature_num): + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=i, + ) + + feature = res["features"] + # print(feature.shape) + all_features.append(feature.unsqueeze(0)) # [bsz,T,1024] + concatenated_output = torch.cat(all_features,dim=0) + return concatenated_output #[channels,bsz,T,1024] + + def get_logits(self, net_output, is_masked=True): + if is_masked: + logits_list = net_output["logit_m_list"] + else: + logits_list = net_output["logit_u_list"] + logits_list = [x.float() for x in logits_list if x is not None] + return logits_list + + def get_targets(self, net_output, is_masked=True): + logits_list = self.get_logits(net_output, is_masked) + targets_list = [x.new_zeros(x.size(0), dtype=torch.long) for x in logits_list] + return targets_list + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + if "cqt_pred_m" in net_output: + extra_losses.append(net_output["cqt_pred_m"]) + names.append("cqt_pred_m") + if "mel_pred_m" in net_output: + extra_losses.append(net_output["mel_pred_m"]) + names.append("mel_pred_m") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None + + def initialize_dynamic_mask_prob(self): + # if len(self.mask_dynamic_prob_step) > 0 and len(self.mask_dynamic_prob) > 0: + if self.num_updates == 0: + logger.info(f'setting masking prob...') + else: + logger.info(f"loading checkpoint at step {self.num_updates}, resuming the mask_prob is set to trained with dynamic schedule") + assert len(self.mask_dynamic_prob_step) + 1 == len(self.mask_dynamic_prob), ("the len(step) is the step of updating mask_prob") + self.mask_dynamic_prob_stage = 0 + for i in self.mask_dynamic_prob_step: + if self.num_updates >= i: + self.mask_dynamic_prob_stage += 1 + self.mask_prob = self.mask_dynamic_prob[self.mask_dynamic_prob_stage] + logger.info(f'set the masking prob as {self.mask_prob}, stage {self.mask_dynamic_prob_stage}') + if self.num_updates >= self.mask_dynamic_prob_step[-1]: + self.mask_dynamic_prob_stage = -1 # no need for further updating + + def initialize_dynamic_mask_len(self): + if self.num_updates == 0: + logger.info(f'setting masking prob...') + else: + logger.info(f"loading checkpoint at step {self.num_updates}, resuming the mask_length is set to trained with dynamic schedule") + assert len(self.mask_dynamic_len_step) + 1 == len(self.mask_dynamic_len), ("the len(step) is the step of updating mask_len") + self.mask_dynamic_len_stage = 0 + for i in self.mask_dynamic_len_step: + if self.num_updates >= i: + self.mask_dynamic_len_stage += 1 + self.mask_length = self.mask_dynamic_len[self.mask_dynamic_len_stage] + logger.info(f'set the masking length as {self.mask_length}, stage {self.mask_dynamic_len_stage}') + if self.num_updates >= self.mask_dynamic_len_step[-1]: + self.mask_dynamic_len_stage = -1 # no need for further updating + + def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + if len(self.mask_dynamic_prob_step) > 0 and len(self.mask_dynamic_prob) > 0: + self.initialize_dynamic_mask_prob() + + if len(self.mask_dynamic_len_step) > 0 and len(self.mask_dynamic_len) > 0: + self.initialize_dynamic_mask_len() + + return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class TransformerEncoder_extend(TransformerEncoder): + def build_encoder_layer(self, args: MERTConfig): + + if args.layer_type == "transformer": + + if (args.deepnorm or args.subln or args.attention_relax > 0.0 ): + residual_alpha = 1.0 + if args.deepnorm: + residual_alpha = math.pow(2.0 * args.encoder_layers, 0.25) + + layer = TransformerSentenceEncoderLayerExtend( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + residual_alpha=residual_alpha, + attention_relax=args.attention_relax, + ) + else: + layer = TransformerSentenceEncoderLayer( + embedding_dim=self.embedding_dim, + ffn_embedding_dim=args.encoder_ffn_embed_dim, + num_attention_heads=args.encoder_attention_heads, + dropout=self.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + activation_fn=args.activation_fn, + layer_norm_first=args.layer_norm_first, + ) + + + elif args.layer_type == "conformer": + layer = ConformerWav2Vec2EncoderLayer( + embed_dim=self.embedding_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + dropout=args.dropout, + depthwise_conv_kernel_size=args.depthwise_conv_kernel_size, + activation_fn="swish", + attn_type=args.attn_type, + use_fp16=args.fp16, + pos_enc_type="abs", + ) + from fairseq.distributed import fsdp_wrap + from fairseq.modules.checkpoint_activations import checkpoint_wrapper + + layer = fsdp_wrap(layer) + if args.checkpoint_activations: + layer = checkpoint_wrapper(layer) + return layer + + def __init__(self, args: MERTConfig): + super().__init__(args) + if args.deepnorm: + # if is_encoder_decoder: + # init_scale = ( + # math.pow( + # math.pow(args.encoder_layers, 4) * args.decoder_layers, 0.0625 + # ) + # / 1.15 + # ) + # else: + init_scale = math.pow(8.0 * args.encoder_layers, 0.25) + for name, p in self.named_parameters(): + if ( + "fc1" in name + or "fc2" in name + or "out_proj" in name + or "v_proj" in name + ): + p.data.div_(init_scale) + +class TransformerSentenceEncoderLayerExtend(TransformerSentenceEncoderLayer): + """ + Extend the Transformer Encoder Layer to support DeepNorm. + """ + + def __init__( + self, + embedding_dim: float = 768, + ffn_embedding_dim: float = 3072, + num_attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.1, + activation_fn: str = "relu", + layer_norm_first: bool = False, + residual_alpha: float = 1.0, + subln: bool = False, + attention_relax: float = -1.0, + ) -> None: + + super().__init__() + # nn.Module().__init__(self) + + self.residual_alpha = residual_alpha + # Initialize parameters + self.embedding_dim = embedding_dim + self.dropout = dropout + self.activation_dropout = activation_dropout + + # Initialize blocks + self.activation_fn = utils.get_activation_fn(activation_fn) + + + if attention_relax > 0: + # self.attention_relax = attention_relax + logger.info(f"creating custom attention layer with relaxation scale: {attention_relax}") + self.self_attn = MultiheadAttention_extend( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + attention_relax=attention_relax, + ) + + else: + self.self_attn = MultiheadAttention( + self.embedding_dim, + num_attention_heads, + dropout=attention_dropout, + self_attention=True, + ) + + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(self.activation_dropout) + self.dropout3 = nn.Dropout(dropout) + + self.layer_norm_first = layer_norm_first + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embedding_dim) + self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim) + self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim) + + self.ffn_layernorm = LayerNorm(ffn_embedding_dim) if subln else None + + # layer norm associated with the position wise feed-forward NN + self.final_layer_norm = LayerNorm(self.embedding_dim) + + def residual_connection(self, x, residual): + return residual * self.residual_alpha + x + + def forward( + self, + x: torch.Tensor, + self_attn_mask: torch.Tensor = None, + self_attn_padding_mask: torch.Tensor = None, + need_weights: bool = False, + att_args=None, + ): + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer imlementation. + """ + residual = x + + if self.layer_norm_first: + x = self.self_attn_layer_norm(x) + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + attn_mask=self_attn_mask, + need_weights=False, + ) + x = self.dropout1(x) + # x = residual + x + x = self.residual_connection(x, residual) + + residual = x + x = self.final_layer_norm(x) + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + + # for subln + if self.ffn_layernorm is not None: + x = self.ffn_layernorm(x) + + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + # x = residual + x + x = self.residual_connection(x, residual) + + else: + x, attn = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + need_weights=False, + ) + + x = self.dropout1(x) + # x = residual + x + x = self.residual_connection(x, residual) + + x = self.self_attn_layer_norm(x) + + residual = x + x = self.activation_fn(self.fc1(x)) + x = self.dropout2(x) + x = self.fc2(x) + + layer_result = x + + x = self.dropout3(x) + # x = residual + x + x = self.residual_connection(x, residual) + x = self.final_layer_norm(x) + + return x, (attn, layer_result) + + +class MultiheadAttention_extend(MultiheadAttention): + def __init__( + self, + embed_dim, + num_heads, + kdim=None, + vdim=None, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + self_attention=False, + encoder_decoder_attention=False, + # dictionary=None, + q_noise=0.0, + qn_block_size=8, + attention_relax = -1.0, + # TODO: pass in config rather than string. + # config defined in xformers.components.attention.AttentionConfig + xformers_att_config: Optional[str] = None, + xformers_blocksparse_layout: Optional[ + torch.Tensor + ] = None, # This should be part of the config + xformers_blocksparse_blocksize: Optional[ + int + ] = 16, # This should be part of the config + ): + # nn.Module.__init__(self) + # super().__init__() + # initialize the instance with the father class method + # MultiheadAttention.__init__(self, + # super(MultiheadAttention_extend, self).__init__( + # super(self).__init__( + super().__init__( + embed_dim, + num_heads, + kdim=kdim, + vdim=vdim, + dropout=dropout, + bias=bias, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + self_attention=self_attention, + encoder_decoder_attention=encoder_decoder_attention, + # dictionary=dictionary, + q_noise=q_noise, + qn_block_size=qn_block_size, + xformers_att_config=xformers_att_config, + xformers_blocksparse_layout=xformers_blocksparse_layout, + xformers_blocksparse_blocksize=xformers_blocksparse_blocksize, + ) + + self.attention_relax = attention_relax + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor], + value: Optional[torch.Tensor], + key_padding_mask: Optional[torch.Tensor] = None, + incremental_state: Optional[Dict[str, Dict[str, Optional[torch.Tensor]]]] = None, + need_weights: bool = True, + static_kv: bool = False, + attn_mask: Optional[torch.Tensor] = None, + before_softmax: bool = False, + need_head_weights: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Input shape: Time x Batch x Channel + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + is_tpu = query.device.type == "xla" + + tgt_len, bsz, embed_dim = query.size() + src_len = tgt_len + if not self.skip_embed_dim_check: + assert ( + embed_dim == self.embed_dim + ), f"query dim {embed_dim} != {self.embed_dim}" + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if key is not None: + src_len, key_bsz, _ = key.size() + if not torch.jit.is_scripting(): + assert value is not None + assert src_len, key_bsz == value.shape[:2] + + if ( + not self.onnx_trace + and not is_tpu # don't use PyTorch version on TPUs + and incremental_state is None + and not static_kv + # A workaround for quantization to work. Otherwise JIT compilation + # treats bias in linear module as method. + and not torch.jit.is_scripting() + # The Multihead attention implemented in pytorch forces strong dimension check + # for input embedding dimention and K,Q,V projection dimension. + # Since pruning will break the dimension check and it is not easy to modify the pytorch API, + # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check + and not self.skip_embed_dim_check + ): + assert key is not None and value is not None + + if self.use_xformers: + return self._xformers_attn_forward( + query, key, value, key_padding_mask, need_weights, attn_mask + ) + + else: + return F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + torch.empty([0]), + torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout_module.p, + self.out_proj.weight, + self.out_proj.bias, + self.training or self.dropout_module.apply_during_inference, + key_padding_mask.bool() if key_padding_mask is not None else None, + need_weights, + attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + ) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if saved_state is not None and "prev_key" in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + q = self.q_proj(query) + k = self.k_proj(query) + v = self.v_proj(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.q_proj(query) + if key is None: + assert value is None + k = v = None + else: + if self.beam_size > 1 and bsz == key.size(1): + # key is [T, bsz*beam_size, C], reduce to [T, bsz, C] + key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ + :, :, 0, : + ] + if key_padding_mask is not None: + key_padding_mask = key_padding_mask.view( + -1, self.beam_size, key_padding_mask.size(1) + )[:, 0, :] + k = self.k_proj(key) + v = self.v_proj(key) + + else: + assert key is not None and value is not None + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k, v, attn_mask, key_padding_mask = self._add_bias( + k, v, attn_mask, key_padding_mask, bsz + ) + + q = ( + q.contiguous() + .view(tgt_len, bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + kv_bsz = bsz # need default value for scripting + if k is not None: + kv_bsz = k.size(1) + k = ( + k.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + if v is not None: + v = ( + v.contiguous() + .view(-1, kv_bsz * self.num_heads, self.head_dim) + .transpose(0, 1) + ) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if "prev_key" in saved_state: + _prev_key = saved_state["prev_key"] + assert _prev_key is not None + kv_bsz = _prev_key.size(0) + prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + assert k is not None + k = torch.cat([prev_key, k], dim=1) + src_len = k.size(1) + if "prev_value" in saved_state: + _prev_value = saved_state["prev_value"] + assert _prev_value is not None + assert kv_bsz == _prev_value.size(0) + prev_value = _prev_value.view( + kv_bsz * self.num_heads, -1, self.head_dim + ) + if static_kv: + v = prev_value + else: + assert v is not None + v = torch.cat([prev_value, v], dim=1) + prev_key_padding_mask: Optional[torch.Tensor] = None + if "prev_key_padding_mask" in saved_state: + prev_key_padding_mask = saved_state["prev_key_padding_mask"] + assert k is not None and v is not None + key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( + key_padding_mask=key_padding_mask, + prev_key_padding_mask=prev_key_padding_mask, + batch_size=kv_bsz, + src_len=k.size(1), + static_kv=static_kv, + ) + + saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim) + saved_state["prev_value"] = v.view( + kv_bsz, self.num_heads, -1, self.head_dim + ) + saved_state["prev_key_padding_mask"] = key_padding_mask + # In this branch incremental_state is never None + assert incremental_state is not None + incremental_state = self._set_input_buffer(incremental_state, saved_state) + assert k is not None + assert k.size(1) == src_len + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.dim() == 0: + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == kv_bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + assert v is not None + src_len += 1 + k, v, key_padding_mask, attn_mask = self._append_zero_attn( + k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask + ) + + if self.encoder_decoder_attention and bsz != kv_bsz: + attn_weights = torch.einsum( + "bxhtd,bhsd->bxhts", + q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), + k.view((kv_bsz, self.num_heads) + k.size()[1:]), + ) + attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) + else: + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + attn_mask = attn_mask.unsqueeze(0) + if self.onnx_trace: + attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) + attn_weights += attn_mask + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + if not is_tpu: + attn_weights = attn_weights.view( + kv_bsz, -1, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .to(torch.bool), + float("-inf"), + ) + else: + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) + attn_weights = attn_weights.transpose(0, 2) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + if self.attention_relax > 0 : + # tgt_len == src_len + + # => (bsz, self.num_heads, tgt_len, src_len) + # attn_weights_relax = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)/self.attention_relax + + # => (bsz * self.num_heads, tgt_len, src_len) + attn_weights_relax = attn_weights / self.attention_relax + + # # => (bsz, self.num_heads, 1, src_len) + # attn_max_relax = torch.max(attn_weights_relax, dim=-2, keepdim=False).unsqueeze(2) + + + # find max according to K_j' => (bsz* self.num_heads, tgt_len, 1) + attn_max_relax = torch.max(attn_weights_relax, dim=-1, keepdim=False).unsqueeze(2) + + # => (bsz * self.num_heads, tgt_len, src_len) + attn_weights = (attn_weights_relax - attn_max_relax) * self.attention_relax + # attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_weights_float = utils.softmax( + attn_weights, dim=-1, onnx_trace=self.onnx_trace + ) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = self.dropout_module(attn_weights) + + assert v is not None + attn: Optional[torch.Tensor] = None + if self.encoder_decoder_attention and bsz != kv_bsz: + attn = torch.einsum( + "bxhts,bhsd->bxhtd", + attn_probs.view( + ( + kv_bsz, + -1, + self.num_heads, + ) + + attn_probs.size()[1:] + ), + v.view( + ( + kv_bsz, + self.num_heads, + ) + + v.size()[1:] + ), + ) + attn = attn.reshape((-1,) + attn.size()[-2:]) + else: + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + if self.onnx_trace and attn.size(1) == 1: + # when ONNX tracing a single decoder step (sequence length == 1) + # the transpose is a no-op copy before view, thus unnecessary + attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) + else: + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) + attn = self.out_proj(attn) + attn_weights: Optional[torch.Tensor] = None + if need_weights: + attn_weights = attn_weights_float.view( + bsz, self.num_heads, tgt_len, src_len + ).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + + return attn, attn_weights \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/rvq.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..9c36e3895f059dc21b750b1e4bfd8cc72c39b2fa --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/mert/rvq.py @@ -0,0 +1,584 @@ +# compared with `descript_quantize2`, we use rvq & random_dropout +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int, stale_tolerance: int = 100): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + self.register_buffer("stale_counter", torch.zeros(self.codebook_size,)) + self.stale_tolerance = stale_tolerance + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + if(self.training): + onehots = torch.nn.functional.one_hot(indices, self.codebook_size).float() # B, T, codebook_size + stale_codes = (onehots.sum(0).sum(0) == 0).float() + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # codebook_size + if replace_code.sum(-1) > 0: + print("Replace {} codes".format(replace_code.sum(-1))) + random_input_idx = torch.randperm(encodings.shape[0]) + random_input = encodings[random_input_idx].view(encodings.shape) + if random_input.shape[0] < self.codebook_size: + random_input = torch.cat([random_input]*(self.codebook_size // random_input.shape[0] + 1), 0) + random_input = random_input[:self.codebook_size,:].contiguous() # codebook_size, dim + + self.codebook.weight.data = self.codebook.weight.data * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + stale_tolerance: int = 100, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i], stale_tolerance=stale_tolerance) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + else: + n_quantizers = torch.ones((z.shape[0],)) * n_quantizers + 1 + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + # if self.training is False and i >= n_quantizers: + # break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + encodings = F.one_hot(codes, self.codebook_size).float() # B N T 1024 + # for n in range(encodings.shape[1]): + # print("Lyaer {}, Ratio of unused vector : {:.1f}".format(n, + # (encodings[:,n,:,:].sum(0).sum(0) < 1.0).sum()/torch.numel(encodings[:,n,:,:].sum(0).sum(0) < 1.0) * 100. + # )) + + return z_q, codes, latents, commitment_loss, codebook_loss, n_quantizers.clamp(max=self.n_codebooks).long() - 1 + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. + + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + +from torch.utils.data import Dataset, DataLoader +import json, traceback +import torchaudio +import math + +from typing import List, Tuple, Dict, Any + +CLIPSECS = 5 +def load_audio_by_json(json_path, max_keep, min_keep, tgt_sample_rate): + # read json file + print(json_path) + datas = [] + inds = [] + sizes = [] + with open(json_path) as fp: + for ind,line in enumerate(fp): + data = json.loads(line) + datas.append(data) + inds.append(ind) + # sz = int(data['duration'] * data['sample_rate']) + sz = int(tgt_sample_rate * CLIPSECS) + sizes.append(sz) + tot = ind + 1 + return datas,inds,tot,sizes + +class Read_and_PadCrop_Normalized_T(torch.nn.Module): + def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): + + super().__init__() + + self.n_samples = n_samples + self.sample_rate = sample_rate + self.randomize = randomize + + + def __call__(self, filename: str, duration: float, cur_sample_rate: int) -> Tuple[torch.Tensor, float, float, int, int]: + if(duration<(float(self.n_samples)/self.sample_rate+1)): + # print(duration,(float(self.n_samples)/self.sample_rate+1)) + chunk, _ = torchaudio.load(filename, frame_offset=0, num_frames=-1) + t_start = 0. + t_end = min(1.0, float(self.n_samples) / float(self.sample_rate) / duration) + offset = 0 + # print('c1:',chunk.shape) + else: + offset = np.random.randint(0,int(duration*cur_sample_rate)-int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + t_start = offset / float(cur_sample_rate) / duration + t_end = t_start + float(self.n_samples) / float(self.sample_rate) / duration + chunk, _ = torchaudio.load(filename, frame_offset=offset, num_frames=int(float(self.n_samples)/self.sample_rate*cur_sample_rate)) + # print('offset:',offset) + # print('c0:',chunk.shape) + # Pad with silence if necessary. + if(chunk.shape[0]>1): + chunk = chunk[torch.randint(chunk.shape[0], size=(1,)),:].float() + else: + chunk = chunk[[0],:].float() + if(cur_sample_rate!=self.sample_rate): + # print('a:',cur_sample_rate,chunk.shape) + chunk = torchaudio.functional.resample(chunk, cur_sample_rate, self.sample_rate) + # print('b:',self.sample_rate,chunk.shape) + if chunk.shape[-1] < self.n_samples: + chunk = torch.cat([chunk, torch.zeros((1, self.n_samples - chunk.shape[-1],))],-1) + else: + chunk = chunk[:,0:self.n_samples] + seconds_start = math.floor(offset / cur_sample_rate) + seconds_total = math.floor(duration) + + return ( + chunk, + t_start, + t_end, + seconds_start, + seconds_total + ) + +class RVQDataset(Dataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + normalize: bool = False, + ): + self.sample_rate = sample_rate + self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate) + self.dataset_len = len(self.datas) + + self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate) + self.normalize = normalize + + + def __getitem__(self, i): + # WORLD_SIZE = int(torch.distributed.get_world_size()) + # WORLD_RANK = int(torch.distributed.get_rank()) + # np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i) + # index = random.randint(0,len(self.sizes) - 1) + index = i + item = None + while item is None: + try: + wav = self.get_audio_by_slice(index) + # labels = self.get_labels(index) #这个得改 + # labels = None + # item = {"id": index, "source": wav, "label_list": labels} + item = {"id": index, "source": wav} + except Exception as e: + # print(e) + traceback.print_exc() + print(f'skip damaged data {index}') + index = np.random.randint(0,len(self.sizes)-1) + return item + + def __len__(self): + return self.dataset_len + + def get_audio_by_slice(self,index): + wav_path = self.datas[index]['path'] + # print(wav_path) + audio_info = torchaudio.info(wav_path) + origin_sample_rate = audio_info.sample_rate + origin_duration = audio_info.num_frames / origin_sample_rate + + wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate) + wav = wav.float() + + # _path, slice_ptr = parse_path(wav_path) #这个应该也要改 + # original way + # if len(slice_ptr) == 0: + # wav, cur_sample_rate = sf.read(_path) + # else: + # assert _path.endswith(".zip") + # data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) + # f = io.BytesIO(data) + # wav, cur_sample_rate = sf.read(f) + # wav = torch.from_numpy(wav).float() + # print(wav.shape) + wav = wav.permute(1,0) + wav = self.postprocess(wav, self.sample_rate) #降至单个声道,确认采样率,归一化 + # print(wav.shape) + + # wav = wav.squeeze(0) + return wav + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav + + +if __name__ == "__main__": + config = dict( + train_dataset = dict( + manifest_path = 'music4all_sh/train.json', + sample_rate = 24000, + normalize = False, + ), + valid_dataset = dict( + manifest_path = None, + sample_rate = 24000, + normalize = False, + ), + model = dict( + input_dim = 120, + n_codebooks = 8, + codebook_size = 1024, + codebook_dim = 16, + quantizer_dropout = 0.0, + ), + train = dict( + batch_size = 96, + num_workers = 6, + valid_interval = 10, + save_interval = 100, + max_updates = 5000, + lr = 1e-4, + device = 'cuda:1', + # loss = 'commitment_loss * 0.25 + codebook_loss * 1.0', + loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()', + preprocess = torchaudio.transforms.MelSpectrogram( + sample_rate = 24000, + n_mels = 120, + n_fft=2048, + win_length = int(24000//75), + hop_length = int(24000//75), + center = True, + pad_mode='constant', # pad=0, + mel_scale='htk', + normalized=True, + ) + ) + ) + train_dataset = RVQDataset(**config['train_dataset']) + if config['valid_dataset']['manifest_path'] is None: + # split train and valid dataset + from torch.utils.data import random_split + train_dataset, valid_dataset = random_split( + train_dataset, lengths=[len(train_dataset) - 500, 500] + ) + else: + valid_dataset = RVQDataset(**config['valid_dataset']) + train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers']) + valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers']) + model = ResidualVectorQuantize(**config['model']) + + device = config['train']['device'] + preprocess = config['train']['preprocess'].to(device) + model = model.to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr']) + cur_updates = 0 + is_running = True + result = {} + from tqdm import tqdm + from tensorboardX import SummaryWriter + writer = SummaryWriter() + from collections import defaultdict + import os + from logging import getLogger + logger = getLogger() + + while is_running: + results = defaultdict(lambda:0) + for item in tqdm(train_dataloader, desc='train'): + wavs = item['source'] + optimizer.zero_grad() + wavs = wavs.to(device) + x = preprocess(wavs) + model.train() + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x) + loss = eval(config['train']['loss']) + loss.backward() + optimizer.step() + + results['loss/train'] += loss.item() + results['commitment_loss/train'] += commitment_loss.item() + results['codebook_loss/train'] += codebook_loss.item() + results['rvq_usage/train'] += rvq_usage.float().mean().item() + + if cur_updates % config['train']['valid_interval'] == 0: + model.eval() + with torch.no_grad(): + for item in tqdm(valid_dataloader, desc='valid'): + wavs = item['source'] + wavs = wavs.to(device) + x = preprocess(wavs) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x) + valid_loss = eval(config['train']['loss']) + + results['loss/valid'] += valid_loss.item() + results['commitment_loss/valid'] += commitment_loss.item() + results['codebook_loss/valid'] += codebook_loss.item() + results['rvq_usage/valid'] += rvq_usage.float().mean().item() + + results['cur_updates'] = cur_updates + results['loss/train'] /= config['train']['valid_interval'] + results['commitment_loss/train'] /= config['train']['valid_interval'] + results['codebook_loss/train'] /= config['train']['valid_interval'] + results['rvq_usage/train'] /= config['train']['valid_interval'] + + results['loss/valid'] /= len(valid_dataloader) + results['commitment_loss/valid'] /= len(valid_dataloader) + results['codebook_loss/valid'] /= len(valid_dataloader) + results['rvq_usage/valid'] /= len(valid_dataloader) + + print('') + logger.info(str(results)) + for k,v in results.items(): + writer.add_scalar(k, v, cur_updates) + + results.clear() + + if cur_updates % config['train']['save_interval'] == 0: + os.makedirs(f'{writer.logdir}/ckpt/', exist_ok=True) + logger.info(f'saving checkpoint to {writer.logdir}/ckpt/RVQ_{cur_updates}.pth') + torch.save(model.state_dict(), f'{writer.logdir}/ckpt/RVQ_{cur_updates}.pth') + + + if cur_updates < config['train']['max_updates']: + cur_updates += 1 + else: + is_running = False + break + + # x = torch.randn(32, 120, 375) + # quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x) + # print(quantized_prompt_embeds.shape) + # print(codes.shape) + # # w/o reconstruction + # loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # # w/ reconstruction + # loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/__init__.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fef38a3fd3766c766ad9526ca991f82c412ff1f8 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/__init__.py @@ -0,0 +1 @@ +from .musicfm_model import * \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/__init__.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/musicfm_25hz.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/musicfm_25hz.py new file mode 100644 index 0000000000000000000000000000000000000000..738d7a595ef92d6e652a49fddd51386d0bbfae96 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/musicfm_25hz.py @@ -0,0 +1,318 @@ +# MIT License +# +# Copyright 2023 ByteDance Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), +# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +import json +import random +import torch +from torch import nn +from einops import rearrange +import os + +try: + from ..modules.random_quantizer import RandomProjectionQuantizer + from ..modules.features import MelSTFT + from ..modules.conv import Conv2dSubsampling +except: + import sys, os + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + from modules.random_quantizer import RandomProjectionQuantizer + from modules.features import MelSTFT + from modules.conv import Conv2dSubsampling + + +class MusicFM25Hz(nn.Module): + """ + MusicFM + + Input: 128-band mel spectrogram + Frontend: 2-layer Residual convolution + Backend: 12-layer Conformer + Quantizer: a codebook for mel spectrogram + """ + + def __init__( + self, + num_codebooks=1, + codebook_dim=16, + codebook_size=4096, + features=["melspec_2048"], + hop_length=240, + n_mels=128, + conv_dim=512, + encoder_dim=1024, + encoder_depth=12, + mask_hop=0.4, + mask_prob=0.6, + is_flash=False, + stat_path="./data/fma_stats.json", + model_path="./data/pretrained_fma.pt", + w2v2_config_path="facebook/wav2vec2-conformer-rope-large-960h-ft", + use_rvq_target=False, + rvq_ckpt_path=None, + ): + super(MusicFM25Hz, self).__init__() + + # global variables + self.hop_length = hop_length + self.mask_hop = mask_hop + self.mask_prob = mask_prob + self.num_codebooks = num_codebooks + self.codebook_size = codebook_size + self.features = features + + # load feature mean / std stats + import os + if stat_path is not None and os.path.exists(stat_path): + with open(stat_path, "r") as f: + self.stat = json.load(f) + else: + print("No stats file found at `{}`, use default from msd.".format(stat_path)) + self.stat = {"spec_256_cnt": 14394344256, "spec_256_mean": -23.34296658431829, "spec_256_std": 26.189295587132637, "spec_512_cnt": 28677104448, "spec_512_mean": -21.31267396860235, "spec_512_std": 26.52644536245769, "spec_1024_cnt": 57242624832, "spec_1024_mean": -18.852271129208273, "spec_1024_std": 26.443154583585663, "spec_2048_cnt": 114373665600, "spec_2048_mean": -15.638743433896792, "spec_2048_std": 26.115825961611545, "spec_4096_cnt": 228635747136, "spec_4096_mean": -11.715532502794836, "spec_4096_std": 25.763972210234062, "melspec_256_cnt": 14282760192, "melspec_256_mean": -26.962600400166156, "melspec_256_std": 36.13614100912126, "melspec_512_cnt": 14282760192, "melspec_512_mean": -9.108344167718862, "melspec_512_std": 24.71910937988429, "melspec_1024_cnt": 14282760192, "melspec_1024_mean": 0.37302579246531126, "melspec_1024_std": 18.684082325919388, "melspec_2048_cnt": 14282760192, "melspec_2048_mean": 6.768444971712967, "melspec_2048_std": 18.417922652295623, "melspec_4096_cnt": 14282760192, "melspec_4096_mean": 13.617164614990036, "melspec_4096_std": 18.08552130124525, "cqt_cnt": 9373061376, "cqt_mean": 0.46341379757927165, "cqt_std": 0.9543998080910191, "mfcc_256_cnt": 1339008768, "mfcc_256_mean": -11.681755459447485, "mfcc_256_std": 29.183186444668316, "mfcc_512_cnt": 1339008768, "mfcc_512_mean": -2.540581461792183, "mfcc_512_std": 31.93752185832081, "mfcc_1024_cnt": 1339008768, "mfcc_1024_mean": 6.606636263169779, "mfcc_1024_std": 34.151644801729624, "mfcc_2048_cnt": 1339008768, "mfcc_2048_mean": 5.281600844245184, "mfcc_2048_std": 33.12784541220003, "mfcc_4096_cnt": 1339008768, "mfcc_4096_mean": 4.7616569480166095, "mfcc_4096_std": 32.61458906894133, "chromagram_256_cnt": 1339008768, "chromagram_256_mean": 55.15596556703181, "chromagram_256_std": 73.91858278719991, "chromagram_512_cnt": 1339008768, "chromagram_512_mean": 175.73092252759895, "chromagram_512_std": 248.48485148525953, "chromagram_1024_cnt": 1339008768, "chromagram_1024_mean": 589.2947481634608, "chromagram_1024_std": 913.857929063196, "chromagram_2048_cnt": 1339008768, "chromagram_2048_mean": 2062.286388327397, "chromagram_2048_std": 3458.92657915397, "chromagram_4096_cnt": 1339008768, "chromagram_4096_mean": 7673.039107997085, "chromagram_4096_std": 13009.883158267234} + + # feature extractor + self.preprocessor_melspec_2048 = MelSTFT( + n_fft=2048, hop_length=hop_length, is_db=True + ) + + # random quantizer + self.use_rvq_target = use_rvq_target + + seed = 142 + if use_rvq_target: + try: + from .rvq_musicfm import ResidualVectorQuantize + + except: + import sys, os + sys.path.append(os.path.dirname(os.path.abspath(__file__))) + from rvq_musicfm import ResidualVectorQuantize + + self.rvq = ResidualVectorQuantize( + input_dim = 128*4, + n_codebooks = 8, + codebook_size = 1024, + codebook_dim = 16, + quantizer_dropout = 0.0, + ) + import os + if rvq_ckpt_path is not None and os.path.exists(rvq_ckpt_path): + state_dict = torch.load(rvq_ckpt_path, map_location="cpu") + self.rvq.load_state_dict(state_dict) + else: + print(f'Checkpoint for rvq `{rvq_ckpt_path}` not found. Using random initialization.') + + else: + for feature in self.features: + for i in range(num_codebooks): + setattr( + self, + f"quantizer_{feature}", # _{i} + RandomProjectionQuantizer( + n_mels * 4, codebook_dim, codebook_size, seed=seed + i + ), + ) + + # two residual convolution layers + one projection layer + self.conv = Conv2dSubsampling( + 1, conv_dim, encoder_dim, strides=[2, 2], n_bands=n_mels + ) + + # Conformer + if is_flash: + from modules.flash_conformer import ( + Wav2Vec2ConformerEncoder, + Wav2Vec2ConformerConfig, + ) + else: + from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( + Wav2Vec2ConformerEncoder, + Wav2Vec2ConformerConfig, + ) + import os + if w2v2_config_path is None or not os.path.exists(w2v2_config_path): + w2v2_config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "w2v2_config.json") + config = Wav2Vec2ConformerConfig.from_pretrained( + w2v2_config_path + ) + config.num_hidden_layers = encoder_depth + config.hidden_size = encoder_dim + + self.conformer = Wav2Vec2ConformerEncoder(config) + + # projection + self.linear = nn.Linear(encoder_dim, codebook_size) + + # loss function + self.loss = nn.CrossEntropyLoss() + + # cls token (used for sequence classification) + random.seed(seed) + self.cls_token = nn.Parameter(torch.randn(encoder_dim)) + + # load model + if model_path: + S = torch.load(model_path)["state_dict"] + SS = {k[6:]: v for k, v in S.items()} + SS['quantizer_melspec_2048.random_projection'] = SS['quantizer_melspec_2048_0.random_projection'] + SS['quantizer_melspec_2048.codebook'] = SS['quantizer_melspec_2048_0.codebook'] + del SS['quantizer_melspec_2048_0.random_projection'] + del SS['quantizer_melspec_2048_0.codebook'] + unmatch = self.load_state_dict(SS, strict=False) + if len(unmatch.missing_keys) > 0: + print(f'Missing keys: {unmatch.missing_keys}') + + def masking(self, x): + """random masking of 400ms with given probability""" + mx = x.clone() + b, t = mx.shape + len_masking_raw = int(24000 * self.mask_hop) # 9600 = 24000 * 0.4 + len_masking_token = int(24000 / self.hop_length / 2 / 2 * self.mask_hop) # 10 = 25Hz * 0.4 + + # get random mask indices + start_indices = torch.rand(b, t // len_masking_raw) < self.mask_prob # Tensor{Size([3, 75]) cpu bol} + time_domain_masked_indices = torch.nonzero( + start_indices.repeat_interleave(len_masking_raw, dim=1) + ) # Tensor{Size([1286400, 2]) cpu i64} + token_domain_masked_indices = torch.nonzero( + start_indices.repeat_interleave(len_masking_token, dim=1) + ) # Tensor{Size([1340, 2]) cpu i64} + + # mask with random values + masking_noise = ( + torch.randn(time_domain_masked_indices.shape[0], dtype=x.dtype) * 0.1 + ) # 0 mean 0.1 std + mx[tuple(time_domain_masked_indices.t())] = masking_noise.to(x.device) + + return mx, token_domain_masked_indices + + @torch.no_grad() + def preprocessing(self, x, features): + """extract classic audio features""" + # check precision + if x.dtype == torch.float16: + precision = 16 + else: + precision = 32 + + out = {} + for key in features: + layer = getattr(self, "preprocessor_%s" % key) + out[key] = layer.float()(x.float())[..., :-1] + if precision == 16: + out[key] = out[key].half() + return out + + def encoder(self, x): + """2-layer conv + w2v-conformer""" + x = self.conv(x) # [3, 128, 3000] -> [3, 750, 1024] + out = self.conformer(x, output_hidden_states=True) + hidden_emb = out["hidden_states"] + last_emb = out["last_hidden_state"] + logits = self.linear(last_emb) + logits = { + key: logits[:, :, i * self.codebook_size : (i + 1) * self.codebook_size] + for i, key in enumerate(self.features) + } + return logits, hidden_emb + + @torch.no_grad() + def normalize(self, x): + """normalize the input audio to have zero mean unit variance""" + for key in x.keys(): + x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967} + return x + + @torch.no_grad() + def rearrange(self, x): + """rearrange the batch to flatten every 4 steps""" + for key in x.keys(): + if key == "chromagram": + x[key] = rearrange(x[key], "b f t -> b t f") + else: + x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4) + return x + + @torch.no_grad() + def tokenize(self, x): + out = {} + for key in x.keys(): + if self.use_rvq_target: + self.rvq.eval() + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = self.rvq(x[key].permute((0, 2, 1))) + out[key] = torch.cat([codes[:, idx, :] for idx in range(int(self.codebook_size//1024))], dim=-1) + else: + layer = getattr(self, "quantizer_%s" % key) + out[key] = layer(x[key]) + return out + + def get_targets(self, x): + x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}} + x = self.normalize(x) + x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}} + target_tokens = self.tokenize(x) # -> {'melspec_2048': Tensor{Size([3, 750]) cuda:0 i64}} + return target_tokens + + def get_predictions(self, x): + # preprocessing + x = self.preprocessing(x, features=["melspec_2048"]) + x = self.normalize(x) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}} + + # encoding + logits, hidden_emb = self.encoder(x["melspec_2048"]) + + return logits, hidden_emb + + def get_latent(self, x, layer_ix=12): + _, hidden_states = self.get_predictions(x) + emb = hidden_states[layer_ix] + return emb + + def get_loss(self, logits, target_tokens, masked_indices): + losses = {} + accuracies = {} + for key in logits.keys(): + masked_logits = logits[key][tuple(masked_indices.t())] + masked_tokens = target_tokens[key][tuple(masked_indices.t())] + losses[key] = self.loss(masked_logits, masked_tokens) + accuracies[key] = ( + torch.sum(masked_logits.argmax(-1) == masked_tokens) + / masked_tokens.numel() + ) + return losses, accuracies + + def forward(self, x): + # get target feature tokens + target_tokens = self.get_targets(x) # {'melspec_2048': Tensor{Size([3, 750]) cuda:0 i64}} + + # masking + x, masked_indices = self.masking(x) # (Tensor{Size([3, 720000]) cuda:0 f32}, Tensor{Size([1340, 2]) cpu i64}) + + # forward + logits, hidden_emb = self.get_predictions(x) + + # get loss + losses, accuracies = self.get_loss(logits, target_tokens, masked_indices) + + return logits, hidden_emb, losses, accuracies + +if __name__ == "__main__": + device = 'cuda' + model = MusicFM25Hz( + stat_path='msd_stats.json', + w2v2_config_path='msd_stats.json', + model_path=None, + ).to(device) + wavs = torch.randn(3, 24000*30).to(device) + logits, hidden_emb, losses, accuracies = model(wavs) diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/rvq.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/rvq.py new file mode 120000 index 0000000000000000000000000000000000000000..521462aec8f84cad20376c6a48267240fafc4818 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/rvq.py @@ -0,0 +1 @@ +../../mert/rvq.py \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/rvq_musicfm.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/rvq_musicfm.py new file mode 100644 index 0000000000000000000000000000000000000000..5161c295c761c7d9b686fcfca95eda09f04439b8 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/rvq_musicfm.py @@ -0,0 +1,301 @@ +try: + from .rvq import * +except: + import sys, os + sys.path.append(os.path.dirname(os.path.abspath(__file__))) + from rvq import * + +try: + from ..modules.random_quantizer import RandomProjectionQuantizer + from ..modules.features import MelSTFT + from ..modules.conv import Conv2dSubsampling +except: + import sys, os + sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + from modules.random_quantizer import RandomProjectionQuantizer + from modules.features import MelSTFT + from modules.conv import Conv2dSubsampling + + +class RVQDataset(Dataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + normalize: bool = False, + ): + self.sample_rate = sample_rate + self.datas,inds,tot,self.sizes = load_audio_by_json(manifest_path, None, None, self.sample_rate) + self.dataset_len = len(self.datas) + + self.reader = Read_and_PadCrop_Normalized_T(n_samples=CLIPSECS*sample_rate,sample_rate = self.sample_rate) + self.normalize = normalize + + + def __getitem__(self, i): + # WORLD_SIZE = int(torch.distributed.get_world_size()) + # WORLD_RANK = int(torch.distributed.get_rank()) + # np.random.seed(1337 + self.epoch * WORLD_SIZE + WORLD_RANK + i) + # index = random.randint(0,len(self.sizes) - 1) + index = i + item = None + while item is None: + try: + wav = self.get_audio_by_slice(index) + # labels = self.get_labels(index) #这个得改 + # labels = None + # item = {"id": index, "source": wav, "label_list": labels} + item = {"id": index, "source": wav} + except Exception as e: + # print(e) + traceback.print_exc() + print(f'skip damaged data {index}') + index = np.random.randint(0,len(self.sizes)-1) + return item + + def __len__(self): + return self.dataset_len + + def get_audio_by_slice(self,index): + + wav_path = self.datas[index]['path'] + audio_info = torchaudio.info(wav_path) + origin_sample_rate = audio_info.sample_rate + origin_duration = audio_info.num_frames / origin_sample_rate + + wav, *ignored = self.reader(wav_path, origin_duration,origin_sample_rate) + wav = wav.float() + + # _path, slice_ptr = parse_path(wav_path) #这个应该也要改 + # original way + # if len(slice_ptr) == 0: + # wav, cur_sample_rate = sf.read(_path) + # else: + # assert _path.endswith(".zip") + # data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1]) + # f = io.BytesIO(data) + # wav, cur_sample_rate = sf.read(f) + # wav = torch.from_numpy(wav).float() + # print(wav.shape) + wav = wav.permute(1,0) + wav = self.postprocess(wav, self.sample_rate) #降至单个声道,确认采样率,归一化 + # print(wav.shape) + + # wav = wav.squeeze(0) + return wav + + def postprocess(self, wav, cur_sample_rate): + if wav.dim() == 2: + wav = wav.mean(-1) + assert wav.dim() == 1, wav.dim() + + if cur_sample_rate != self.sample_rate: + raise Exception(f"sr {cur_sample_rate} != {self.sample_rate}") + + if self.normalize: + with torch.no_grad(): + wav = F.layer_norm(wav, wav.shape) + return wav + +class Preprocessor(nn.Module): + def __init__(self, + codebook_dim=16, + codebook_size=4096, + hop_length=240, + n_mels=128, + stat_path='msd_stats.json', + ) -> None: + super().__init__() + + self.features=["melspec_2048"] + + # load feature mean / std stats + with open(stat_path, "r") as f: + self.stat = json.load(f) + + # feature extractor + self.preprocessor_melspec_2048 = MelSTFT( + n_fft=2048, hop_length=hop_length, is_db=True + ) + + + @torch.no_grad() + def normalize(self, x): + """normalize the input audio to have zero mean unit variance""" + for key in x.keys(): + x[key] = (x[key] - self.stat["%s_mean" % key]) / self.stat["%s_std" % key] # {'melspec_2048_cnt': 14282760192, 'melspec_2048_mean': 6.768444971712967} + return x + + @torch.no_grad() + def rearrange(self, x): + """rearrange the batch to flatten every 4 steps""" + for key in x.keys(): + if key == "chromagram": + x[key] = rearrange(x[key], "b f t -> b t f") + else: + x[key] = rearrange(x[key], "b f (t s) -> b t (s f)", s=4) + return x + + @torch.no_grad() + def preprocessing(self, x, features): + """extract classic audio features""" + # check precision + if x.dtype == torch.float16: + precision = 16 + else: + precision = 32 + + out = {} + for key in features: + layer = getattr(self, "preprocessor_%s" % key) + out[key] = layer.float()(x.float())[..., :-1] + if precision == 16: + out[key] = out[key].half() + return out + + @torch.no_grad() + def tokenize(self, x): + out = {} + for key in x.keys(): + layer = getattr(self, "quantizer_%s" % key) + out[key] = layer(x[key]) + return out + + @torch.no_grad() + def __call__(self, x): + x = self.preprocessing(x, features=self.features) # -> {'melspec_2048': Tensor{Size([3, 128, 3000]) cuda:0 f32}} + x = self.normalize(x) + x = self.rearrange(x) # -> {'melspec_2048': Tensor{Size([3, 750, 512]) cuda:0 f32}} + return x['melspec_2048'].permute((0, 2, 1)) + +if __name__ == "__main__": + config = dict( + train_dataset = dict( + manifest_path = 'music4all_sh/train.json', + sample_rate = 24000, + normalize = False, + ), + valid_dataset = dict( + manifest_path = None, + sample_rate = 24000, + normalize = False, + ), + model = dict( + input_dim = 128*4, + n_codebooks = 8, + codebook_size = 1024, + codebook_dim = 16, + quantizer_dropout = 0.0, + ), + train = dict( + batch_size = 96, + num_workers = 6, + valid_interval = 10, + save_interval = 100, + max_updates = 5000, + lr = 1e-4, + device = 'cuda:1', + # loss = 'commitment_loss * 0.25 + codebook_loss * 1.0', + loss = 'commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean()', + preprocess = Preprocessor(), + ) + ) + train_dataset = RVQDataset(**config['train_dataset']) + if config['valid_dataset']['manifest_path'] is None: + # split train and valid dataset + from torch.utils.data import random_split + train_dataset, valid_dataset = random_split( + train_dataset, lengths=[len(train_dataset) - 500, 500] + ) + else: + valid_dataset = RVQDataset(**config['valid_dataset']) + train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers']) + valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=config['train']['batch_size'], drop_last=True, num_workers=config['train']['num_workers']) + model = ResidualVectorQuantize(**config['model']) + + device = config['train']['device'] + preprocess = config['train']['preprocess'].to(device) + model = model.to(device) + optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr']) + cur_updates = 0 + is_running = True + result = {} + from tqdm import tqdm + from tensorboardX import SummaryWriter + writer = SummaryWriter() + from collections import defaultdict + import os + from logging import getLogger + logger = getLogger() + + while is_running: + results = defaultdict(lambda:0) + for item in tqdm(train_dataloader, desc='train'): + wavs = item['source'] + optimizer.zero_grad() + wavs = wavs.to(device) + x = preprocess(wavs) + model.train() + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x) + loss = eval(config['train']['loss']) + loss.backward() + optimizer.step() + + results['loss/train'] += loss.item() + results['commitment_loss/train'] += commitment_loss.item() + results['codebook_loss/train'] += codebook_loss.item() + results['rvq_usage/train'] += rvq_usage.float().mean().item() + + if cur_updates % config['train']['valid_interval'] == 0: + model.eval() + with torch.no_grad(): + for item in tqdm(valid_dataloader, desc='valid'): + wavs = item['source'] + wavs = wavs.to(device) + x = preprocess(wavs) + quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x) + valid_loss = eval(config['train']['loss']) + + results['loss/valid'] += valid_loss.item() + results['commitment_loss/valid'] += commitment_loss.item() + results['codebook_loss/valid'] += codebook_loss.item() + results['rvq_usage/valid'] += rvq_usage.float().mean().item() + + results['cur_updates'] = cur_updates + results['loss/train'] /= config['train']['valid_interval'] + results['commitment_loss/train'] /= config['train']['valid_interval'] + results['codebook_loss/train'] /= config['train']['valid_interval'] + results['rvq_usage/train'] /= config['train']['valid_interval'] + + results['loss/valid'] /= len(valid_dataloader) + results['commitment_loss/valid'] /= len(valid_dataloader) + results['codebook_loss/valid'] /= len(valid_dataloader) + results['rvq_usage/valid'] /= len(valid_dataloader) + + print('') + logger.info(str(results)) + for k,v in results.items(): + writer.add_scalar(k, v, cur_updates) + + results.clear() + + if cur_updates % config['train']['save_interval'] == 0: + os.makedirs(f'{writer.logdir}/ckpt/', exist_ok=True) + logger.info(f'saving checkpoint to {writer.logdir}/ckpt/RVQ_{cur_updates}.pth') + torch.save(model.state_dict(), f'{writer.logdir}/ckpt/RVQ_{cur_updates}.pth') + + + if cur_updates < config['train']['max_updates']: + cur_updates += 1 + else: + is_running = False + break + + # x = torch.randn(32, 120, 375) + # quantized_prompt_embeds, codes, _, commitment_loss, codebook_loss, rvq_usage = model(x) + # print(quantized_prompt_embeds.shape) + # print(codes.shape) + # # w/o reconstruction + # loss = commitment_loss * 0.25 + codebook_loss * 1.0 + # # w/ reconstruction + # loss = commitment_loss * 0.25 + codebook_loss * 1.0 + (x - quantized_prompt_embeds).abs().mean() diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/w2v2_config.json b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/w2v2_config.json new file mode 100644 index 0000000000000000000000000000000000000000..f74dbbf6fe96728cceda4888cf841b39a579e66e --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/model/w2v2_config.json @@ -0,0 +1,113 @@ +{ + "activation_dropout": 0.1, + "adapter_kernel_size": 3, + "adapter_stride": 2, + "add_adapter": false, + "apply_spec_augment": true, + "architectures": [ + "Wav2Vec2ConformerForCTC" + ], + "attention_dropout": 0.1, + "bos_token_id": 1, + "classifier_proj_size": 256, + "codevector_dim": 768, + "conformer_conv_dropout": 0.1, + "contrastive_logits_temperature": 0.1, + "conv_bias": true, + "conv_depthwise_kernel_size": 31, + "conv_dim": [ + 512, + 512, + 512, + 512, + 512, + 512, + 512 + ], + "conv_kernel": [ + 10, + 3, + 3, + 3, + 3, + 2, + 2 + ], + "conv_stride": [ + 5, + 2, + 2, + 2, + 2, + 2, + 2 + ], + "ctc_loss_reduction": "sum", + "ctc_zero_infinity": false, + "diversity_loss_weight": 0.1, + "do_stable_layer_norm": true, + "eos_token_id": 2, + "feat_extract_activation": "gelu", + "feat_extract_dropout": 0.0, + "feat_extract_norm": "layer", + "feat_proj_dropout": 0.1, + "feat_quantizer_dropout": 0.0, + "final_dropout": 0.1, + "gradient_checkpointing": false, + "hidden_act": "swish", + "hidden_dropout": 0.1, + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-05, + "layerdrop": 0.0, + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_prob": 0.05, + "max_source_positions": 5000, + "model_type": "wav2vec2-conformer", + "num_adapter_layers": 3, + "num_attention_heads": 16, + "num_codevector_groups": 2, + "num_codevectors_per_group": 320, + "num_conv_pos_embedding_groups": 16, + "num_conv_pos_embeddings": 128, + "num_feat_extract_layers": 7, + "num_hidden_layers": 24, + "num_negatives": 100, + "output_hidden_size": 1024, + "pad_token_id": 0, + "position_embeddings_type": "rotary", + "proj_codevector_dim": 768, + "rotary_embedding_base": 10000, + "tdnn_dilation": [ + 1, + 2, + 3, + 1, + 1 + ], + "tdnn_dim": [ + 512, + 512, + 512, + 512, + 1500 + ], + "tdnn_kernel": [ + 5, + 3, + 3, + 1, + 1 + ], + "torch_dtype": "float32", + "transformers_version": "4.19.0.dev0", + "use_weighted_layer_sum": false, + "vocab_size": 32, + "xvector_output_dim": 512 +} diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/__init__.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/conv.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc1a8f16cd103d09c86ace5b7c48b0583134e02 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/conv.py @@ -0,0 +1,82 @@ +# MIT License +# +# Copyright 2023 ByteDance Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), +# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +from torch import nn +from einops import rearrange + + +class Res2dModule(nn.Module): + def __init__(self, idim, odim, stride=(2, 2)): + super(Res2dModule, self).__init__() + self.conv1 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride) + self.bn1 = nn.BatchNorm2d(odim) + self.conv2 = nn.Conv2d(odim, odim, 3, padding=1) + self.bn2 = nn.BatchNorm2d(odim) + self.relu = nn.ReLU() + + # residual + self.diff = False + if (idim != odim) or (stride[0] > 1): + self.conv3 = nn.Conv2d(idim, odim, 3, padding=1, stride=stride) + self.bn3 = nn.BatchNorm2d(odim) + self.diff = True + + def forward(self, x): + out = self.bn2(self.conv2(self.relu(self.bn1(self.conv1(x))))) + if self.diff: + x = self.bn3(self.conv3(x)) + out = x + out + out = self.relu(out) + return out + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Args: + idim (int): Input dimension. + hdim (int): Hidden dimension. + odim (int): Output dimension. + strides (list): Sizes of strides. + n_bands (int): Number of frequency bands. + """ + + def __init__(self, idim, hdim, odim, strides=[2, 2], n_bands=64): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling, self).__init__() + + self.conv = nn.Sequential( + Res2dModule(idim, hdim, (2, strides[0])), + Res2dModule(hdim, hdim, (2, strides[1])), + ) + self.linear = nn.Linear(hdim * n_bands // 2 // 2, odim) + + def forward(self, x): + """Subsample x. + + Args: + x (torch.Tensor): Input tensor (#batch, idim, time). + + Returns: + torch.Tensor: Subsampled tensor (#batch, time', odim), + where time' = time // 4. + """ + + if x.dim() == 3: + x = x.unsqueeze(1) # (b, c, f, t) + x = self.conv(x) + x = rearrange(x, "b c f t -> b t (c f)") + x = self.linear(x) + return x diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/features.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/features.py new file mode 100644 index 0000000000000000000000000000000000000000..c38f525e856eeefffeb2580c7bf61058ed228e0e --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/features.py @@ -0,0 +1,45 @@ +# MIT License +# +# Copyright 2023 ByteDance Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), +# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +import torchaudio +from torch import nn + + +class MelSTFT(nn.Module): + def __init__( + self, + sample_rate=24000, + n_fft=2048, + hop_length=240, + n_mels=128, + is_db=False, + ): + super(MelSTFT, self).__init__() + + # spectrogram + self.mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels + ) + + # amplitude to decibel + self.is_db = is_db + if is_db: + self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB() + + def forward(self, waveform): + if self.is_db: + return self.amplitude_to_db(self.mel_stft(waveform)) + else: + return self.mel_stft(waveform) diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/flash_conformer.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/flash_conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..d4c7e58d2a23ea985e22e6e7ec0adc0def7811fa --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/flash_conformer.py @@ -0,0 +1,2115 @@ +# coding=utf-8 +# Copyright 2022 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Wav2Vec2-Conformer model.""" + +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.nn import functional as F + +from transformers.activations import ACT2FN +# from transformers.deepspeed import is_deepspeed_zero3_enabled +from transformers.modeling_outputs import ( + BaseModelOutput, + CausalLMOutput, + SequenceClassifierOutput, + TokenClassifierOutput, + Wav2Vec2BaseModelOutput, + XVectorOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, + replace_return_docstrings, +) +from transformers.models.wav2vec2_conformer.configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig + + +logger = logging.get_logger(__name__) + + +_HIDDEN_STATES_START_POSITION = 2 + +# General docstring +_CONFIG_FOR_DOC = "Wav2Vec2ConformerConfig" + +# Base docstring +_CHECKPOINT_FOR_DOC = "facebook/wav2vec2-conformer-rope-large-960h-ft" +_EXPECTED_OUTPUT_SHAPE = [1, 292, 1024] + +# CTC docstring +_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'" +_CTC_EXPECTED_LOSS = 64.21 + + +WAV2VEC2_CONFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "facebook/wav2vec2-conformer-rel-pos-large", + # See all Wav2Vec2Conformer models at https://huggingface.co/models?filter=wav2vec2-conformer +] + + +@dataclass +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTrainingOutput with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerForPreTrainingOutput(ModelOutput): + """ + Output type of [`Wav2Vec2ConformerForPreTraining`], with potential hidden states and attentions. + + Args: + loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + Total loss as the sum of the contrastive loss (L_m) and the diversity loss (L_d) as stated in the [official + paper](https://arxiv.org/pdf/2006.11477.pdf) . (classification) loss. + projected_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Hidden-states of the model projected to *config.proj_codevector_dim* that can be used to predict the masked + projected quantized states. + projected_quantized_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.proj_codevector_dim)`): + Quantized extracted feature vectors projected to *config.proj_codevector_dim* representing the positive + target vectors for contrastive loss. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + contrastive_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + The contrastive loss (L_m) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . + diversity_loss (*optional*, returned when `sample_negative_indices` are passed, `torch.FloatTensor` of shape `(1,)`): + The diversity loss (L_d) as stated in the [official paper](https://arxiv.org/pdf/2006.11477.pdf) . + """ + + loss: Optional[torch.FloatTensor] = None + projected_states: torch.FloatTensor = None + projected_quantized_states: torch.FloatTensor = None + codevector_perplexity: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + contrastive_loss: Optional[torch.FloatTensor] = None + diversity_loss: Optional[torch.FloatTensor] = None + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.sum(-1).detach().tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._sample_negative_indices +def _sample_negative_indices( + features_shape: Tuple, num_negatives: int, mask_time_indices: Optional[np.ndarray] = None +): + """ + Sample `num_negatives` vectors from feature vectors. + """ + batch_size, sequence_length = features_shape + + # generate indices of the positive vectors themselves, repeat them `num_negatives` times + sequence_length_range = np.arange(sequence_length) + + # get `num_negatives` random vector indices from the same utterance + sampled_negative_indices = np.zeros(shape=(batch_size, sequence_length, num_negatives), dtype=np.int32) + + mask_time_indices = ( + mask_time_indices.astype(bool) if mask_time_indices is not None else np.ones(features_shape, dtype=bool) + ) + + for batch_idx in range(batch_size): + high = mask_time_indices[batch_idx].sum() - 1 + mapped_masked_indices = sequence_length_range[mask_time_indices[batch_idx]] + + feature_indices = np.broadcast_to(np.arange(high + 1)[:, None], (high + 1, num_negatives)) + sampled_indices = np.random.randint(0, high, size=(high + 1, num_negatives)) + # avoid sampling the same positive vector, but keep the distribution uniform + sampled_indices[sampled_indices >= feature_indices] += 1 + + # remap to actual indices + sampled_negative_indices[batch_idx][mask_time_indices[batch_idx]] = mapped_masked_indices[sampled_indices] + + # correct for batch size + sampled_negative_indices[batch_idx] += batch_idx * sequence_length + + return sampled_negative_indices + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerNoLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerLayerNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + + hidden_states = hidden_states.transpose(-2, -1) + hidden_states = self.layer_norm(hidden_states) + hidden_states = hidden_states.transpose(-2, -1) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerGroupNormConvLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.conv_dim[layer_id - 1] if layer_id > 0 else 1 + self.out_conv_dim = config.conv_dim[layer_id] + + self.conv = nn.Conv1d( + self.in_conv_dim, + self.out_conv_dim, + kernel_size=config.conv_kernel[layer_id], + stride=config.conv_stride[layer_id], + bias=config.conv_bias, + ) + self.activation = ACT2FN[config.feat_extract_activation] + + self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PositionalConvEmbedding with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + ) + + # if is_deepspeed_zero3_enabled(): + # import deepspeed + + # with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + # self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + # deepspeed.zero.register_external_parameter(self, self.conv.weight_v) + # deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + # else: + self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) + + self.padding = Wav2Vec2ConformerSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + def __init__(self, config): + super().__init__() + dim = config.hidden_size // config.num_attention_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + + cos_embeddings = embeddings.cos()[:, None, None, :] + sin_embeddings = embeddings.sin()[:, None, None, :] + self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]) + return self.cached_rotary_positional_embedding + + +class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module): + """Relative positional encoding module.""" + + def __init__(self, config): + super().__init__() + self.max_len = config.max_source_positions + self.d_model = config.hidden_size + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)) + + def extend_pe(self, x): + # Reset the positional encodings + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` is the position of query vector and `j` is the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (iWav2Vec2Conformer +class Wav2Vec2ConformerSamePadLayer(nn.Module): + def __init__(self, num_conv_pos_embeddings): + super().__init__() + self.num_pad_remove = 1 if num_conv_pos_embeddings % 2 == 0 else 0 + + def forward(self, hidden_states): + if self.num_pad_remove > 0: + hidden_states = hidden_states[:, :, : -self.num_pad_remove] + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerFeatureEncoder(nn.Module): + """Construct the features from raw audio waveform""" + + def __init__(self, config): + super().__init__() + + if config.feat_extract_norm == "group": + conv_layers = [Wav2Vec2ConformerGroupNormConvLayer(config, layer_id=0)] + [ + Wav2Vec2ConformerNoLayerNormConvLayer(config, layer_id=i + 1) + for i in range(config.num_feat_extract_layers - 1) + ] + elif config.feat_extract_norm == "layer": + conv_layers = [ + Wav2Vec2ConformerLayerNormConvLayer(config, layer_id=i) for i in range(config.num_feat_extract_layers) + ] + else: + raise ValueError( + f"`config.feat_extract_norm` is {config.feat_extract_norm}, but has to be one of ['group', 'layer']" + ) + self.conv_layers = nn.ModuleList(conv_layers) + self.gradient_checkpointing = False + self._requires_grad = True + + def _freeze_parameters(self): + for param in self.parameters(): + param.requires_grad = False + self._requires_grad = False + + def forward(self, input_values): + hidden_states = input_values[:, None] + + # make sure hidden_states require grad for gradient_checkpointing + if self._requires_grad and self.training: + hidden_states.requires_grad = True + + for conv_layer in self.conv_layers: + if self._requires_grad and self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(conv_layer), + hidden_states, + ) + else: + hidden_states = conv_layer(hidden_states) + + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureProjection with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerFeatureProjection(nn.Module): + def __init__(self, config): + super().__init__() + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.dropout = nn.Dropout(config.feat_proj_dropout) + + def forward(self, hidden_states): + # non-projected hidden states are needed for quantization + norm_hidden_states = self.layer_norm(hidden_states) + hidden_states = self.projection(norm_hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states, norm_hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.activation_dropout) + + self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.output_dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class Wav2Vec2ConformerConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__(self, config): + super().__init__() + if (config.conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(config.hidden_size) + self.pointwise_conv1 = torch.nn.Conv1d( + config.hidden_size, + 2 * config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.glu = torch.nn.GLU(dim=1) + self.depthwise_conv = torch.nn.Conv1d( + config.hidden_size, + config.hidden_size, + config.conv_depthwise_kernel_size, + stride=1, + padding=(config.conv_depthwise_kernel_size - 1) // 2, + groups=config.hidden_size, + bias=False, + ) + self.batch_norm = torch.nn.BatchNorm1d(config.hidden_size) + self.activation = ACT2FN[config.hidden_act] + self.pointwise_conv2 = torch.nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = torch.nn.Dropout(config.conformer_conv_dropout) + + def forward(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2ConformerSelfAttention(nn.Module): + """Construct an Wav2Vec2ConformerSelfAttention object. + Can be enhanced with rotary or relative position embeddings. + """ + + def __init__(self, config): + super().__init__() + + self.head_size = config.hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + self.position_embeddings_type = config.position_embeddings_type + + self.linear_q = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_k = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_v = nn.Linear(config.hidden_size, config.hidden_size) + self.linear_out = nn.Linear(config.hidden_size, config.hidden_size) + + self.dropout = nn.Dropout(p=config.attention_dropout) + self.dropout_p = config.attention_dropout + + self.is_causal = config.is_causal + + if self.position_embeddings_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, hidden_size = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_embeddings_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_embeddings_type == 'rotary'" + ) + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False): + hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=self.dropout_p, is_causal=self.is_causal) + probs = None + + # # apply attention_mask if necessary + # if attention_mask is not None: + # scores = scores + attention_mask + + # # => (batch, head, time1, time2) + # probs = torch.softmax(scores, dim=-1) + # probs = self.dropout(probs) + + # # => (batch, head, time1, d_k) + # hidden_states = torch.matmul(probs, value) + + # => (batch, time1, hidden_size) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, probs + + def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): + batch_size, sequence_length, hidden_size = hidden_states.size() + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) + + cos = relative_position_embeddings[0, :sequence_length, ...] + sin = relative_position_embeddings[1, :sequence_length, ...] + + # rotate hidden_states with rotary embeddings + hidden_states = hidden_states.transpose(0, 1) + rotated_states_begin = hidden_states[..., : self.head_size // 2] + rotated_states_end = hidden_states[..., self.head_size // 2 :] + rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1) + hidden_states = (hidden_states * cos) + (rotated_states * sin) + hidden_states = hidden_states.transpose(0, 1) + + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size) + + return hidden_states + + def _apply_relative_embeddings(self, query, key, relative_position_embeddings): + # 1. project positional embeddings + # => (batch, head, 2*time1-1, d_k) + proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings) + proj_relative_position_embeddings = proj_relative_position_embeddings.view( + relative_position_embeddings.size(0), -1, self.num_heads, self.head_size + ) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3) + + # 2. Add bias to query + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + # 3. attention score: first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # => (batch, head, time1, time2) + scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + # 4. then compute matrix b and matrix d + # => (batch, head, time1, 2*time1-1) + scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings) + + # 5. shift matrix b and matrix d + zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype) + scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1) + scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2]) + scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape) + scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd) + scores_bd = scores_bd[:, :, :, : scores_bd.size(-1) // 2 + 1] + + # 6. sum matrices + # => (batch, head, time1, time2) + scores = (scores_ac + scores_bd) / math.sqrt(self.head_size) + + return scores + + +class Wav2Vec2ConformerEncoderLayer(nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + def __init__(self, config): + super().__init__() + embed_dim = config.hidden_size + dropout = config.attention_dropout + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = Wav2Vec2ConformerFeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = torch.nn.Dropout(dropout) + self.self_attn = Wav2Vec2ConformerSelfAttention(config) + + # Conformer Convolution + self.conv_module = Wav2Vec2ConformerConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = Wav2Vec2ConformerFeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + hidden_states = hidden_states + + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weigts = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states) + hidden_states = residual + hidden_states + + # 4. Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weigts + + +class Wav2Vec2ConformerEncoder(nn.Module): + def __init__(self, config, is_causal=False): + super().__init__() + config.is_causal = is_causal + self.config = config + + if config.position_embeddings_type == "relative": + self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config) + elif config.position_embeddings_type == "rotary": + self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([Wav2Vec2ConformerEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states[~attention_mask] = 0.0 + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + # deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled() + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = np.random.uniform(0, 1) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer: + # if not skip_the_layer or deepspeed_zero3_is_enabled: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + relative_position_embeddings, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GumbelVectorQuantizer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerGumbelVectorQuantizer(nn.Module): + """ + Vector quantization using gumbel softmax. See `[CATEGORICAL REPARAMETERIZATION WITH + GUMBEL-SOFTMAX](https://arxiv.org/pdf/1611.01144.pdf) for more information. + """ + + def __init__(self, config): + super().__init__() + self.num_groups = config.num_codevector_groups + self.num_vars = config.num_codevectors_per_group + + if config.codevector_dim % self.num_groups != 0: + raise ValueError( + f"`config.codevector_dim {config.codevector_dim} must be divisible " + f"by `config.num_codevector_groups` {self.num_groups} for concatenation" + ) + + # storage for codebook variables (codewords) + self.codevectors = nn.Parameter( + torch.FloatTensor(1, self.num_groups * self.num_vars, config.codevector_dim // self.num_groups) + ) + self.weight_proj = nn.Linear(config.conv_dim[-1], self.num_groups * self.num_vars) + + # can be decayed for training + self.temperature = 2 + + @staticmethod + def _compute_perplexity(probs, mask=None): + if mask is not None: + mask_extended = mask.flatten()[:, None, None].expand(probs.shape) + probs = torch.where(mask_extended, probs, torch.zeros_like(probs)) + marginal_probs = probs.sum(dim=0) / mask.sum() + else: + marginal_probs = probs.mean(dim=0) + + perplexity = torch.exp(-torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)).sum() + return perplexity + + def forward(self, hidden_states, mask_time_indices=None): + batch_size, sequence_length, hidden_size = hidden_states.shape + + # project to codevector dim + hidden_states = self.weight_proj(hidden_states) + hidden_states = hidden_states.view(batch_size * sequence_length * self.num_groups, -1) + + if self.training: + # sample code vector probs via gumbel in differentiateable way + codevector_probs = nn.functional.gumbel_softmax( + hidden_states.float(), tau=self.temperature, hard=True + ).type_as(hidden_states) + + # compute perplexity + codevector_soft_dist = torch.softmax( + hidden_states.view(batch_size * sequence_length, self.num_groups, -1).float(), dim=-1 + ) + perplexity = self._compute_perplexity(codevector_soft_dist, mask_time_indices) + else: + # take argmax in non-differentiable way + # comptute hard codevector distribution (one hot) + codevector_idx = hidden_states.argmax(dim=-1) + codevector_probs = hidden_states.new_zeros(hidden_states.shape).scatter_( + -1, codevector_idx.view(-1, 1), 1.0 + ) + codevector_probs = codevector_probs.view(batch_size * sequence_length, self.num_groups, -1) + + perplexity = self._compute_perplexity(codevector_probs, mask_time_indices) + + codevector_probs = codevector_probs.view(batch_size * sequence_length, -1) + # use probs to retrieve codevectors + codevectors_per_group = codevector_probs.unsqueeze(-1) * self.codevectors + codevectors = codevectors_per_group.view(batch_size * sequence_length, self.num_groups, self.num_vars, -1) + codevectors = codevectors.sum(-2).view(batch_size, sequence_length, -1) + + return codevectors, perplexity + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Adapter with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerAdapter(nn.Module): + def __init__(self, config): + super().__init__() + + # feature dim might need to be down-projected + if config.output_hidden_size != config.hidden_size: + self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) + self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) + else: + self.proj = self.proj_layer_norm = None + + self.layers = nn.ModuleList(Wav2Vec2ConformerAdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layerdrop = config.layerdrop + + def forward(self, hidden_states): + # down project hidden_states if necessary + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + + for layer in self.layers: + layerdrop_prob = np.random.random() + if not self.training or (layerdrop_prob > self.layerdrop): + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2AdapterLayer with Wav2Vec2->Wav2Vec2Conformer +class Wav2Vec2ConformerAdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.output_hidden_size, + 2 * config.output_hidden_size, + config.adapter_kernel_size, + stride=config.adapter_stride, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + + return hidden_states + + +class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = Wav2Vec2ConformerConfig + base_model_prefix = "wav2vec2_conformer" + main_input_name = "input_values" + _keys_to_ignore_on_load_missing = [r"position_ids"] + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + # Wav2Vec2ForPreTraining last 2 linear layers need standard Linear init. + if isinstance(module, Wav2Vec2ConformerForPreTraining): + module.project_hid.reset_parameters() + module.project_q.reset_parameters() + module.project_hid._is_hf_initialized = True + module.project_q._is_hf_initialized = True + # gumbel softmax requires special init + elif isinstance(module, Wav2Vec2ConformerGumbelVectorQuantizer): + module.weight_proj.weight.data.normal_(mean=0.0, std=1) + module.weight_proj.bias.data.zero_() + nn.init.uniform_(module.codevectors) + elif isinstance(module, Wav2Vec2ConformerSelfAttention): + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, Wav2Vec2ConformerPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, Wav2Vec2ConformerFeatureProjection): + k = math.sqrt(1 / module.projection.in_features) + nn.init.uniform_(module.projection.weight, a=-k, b=k) + nn.init.uniform_(module.projection.bias, a=-k, b=k) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _get_feat_extract_output_lengths( + self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None + ): + """ + Computes the output length of the convolutional layers + """ + + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + if add_adapter: + for _ in range(self.config.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) + + return input_lengths + + def _get_feature_vector_attention_mask( + self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None + ): + # Effectively attention_mask.sum(-1), but not inplace to be able to run + # on inference mode. + non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] + + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) + output_lengths = output_lengths.to(torch.long) + + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): + module.gradient_checkpointing = value + + +WAV2VEC2_CONFORMER_START_DOCSTRING = r""" + Wav2Vec2Conformer was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech + Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael + Auli. + + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving etc.). + + This model is a PyTorch [nn.Module](https://pytorch.org/docs/stable/nn.html#nn.Module) sub-class. Use it as a + regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. + + Parameters: + config ([`Wav2Vec2ConformerConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +WAV2VEC2_CONFORMER_INPUTS_DOCSTRING = r""" + Args: + input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`): + Float values of input raw speech waveform. Values can be obtained by loading a `.flac` or `.wav` audio file + into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via the soundfile library (`pip install + soundfile`). To prepare the array into `input_values`, the [`AutoProcessor`] should be used for padding and + conversion into a tensor of type `torch.FloatTensor`. See [`Wav2Vec2Processor.__call__`] for details. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0, + 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + + + `attention_mask` should only be passed if the corresponding processor has `config.return_attention_mask == + True`. For all models whose processor has `config.return_attention_mask == False`, such as + [wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large), + `attention_mask` should **not** be passed to avoid degraded performance when doing batched inference. For + such models `input_values` should simply be padded with 0 and passed without `attention_mask`. Be aware + that these models also yield slightly different results depending on whether `input_values` is padded or + not. + + + + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Wav2Vec2Conformer Model transformer outputting raw hidden-states without any specific head on top.", + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerModel(Wav2Vec2ConformerPreTrainedModel): + def __init__(self, config: Wav2Vec2ConformerConfig): + super().__init__(config) + self.config = config + self.feature_extractor = Wav2Vec2ConformerFeatureEncoder(config) + self.feature_projection = Wav2Vec2ConformerFeatureProjection(config) + + # model only needs masking vector if mask prob is > 0.0 + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_()) + + self.encoder = Wav2Vec2ConformerEncoder(config) + + self.adapter = Wav2Vec2ConformerAdapter(config) if config.add_adapter else None + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.freeze_feature_encoder + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.feature_extractor._freeze_parameters() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=Wav2Vec2BaseModelOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + expected_output=_EXPECTED_OUTPUT_SHAPE, + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model.forward with wav2vec2->wav2vec2_conformer + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Wav2Vec2BaseModelOutput]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + hidden_states, extract_features = self.feature_projection(extract_features) + hidden_states = self._mask_hidden_states( + hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask + ) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + + if not return_dict: + return (hidden_states, extract_features) + encoder_outputs[1:] + + return Wav2Vec2BaseModelOutput( + last_hidden_state=hidden_states, + extract_features=extract_features, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """Wav2Vec2Conformer Model with a quantizer and `VQ` head on top.""", WAV2VEC2_CONFORMER_START_DOCSTRING +) +class Wav2Vec2ConformerForPreTraining(Wav2Vec2ConformerPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer + def __init__(self, config: Wav2Vec2ConformerConfig): + super().__init__(config) + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + self.dropout_features = nn.Dropout(config.feat_quantizer_dropout) + + self.quantizer = Wav2Vec2ConformerGumbelVectorQuantizer(config) + + self.project_hid = nn.Linear(config.hidden_size, config.proj_codevector_dim) + self.project_q = nn.Linear(config.codevector_dim, config.proj_codevector_dim) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.set_gumbel_temperature + def set_gumbel_temperature(self, temperature: int): + """ + Set the Gumbel softmax temperature to a given value. Only necessary for training + """ + self.quantizer.temperature = temperature + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + @staticmethod + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.compute_contrastive_logits + def compute_contrastive_logits( + target_features: torch.FloatTensor, + negative_features: torch.FloatTensor, + predicted_features: torch.FloatTensor, + temperature: int = 0.1, + ): + """ + Compute logits for contrastive loss based using cosine similarity as the distance measure between + `[positive_feature, negative_features]` and `[predicted_features]`. Additionally, temperature can be applied. + """ + target_features = torch.cat([target_features, negative_features], dim=0) + + logits = torch.cosine_similarity(predicted_features.float(), target_features.float(), dim=-1).type_as( + target_features + ) + + # apply temperature + logits = logits / temperature + return logits + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=Wav2Vec2ConformerForPreTrainingOutput, config_class=_CONFIG_FOR_DOC) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForPreTraining.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,wav2vec2_conformer-base->wav2vec2-conformer-rel-pos-large + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.BoolTensor] = None, + sampled_negative_indices: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Wav2Vec2ConformerForPreTrainingOutput]: + r""" + mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. + sampled_negative_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length, num_negatives)`, *optional*): + Indices indicating which quantized target vectors are used as negative sampled vectors in contrastive loss. + Required input for pre-training. + + Returns: + + Example: + + ```python + >>> import torch + >>> from transformers import AutoFeatureExtractor, Wav2Vec2ConformerForPreTraining + >>> from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( + ... _compute_mask_indices, + ... _sample_negative_indices, + ... ) + >>> from datasets import load_dataset + + >>> feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large") + >>> model = Wav2Vec2ConformerForPreTraining.from_pretrained("facebook/wav2vec2-conformer-rel-pos-large") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values # Batch size 1 + + >>> # compute masked indices + >>> batch_size, raw_sequence_length = input_values.shape + >>> sequence_length = model._get_feat_extract_output_lengths(raw_sequence_length).item() + >>> mask_time_indices = _compute_mask_indices( + ... shape=(batch_size, sequence_length), mask_prob=0.2, mask_length=2 + ... ) + >>> sampled_negative_indices = _sample_negative_indices( + ... features_shape=(batch_size, sequence_length), + ... num_negatives=model.config.num_negatives, + ... mask_time_indices=mask_time_indices, + ... ) + >>> mask_time_indices = torch.tensor(data=mask_time_indices, device=input_values.device, dtype=torch.long) + >>> sampled_negative_indices = torch.tensor( + ... data=sampled_negative_indices, device=input_values.device, dtype=torch.long + ... ) + + >>> with torch.no_grad(): + ... outputs = model(input_values, mask_time_indices=mask_time_indices) + + >>> # compute cosine similarity between predicted (=projected_states) and target (=projected_quantized_states) + >>> cosine_sim = torch.cosine_similarity(outputs.projected_states, outputs.projected_quantized_states, dim=-1) + + >>> # show that cosine similarity is much higher than random + >>> cosine_sim[mask_time_indices.to(torch.bool)].mean() > 0.5 + tensor(True) + + >>> # for contrastive loss training model should be put into train mode + >>> model = model.train() + >>> loss = model( + ... input_values, mask_time_indices=mask_time_indices, sampled_negative_indices=sampled_negative_indices + ... ).loss + ```""" + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if mask_time_indices is not None: + mask_time_indices = mask_time_indices.to(torch.bool) + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + mask_time_indices=mask_time_indices, + return_dict=return_dict, + ) + + # 1. project all transformed features (including masked) to final vq dim + transformer_features = self.project_hid(outputs[0]) + + # 2. quantize all (unmasked) extracted features and project to final vq dim + extract_features = self.dropout_features(outputs[1]) + + if attention_mask is not None: + # compute reduced attention_mask correponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) + + quantized_features, codevector_perplexity = self.quantizer( + extract_features, mask_time_indices=mask_time_indices + ) + quantized_features = self.project_q(quantized_features) + + loss = contrastive_loss = diversity_loss = None + if sampled_negative_indices is not None: + batch_size, sequence_length, hidden_size = quantized_features.shape + + # for training, we sample negatives + # 3. sample K negatives (distractors) quantized states for contrastive loss + # if attention_mask is passed, make sure that padded feature vectors cannot be sampled + # sample negative quantized vectors BTC => (BxT)C + negative_quantized_features = quantized_features.view(-1, hidden_size)[ + sampled_negative_indices.long().view(-1) + ] + negative_quantized_features = negative_quantized_features.view( + batch_size, sequence_length, -1, hidden_size + ).permute(2, 0, 1, 3) + + # 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa` + # of equation (3) in https://arxiv.org/pdf/2006.11477.pdf + logits = self.compute_contrastive_logits( + quantized_features[None, :], + negative_quantized_features, + transformer_features, + self.config.contrastive_logits_temperature, + ) + + # 5. if a negative vector is identical to the positive (i.e. when codebook utilization is low), + # its cosine similarity will be masked + neg_is_pos = (quantized_features == negative_quantized_features).all(-1) + + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + + # 6. compute contrastive loss \mathbf{L}_m = cross_entropy(logs) = + # -log(exp(sim(c_t, q_t)/\kappa) / \sum_{\sim{q}} exp(sim(c_t, \sim{q})/\kappa)) + logits = logits.transpose(0, 2).reshape(-1, logits.size(0)) + target = ((1 - mask_time_indices.long()) * -100).transpose(0, 1).flatten() + + contrastive_loss = nn.functional.cross_entropy(logits.float(), target, reduction="sum") + # 7. compute diversity loss: \mathbf{L}_d + num_codevectors = self.config.num_codevectors_per_group * self.config.num_codevector_groups + diversity_loss = ((num_codevectors - codevector_perplexity) / num_codevectors) * mask_time_indices.sum() + + # 8. \mathbf{L} = \mathbf{L}_m + \alpha * \mathbf{L}_d + loss = contrastive_loss + self.config.diversity_loss_weight * diversity_loss + + if not return_dict: + if loss is not None: + return (loss, transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + return (transformer_features, quantized_features, codevector_perplexity) + outputs[2:] + + return Wav2Vec2ConformerForPreTrainingOutput( + loss=loss, + projected_states=transformer_features, + projected_quantized_states=quantized_features, + codevector_perplexity=codevector_perplexity, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + contrastive_loss=contrastive_loss, + diversity_loss=diversity_loss, + ) + + +@add_start_docstrings( + """Wav2Vec2Conformer Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).""", + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForCTC(Wav2Vec2ConformerPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer + def __init__(self, config): + super().__init__(config) + + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + self.dropout = nn.Dropout(config.final_dropout) + + if config.vocab_size is None: + raise ValueError( + f"You are trying to instantiate {self.__class__} with a configuration that " + "does not define the vocabulary size of the language model head. Please " + "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. " + "or define `vocab_size` of your model's configuration." + ) + output_hidden_size = ( + config.output_hidden_size if hasattr(config, "add_adapter") and config.add_adapter else config.hidden_size + ) + self.lm_head = nn.Linear(output_hidden_size, config.vocab_size) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=CausalLMOutput, + config_class=_CONFIG_FOR_DOC, + expected_output=_CTC_EXPECTED_OUTPUT, + expected_loss=_CTC_EXPECTED_LOSS, + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, CausalLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*): + Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to + the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`. + All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., + config.vocab_size - 1]`. + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states) + + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + if labels.max() >= self.config.vocab_size: + raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}") + + # retrieve loss input_lengths from attention_mask + attention_mask = ( + attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long) + ) + input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + + # assuming that padded tokens are filled with -100 + # when not being attended to + labels_mask = labels >= 0 + target_lengths = labels_mask.sum(-1) + flattened_targets = labels.masked_select(labels_mask) + + # ctc_loss doesn't support fp16 + log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1) + + with torch.backends.cudnn.flags(enabled=False): + loss = nn.functional.ctc_loss( + log_probs, + flattened_targets, + input_lengths, + target_lengths, + blank=self.config.pad_token_id, + reduction=self.config.ctc_loss_reduction, + zero_infinity=self.config.ctc_zero_infinity, + ) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutput( + loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions + ) + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with a sequence classification head on top (a linear layer over the pooled output) for + tasks like SUPERB Keyword Spotting. + """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForSequenceClassification(Wav2Vec2ConformerPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Sequence classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)" + ) + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size) + self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wav2vec2_conformer.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=SequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, SequenceClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + if attention_mask is None: + pooled_output = hidden_states.mean(dim=1) + else: + padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + hidden_states[~padding_mask] = 0.0 + pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1) + + logits = self.classifier(pooled_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with a frame classification head on top for tasks like Speaker Diarization. + """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForAudioFrameClassification(Wav2Vec2ConformerPreTrainedModel): + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.__init__ with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER + def __init__(self, config): + super().__init__(config) + + if hasattr(config, "add_adapter") and config.add_adapter: + raise ValueError( + "Audio frame classification does not support the use of Wav2Vec2Conformer adapters (config.add_adapter=True)" + ) + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self.num_labels = config.num_labels + + self.init_weights() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.freeze_base_model with wav2vec2->wav2vec2_conformer + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wav2vec2_conformer.parameters(): + param.requires_grad = False + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForAudioFrameClassification.forward with wav2vec2->wav2vec2_conformer + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), torch.argmax(labels.view(-1, self.num_labels), axis=1)) + + if not return_dict: + output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:] + return output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.AMSoftmaxLoss +class AMSoftmaxLoss(nn.Module): + def __init__(self, input_dim, num_labels, scale=30.0, margin=0.4): + super(AMSoftmaxLoss, self).__init__() + self.scale = scale + self.margin = margin + self.num_labels = num_labels + self.weight = nn.Parameter(torch.randn(input_dim, num_labels), requires_grad=True) + self.loss = nn.CrossEntropyLoss() + + def forward(self, hidden_states, labels): + labels = labels.flatten() + weight = nn.functional.normalize(self.weight, dim=0) + hidden_states = nn.functional.normalize(hidden_states, dim=1) + cos_theta = torch.mm(hidden_states, weight) + psi = cos_theta - self.margin + + onehot = nn.functional.one_hot(labels, self.num_labels) + logits = self.scale * torch.where(onehot.bool(), psi, cos_theta) + loss = self.loss(logits, labels) + + return loss + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.TDNNLayer +class TDNNLayer(nn.Module): + def __init__(self, config, layer_id=0): + super().__init__() + self.in_conv_dim = config.tdnn_dim[layer_id - 1] if layer_id > 0 else config.tdnn_dim[layer_id] + self.out_conv_dim = config.tdnn_dim[layer_id] + self.kernel_size = config.tdnn_kernel[layer_id] + self.dilation = config.tdnn_dilation[layer_id] + + self.kernel = nn.Linear(self.in_conv_dim * self.kernel_size, self.out_conv_dim) + self.activation = nn.ReLU() + + def forward(self, hidden_states): + hidden_states = hidden_states.unsqueeze(1) + hidden_states = nn.functional.unfold( + hidden_states, + (self.kernel_size, self.in_conv_dim), + stride=(1, self.in_conv_dim), + dilation=(self.dilation, 1), + ) + hidden_states = hidden_states.transpose(1, 2) + hidden_states = self.kernel(hidden_states) + + hidden_states = self.activation(hidden_states) + return hidden_states + + +@add_start_docstrings( + """ + Wav2Vec2Conformer Model with an XVector feature extraction head on top for tasks like Speaker Verification. + """, + WAV2VEC2_CONFORMER_START_DOCSTRING, +) +class Wav2Vec2ConformerForXVector(Wav2Vec2ConformerPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + self.wav2vec2_conformer = Wav2Vec2ConformerModel(config) + num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings + if config.use_weighted_layer_sum: + self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers) + self.projector = nn.Linear(config.hidden_size, config.tdnn_dim[0]) + + tdnn_layers = [TDNNLayer(config, i) for i in range(len(config.tdnn_dim))] + self.tdnn = nn.ModuleList(tdnn_layers) + + self.feature_extractor = nn.Linear(config.tdnn_dim[-1] * 2, config.xvector_output_dim) + self.classifier = nn.Linear(config.xvector_output_dim, config.xvector_output_dim) + + self.objective = AMSoftmaxLoss(config.xvector_output_dim, config.num_labels) + + self.init_weights() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_feature_encoder with wav2vec2->wav2vec2_conformer + def freeze_feature_encoder(self): + """ + Calling this function will disable the gradient computation for the feature encoder so that its parameter will + not be updated during training. + """ + self.wav2vec2_conformer.feature_extractor._freeze_parameters() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.freeze_base_model with wav2vec2->wav2vec2_conformer + def freeze_base_model(self): + """ + Calling this function will disable the gradient computation for the base model so that its parameters will not + be updated during training. Only the classification head will be updated. + """ + for param in self.wav2vec2_conformer.parameters(): + param.requires_grad = False + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector._get_tdnn_output_lengths with wav2vec2->wav2vec2_conformer + def _get_tdnn_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the TDNN layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return (input_length - kernel_size) // stride + 1 + + for kernel_size in self.config.tdnn_kernel: + input_lengths = _conv_out_length(input_lengths, kernel_size, 1) + + return input_lengths + + @add_start_docstrings_to_model_forward(WAV2VEC2_CONFORMER_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=XVectorOutput, + config_class=_CONFIG_FOR_DOC, + modality="audio", + ) + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForXVector.forward with Wav2Vec2->Wav2Vec2Conformer,wav2vec2->wav2vec2_conformer,WAV_2_VEC_2->WAV2VEC2_CONFORMER + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[torch.Tensor] = None, + ) -> Union[Tuple, XVectorOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states + + outputs = self.wav2vec2_conformer( + input_values, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if self.config.use_weighted_layer_sum: + hidden_states = outputs[_HIDDEN_STATES_START_POSITION] + hidden_states = torch.stack(hidden_states, dim=1) + norm_weights = nn.functional.softmax(self.layer_weights, dim=-1) + hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1) + else: + hidden_states = outputs[0] + + hidden_states = self.projector(hidden_states) + + for tdnn_layer in self.tdnn: + hidden_states = tdnn_layer(hidden_states) + + # Statistic Pooling + if attention_mask is None: + mean_features = hidden_states.mean(dim=1) + std_features = hidden_states.std(dim=1) + else: + feat_extract_output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(dim=1)) + tdnn_output_lengths = self._get_tdnn_output_lengths(feat_extract_output_lengths) + mean_features = [] + std_features = [] + for i, length in enumerate(tdnn_output_lengths): + mean_features.append(hidden_states[i, :length].mean(dim=0)) + std_features.append(hidden_states[i, :length].std(dim=0)) + mean_features = torch.stack(mean_features) + std_features = torch.stack(std_features) + statistic_pooling = torch.cat([mean_features, std_features], dim=-1) + + output_embeddings = self.feature_extractor(statistic_pooling) + logits = self.classifier(output_embeddings) + + loss = None + if labels is not None: + loss = self.objective(logits, labels) + + if not return_dict: + output = (logits, output_embeddings) + outputs[_HIDDEN_STATES_START_POSITION:] + return ((loss,) + output) if loss is not None else output + + return XVectorOutput( + loss=loss, + logits=logits, + embeddings=output_embeddings, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/random_quantizer.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/random_quantizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1257014658a24e4557814ccbb746de455ec111fa --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/modules/random_quantizer.py @@ -0,0 +1,83 @@ +# MIT License +# +# Copyright 2023 ByteDance Inc. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), +# to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, +# and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. + +import torch +from torch import nn, einsum +from einops import rearrange + + +class RandomProjectionQuantizer(nn.Module): + """ + Random projection and codebook lookup module + + Some code is borrowed from: + https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/random_projection_quantizer.py + But I did normalization using pre-computed global mean & variance instead of using layer norm. + """ + + def __init__( + self, + input_dim, + codebook_dim, + codebook_size, + seed=142, + ): + super().__init__() + + # random seed + torch.manual_seed(seed) + + # randomly initialized projection + random_projection = torch.empty(input_dim, codebook_dim) + nn.init.xavier_normal_(random_projection) + self.register_buffer("random_projection", random_projection) + + # randomly initialized codebook + codebook = torch.empty(codebook_size, codebook_dim) + nn.init.normal_(codebook) + self.register_buffer("codebook", codebook) + + def codebook_lookup(self, x): + # reshape + b = x.shape[0] + x = rearrange(x, "b n e -> (b n) e") + + # L2 normalization + normalized_x = nn.functional.normalize(x, dim=1, p=2) + normalized_codebook = nn.functional.normalize(self.codebook, dim=1, p=2) + + # compute distances + distances = torch.cdist(normalized_codebook, normalized_x) + + # get nearest + nearest_indices = torch.argmin(distances, dim=0) + + # reshape + xq = rearrange(nearest_indices, "(b n) -> b n", b=b) + + return xq + + @torch.no_grad() + def forward(self, x): + # always eval + self.eval() + + # random projection [batch, length, input_dim] -> [batch, length, codebook_dim] + x = einsum("b n d, d e -> b n e", x, self.random_projection) + + # codebook lookup + xq = self.codebook_lookup(x) + + return xq diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/musicfm_model.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/musicfm_model.py new file mode 100644 index 0000000000000000000000000000000000000000..376f1c8051eebab17c4875fd88924e8b7ef88eed --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/models/musicfm/musicfm_model.py @@ -0,0 +1,109 @@ +try: + from .model.musicfm_25hz import MusicFM25Hz +except: + import sys, os + sys.path.append(os.path.dirname(os.path.abspath(__file__))) + from model.musicfm_25hz import MusicFM25Hz +try: + from fairseq.fairseq.dataclass import FairseqDataclass + from fairseq.fairseq.models import BaseFairseqModel, register_model + from fairseq.fairseq.tasks.fairseq_task import FairseqTask +except: + from fairseq.dataclass import FairseqDataclass + from fairseq.models import BaseFairseqModel, register_model + from fairseq.tasks.fairseq_task import FairseqTask + +from dataclasses import dataclass, field +from typing import List, Tuple, Optional +import torch + +from logging import getLogger + +logger = getLogger(__name__) + +@dataclass +class MusicFMConfig(FairseqDataclass): + label_rate:int = field(default=25) + num_codebooks:int = field(default=1) + codebook_dim:int = field(default=16) + codebook_size:int = field(default=4096) + features:List[str] = field(default_factory=lambda:["melspec_2048"]) + hop_length:int = field(default=240) + n_mels:int = field(default=128) + conv_dim:int = field(default=512) + encoder_dim:int = field(default=1024) + encoder_depth:int = field(default=12) + mask_hop:float = field(default=0.4) + mask_prob:float = field(default=0.6) + is_flash:bool = field(default=False) + stat_path:Optional[str] = field(default=None) + model_path:Optional[str] = field(default=None) + w2v2_config_path:Optional[str] = field(default=None) + use_rvq_target:bool = field(default=False) + rvq_ckpt_path: Optional[str] = field(default=None) + + +SAMPLE_RATE = 24_000 + +@register_model("musicfm", dataclass=MusicFMConfig) +class MusicFMModel(BaseFairseqModel): + def __init__(self, cfg: MusicFMConfig, task_cfg: FairseqTask): + super().__init__() + self.cfg = cfg + self.model = MusicFM25Hz( + num_codebooks=cfg.num_codebooks, + codebook_dim=cfg.codebook_dim, + codebook_size=cfg.codebook_size, + features=cfg.features, + n_mels=cfg.n_mels, + conv_dim=cfg.conv_dim, + encoder_dim=cfg.encoder_dim, + encoder_depth=cfg.encoder_depth, + mask_hop=cfg.mask_hop, + mask_prob=cfg.mask_prob, + is_flash=cfg.is_flash, + stat_path=cfg.stat_path, + model_path=cfg.model_path, + w2v2_config_path=cfg.w2v2_config_path, + use_rvq_target=cfg.use_rvq_target, + rvq_ckpt_path=cfg.rvq_ckpt_path, + ) + + def forward( + self, + source: torch.Tensor, # B,L + features_only: bool = False, + **kwargs, + ): + source = source[..., :int((source.shape[-1]//(SAMPLE_RATE//self.cfg.label_rate))*(SAMPLE_RATE//self.cfg.label_rate)) ] + # logger.info("source shape: "+str(source.shape)) + if features_only: + _, hidden_states = self.model.get_predictions(source) + result = { + "layer_results": hidden_states + } + return result + else: + result = {} + logits, hidden_emb, losses, accuracies = self.model(source) + result["losses"] = losses + result["accuracies"] = accuracies + result["logits"] = logits + result["hidden_emb"] = hidden_emb + return result + + @classmethod + def build_model(cls, cfg: MusicFMConfig, task: FairseqTask): + """Build a new model instance.""" + + model = MusicFMModel(cfg, task.cfg) + import numpy as np + s = 0 + for param in model.parameters(): + s += np.product(param.size()) + print('# of parameters: '+str(s/1024.0/1024.0)) + return model + + def get_losses(self, result, batch): + return result['losses'] + \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py new file mode 100644 index 0000000000000000000000000000000000000000..79a6cad1b6be7d902c9dd06af5ca3334682cfc20 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/mert_pretraining.py @@ -0,0 +1,452 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import os +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch + +from dataclasses import dataclass, field +from fairseq.data import Dictionary, HubertDataset +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING + +# from ..data.mert_dataset import MERTDataset +from ..data.mert_dataset import MERTDataset #这么做感觉有大问题,得换个办法 +from ..data.ark_dataset import ArkDataset + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + # @yizhilll: https://fairseq.readthedocs.io/en/latest/_modules/fairseq/data/dictionary.html \ + # encode_line return a torch.IntTensor, should be all 1 for vanila HuBERT + return self.dictionary.encode_line( + label, + append_eos=False, + add_if_not_exist=False, + ) +class PaddedNumpyLabelEncoder(object): + def __init__(self): + # self.dictionary = dictionary + pass + + def __call__(self, label): + # @yizhilll: https://fairseq.readthedocs.io/en/latest/_modules/fairseq/data/dictionary.html \ + # encode_line return a torch.IntTensor, should be all 1 for vanila HuBERT + # return self.dictionary.encode_line( + # label, + # append_eos=False, + # add_if_not_exist=False, + # ) + # if isisntance(label, np.memmap): + + # assert isisntance(label, np.memmap) + # t = torch.IntTensor(np.asarray(label).copy()) + t = torch.IntTensor(np.asarray(label)) + t = t[t>=0] # remove padded -1 values at the end + return t + +@dataclass +class MERTPretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + sharding_data: int = field( + default=-1, + metadata={ + "help": "set this para >1 to use sharding dataset to prevent OOM" + "prepare data tsv and label files by adding postfix for sharding 64 like:" + "train_28_64.tsv and train_28_64.encodec_6" + }, + ) + load_random_data_shard: bool = field( + default=True, + metadata={ + "help": "whether to laod shards randomly or in order when use sharding_data" + }, + ) + fine_tuning: bool = field( + default=False, metadata={"help": "set to true if fine-tuning Hubert"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: float = field( + default=-1.0, + metadata={"help": "label frame rate. -1.0 for sequence label"}, + ) + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_keep_size: Optional[int] = field( + default=None, + metadata={"help": "exclude sample longer than this"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to crop to for batching"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to crop to for batching"}, + ) + single_target: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, AddTargetDatasets outputs same keys " "as AddTargetDataset" + }, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + + store_labels: Optional[bool] = field( + default=False, + metadata={"help": "whether to load all of the label into memory"}, + ) + + numpy_memmap_label: Optional[bool] = field( + default=False, + metadata={"help": "whether the label file is saved as a numpy file, each line is ended with padding -1"}, + ) + + augmentation_effects: Optional[str] = field( + default="[]", + metadata={ + "help": ( + "a list of effects that might apply to the audios" + "example: \"['random_mute', 'random_Gaussian', 'reverse_polarity']\" " + "supported: random_mute," + "todo: " + ) + }, + ) + augmentation_probs: Optional[str] = field( + default="[]", + metadata={ + "help": ( + "the corresponding probabilities for the data augmentation effects" + "example: \"[0.1, 0.5, 0.8]\" " + "the sum is not necessarily need to be 1.0, and multiple effects can be applied to the same audio" + ) + }, + ) + + # inbatch_noise_augment_len_range: Optional[List[int]] = field( + # default_factory=lambda: [8000, 24000], + # default = [8000, 24000], + inbatch_noise_augment_len_range: Optional[str] = field( + default = "[8000, 24000]", + metadata={ + "help": ( + "the range of length of the mix-up noise augmentation, unit in smaples" + ) + }, + ) + # inbatch_noise_augment_number_range: Optional[List[int]] = field( + # default_factory=lambda: [1, 3], + # default = [1, 3], + inbatch_noise_augment_number_range: Optional[str] = field( + default = "[1, 3]", + metadata={ + "help": ( + "the range of numbers of the mix-up noise augmentation" + ) + }, + ) + inbatch_noise_augment_volume: float = field( + default = 1.0, + metadata={ + "help": ( + "the coefficient used to modify the volume of the noise audios wavs" + ) + }, + ) + dynamic_crops: Optional[str] = field( + default="[]", + metadata={ + "help": ( + "used to set the maximum audio length setting, for training" + "example: \"[1, 2, 3, 4, 5, 10]\" " + ) + }, + ) + dynamic_crops_epoches: Optional[str] = field( + default="[]", + metadata={ + "help": ( + "used to set training epoches of changing the maximum audio length" + "example: \"[1, 10, 20, 40, 80, 160,]\" " + "then len need to be equal to len(dynamic_crops)" + ) + }, + ) + + cqt_loss_bin_dataloader: Optional[int] = field( + default=-1, + metadata={ + "help": ( + "use this parameter to prepare cqt prediction objective in dataloader" + ) + }, + ) + + clip_secs: int = field( + default=5, + metadata={ + "help": "clip secs for each audio" + } + ) + + +@register_task("mert_pretraining", dataclass=MERTPretrainingConfig) +class MERTPretrainingTask(FairseqTask): + + cfg: MERTPretrainingConfig + + def __init__( + self, + cfg: MERTPretrainingConfig, + ) -> None: + super().__init__(cfg) + + logger.info(f"current directory is {os.getcwd()}") + logger.info(f"MERTPretrainingTask Config {cfg}") + + self.cfg = cfg + self.fine_tuning = cfg.fine_tuning + + if cfg.fine_tuning: + self.state.add_factory("target_dictionary", self.load_dictionaries) + else: + self.state.add_factory("dictionaries", self.load_dictionaries) + + self.blank_symbol = "" + + # @yizhilll: use eval() to pass list parameters, skirt the fairseq/torch error: Can't pickle : attribute lookup Choices on fairseq.dataclass.constants failed + self.augmentation_effects = eval(self.cfg.augmentation_effects) + self.augmentation_probs = eval(self.cfg.augmentation_probs) + if len(self.augmentation_effects) > 0: + assert len(self.augmentation_effects) == len(self.augmentation_probs) + logger.info(f"Applying audio augmentation {self.augmentation_effects}, probabilities: {self.augmentation_probs}") + + self.inbatch_noise_augment_number_range = eval(self.cfg.inbatch_noise_augment_number_range) + self.inbatch_noise_augment_len_range = eval(self.cfg.inbatch_noise_augment_len_range) + + self.max_sample_size = self.cfg.max_sample_size + + self.dynamic_crops = eval(self.cfg.dynamic_crops) + self.dynamic_crops_epoches = eval(self.cfg.dynamic_crops_epoches) + assert len(self.dynamic_crops) == len(self.dynamic_crops_epoches) + if len(self.dynamic_crops) > 0: + assert self.dynamic_crops_epoches[0] == 1 + + self.cqt_loss_bin_dataloader = self.cfg.cqt_loss_bin_dataloader + + self.numpy_memmap_label = self.cfg.numpy_memmap_label + self.store_labels = self.cfg.store_labels + if self.numpy_memmap_label: + assert self.store_labels + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return None + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return self.state.target_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return self.state.dictionaries + + @classmethod + def setup_task( + cls, cfg: MERTPretrainingConfig, **kwargs + ) -> "MERTPretrainingTask": + return cls(cfg) + + def load_dictionaries(self): + label_dir = self.cfg.data if (self.cfg.label_dir is None or self.cfg.label_dir == '') else self.cfg.label_dir + print(label_dir) + dictionaries = [ + Dictionary.load(f"{label_dir}/dict.{label}.txt") + for label in self.cfg.labels + ] + return dictionaries[0] if self.cfg.fine_tuning else dictionaries + + def get_label_dir(self) -> str: + if self.cfg.label_dir is None or self.cfg.label_dir=='': + return self.cfg.data + return self.cfg.label_dir + + # def has_sharded_data(self, split): + # """overwrite this function for let the trainier do dataset reload for changing the the dynamic croppings""" + # logger.info(f"check whether to re-load dataset for epoch {epoch} by overwritting task.has_sharded_data()") + # # find the threshold that holds epoch \in [threshold, next_threshold) + # is_reload_dataset = epoch in self.dynamic_crops_epoches + + # return os.pathsep in getattr(self.cfg, "data", "") or is_reload_dataset + # def is_force_load_dataset(self, epoch): + def is_force_load_dataset(self, epoch, training_restore=False): + # find the threshold that holds epoch \in [threshold, next_threshold) + return (epoch in self.dynamic_crops_epoches) or training_restore or (self.cfg.sharding_data > 1) + # for idx in range(len(self.dynamic_crops_epoches)): + # if (idx == len(self.dynamic_crops_epoches)-1) or \ + # (epoch >= self.dynamic_crops_epoches[idx] and epoch < self.dynamic_crops_epoches[idx+1]): + # return True + # return False + + def set_dynamic_crop_max_sample(self, epoch): + """ force to set the max_sample_size config for the dynamic cropping function""" + # pass + # @yizhilll: the parameter "epoch" is passed into this funciton in trainer.py#688, + # containing in "**kwargs" + # if 'train' in split: + # epoch = kwargs['epoch'] + + # find the threshold that holds epoch \in [threshold, next_threshold) + # is_reload_dataset = epoch in self.dynamic_crops_epoches # test again + # if is_reload_dataset: + if epoch in self.dynamic_crops_epoches: + for idx in range(len(self.dynamic_crops_epoches)): + if (idx == len(self.dynamic_crops_epoches)-1) or \ + (epoch >= self.dynamic_crops_epoches[idx] and epoch < self.dynamic_crops_epoches[idx+1]): + # set new cropping parameters and end loop + self.max_sample_size = self.dynamic_crops[idx]*self.cfg.sample_rate + self.cfg.max_sample_size = self.dynamic_crops[idx]*self.cfg.sample_rate + logger.info(f"epoch {epoch} forcely set new maximum audio length as {self.dynamic_crops[idx]}s == {self.max_sample_size} samples") + break + # logger.info(f'reloading dataset for changing the sequence length') + # self.load_dataset('train') + #TODO : 修改数据地址 + def load_dataset(self, split: str, **kwargs) -> None: + if len(list(filter(lambda s: s.endswith('.scp'), os.listdir(self.cfg.data)))) > 0: + return self.load_dataset_ark(split, **kwargs) + else: + return self.load_dataset_mert(split, **kwargs) + + def load_dataset_ark(self, split, **kwargs): + if 'train' not in split: + logger.info(f'split {split} is only used for training') + # raise ValueError(f"No support for split: {split}") + else: + self.datasets[split] = ArkDataset( + wav_scp=os.path.join(self.cfg.data, f"wav_ark.scp"), + dur_scp=os.path.join(self.cfg.data, f"dur_ark.scp"), + sr=self.cfg.sample_rate, + ) + + def load_dataset_mert(self, split: str, **kwargs) -> None: + if 'train' in split: + epoch = kwargs['epoch'] + # the epoch to change crops + if self.is_force_load_dataset(epoch): + self.set_dynamic_crop_max_sample(epoch) + + # load all training data + if self.cfg.sharding_data <= 1: + # manifest = f"{self.cfg.data}/{split}.tsv" + manifest = f"{self.cfg.data}/{split}.json" + + paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels] + # load part of the training data + else: + if self.cfg.load_random_data_shard: + data_shard_idx = np.random.randint(self.cfg.sharding_data) + else: + data_shard_idx = (epoch-1) % self.cfg.sharding_data # epoch start from 1 + assert data_shard_idx < self.cfg.sharding_data + logger.info(f'loading shard {data_shard_idx} of {self.cfg.sharding_data} training data for ecpoh {epoch}') + + # manifest = f"{self.cfg.data}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.tsv" + manifest = f"{self.cfg.data}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.json" + + paths = [f"{self.get_label_dir()}/{split}_{data_shard_idx}_{self.cfg.sharding_data}.{l}" for l in self.cfg.labels] + else: + # manifest = f"{self.cfg.data}/{split}.tsv" + manifest = f"{self.cfg.data}/{split}.json" + + paths = [f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels] + + dicts = [self.target_dictionary] if self.cfg.fine_tuning else self.dictionaries + pad_list = [dict.pad() for dict in dicts] + eos_list = [dict.eos() for dict in dicts] + + if self.numpy_memmap_label: + procs = [PaddedNumpyLabelEncoder() for dict in dicts] + else: + procs = [LabelEncoder(dict) for dict in dicts] + + self.datasets[split] = MERTDataset( + manifest, + sample_rate=self.cfg.sample_rate, + label_paths=paths, # this containes the ensemble label sequence names + label_rates=self.cfg.label_rate, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=self.cfg.max_keep_size, + min_keep_sample_size=self.cfg.min_sample_size, + max_sample_size=self.max_sample_size, + pad_audio=self.cfg.pad_audio, + normalize=self.cfg.normalize, + store_labels=self.store_labels, + npmemmap=self.numpy_memmap_label, + random_crop=self.cfg.random_crop, + single_target=self.cfg.single_target, + augmentation_effects=self.augmentation_effects, + augmentation_probs=self.augmentation_probs, + inbatch_noise_augment_len_range=self.inbatch_noise_augment_len_range, + inbatch_noise_augment_number_range=self.inbatch_noise_augment_number_range, + inbatch_noise_augment_volume=self.cfg.inbatch_noise_augment_volume, + cqt_prediction_bin=self.cqt_loss_bin_dataloader, + clip_secs=self.cfg.clip_secs, + ) + + def max_positions(self) -> Tuple[int, int]: + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size(self, indices: np.array, *args, **kwargs) -> np.array: + return indices diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/pretraining_AS2M.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/pretraining_AS2M.py new file mode 100644 index 0000000000000000000000000000000000000000..4073a364ca1e361f088a53d7301089ce70d1dfea --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq/tasks/pretraining_AS2M.py @@ -0,0 +1,141 @@ +# Copyright (c) 2017-present, Facebook, Inc. +# All rights reserved. +# +# This source code is licensed under the license found in the LICENSE file in +# the root directory of this source tree. An additional grant of patent rights +# can be found in the PATENTS file in the same directory. + +import logging +import sys + +from typing import Optional, List +from dataclasses import dataclass, field +from omegaconf import MISSING, II + +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task + +try: + from ..data.eat_data import MaeImageDataset +except: + import sys, os + sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), '..')) + from data.eat_data.mae_image_dataset import MaeImageDataset + +logger = logging.getLogger(__name__) + + +@dataclass +class ImageMaskingConfig: + patch_size: int = II("model.modalities.image.patch_size") + mask_prob: float = II("model.modalities.image.mask_prob") + mask_prob_adjust: float = II("model.modalities.image.mask_prob_adjust") + mask_length: int = II("model.modalities.image.mask_length") + inverse_mask: bool = II("model.modalities.image.inverse_mask") + mask_dropout: float = II("model.modalities.image.mask_dropout") + clone_batch: int = II("model.clone_batch") + expand_adjacent: bool = False + non_overlapping: bool = False + + +@dataclass +class MaeImagePretrainingConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + multi_data: Optional[List[str]] = None + input_size: int = 224 + local_cache_path: Optional[str] = None + key: str = "imgs" + beit_transforms: bool = False + target_transform: bool = False + no_transform: bool = False + + rebuild_batches: bool = True + precompute_mask_config: Optional[ImageMaskingConfig] = None + subsample: float = 1 + seed: int = II("common.seed") + dataset_type: str = "imagefolder" + + audio_mae: bool = field(default=False,metadata={"help": "if set, we use image_mae way to deal with audio files."}) + h5_format: bool = field(default=False,metadata={"help": "if set, dataset will read data file in h5df format."}) + downsr_16hz: bool = field(default=False,metadata={"help": "if set, wav file's sample rate will be reduced to 16kHz."}) + target_length: int = field(default=1024,metadata={"help": "This setting will pad the audio spectrogram with zeros."}) + flexible_mask: bool = field(default=False, metadata={"help": "if true, we will using flexible inverse block mask method."}) + + esc50_eval: bool = field(default=False, metadata={"help": "if true, the task is to finetune model on esc50 dataset."}) + spcv2_eval: bool = field(default=False, metadata={"help": "if true, the task is to finetune model on speech command v2 dataset."}) + AS2M_finetune: bool = field(default=False, metadata={"help": "if true, the task is to finetune model on Audioset 2M with weighted sample."}) + spcv1_finetune: bool = field(default=False, metadata={"help": "if true, the task is to finetune model on speech commands v1 with weighted sample."}) + roll_aug: bool = field(default=False, metadata={"help": "if true, we will use roll aug in fine-tuning."}) + noise: bool = field(default=False, metadata={"help": "if true, we will add gaussian noise as augmentation during fine-tuning."}) + weights_file : str = field(default="", metadata={"help": "the path of weighted sample file"}) + num_samples: int = field(default=200000, metadata={"help": "this setting will determine the number of samples in each epoch, usually used in unbalanced training."}) + is_finetuning: bool = field(default=False, metadata={"help": "this property has been deprecated"}) + + sample_rate: int = field(default=24000) + fixed_duration: float = field(default=30.0) + + + + +@register_task("mae_image_pretraining", dataclass=MaeImagePretrainingConfig) +class MaeImagePretrainingTask(FairseqTask): + """ """ + + cfg: MaeImagePretrainingConfig + + @classmethod + def setup_task(cls, cfg: MaeImagePretrainingConfig, **kwargs): + """Setup the task (e.g., load dictionaries). + + Args: + cfg (AudioPretrainingConfig): configuration of this task + """ + + return cls(cfg) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + cfg = task_cfg or self.cfg + + + compute_mask = cfg.precompute_mask_config is not None + mask_args = {} + if compute_mask: + mask_args = cfg.precompute_mask_config + + self.datasets[split] = MaeImageDataset( + root=data_path if cfg.multi_data is None else cfg.multi_data, + split=split, + input_size=cfg.input_size, + key=cfg.key, + compute_mask=compute_mask, + dataset_type=cfg.dataset_type, + audio_mae=cfg.audio_mae, + downsr_16hz=cfg.downsr_16hz, + h5_format=cfg.h5_format, + esc50_eval=cfg.esc50_eval, + spcv2_eval=cfg.spcv2_eval, + roll_aug=cfg.roll_aug and split == 'train', + target_length=cfg.target_length, + noise=cfg.noise, + AS2M_finetune=cfg.AS2M_finetune, + spcv1_finetune=cfg.spcv1_finetune, + num_samples=cfg.num_samples, + weights_file=cfg.weights_file, + flexible_mask=cfg.flexible_mask, + sample_rate=cfg.sample_rate, + fixed_duration=cfg.fixed_duration, + **mask_args, + ) + + @property + def source_dictionary(self): + return None + + @property + def target_dictionary(self): + return None + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return sys.maxsize, sys.maxsize diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/modify_env.md b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/modify_env.md new file mode 100644 index 0000000000000000000000000000000000000000..417c6e0cab4f072e297d371950f2534105ddedec --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/modify_env.md @@ -0,0 +1,4 @@ +cp -r fairseq/fairseq/model_parallel/megatron /opt/conda/envs/map/lib/python3.8/site-packages/fairseq/model_parallel/ +vi /opt/conda/envs/map/lib/python3.8/site-packages/apex/amp/_initialize.py # string_classes = str +vi /opt/conda/envs/map/lib/python3.8/site-packages/fairseq/modules/layer_norm.py +vi /opt/conda/envs/map/lib/python3.8/site-packages/fairseq/distributed/utils.py # import datetime; timeout=datetime.timedelta(seconds=51200); logger.info("add nccl time to 51200") diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_eat.sh b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_eat.sh new file mode 100644 index 0000000000000000000000000000000000000000..d7cf006b58dec68fa9c93b83684e03393f231d87 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_eat.sh @@ -0,0 +1,72 @@ +WORKER_RANK=${1:-$INDEX} +PLATFORM=${2:-'shef'} +YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'} +TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'} +MASTER_PROC_ADD=${5:-$CHIEF_IP} +DIST_PORT=${6:-'25520'} +# echo $PATH +# export PATH=$PATH:./ +echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}" + +MAP_PROJ_DIR=$(pwd) +echo $MAP_PROJ_DIR + +NNODS=1 +BATCH_SIZE=12 +NUM_WOKERS=6 + +run_command_prefix=' ' +# Loading folders +# 1. tsv files for audio paths +# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv +DATA_DIR=${MAP_PROJ_DIR}/data/music4all_sh #audio_manifest +# 2. working folder for saving checkpoints and loading config files +CONFIG_DIR=/${MAP_PROJ_DIR}/mert_fairseq/config/pretrain +# 3. clustering labels for training data +LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset + +FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq; +SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/ + +case $YAML_NAME_WITHOUT_EXT in + EAT_pretraining_music_multinodes) + NNODS=4 + NPROCES_PER_NODE=8 + LABEL_RATE=25 + BATCH_SIZE=12 + ;; + *) + echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT}" + exit 1 + ;; + esac + +echo running $YAML_NAME_WITHOUT_EXT .. + +mkdir -p ${SAVE_DIR} +echo "checkpoint save at: ${SAVE_DIR}" +cd ${SAVE_DIR} + +DISTRIBUTED_WORLD_SIZE=`expr ${NNODS} \* ${NPROCES_PER_NODE}` +ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}` +echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}" + +DATE_SUFFIX=`date +"%Y-%m-%d_%H-%M"` + +OMP_NUM_THREADS=6 ${run_command_prefix} \ +python -u ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \ +--config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \ +common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \ +common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \ +checkpoint.save_dir=${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT} \ +distributed_training.distributed_rank=${ACTUAL_WORKER_RANK} \ +distributed_training.distributed_world_size=${DISTRIBUTED_WORLD_SIZE} \ +distributed_training.distributed_num_procs=${DISTRIBUTED_WORLD_SIZE} \ +distributed_training.nprocs_per_node=${NPROCES_PER_NODE} \ +distributed_training.distributed_init_method="tcp://${CHIEF_IP}:${DIST_PORT}" \ +task.data=${DATA_DIR} \ +dataset.num_workers=${NUM_WOKERS} \ +dataset.batch_size=${BATCH_SIZE} \ +dataset.disable_validation=true \ + +# pip install h5py timm -i https://mirrors.tencent.com/pypi/simple/ \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_mulNodes_wotorchdist_womodelparsize.sh b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_mulNodes_wotorchdist_womodelparsize.sh new file mode 100644 index 0000000000000000000000000000000000000000..238d6a28678f4b272c097ce7b91f8161e22543ff --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_mulNodes_wotorchdist_womodelparsize.sh @@ -0,0 +1,177 @@ +# bash run_training_mulNodes_wotorchdist.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes +# bash run_training_mulNodes_wotorchdist.sh 1 dummy MERT_RVQ-VAE_CQT_330M_multinodes +# bash run_training_mulNodes_wotorchdist.sh 2 dummy MERT_RVQ-VAE_CQT_330M_multinodes +# bash run_training_mulNodes_wotorchdist.sh 3 dummy MERT_RVQ-VAE_CQT_330M_multinodes + +# the rank of distributed node worker +# If I use two nodes, 4 gpus per each, then WORKER_RANK for the two node should be 0, 4, i.e. the starting indice of the GPU. +WORKER_RANK=${1:-$INDEX} +PLATFORM=${2:-'shef'} +YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'} +TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'} +MASTER_PROC_ADD=${5:-$CHIEF_IP} +DIST_PORT=${6:-'25520'} +DATASET_NAME=${7:-'dataindex'} +# echo $PATH +# export PATH=$PATH:./ +echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}" + +MAP_PROJ_DIR=$(pwd) +echo $MAP_PROJ_DIR + +NNODS=1 +# MAX_TOKENS=1000000 # set for 80GB A100 batchsize +NUM_WOKERS=6 + +run_command_prefix=' ' +# Loading folders +# 1. tsv files for audio paths +# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv +DATA_DIR=${MAP_PROJ_DIR}/data/${DATASET_NAME} #audio_manifest +# 2. working folder for saving checkpoints and loading config files +CONFIG_DIR=/${MAP_PROJ_DIR}/mert_fairseq/config/pretrain +# 3. clustering labels for training data +LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset + +FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq; +SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/ + +# set 75 for the RVQ-VAE model +LABEL_RATE=75 + +case $YAML_NAME_WITHOUT_EXT in + MERT_RVQ-VAE_CQT_95M) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=1 + LABEL_RATE=75 + MAX_TOKENS=1800000 + ;; + MERT_RVQ-VAE_CQT_95M_mel_multinodes) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=4 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=1200000 + ;; + MERT_RVQ-VAE_CQT_95M_bestrq_multinodes) + TASK_LABELS_POSTFIX='["rq_0"]' + NNODS=4 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=1200000 + ;; + MERT_RVQ-VAE_CQT_95M_bestrq_chroma_multinodes) + TASK_LABELS_POSTFIX='["rq_0"]' + NNODS=4 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=1600000 + ;; + MERT_RVQ-VAE_CQT_95M_bestrq_norm_multinodes) + TASK_LABELS_POSTFIX='["rq_0"]' + NNODS=4 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=1600000 + ;; + MERT_RVQ-VAE_CQT_95M_dac_multinodes) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=4 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=1600000 + ;; + MERT_RVQ-VAE_CQT_95M_groupbestrq_multinodes) + TASK_LABELS_POSTFIX='["grq_0"]' + NNODS=4 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=1600000 + ;; + MusicFM_95M_multinodes) + TASK_LABELS_POSTFIX='["grq_0"]' + NNODS=4 + LABEL_RATE=25 + NPROCES_PER_NODE=8 + MAX_TOKENS=4800000 + ;; + MusicFM_95M_bestrvq_multinodes) + TASK_LABELS_POSTFIX='["grq_0"]' + NNODS=4 + LABEL_RATE=25 + NPROCES_PER_NODE=8 + MAX_TOKENS=4800000 + ;; + MusicFM_95M_speech_multinodes) + TASK_LABELS_POSTFIX='[]' + NNODS=4 + LABEL_RATE=25 + NPROCES_PER_NODE=8 + MAX_TOKENS=1200000 + ;; + MERT_RVQ-VAE_CQT_330M) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=1 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=720000 + ;; + MERT_RVQ-VAE_CQT_330M_multinodes) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=4 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=600000 + ;; + MERT_RVQ-VAE_CQT_330M_multinodes_debug2node) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=2 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=600000 + ;; + MERT_RVQ-VAE_CQT_330M_multinodes_debug1node) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=1 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=600000 + ;; + *) + echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT} = ${YAML_NAME_WITHOUT_EXT}" + exit 1 + ;; + esac + + echo running $YAML_NAME_WITHOUT_EXT .. + + mkdir -p ${SAVE_DIR} + echo "checkpoint save at: ${SAVE_DIR}" + cd ${SAVE_DIR} + + echo "NPROCES_PER_NODE is ${NPROCES_PER_NODE}" + + DISTRIBUTED_WORLD_SIZE=`expr ${NNODS} \* ${NPROCES_PER_NODE}` + ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}` + echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}" + + DATE_SUFFIX=`date +"%Y-%m-%d_%H-%M"` + + OMP_NUM_THREADS=6 ${run_command_prefix} \ + python -u ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \ + --config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \ + common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \ + common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \ + checkpoint.save_dir=${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT} \ + distributed_training.distributed_rank=${ACTUAL_WORKER_RANK} \ + distributed_training.distributed_world_size=${DISTRIBUTED_WORLD_SIZE} \ + distributed_training.distributed_num_procs=${DISTRIBUTED_WORLD_SIZE} \ + distributed_training.nprocs_per_node=${NPROCES_PER_NODE} \ + distributed_training.distributed_init_method="tcp://${CHIEF_IP}:${DIST_PORT}" \ + task.data=${DATA_DIR} \ + task.label_dir=${LABEL_DIR} \ + task.labels=${TASK_LABELS_POSTFIX} \ + dataset.num_workers=${NUM_WOKERS} \ + dataset.max_tokens=${MAX_TOKENS} \ + dataset.disable_validation=true \ + model.label_rate=${LABEL_RATE} \ diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_orig.sh b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_orig.sh new file mode 100644 index 0000000000000000000000000000000000000000..889ae0efcd041defce188823f2d822ce84417f55 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_orig.sh @@ -0,0 +1,79 @@ +# the rank of distributed node worker +# If I use two nodes, 4 gpus per each, then WORKER_RANK for the two node should be 0, 4, i.e. the starting indice of the GPU. +WORKER_RANK=${1:-'0'} +PLATFORM=${2:-'shef'} +YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'} +TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'} +MASTER_PROC_ADD=${5:-'127.0.0.1'} +DIST_PORT=${6:-'39683'} + +echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}" + +MAP_PROJ_DIR=$HOME/MERT + +DISTRIBUTED_WORLD_SIZE=2 +NPROCES_PER_NODE=2 +MAX_TOKENS=1000000 # set for 80GB A100 +NUM_WOKERS=6 + +run_command_prefix=' ' +# Loading folders +# 1. tsv files for audio paths +DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv +# 2. working folder for saving checkpoints and loading config files +CONFIG_DIR=/${MAP_PROJ_DIR}/mert_fairseq/config/pretrain +# 3. clustering labels for training data +LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/labels + + +FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq; +SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/ + +# set 75 for the RVQ-VAE model +LABEL_RATE=75 + +case $YAML_NAME_WITHOUT_EXT in + MERT_RVQ-VAE_CQT_95M) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + DISTRIBUTED_WORLD_SIZE=8 + NPROCES_PER_NODE=1 + LABEL_RATE=75 + MAX_TOKENS=1800000 + ;; + MERT_RVQ-VAE_CQT_330M) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + DISTRIBUTED_WORLD_SIZE=64 + NPROCES_PER_NODE=8 + LABEL_RATE=75 + MAX_TOKENS=920000 + ;; + *) + echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT}" + exit 1 + ;; + esac + + echo running $YAML_NAME_WITHOUT_EXT .. + + mkdir -p ${SAVE_DIR} + echo "checkpoint save at: ${SAVE_DIR}" + cd ${SAVE_DIR} + + ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}` + echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}" + + OMP_NUM_THREADS=6 ${run_command_prefix} python -u ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \ + --config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT} \ + common.user_dir=${MAP_PROJ_DIR}/mert_faiseq \ + common.wandb_project=pretrain_${TRAINING_SETTING} \ + checkpoint.save_dir=${SAVE_DIR}/ckpt_${TRAINING_SETTING}/${YAML_NAME_WITHOUT_EXT} \ + distributed_training.distributed_rank=${ACTUAL_WORKER_RANK} \ + distributed_training.distributed_world_size=${DISTRIBUTED_WORLD_SIZE} \ + distributed_training.nprocs_per_node=${NPROCES_PER_NODE} \ + distributed_training.distributed_init_method="tcp://${MASTER_PROC_ADD}:${DIST_PORT}" \ + task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} \ + task.labels=${TASK_LABELS_POSTFIX} \ + dataset.num_workers=${NUM_WOKERS} \ + dataset.max_tokens=${MAX_TOKENS} \ + dataset.disable_validation=true \ + model.label_rate=${LABEL_RATE} \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_sglNodes.sh b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_sglNodes.sh new file mode 100644 index 0000000000000000000000000000000000000000..6a9f6d873f888e942474abb141c9ad56e01e5b7f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/run_training_sglNodes.sh @@ -0,0 +1,115 @@ +# bash run_training_sglNodes.sh 0 dummy MERT_RVQ-VAE_CQT_330M_multinodes_debug1node + +# the rank of distributed node worker +# If I use two nodes, 4 gpus per each, then WORKER_RANK for the two node should be 0, 4, i.e. the starting indice of the GPU. +WORKER_RANK=${1:-'0'} +PLATFORM=${2:-'shef'} +YAML_NAME_WITHOUT_EXT=${3:-'MERT_RVQ-VAE_CQT_95M'} +TRAINING_SETTING=${4:-'MERT_RVQ-VAE_CQT'} +MASTER_PROC_ADD=${5:-'127.0.0.1'} +DIST_PORT=${6:-'39685'} +# echo $PATH +# export PATH=$PATH:./ +echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}" + +MAP_PROJ_DIR=$(pwd) +echo $MAP_PROJ_DIR + +NNODS=1 +MAX_TOKENS=1000000 # set for 80GB A100 batchsize +NUM_WOKERS=0 + +run_command_prefix=' ' +# Loading folders +# 1. tsv files for audio paths +# DATA_DIR=${MAP_PROJ_DIR}/data/audio_tsv +DATA_DIR=${MAP_PROJ_DIR}/data/music4all_sh #audio_manifest +# 2. working folder for saving checkpoints and loading config files +CONFIG_DIR=/${MAP_PROJ_DIR}/mert_fairseq/config/pretrain +# 3. clustering labels for training data +LABEL_ROOT_DIR=${MAP_PROJ_DIR}/data/encodec_labels/custom_audio_dataset + +FAIRSEQ_PATH=${MAP_PROJ_DIR}/src/fairseq; +SAVE_DIR=${MAP_PROJ_DIR}/data/fairseq_savedir/ + +# set 75 for the RVQ-VAE model +LABEL_RATE=75 + +case $YAML_NAME_WITHOUT_EXT in + MERT_RVQ-VAE_CQT_95M) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=1 + LABEL_RATE=75 + MAX_TOKENS=1800000 + ;; + MERT_RVQ-VAE_CQT_95M_bestrq) + TASK_LABELS_POSTFIX='["rq_0"]' + NNODS=1 + LABEL_RATE=75 + MAX_TOKENS=1200000 + ;; + MERT_RVQ-VAE_CQT_330M) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=1 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=720000 + ;; + MERT_RVQ-VAE_CQT_330M_multinodes) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=4 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=600000 + ;; + MERT_RVQ-VAE_CQT_330M_multinodes_debug2node) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=2 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=600000 + ;; + MERT_RVQ-VAE_CQT_330M_multinodes_debug1node) + TASK_LABELS_POSTFIX='["encodec_0","encodec_1","encodec_2","encodec_3","encodec_4","encodec_5","encodec_6","encodec_7"]' + NNODS=1 + LABEL_RATE=75 + NPROCES_PER_NODE=8 + MAX_TOKENS=600000 + ;; + *) + echo "Unknown running config: ${$YAML_NAME_WITHOUT_EXT}" + exit 1 + ;; + esac + + echo running $YAML_NAME_WITHOUT_EXT .. + + mkdir -p ${SAVE_DIR} + echo "checkpoint save at: ${SAVE_DIR}" + cd ${SAVE_DIR} + + DISTRIBUTED_WORLD_SIZE=`expr ${NNODS} \* ${NPROCES_PER_NODE}` + ACTUAL_WORKER_RANK=`expr ${WORKER_RANK} \* ${NPROCES_PER_NODE}` + echo "worker rank ${WORKER_RANK}, master address ${MASTER_PROC_ADD}:${DIST_PORT}, actual rank ${ACTUAL_WORKER_RANK}" + + DATE_SUFFIX=`date +"%Y-%m-%d_%H-%M"` + CKPT_SAVE_DIR="${SAVE_DIR}/ckpt_${TRAINING_SETTING}_multinodes${NNODS}_${DATE_SUFFIX}/${YAML_NAME_WITHOUT_EXT}" + + OMP_NUM_THREADS=6 ${run_command_prefix} \ + python -u -m torch.distributed.launch --use_env \ + --nproc_per_node=8 --nnodes=${NNODS} --node_rank=${INDEX} \ + --master_addr=${CHIEF_IP} --master_port=25521 \ + ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py -m \ + --config-dir ${CONFIG_DIR} --config-name ${YAML_NAME_WITHOUT_EXT}\ + common.user_dir=${MAP_PROJ_DIR}/mert_fairseq \ + common.tensorboard_logdir=${MAP_PROJ_DIR}/logs/pretrain_tb_${TRAINING_SETTING}_${YAML_NAME_WITHOUT_EXT}_multinodes${NNODS} \ + task.data=${DATA_DIR}\ + task.label_dir=${LABEL_DIR} \ + task.labels=${TASK_LABELS_POSTFIX} \ + dataset.num_workers=${NUM_WOKERS} \ + dataset.max_tokens=${MAX_TOKENS} \ + dataset.disable_validation=true \ + model.label_rate=${LABEL_RATE}\ + checkpoint.save_dir=${CKPT_SAVE_DIR} \ + checkpoint.restore_file="checkpoint_last.pt" + \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/test.py b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/test.py new file mode 100644 index 0000000000000000000000000000000000000000..993b0597e8e2e8bcffafe90ad88f742b82550739 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/test.py @@ -0,0 +1,25 @@ +import torch +from dataclasses import dataclass +from logging import getLogger +import torch.nn.functional as F +import fairseq.utils +from fairseq.checkpoint_utils import load_model_ensemble_and_task + +logger = getLogger(__name__) + +@dataclass +class UserDirModule: + user_dir: str + +def load_model(model_dir, checkpoint_dir): + '''Load Fairseq SSL model''' + + #导入模型所在的代码模块 + model_path = UserDirModule(model_dir) + fairseq.utils.import_user_module(model_path) + + #载入模型的checkpoint + model, cfg, task = load_model_ensemble_and_task([checkpoint_dir], strict=False) + model = model[0] + + return model diff --git a/codeclm/tokenizer/Flow1dVAE/tools/__init__.py b/codeclm/tokenizer/Flow1dVAE/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/codeclm/tokenizer/Flow1dVAE/tools/check_stereo.py b/codeclm/tokenizer/Flow1dVAE/tools/check_stereo.py new file mode 100644 index 0000000000000000000000000000000000000000..be7c2ff166a0f2c655c0b4cbbd110e8c469f2a37 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/check_stereo.py @@ -0,0 +1,59 @@ +''' +TAMPLEATE = { + "path": "" + "duration": "" + "sample_rate": "" + "amplitude": null, + "weight": null, + "info_path": null +} +''' +import torchaudio +import json +from tqdm import tqdm + +import torchaudio +import numpy as np +import torch, torch.nn as nn, random +from torchaudio import transforms +import os +import argparse +from tqdm import tqdm +import torchaudio +from torchaudio.transforms import Resample +from multiprocessing import Pool + +def preprocess(args, wav_json, thread_id): + # f = open("pretrain_tme_20230927.scp").readlines() + f = open("out.{}".format(thread_id), 'w') + for line in tqdm(wav_json): + try: + # import pdb; pdb.set_trace() + line = line.strip() + wav_info = json.loads(line) + meta = torchaudio.info(wav_info["path"]) + + wav_info["num_channels"] = meta.num_channels + json_string = json.dumps(wav_info) + # print(json_string) + f.write("{}\n".format(json_string)) + except: + print(line) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Deep Speaker Embedding Inference') + parser.add_argument('--wav_json', type=str) + parser.add_argument('--num_thread', default=10, type=int, help='random seed') + args = parser.parse_args() + + wav_json_total = open(args.wav_json).readlines() + args.num_thread = min(len(wav_json_total), args.num_thread) + wav_json_list = np.array_split(wav_json_total, args.num_thread) + + p = Pool(args.num_thread) + for thread_id, wav_json in enumerate(wav_json_list): + r = p.apply_async(preprocess, (args, wav_json, thread_id)) + p.close() + p.join() + r.get() diff --git a/codeclm/tokenizer/Flow1dVAE/tools/compare_2models.py b/codeclm/tokenizer/Flow1dVAE/tools/compare_2models.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd0385c8cb8027a99311ff9a0f816bc672ff6a0 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/compare_2models.py @@ -0,0 +1,20 @@ +import torch +import sys + +if __name__=="__main__": + m1, m2 = sys.argv[1:3] + m1 = torch.load(m1, map_location = 'cpu') + m2 = torch.load(m2, map_location = 'cpu') + m1_keys = set(m1.keys()) + m2_keys = set(m2.keys()) + + m1_uniq_keys = m1_keys - m2_keys + m2_uniq_keys = m2_keys - m1_keys + m12_shared_keys = m1_keys & m2_keys + + print("m1_uniq_keys: ", m1_uniq_keys) + print("m2_uniq_keys: ", m2_uniq_keys) + print("m12_shared_keys but different: ") + for k in m12_shared_keys: + if(m1[k].numel() != m2[k].numel()): + print(k,m1[k].shape,m2[k].shape) diff --git a/codeclm/tokenizer/Flow1dVAE/tools/creat_jsonl.py b/codeclm/tokenizer/Flow1dVAE/tools/creat_jsonl.py new file mode 100644 index 0000000000000000000000000000000000000000..7f8bb4ae22c431955a5be6e98edef64a1844ba4b --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/creat_jsonl.py @@ -0,0 +1,71 @@ +''' +TAMPLEATE = { + "path": "" + "duration": "" + "sample_rate": "" + "amplitude": null, + "weight": null, + "info_path": null +} +''' +import torchaudio +import json +from tqdm import tqdm + +import torchaudio +import numpy as np +import torch, torch.nn as nn, random +from torchaudio import transforms +import os +import argparse +from tqdm import tqdm +import torchaudio +from torchaudio.transforms import Resample +from multiprocessing import Pool + +def preprocess(args, wav_scp, thread_id): + # f = open("pretrain_tme_20230927.scp").readlines() + f = open("out.{}".format(thread_id), 'w') + for line in tqdm(wav_scp): + try: + # import pdb; pdb.set_trace() + line = line.strip() + meta = torchaudio.info(line) + duration = meta.num_frames / float(meta.sample_rate) + sr = meta.sample_rate + + # json_path = line.replace(".flac", ".json") + # with open(json_path, encoding='utf-8') as fh: + # data = json.load(fh) + # duration = data['duration'] + wav_info = { + "path": line, + "duration": duration, + "sample_rate": sr, + "amplitude": None, + "weight": None, + "info_path": None + } + json_string = json.dumps(wav_info) + # print(json_string) + f.write("{}\n".format(json_string)) + except: + print(line) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Deep Speaker Embedding Inference') + parser.add_argument('--wav_scp', type=str) + parser.add_argument('--num_thread', default=10, type=int, help='random seed') + args = parser.parse_args() + + wav_scp_total = open(args.wav_scp).readlines() + args.num_thread = min(len(wav_scp_total), args.num_thread) + wav_scp_list = np.array_split(wav_scp_total, args.num_thread) + + p = Pool(args.num_thread) + for thread_id, wav_scp in enumerate(wav_scp_list): + r = p.apply_async(preprocess, (args, wav_scp, thread_id)) + p.close() + p.join() + r.get() diff --git a/codeclm/tokenizer/Flow1dVAE/tools/extract_rvq.py b/codeclm/tokenizer/Flow1dVAE/tools/extract_rvq.py new file mode 100644 index 0000000000000000000000000000000000000000..6a58f76ae5b7c8a3a51751a499d115b1caa40214 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/extract_rvq.py @@ -0,0 +1,15 @@ +import torch +import sys + +if __name__=="__main__": + p = sys.argv[1] + bd = '/'.join(p.split('/')[:-1]) + bn = p.split('/')[-1] + + d = {} + m = torch.load(p, map_location='cpu') + for k in m.keys(): + if('rvq' in k): + d[k] = m[k] + + torch.save(d, '{}/rvq.bin'.format(bd)) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae.py b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7af1b8470bc18050db3e0091cea93214bafbbb --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae.py @@ -0,0 +1,15 @@ +import torch +from tqdm import tqdm +import torchaudio +from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config +import numpy as np +import os +import json + +def get_model(model_config, path): + with open(model_config) as f: + model_config = json.load(f) + state_dict = torch.load(path) + model = create_autoencoder_from_config(model_config) + model.load_state_dict(state_dict['state_dict']) + return model \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_1920.py b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_1920.py new file mode 100644 index 0000000000000000000000000000000000000000..ef7af1b8470bc18050db3e0091cea93214bafbbb --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_1920.py @@ -0,0 +1,15 @@ +import torch +from tqdm import tqdm +import torchaudio +from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config +import numpy as np +import os +import json + +def get_model(model_config, path): + with open(model_config) as f: + model_config = json.load(f) + state_dict = torch.load(path) + model = create_autoencoder_from_config(model_config) + model.load_state_dict(state_dict['state_dict']) + return model \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large.py b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large.py new file mode 100644 index 0000000000000000000000000000000000000000..a80fcec2c9e579f415d274dc46a0447e4c6477ee --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large.py @@ -0,0 +1,15 @@ +import torch +from tqdm import tqdm +import torchaudio +from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config +import numpy as np +import os +import json + +def get_model(model_config, path): + with open(model_config) as f: + model_config = json.load(f) + state_dict = torch.load(path, map_location='cpu') + model = create_autoencoder_from_config(model_config) + model.load_state_dict(state_dict['state_dict'], strict=False) + return model diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large_melvae.py b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large_melvae.py new file mode 100644 index 0000000000000000000000000000000000000000..86e701e994151e18a5c824e9c9feb0d575837f20 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_1dvae_large_melvae.py @@ -0,0 +1,15 @@ +import torch +from tqdm import tqdm +import torchaudio +from third_party.stable_audio_tools.stable_audio_tools.models.autoencoders import create_autoencoder_from_config +import numpy as np +import os +import json + +def get_model(model_config, path): + with open(model_config) as f: + model_config = json.load(f) + state_dict = torch.load(path, map_location='cpu') + model = create_autoencoder_from_config(model_config) + model.load_state_dict(state_dict['state_dict'], strict=False) + return model \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae.py b/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae.py new file mode 100644 index 0000000000000000000000000000000000000000..a850487b5219187a5335c760e1ed0220570798a6 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae.py @@ -0,0 +1,395 @@ +"""! +@author Yi Luo (oulyluo) +@copyright Tencent AI Lab +""" + +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch.utils.checkpoint import checkpoint_sequential +from thop import profile, clever_format + +class RMVN(nn.Module): + """ + Rescaled MVN. + """ + def __init__(self, dimension, groups=1): + super(RMVN, self).__init__() + + self.mean = nn.Parameter(torch.zeros(dimension)) + self.std = nn.Parameter(torch.ones(dimension)) + self.groups = groups + self.eps = torch.finfo(torch.float32).eps + + def forward(self, input): + # input size: (B, N, T) + B, N, T = input.shape + assert N % self.groups == 0 + + input = input.view(B, self.groups, -1, T) + input_norm = (input - input.mean(2).unsqueeze(2)) / (input.var(2).unsqueeze(2) + self.eps).sqrt() + input_norm = input_norm.view(B, N, T) * self.std.view(1, -1, 1) + self.mean.view(1, -1, 1) + + return input_norm + +class ConvActNorm1d(nn.Module): + def __init__(self, in_channel, hidden_channel, kernel=7, causal=False): + super(ConvActNorm1d, self).__init__() + + self.in_channel = in_channel + self.kernel = kernel + self.causal = causal + if not causal: + self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel-1)//2), + RMVN(in_channel), + nn.Conv1d(in_channel, hidden_channel*2, 1), + nn.GLU(dim=1), + nn.Conv1d(hidden_channel, in_channel, 1) + ) + else: + self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel-1), + RMVN(in_channel), + nn.Conv1d(in_channel, hidden_channel*2, 1), + nn.GLU(dim=1), + nn.Conv1d(hidden_channel, in_channel, 1) + ) + + def forward(self, input): + + output = self.conv(input) + if self.causal: + output = output[...,:-self.kernel+1].contiguous() + return input + output + +class ICB(nn.Module): + def __init__(self, in_channel, kernel=7, causal=False): + super(ICB, self).__init__() + + self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), + ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), + ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal) + ) + + def forward(self, input): + + return self.blocks(input) + +class ResRNN(nn.Module): + def __init__(self, input_size, hidden_size, bidirectional=False): + super(ResRNN, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.eps = torch.finfo(torch.float32).eps + + self.norm = RMVN(input_size) + self.rnn = nn.LSTM(input_size, hidden_size, 1, batch_first=True, bidirectional=bidirectional) + + self.proj = nn.Linear(hidden_size*(int(bidirectional)+1), input_size) + + def forward(self, input, use_head=1): + # input shape: batch, dim, seq + + B, N, T = input.shape + + rnn_output, _ = self.rnn(self.norm(input).transpose(1,2).contiguous()) + + output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])) + output = output.view(B, T, -1).transpose(1,2).contiguous() + + return input + output + +class BSNet(nn.Module): + def __init__(self, feature_dim, kernel=7, causal=False): + super(BSNet, self).__init__() + + self.feature_dim = feature_dim + + self.seq_net = ICB(self.feature_dim, kernel=kernel, causal=causal) + self.band_net = ResRNN(self.feature_dim, self.feature_dim*2, bidirectional=True) + + def forward(self, input): + # input shape: B, nband, N, T + + B, nband, N, T = input.shape + + band_output = self.seq_net(input.view(B*nband, N, T)).view(B, nband, -1, T) + + # band comm + band_output = band_output.permute(0,3,2,1).contiguous().view(B*T, -1, nband) + output = self.band_net(band_output).view(B, T, -1, nband).permute(0,3,2,1).contiguous() + + return output.view(B, nband, N, T) + +# https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py +class VQEmbeddingEMA(nn.Module): + def __init__(self, num_code, code_dim, decay=0.99, layer=0): + super(VQEmbeddingEMA, self).__init__() + + self.num_code = num_code + self.code_dim = code_dim + self.decay = decay + self.layer = layer + self.stale_tolerance = 100 + self.eps = torch.finfo(torch.float32).eps + + embedding = torch.empty(num_code, code_dim).normal_() / ((layer+1) * code_dim) + self.register_buffer("embedding", embedding) + self.register_buffer("ema_weight", self.embedding.clone()) + self.register_buffer("ema_count", torch.zeros(self.num_code)) + self.register_buffer("stale_counter", torch.zeros(self.num_code)) + + def forward(self, input): + + B, N, T = input.shape + assert N == self.code_dim + + input_detach = input.detach().mT.contiguous().view(B*T, N) # B*T, dim + + # distance + eu_dis = input_detach.pow(2).sum(-1).unsqueeze(-1) + self.embedding.pow(2).sum(-1).unsqueeze(0) # B*T, num_code + eu_dis = eu_dis - 2 * input_detach.mm(self.embedding.T) # B*T, num_code + + # best codes + indices = torch.argmin(eu_dis, dim=-1) # B*T + quantized = torch.gather(self.embedding, 0, indices.unsqueeze(-1).expand(-1, self.code_dim)) # B*T, dim + quantized = quantized.view(B, T, N).mT.contiguous() # B, N, T + + # calculate perplexity + encodings = F.one_hot(indices, self.num_code).float() # B*T, num_code + avg_probs = encodings.mean(0) # num_code + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean() + indices = indices.view(B, T) + + if self.training: + # EMA update for codebook + + self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0) # num_code + + update_direction = encodings.T.mm(input_detach) # num_code, dim + self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * update_direction # num_code, dim + + # Laplace smoothing on the counters + # make sure the denominator will never be zero + n = torch.sum(self.ema_count, dim=-1, keepdim=True) # 1 + self.ema_count = (self.ema_count + self.eps) / (n + self.num_code * self.eps) * n # num_code + + self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1) + + # calculate code usage + stale_codes = (encodings.sum(0) == 0).float() # num_code + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # num_code + if replace_code.sum(-1).max() > 0: + random_input_idx = torch.randperm(input_detach.shape[0]) + random_input = input_detach[random_input_idx].view(input_detach.shape) + if random_input.shape[0] < self.num_code: + random_input = torch.cat([random_input]*(self.num_code // random_input.shape[0] + 1), 0) + random_input = random_input[:self.num_code].contiguous() # num_code, dim + + self.embedding = self.embedding * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.ema_weight = self.ema_weight * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.ema_count = self.ema_count * (1 - replace_code) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return quantized, indices, perplexity + +class RVQEmbedding(nn.Module): + def __init__(self, code_dim, decay=0.99, bit=[10]): + super(RVQEmbedding, self).__init__() + + self.code_dim = code_dim + self.decay = decay + self.eps = torch.finfo(torch.float32).eps + + self.VQEmbedding = nn.ModuleList([]) + for i in range(len(bit)): + self.VQEmbedding.append(VQEmbeddingEMA(2**bit[i], code_dim, decay, layer=i)) + + def forward(self, input): + quantized = [] + indices = [] + ppl = [] + + residual_input = input + for i in range(len(self.VQEmbedding)): + this_quantized, this_indices, this_perplexity = self.VQEmbedding[i](residual_input) + indices.append(this_indices) + ppl.append(this_perplexity) + residual_input = residual_input - this_quantized + if i == 0: + quantized.append(this_quantized) + else: + quantized.append(quantized[-1] + this_quantized) + + quantized = torch.stack(quantized, -1) + indices = torch.stack(indices, -1) + ppl = torch.stack(ppl, -1) + latent_loss = 0 + for i in range(quantized.shape[-1]): + latent_loss = latent_loss + F.mse_loss(input, quantized.detach()[...,i]) + + return quantized, indices, ppl, latent_loss + +class Codec(nn.Module): + def __init__(self, nch=1, sr=44100, win=100, feature_dim=128, vae_dim=2, enc_layer=12, dec_layer=12, bit=[8]*5, causal=True): + super(Codec, self).__init__() + + self.nch = nch + self.sr = sr + self.win = int(sr / 1000 * win) + self.stride = self.win // 2 + self.enc_dim = self.win // 2 + 1 + self.feature_dim = feature_dim + self.vae_dim = vae_dim + self.bit = bit + self.eps = torch.finfo(torch.float32).eps + + # 0-1k (50 hop), 1k-4k (100 hop), 4k-8k (200 hop), 8k-12k (400 hop), 12k-22k (500 hop) + # 100 bands + bandwidth_50 = int(np.floor(50 / (sr / 2.) * self.enc_dim)) + bandwidth_100 = int(np.floor(100 / (sr / 2.) * self.enc_dim)) + bandwidth_200 = int(np.floor(200 / (sr / 2.) * self.enc_dim)) + bandwidth_400 = int(np.floor(400 / (sr / 2.) * self.enc_dim)) + bandwidth_500 = int(np.floor(500 / (sr / 2.) * self.enc_dim)) + self.band_width = [bandwidth_50]*20 + self.band_width += [bandwidth_100]*30 + self.band_width += [bandwidth_200]*20 + self.band_width += [bandwidth_400]*10 + self.band_width += [bandwidth_500]*19 + self.band_width.append(self.enc_dim - np.sum(self.band_width)) + self.nband = len(self.band_width) + print(self.band_width, self.nband) + + self.VAE_BN = nn.ModuleList([]) + for i in range(self.nband): + self.VAE_BN.append(nn.Sequential(RMVN((self.band_width[i]*2+1)*self.nch), + nn.Conv1d(((self.band_width[i]*2+1)*self.nch), self.feature_dim, 1)) + ) + + self.VAE_encoder = [] + for _ in range(enc_layer): + self.VAE_encoder.append(BSNet(self.feature_dim, kernel=7, causal=causal)) + self.VAE_encoder = nn.Sequential(*self.VAE_encoder) + + self.vae_FC = nn.Sequential(RMVN(self.nband*self.feature_dim, groups=self.nband), + nn.Conv1d(self.nband*self.feature_dim, self.nband*self.vae_dim*2, 1, groups=self.nband) + ) + self.codebook = RVQEmbedding(self.nband*self.vae_dim*2, bit=bit) + self.vae_reshape = nn.Conv1d(self.nband*self.vae_dim, self.nband*self.feature_dim, 1, groups=self.nband) + + self.VAE_decoder = [] + for _ in range(dec_layer): + self.VAE_decoder.append(BSNet(self.feature_dim, kernel=7, causal=causal)) + self.VAE_decoder = nn.Sequential(*self.VAE_decoder) + + self.VAE_output = nn.ModuleList([]) + for i in range(self.nband): + self.VAE_output.append(nn.Sequential(RMVN(self.feature_dim), + nn.Conv1d(self.feature_dim, self.band_width[i]*4*self.nch, 1), + nn.GLU(dim=1)) + ) + + def spec_band_split(self, input): + + B, nch, nsample = input.shape + + spec = torch.stft(input.view(B*nch, nsample).float(), n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(input.device), return_complex=True) + + subband_spec = [] + subband_spec_norm = [] + subband_power = [] + band_idx = 0 + for i in range(self.nband): + this_spec = spec[:,band_idx:band_idx+self.band_width[i]] + subband_spec.append(this_spec) # B, BW, T + subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T + subband_spec_norm.append([this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1]]) # B, BW, T + band_idx += self.band_width[i] + subband_power = torch.cat(subband_power, 1) # B, nband, T + + return subband_spec, subband_spec_norm, subband_power + + def feature_extractor(self, input): + + _, subband_spec_norm, subband_power = self.spec_band_split(input) + + # normalization and bottleneck + subband_feature = [] + for i in range(self.nband): + concat_spec = torch.cat([subband_spec_norm[i][0], subband_spec_norm[i][1], torch.log(subband_power[:,i].unsqueeze(1))], 1) + concat_spec = concat_spec.view(-1, (self.band_width[i]*2+1)*self.nch, concat_spec.shape[-1]) + subband_feature.append(self.VAE_BN[i](concat_spec.type(input.type()))) + subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T + + return subband_feature + + def vae_sample(self, input): + + B, nch, _ = input.shape + + subband_feature = self.feature_extractor(input) + + # encode + enc_output = checkpoint_sequential(self.VAE_encoder, len(self.VAE_encoder), subband_feature) + enc_output = self.vae_FC(enc_output.view(B, self.nband*self.feature_dim, -1)).view(B, self.nband, 2, self.vae_dim, -1) + mu = enc_output[:,:,0].contiguous() + logvar = enc_output[:,:,1].contiguous() + + # vae + reparam_feature = mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar) + vae_loss = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(2)).mean() + + # quantization + mu_var = torch.stack([mu, logvar], 1).view(B, self.nband*self.vae_dim*2, -1) + quantized_emb, indices, ppl, latent_loss = self.codebook(mu_var.detach()) + + return reparam_feature, quantized_emb, mu_var, indices, ppl, latent_loss, vae_loss + + def vae_decode(self, vae_feature, nsample=None): + B = vae_feature.shape[0] + dec_input = self.vae_reshape(vae_feature.contiguous().view(B, self.nband*self.vae_dim, -1)) + output = checkpoint_sequential(self.VAE_decoder, len(self.VAE_decoder), dec_input.view(B, self.nband, self.feature_dim, -1)) + + est_spec = [] + for i in range(self.nband): + this_RI = self.VAE_output[i](output[:,i]).view(B*self.nch, 2, self.band_width[i], -1) + est_spec.append(torch.complex(this_RI[:,0].float(), this_RI[:,1].float())) + est_spec = torch.cat(est_spec, 1) + if nsample is not None: + output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(vae_feature.device), length=nsample).view(B, self.nch, -1) + else: + output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(vae_feature.device)).view(B, self.nch, -1) + + return output.type(vae_feature.type()) + + def forward(self, input): + + B, nch, nsample = input.shape + assert nch == self.nch + + vae_feature, quantized_emb, mu_var, indices, ppl, latent_loss, vae_loss = self.vae_sample(input) + output = self.vae_decode(vae_feature, nsample=nsample).view(input.shape) + + + return output # , vae_feature, quantized_emb, mu_var, indices, ppl, latent_loss, vae_loss + +def get_bsrnnvae(ckpt): + nch = 1 + model = Codec(nch = nch, \ + win = 100, \ + feature_dim = 128, \ + vae_dim = 8, \ + bit = [14]*5, \ + causal = True) + weight = torch.load(ckpt, map_location='cpu') + model.load_state_dict(weight) + return model.eval() diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae_old.py b/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae_old.py new file mode 100644 index 0000000000000000000000000000000000000000..010606804f85edcc3fc43850c0ee3e943080218d --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_bsrnnvae_old.py @@ -0,0 +1,427 @@ +"""! +@author Yi Luo (oulyluo) +@copyright Tencent AI Lab +""" + +from __future__ import print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch.utils.checkpoint import checkpoint_sequential +from thop import profile, clever_format + +class RMVN(nn.Module): + """ + Rescaled MVN. + """ + def __init__(self, dimension, groups=1): + super(RMVN, self).__init__() + + self.mean = nn.Parameter(torch.zeros(dimension)) + self.std = nn.Parameter(torch.ones(dimension)) + self.groups = groups + self.eps = torch.finfo(torch.float32).eps + + def forward(self, input): + # input size: (B, N, T) + B, N, T = input.shape + assert N % self.groups == 0 + + input = input.view(B, self.groups, -1, T) + input_norm = (input - input.mean(2).unsqueeze(2)) / (input.var(2).unsqueeze(2) + self.eps).sqrt() + input_norm = input_norm.view(B, N, T) * self.std.view(1, -1, 1) + self.mean.view(1, -1, 1) + + return input_norm + +class ConvActNorm1d(nn.Module): + def __init__(self, in_channel, hidden_channel, kernel=7, causal=False): + super(ConvActNorm1d, self).__init__() + + self.in_channel = in_channel + self.kernel = kernel + self.causal = causal + if not causal: + self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=(kernel-1)//2), + RMVN(in_channel), + nn.Conv1d(in_channel, hidden_channel*2, 1), + nn.GLU(dim=1), + nn.Conv1d(hidden_channel, in_channel, 1) + ) + else: + self.conv = nn.Sequential(nn.Conv1d(in_channel, in_channel, kernel, padding=kernel-1), + RMVN(in_channel), + nn.Conv1d(in_channel, hidden_channel*2, 1), + nn.GLU(dim=1), + nn.Conv1d(hidden_channel, in_channel, 1) + ) + + def forward(self, input): + + output = self.conv(input) + if self.causal: + output = output[...,:-self.kernel+1].contiguous() + return input + output + +class ICB(nn.Module): + def __init__(self, in_channel, kernel=7, causal=False): + super(ICB, self).__init__() + + self.blocks = nn.Sequential(ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), + ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal), + ConvActNorm1d(in_channel, in_channel*4, kernel, causal=causal) + ) + + def forward(self, input): + + return self.blocks(input) + +class ResRNN(nn.Module): + def __init__(self, input_size, hidden_size, bidirectional=False): + super(ResRNN, self).__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.eps = torch.finfo(torch.float32).eps + + self.norm = RMVN(input_size) + self.rnn = nn.LSTM(input_size, hidden_size, 1, batch_first=True, bidirectional=bidirectional) + + self.proj = nn.Linear(hidden_size*(int(bidirectional)+1), input_size) + + def forward(self, input, use_head=1): + # input shape: batch, dim, seq + + B, N, T = input.shape + + rnn_output, _ = self.rnn(self.norm(input).transpose(1,2).contiguous()) + + output = self.proj(rnn_output.contiguous().view(-1, rnn_output.shape[2])) + output = output.view(B, T, -1).transpose(1,2).contiguous() + + return input + output + +class BSNet(nn.Module): + def __init__(self, feature_dim, kernel=7, causal=False): + super(BSNet, self).__init__() + + self.feature_dim = feature_dim + + self.seq_net = ICB(self.feature_dim, kernel=kernel, causal=causal) + self.band_net = ResRNN(self.feature_dim, self.feature_dim*2, bidirectional=True) + + def forward(self, input): + # input shape: B, nband, N, T + + B, nband, N, T = input.shape + + band_output = self.seq_net(input.view(B*nband, N, T)).view(B, nband, -1, T) + + # band comm + band_output = band_output.permute(0,3,2,1).contiguous().view(B*T, -1, nband) + output = self.band_net(band_output).view(B, T, -1, nband).permute(0,3,2,1).contiguous() + + return output.view(B, nband, N, T) + +# https://github.com/bshall/VectorQuantizedVAE/blob/master/model.py +class VQEmbeddingEMA(nn.Module): + def __init__(self, num_code, code_dim, decay=0.99, layer=0): + super(VQEmbeddingEMA, self).__init__() + + self.num_code = num_code + self.code_dim = code_dim + self.decay = decay + self.layer = layer + self.stale_tolerance = 100 + self.eps = torch.finfo(torch.float32).eps + + embedding = torch.empty(num_code, code_dim).normal_() / ((layer+1) * code_dim) + self.register_buffer("embedding", embedding) + self.register_buffer("ema_weight", self.embedding.clone()) + self.register_buffer("ema_count", torch.zeros(self.num_code)) + self.register_buffer("stale_counter", torch.zeros(self.num_code)) + + def forward(self, input): + + B, N, T = input.shape + assert N == self.code_dim + + input_detach = input.detach().mT.contiguous().view(B*T, N) # B*T, dim + + # distance + eu_dis = input_detach.pow(2).sum(-1).unsqueeze(-1) + self.embedding.pow(2).sum(-1).unsqueeze(0) # B*T, num_code + eu_dis = eu_dis - 2 * input_detach.mm(self.embedding.T) # B*T, num_code + + # best codes + indices = torch.argmin(eu_dis, dim=-1) # B*T + quantized = torch.gather(self.embedding, 0, indices.unsqueeze(-1).expand(-1, self.code_dim)) # B*T, dim + quantized = quantized.view(B, T, N).mT.contiguous() # B, N, T + + # calculate perplexity + encodings = F.one_hot(indices, self.num_code).float() # B*T, num_code + avg_probs = encodings.mean(0) # num_code + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + self.eps), -1)).mean() + indices = indices.view(B, T) + + if self.training: + # EMA update for codebook + + self.ema_count = self.decay * self.ema_count + (1 - self.decay) * torch.sum(encodings, dim=0) # num_code + + update_direction = encodings.T.mm(input_detach) # num_code, dim + self.ema_weight = self.decay * self.ema_weight + (1 - self.decay) * update_direction # num_code, dim + + # Laplace smoothing on the counters + # make sure the denominator will never be zero + n = torch.sum(self.ema_count, dim=-1, keepdim=True) # 1 + self.ema_count = (self.ema_count + self.eps) / (n + self.num_code * self.eps) * n # num_code + + self.embedding = self.ema_weight / self.ema_count.unsqueeze(-1) + + # calculate code usage + stale_codes = (encodings.sum(0) == 0).float() # num_code + self.stale_counter = self.stale_counter * stale_codes + stale_codes + + # random replace codes that haven't been used for a while + replace_code = (self.stale_counter == self.stale_tolerance).float() # num_code + if replace_code.sum(-1).max() > 0: + random_input_idx = torch.randperm(input_detach.shape[0]) + random_input = input_detach[random_input_idx].view(input_detach.shape) + if random_input.shape[0] < self.num_code: + random_input = torch.cat([random_input]*(self.num_code // random_input.shape[0] + 1), 0) + random_input = random_input[:self.num_code].contiguous() # num_code, dim + + self.embedding = self.embedding * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.ema_weight = self.ema_weight * (1 - replace_code).unsqueeze(-1) + random_input * replace_code.unsqueeze(-1) + self.ema_count = self.ema_count * (1 - replace_code) + self.stale_counter = self.stale_counter * (1 - replace_code) + + return quantized, indices, perplexity + +class RVQEmbedding(nn.Module): + def __init__(self, code_dim, decay=0.99, bit=[10]): + super(RVQEmbedding, self).__init__() + + self.code_dim = code_dim + self.decay = decay + self.eps = torch.finfo(torch.float32).eps + + self.VQEmbedding = nn.ModuleList([]) + for i in range(len(bit)): + self.VQEmbedding.append(VQEmbeddingEMA(2**bit[i], code_dim, decay, layer=i)) + + def forward(self, input): + quantized = [] + indices = [] + ppl = [] + + residual_input = input + for i in range(len(self.VQEmbedding)): + this_quantized, this_indices, this_perplexity = self.VQEmbedding[i](residual_input) + indices.append(this_indices) + ppl.append(this_perplexity) + residual_input = residual_input - this_quantized + if i == 0: + quantized.append(this_quantized) + else: + quantized.append(quantized[-1] + this_quantized) + + quantized = torch.stack(quantized, -1) + indices = torch.stack(indices, -1) + ppl = torch.stack(ppl, -1) + latent_loss = 0 + for i in range(quantized.shape[-1]): + latent_loss = latent_loss + F.mse_loss(input, quantized.detach()[...,i]) + + return quantized, indices, ppl, latent_loss + +class Codec(nn.Module): + def __init__(self, nch=1, sr=44100, win=80, feature_dim=128, vae_dim=2, enc_layer=12, dec_layer=12, bit=[8]*5, causal=False): + super(Codec, self).__init__() + + self.nch = nch + self.sr = sr + self.win = int(sr / 1000 * win) + self.stride = self.win // 2 + self.enc_dim = self.win // 2 + 1 + self.feature_dim = feature_dim + self.vae_dim = vae_dim + self.bit = bit + self.eps = torch.finfo(torch.float32).eps + + # 0-1k (50 hop), 1k-2k (100 hop), 2k-4k (250 hop), 4k-8k (500 hop), 8k-12k (1k hop), 12k-20k (2k hop), 20k-inf + # 55 bands + bandwidth_50 = int(np.floor(50 / (sr / 2.) * self.enc_dim)) + bandwidth_100 = int(np.floor(100 / (sr / 2.) * self.enc_dim)) + bandwidth_250 = int(np.floor(250 / (sr / 2.) * self.enc_dim)) + bandwidth_500 = int(np.floor(500 / (sr / 2.) * self.enc_dim)) + bandwidth_1k = int(np.floor(1000 / (sr / 2.) * self.enc_dim)) + bandwidth_2k = int(np.floor(2000 / (sr / 2.) * self.enc_dim)) + self.band_width = [bandwidth_50]*20 + self.band_width += [bandwidth_100]*10 + self.band_width += [bandwidth_250]*8 + self.band_width += [bandwidth_500]*8 + self.band_width += [bandwidth_1k]*4 + self.band_width += [bandwidth_2k]*4 + self.band_width.append(self.enc_dim - np.sum(self.band_width)) + self.nband = len(self.band_width) + print(self.band_width, self.nband) + + self.VAE_BN = nn.ModuleList([]) + for i in range(self.nband): + self.VAE_BN.append(nn.Sequential(RMVN((self.band_width[i]*2+1)*self.nch), + nn.Conv1d(((self.band_width[i]*2+1)*self.nch), self.feature_dim, 1)) + ) + + self.VAE_encoder = [] + for _ in range(enc_layer): + self.VAE_encoder.append(BSNet(self.feature_dim, kernel=7, causal=causal)) + self.VAE_encoder = nn.Sequential(*self.VAE_encoder) + + self.vae_FC = nn.Sequential(RMVN(self.nband*self.feature_dim, groups=self.nband), + nn.Conv1d(self.nband*self.feature_dim, self.nband*self.vae_dim*2, 1, groups=self.nband) + ) + self.codebook = RVQEmbedding(self.nband*self.vae_dim*2, bit=bit) + self.vae_reshape = nn.Conv1d(self.nband*self.vae_dim, self.nband*self.feature_dim, 1, groups=self.nband) + + self.VAE_decoder = [] + for _ in range(dec_layer): + self.VAE_decoder.append(BSNet(self.feature_dim, kernel=7, causal=causal)) + self.VAE_decoder = nn.Sequential(*self.VAE_decoder) + + self.VAE_output = nn.ModuleList([]) + for i in range(self.nband): + self.VAE_output.append(nn.Sequential(RMVN(self.feature_dim), + nn.Conv1d(self.feature_dim, self.band_width[i]*4*self.nch, 1), + nn.GLU(dim=1)) + ) + + def spec_band_split(self, input): + + B, nch, nsample = input.shape + + spec = torch.stft(input.view(B*nch, nsample).float(), n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(input.device), return_complex=True) + + subband_spec = [] + subband_spec_norm = [] + subband_power = [] + band_idx = 0 + for i in range(self.nband): + this_spec = spec[:,band_idx:band_idx+self.band_width[i]] + subband_spec.append(this_spec) # B, BW, T + subband_power.append((this_spec.abs().pow(2).sum(1) + self.eps).sqrt().unsqueeze(1)) # B, 1, T + subband_spec_norm.append([this_spec.real / subband_power[-1], this_spec.imag / subband_power[-1]]) # B, BW, T + band_idx += self.band_width[i] + subband_power = torch.cat(subband_power, 1) # B, nband, T + + return subband_spec, subband_spec_norm, subband_power + + def feature_extractor(self, input): + + _, subband_spec_norm, subband_power = self.spec_band_split(input) + + # normalization and bottleneck + subband_feature = [] + for i in range(self.nband): + concat_spec = torch.cat([subband_spec_norm[i][0], subband_spec_norm[i][1], torch.log(subband_power[:,i].unsqueeze(1))], 1) + concat_spec = concat_spec.view(-1, (self.band_width[i]*2+1)*self.nch, concat_spec.shape[-1]) + subband_feature.append(self.VAE_BN[i](concat_spec.type(input.type()))) + subband_feature = torch.stack(subband_feature, 1) # B, nband, N, T + + return subband_feature + + def vae_sample(self, input): + + B, nch, _ = input.shape + + subband_feature = self.feature_extractor(input) + + # encode + enc_output = checkpoint_sequential(self.VAE_encoder, len(self.VAE_encoder), subband_feature) + enc_output = self.vae_FC(enc_output.view(B, self.nband*self.feature_dim, -1)).view(B, self.nband, 2, self.vae_dim, -1) + mu = enc_output[:,:,0].contiguous() + logvar = enc_output[:,:,1].contiguous() + + # vae + reparam_feature = mu + torch.randn_like(logvar) * torch.exp(0.5 * logvar) + + return reparam_feature.view(B, nch, self.nband, self.vae_dim, -1) + + def vae_decode(self, vae_feature): + B = vae_feature.shape[0] + dec_input = self.vae_reshape(vae_feature.contiguous().view(B, self.nband*self.vae_dim, -1)) + output = checkpoint_sequential(self.VAE_decoder, len(self.VAE_decoder), dec_input.view(B, self.nband, self.feature_dim, -1)) + + est_spec = [] + for i in range(self.nband): + this_RI = self.VAE_output[i](output[:,i]).view(B*self.nch, 2, self.band_width[i], -1) + est_spec.append(torch.complex(this_RI[:,0].float(), this_RI[:,1].float())) + est_spec = torch.cat(est_spec, 1) + + output = torch.istft(est_spec, n_fft=self.win, hop_length=self.stride, + window=torch.hann_window(self.win).to(vae_feature.device)).view(B, self.nch, -1) + + return output.type(vae_feature.type()) + + def forward(self, input): + + B, nch, nsample = input.shape + assert nch == self.nch + + vae_feature = self.vae_sample(input) + output = self.vae_decode(vae_feature).view(B, nch, -1) + if(output.shape[-1] > nsample): + output = output[:,:,0:nsample] + elif(output.shape[-1] < nsample): + output = torch.cat([output, torch.zeros(B, nch, nsample - output.shape[-1], device= output.device, dtype=output.dtype)],-1) + + return output + + def encode(self, input, do_sample=True): + assert do_sample, do_sample + B, nch, nsample = input.shape + assert nch == self.nch + + vae_feature = self.vae_sample(input) + return vae_feature + +def get_bsrnnvae(ckpt): + nch = 1 + model = Codec(nch = nch, \ + win = 100, \ + feature_dim = 128, \ + vae_dim = 2, \ + bit = [14]*5, \ + causal = True) + weight = torch.load(ckpt, map_location='cpu') + model.load_state_dict(weight) + return model.eval() + +if __name__ == '__main__': + model = Codec(causal=True) + x = torch.empty(1, 1, 44100).uniform_(-1, 1) + + s = 0 + for param in model.parameters(): + s += np.product(param.size()) + print('# of parameters: '+str(s/1e6)+" M") + + output = model(x) + print(output.shape) + + macs, params = profile(model, inputs=(x,)) + macs, params = clever_format([macs, params], "%.3f") + print(macs, params) + + import torchaudio + model = get_bsrnnvae() + inp, fs = torchaudio.load('769000.mp3') + inp = inp[[0],:] + if(fs!=44100): + inp = torchaudio.functional.resample(inp, fs, 44100) + fs = 44100 + inp = inp[:,0:30*44100] + out = model(inp[None,:,:]).detach() + torchaudio.save('out.flac', out[0], fs) \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_melvaehifigan48k.py b/codeclm/tokenizer/Flow1dVAE/tools/get_melvaehifigan48k.py new file mode 100644 index 0000000000000000000000000000000000000000..9e94c9bb83bb0a85c3ec9c3d895b1c67ec02544c --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_melvaehifigan48k.py @@ -0,0 +1,1578 @@ + +import soundfile as sf +import os +from librosa.filters import mel as librosa_mel_fn +import sys +import tools.torch_tools as torch_tools +import torch.nn as nn +import torch +import numpy as np +from einops import rearrange +from scipy.signal import get_window +from librosa.util import pad_center, tiny +import librosa.util as librosa_util + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + +LRELU_SLOPE = 0.1 + +class ResBlock(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + torch.nn.utils.weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + torch.nn.utils.weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + torch.nn.utils.weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + torch.nn.utils.weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + torch.nn.utils.weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + torch.nn.utils.weight_norm( + nn.Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + torch.nn.utils.remove_weight_norm(l) + for l in self.convs2: + torch.nn.utils.remove_weight_norm(l) + + +class Generator_old(torch.nn.Module): + def __init__(self, h): + super(Generator_old, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = torch.nn.utils.weight_norm( + nn.Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + torch.nn.utils.weight_norm( + nn.ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = torch.nn.utils.weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = torch.nn.functional.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + torch.nn.utils.remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + torch.nn.utils.remove_weight_norm(self.conv_pre) + torch.nn.utils.remove_weight_norm(self.conv_post) + + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class DownsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=(4, 2), padding=1 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2)) + return x + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h * w).contiguous() + q = q.permute(0, 2, 1).contiguous() # b,hw,c + k = k.reshape(b, c, h * w).contiguous() # b,c,hw + w_ = torch.bmm(q, k).contiguous() # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h * w).contiguous() + w_ = w_.permute(0, 2, 1).contiguous() # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm( + v, w_ + ).contiguous() # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b, c, h, w).contiguous() + + h_ = self.proj_out(h_) + + return x + h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown" + # print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + raise ValueError(attn_type) + + +class ResnetBlock(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout, + temb_channels=512, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + else: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, out_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x + h + + +class Encoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + double_z=True, + use_linear_attn=False, + attn_type="vanilla", + downsample_time_stride4_levels=[], + **ignore_kwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + if i_level in self.downsample_time_stride4_levels: + down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv) + else: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1, + ) + + def forward(self, x): + # timestep embedding + temb = None + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + z_channels, + give_pre_end=False, + tanh_out=False, + use_linear_attn=False, + downsample_time_stride4_levels=[], + attn_type="vanilla", + **ignorekwargs, + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + self.downsample_time_stride4_levels = downsample_time_stride4_levels + + if len(self.downsample_time_stride4_levels) > 0: + assert max(self.downsample_time_stride4_levels) < self.num_resolutions, ( + "The level to perform downsample 4 operation need to be smaller than the total resolution number %s" + % str(self.num_resolutions) + ) + + # compute in_ch_mult, block_in and curr_res at lowest res + (1,) + tuple(ch_mult) + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + # print( + # "Working with z of shape {} = {} dimensions.".format( + # self.z_shape, np.prod(self.z_shape) + # ) + # ) + + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + if i_level - 1 in self.downsample_time_stride4_levels: + up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv) + else: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, z): + # assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + +def get_vocoder_config_48k(): + return { + "resblock": "1", + "num_gpus": 8, + "batch_size": 128, + "learning_rate": 0.0001, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + + "upsample_rates": [6,5,4,2,2], + "upsample_kernel_sizes": [12,10,8,4,4], + "upsample_initial_channel": 1536, + "resblock_kernel_sizes": [3,7,11,15], + "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5], [1,3,5]], + + "segment_size": 15360, + "num_mels": 256, + "n_fft": 2048, + "hop_size": 480, + "win_size": 2048, + + "sampling_rate": 48000, + + "fmin": 20, + "fmax": 24000, + "fmax_for_loss": None, + + "num_workers": 8, + + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:18273", + "world_size": 1 + } + } + +def get_vocoder(config, device, mel_bins): + name = "HiFi-GAN" + speaker = "" + if name == "MelGAN": + if speaker == "LJSpeech": + vocoder = torch.hub.load( + "descriptinc/melgan-neurips", "load_melgan", "linda_johnson" + ) + elif speaker == "universal": + vocoder = torch.hub.load( + "descriptinc/melgan-neurips", "load_melgan", "multi_speaker" + ) + vocoder.mel2wav.eval() + vocoder.mel2wav.to(device) + elif name == "HiFi-GAN": + if(mel_bins == 256): + config = get_vocoder_config_48k() + config = AttrDict(config) + vocoder = Generator_old(config) + # print("Load hifigan/g_01080000") + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_01080000")) + # ckpt = torch.load(os.path.join(ROOT, "hifigan/g_00660000")) + # ckpt = torch_version_orig_mod_remove(ckpt) + # vocoder.load_state_dict(ckpt["generator"]) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder = vocoder.to(device) + # vocoder = vocoder.half() + else: + raise ValueError(mel_bins) + return vocoder + +def vocoder_infer(mels, vocoder, lengths=None): + with torch.no_grad(): + wavs = vocoder(mels).squeeze(1) + + #wavs = (wavs.cpu().numpy() * 32768).astype("int16") + wavs = (wavs.cpu().numpy()) + + if lengths is not None: + wavs = wavs[:, :lengths] + + # wavs = [wav for wav in wavs] + + # for i in range(len(mels)): + # if lengths is not None: + # wavs[i] = wavs[i][: lengths[i]] + + return wavs + +@torch.no_grad() +def vocoder_chunk_infer(mels, vocoder, lengths=None): + chunk_size = 256*4 + shift_size = 256*1 + ov_size = chunk_size-shift_size + # import pdb;pdb.set_trace() + + for cinx in range(0, mels.shape[2], shift_size): + if(cinx==0): + wavs = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1).float() + num_samples = int(wavs.shape[-1]/chunk_size)*chunk_size + wavs = wavs[:,0:num_samples] + ov_sample = int(float(wavs.shape[-1]) * ov_size / chunk_size) + ov_win = torch.linspace(0, 1, ov_sample, device="cuda").unsqueeze(0) + ov_win = torch.cat([ov_win,1-ov_win],-1) + if(cinx+chunk_size>=mels.shape[2]): + break + else: + cur_wav = vocoder(mels[:,:,cinx:cinx+chunk_size]).squeeze(1)[:,0:num_samples].float() + wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * ov_win[:,-ov_sample:] + cur_wav[:,0:ov_sample] * ov_win[:,0:ov_sample] + # wavs[:,-ov_sample:] = wavs[:,-ov_sample:] * 1.0 + cur_wav[:,0:ov_sample] * 0.0 + wavs = torch.cat([wavs, cur_wav[:,ov_sample:]],-1) + if(cinx+chunk_size>=mels.shape[2]): + break + # print(wavs.shape) + + wavs = (wavs.cpu().numpy()) + + if lengths is not None: + wavs = wavs[:, :lengths] + # print(wavs.shape) + return wavs + +def synth_one_sample(mel_input, mel_prediction, labels, vocoder): + if vocoder is not None: + + wav_reconstruction = vocoder_infer( + mel_input.permute(0, 2, 1), + vocoder, + ) + wav_prediction = vocoder_infer( + mel_prediction.permute(0, 2, 1), + vocoder, + ) + else: + wav_reconstruction = wav_prediction = None + + return wav_reconstruction, wav_prediction + + +class AutoencoderKL(nn.Module): + def __init__( + self, + ddconfig=None, + lossconfig=None, + batchsize=None, + embed_dim=None, + time_shuffle=1, + subband=1, + sampling_rate=16000, + ckpt_path=None, + reload_from_ckpt=None, + ignore_keys=[], + image_key="fbank", + colorize_nlabels=None, + monitor=None, + base_learning_rate=1e-5, + scale_factor=1 + ): + super().__init__() + self.automatic_optimization = False + assert ( + "mel_bins" in ddconfig.keys() + ), "mel_bins is not specified in the Autoencoder config" + num_mel = ddconfig["mel_bins"] + self.image_key = image_key + self.sampling_rate = sampling_rate + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + + self.loss = None + self.subband = int(subband) + + if self.subband > 1: + print("Use subband decomposition %s" % self.subband) + + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + if self.image_key == "fbank": + self.vocoder = get_vocoder(None, torch.device("cuda"), num_mel) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels) == int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.learning_rate = float(base_learning_rate) + # print("Initial learning rate %s" % self.learning_rate) + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + self.feature_cache = None + self.flag_first_run = True + self.train_step = 0 + + self.logger_save_dir = None + self.logger_exp_name = None + self.scale_factor = scale_factor + + print("Num parameters:") + print("Encoder : ", sum(p.numel() for p in self.encoder.parameters())) + print("Decoder : ", sum(p.numel() for p in self.decoder.parameters())) + print("Vocoder : ", sum(p.numel() for p in self.vocoder.parameters())) + + def get_log_dir(self): + if self.logger_save_dir is None and self.logger_exp_name is None: + return os.path.join(self.logger.save_dir, self.logger._project) + else: + return os.path.join(self.logger_save_dir, self.logger_exp_name) + + def set_log_dir(self, save_dir, exp_name): + self.logger_save_dir = save_dir + self.logger_exp_name = exp_name + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + # x = self.time_shuffle_operation(x) + # x = self.freq_split_subband(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + # bs, ch, shuffled_timesteps, fbins = dec.size() + # dec = self.time_unshuffle_operation(dec, bs, int(ch*shuffled_timesteps), fbins) + # dec = self.freq_merge_subband(dec) + return dec + + def decode_to_waveform(self, dec): + + if self.image_key == "fbank": + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = vocoder_chunk_infer(dec, self.vocoder) + elif self.image_key == "stft": + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = self.wave_decoder(dec) + return wav_reconstruction + + def mel_spectrogram_to_waveform( + self, mel, savepath=".", bs=None, name="outwav", save=True + ): + # Mel: [bs, 1, t-steps, fbins] + if len(mel.size()) == 4: + mel = mel.squeeze(1) + mel = mel.permute(0, 2, 1) + waveform = self.vocoder(mel) + waveform = waveform.cpu().detach().numpy() + #if save: + # self.save_waveform(waveform, savepath, name) + return waveform + + @torch.no_grad() + def encode_first_stage(self, x): + return self.encode(x) + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + return self.decode(z) + + def decode_first_stage_withgrad(self, z): + z = 1.0 / self.scale_factor * z + return self.decode(z) + + def get_first_stage_encoding(self, encoder_posterior, use_mode=False): + if isinstance(encoder_posterior, DiagonalGaussianDistribution) and not use_mode: + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, DiagonalGaussianDistribution) and use_mode: + z = encoder_posterior.mode() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z + + def visualize_latent(self, input): + import matplotlib.pyplot as plt + + # for i in range(10): + # zero_input = torch.zeros_like(input) - 11.59 + # zero_input[:,:,i * 16: i * 16 + 16,:16] += 13.59 + + # posterior = self.encode(zero_input) + # latent = posterior.sample() + # avg_latent = torch.mean(latent, dim=1)[0] + # plt.imshow(avg_latent.cpu().detach().numpy().T) + # plt.savefig("%s.png" % i) + # plt.close() + + np.save("input.npy", input.cpu().detach().numpy()) + # zero_input = torch.zeros_like(input) - 11.59 + time_input = input.clone() + time_input[:, :, :, :32] *= 0 + time_input[:, :, :, :32] -= 11.59 + + np.save("time_input.npy", time_input.cpu().detach().numpy()) + + posterior = self.encode(time_input) + latent = posterior.sample() + np.save("time_latent.npy", latent.cpu().detach().numpy()) + avg_latent = torch.mean(latent, dim=1) + for i in range(avg_latent.size(0)): + plt.imshow(avg_latent[i].cpu().detach().numpy().T) + plt.savefig("freq_%s.png" % i) + plt.close() + + freq_input = input.clone() + freq_input[:, :, :512, :] *= 0 + freq_input[:, :, :512, :] -= 11.59 + + np.save("freq_input.npy", freq_input.cpu().detach().numpy()) + + posterior = self.encode(freq_input) + latent = posterior.sample() + np.save("freq_latent.npy", latent.cpu().detach().numpy()) + avg_latent = torch.mean(latent, dim=1) + for i in range(avg_latent.size(0)): + plt.imshow(avg_latent[i].cpu().detach().numpy().T) + plt.savefig("time_%s.png" % i) + plt.close() + + def get_input(self, batch): + fname, text, label_indices, waveform, stft, fbank = ( + batch["fname"], + batch["text"], + batch["label_vector"], + batch["waveform"], + batch["stft"], + batch["log_mel_spec"], + ) + # if(self.time_shuffle != 1): + # if(fbank.size(1) % self.time_shuffle != 0): + # pad_len = self.time_shuffle - (fbank.size(1) % self.time_shuffle) + # fbank = torch.nn.functional.pad(fbank, (0,0,0,pad_len)) + + ret = {} + + ret["fbank"], ret["stft"], ret["fname"], ret["waveform"] = ( + fbank.unsqueeze(1), + stft.unsqueeze(1), + fname, + waveform.unsqueeze(1), + ) + + return ret + + def save_wave(self, batch_wav, fname, save_dir): + os.makedirs(save_dir, exist_ok=True) + + for wav, name in zip(batch_wav, fname): + name = os.path.basename(name) + + sf.write(os.path.join(save_dir, name), wav, samplerate=self.sampling_rate) + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, train=True, only_inputs=False, waveform=None, **kwargs): + log = dict() + x = batch.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + log["samples"] = self.decode(posterior.sample()) + log["reconstructions"] = xrec + + log["inputs"] = x + wavs = self._log_img(log, train=train, index=0, waveform=waveform) + return wavs + + def _log_img(self, log, train=True, index=0, waveform=None): + images_input = self.tensor2numpy(log["inputs"][index, 0]).T + images_reconstruct = self.tensor2numpy(log["reconstructions"][index, 0]).T + images_samples = self.tensor2numpy(log["samples"][index, 0]).T + + if train: + name = "train" + else: + name = "val" + + if self.logger is not None: + self.logger.log_image( + "img_%s" % name, + [images_input, images_reconstruct, images_samples], + caption=["input", "reconstruct", "samples"], + ) + + inputs, reconstructions, samples = ( + log["inputs"], + log["reconstructions"], + log["samples"], + ) + + if self.image_key == "fbank": + wav_original, wav_prediction = synth_one_sample( + inputs[index], + reconstructions[index], + labels="validation", + vocoder=self.vocoder, + ) + wav_original, wav_samples = synth_one_sample( + inputs[index], samples[index], labels="validation", vocoder=self.vocoder + ) + wav_original, wav_samples, wav_prediction = ( + wav_original[0], + wav_samples[0], + wav_prediction[0], + ) + elif self.image_key == "stft": + wav_prediction = ( + self.decode_to_waveform(reconstructions)[index, 0] + .cpu() + .detach() + .numpy() + ) + wav_samples = ( + self.decode_to_waveform(samples)[index, 0].cpu().detach().numpy() + ) + wav_original = waveform[index, 0].cpu().detach().numpy() + + if self.logger is not None: + self.logger.experiment.log( + { + "original_%s" + % name: wandb.Audio( + wav_original, caption="original", sample_rate=self.sampling_rate + ), + "reconstruct_%s" + % name: wandb.Audio( + wav_prediction, + caption="reconstruct", + sample_rate=self.sampling_rate, + ), + "samples_%s" + % name: wandb.Audio( + wav_samples, caption="samples", sample_rate=self.sampling_rate + ), + } + ) + + return wav_original, wav_prediction, wav_samples + + def tensor2numpy(self, tensor): + return tensor.cpu().detach().numpy() + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = torch.nn.functional.conv2d(x, weight=self.colorize) + x = 2.0 * (x - x.min()) / (x.max() - x.min()) - 1.0 + return x + + +class IdentityFirstStage(torch.nn.Module): + def __init__(self, *args, vq_interface=False, **kwargs): + self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff + super().__init__() + + def encode(self, x, *args, **kwargs): + return x + + def decode(self, x, *args, **kwargs): + return x + + def quantize(self, x, *args, **kwargs): + if self.vq_interface: + return x, None, [None, None, None] + return x + + def forward(self, x, *args, **kwargs): + return x + + +def window_sumsquare( + window, + n_frames, + hop_length, + win_length, + n_fft, + dtype=np.float32, + norm=None, +): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. + + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + +def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return normalize_fun(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, filter_length, hop_length, win_length, window="hann"): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + if window is not None: + assert filter_length >= win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, size=filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + + device = self.forward_basis.device + input_data = input_data.to(device) + + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = torch.nn.functional.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode="reflect", + ) + input_data = input_data.squeeze(1) + + forward_transform = torch.nn.functional.conv1d( + input_data, + torch.autograd.Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + )#.cpu() + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + + device = self.forward_basis.device + magnitude, phase = magnitude.to(device), phase.to(device) + + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + + inverse_transform = torch.nn.functional.conv_transpose1d( + recombine_magnitude_phase, + torch.autograd.Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False + ) + window_sum = window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TacotronSTFT(torch.nn.Module): + def __init__( + self, + filter_length, + hop_length, + win_length, + n_mel_channels, + sampling_rate, + mel_fmin, + mel_fmax, + ): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sr = sampling_rate, n_fft = filter_length, n_mels = n_mel_channels, fmin = mel_fmin, fmax = mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + + def spectral_normalize(self, magnitudes, normalize_fun): + output = dynamic_range_compression(magnitudes, normalize_fun) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y, normalize_fun=torch.log): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert torch.min(y.data) >= -1, torch.min(y.data) + assert torch.max(y.data) <= 1, torch.max(y.data) + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output, normalize_fun) + energy = torch.norm(magnitudes, dim=1) + + log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) + + return mel_output, log_magnitudes, energy + + +def build_pretrained_models(ckpt): + checkpoint = torch.load(ckpt, map_location="cpu") + scale_factor = checkpoint["state_dict"]["scale_factor"].item() + print("scale_factor: ", scale_factor) + + vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k} + + config = { + "preprocessing": { + "audio": { + "sampling_rate": 48000, + "max_wav_value": 32768, + "duration": 10.24 + }, + "stft": { + "filter_length": 2048, + "hop_length": 480, + "win_length": 2048 + }, + "mel": { + "n_mel_channels": 256, + "mel_fmin": 20, + "mel_fmax": 24000 + } + }, + "model": { + "params": { + "first_stage_config": { + "params": { + "sampling_rate": 48000, + "batchsize": 4, + "monitor": "val/rec_loss", + "image_key": "fbank", + "subband": 1, + "embed_dim": 16, + "time_shuffle": 1, + "lossconfig": { + "target": "audioldm2.latent_diffusion.modules.losses.LPIPSWithDiscriminator", + "params": { + "disc_start": 50001, + "kl_weight": 1000, + "disc_weight": 0.5, + "disc_in_channels": 1 + } + }, + "ddconfig": { + "double_z": True, + "mel_bins": 256, + "z_channels": 16, + "resolution": 256, + "downsample_time": False, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [ + 1, + 2, + 4, + 8 + ], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0 + } + } + }, + } + } + } + vae_config = config["model"]["params"]["first_stage_config"]["params"] + vae_config["scale_factor"] = scale_factor + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(vae_state_dict) + + fn_STFT = TacotronSTFT( + config["preprocessing"]["stft"]["filter_length"], + config["preprocessing"]["stft"]["hop_length"], + config["preprocessing"]["stft"]["win_length"], + config["preprocessing"]["mel"]["n_mel_channels"], + config["preprocessing"]["audio"]["sampling_rate"], + config["preprocessing"]["mel"]["mel_fmin"], + config["preprocessing"]["mel"]["mel_fmax"], + ) + + vae.eval() + fn_STFT.eval() + return vae, fn_STFT + + +if __name__=="__main__": + vae, stft = build_pretrained_models() + vae, stft = vae.cuda(), stft.cuda() + + json_file="outputs/wav.scp" + out_path="outputs/Music_inverse" + + wavform = torch.randn(2,int(48000*10.24)) + mel, _, waveform = torch_tools.wav_to_fbank2(wavform, target_length=-1, fn_STFT=stft) + mel = mel.unsqueeze(1).cuda() + print(mel.shape) + # true_latent = torch.cat([vae.get_first_stage_encoding(vae.encode_first_stage(mel[[m]])) for m in range(mel.shape[0])],0) + # print(true_latent.shape) + true_latent = vae.get_first_stage_encoding(vae.encode_first_stage(mel)) + print(true_latent.shape) + true_latent = true_latent.reshape(true_latent.shape[0]//2, -1, true_latent.shape[2], true_latent.shape[3]).detach() + + true_latent = true_latent.reshape(true_latent.shape[0]*2,-1,true_latent.shape[2],true_latent.shape[3]) + print("111", true_latent.size()) + + mel = vae.decode_first_stage(true_latent) + print("222", mel.size()) + audio = vae.decode_to_waveform(mel) + print("333", audio.shape) + + # out_file = out_path + "/" + os.path.basename(fname.strip()) + # sf.write(out_file, audio[0], samplerate=48000) diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_mert_embedding.py b/codeclm/tokenizer/Flow1dVAE/tools/get_mert_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..a7fdd45820a6a1126b5c2d1559abdad61470747f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_mert_embedding.py @@ -0,0 +1,120 @@ +import hydra +import librosa +import torch +import yaml +from prodict import Prodict +import torchaudio + +from musiclm_pytorch import AudioSpectrogramTransformerPretrained, TextTransformerPretrained, MuLaN, MuLaNEmbedder +from omegaconf import DictConfig +import os + +def get_pretrained_config(root, name): + if root is None: + return name + path = os.path.join(root, name) + #获取snapshots目录下的目录 + config_dir = os.path.join(path, 'snapshots') + config_files = os.listdir(config_dir) + assert len(config_files) == 1 + config_path = os.path.join(config_dir, config_files[0]) + return config_path + +def create_MuLaN_from_config(config: DictConfig): + """ + Create a MuLaN model from a configuration file. + """ + pretraind_root = config.model.pretraind_root + + audio_model_name = get_pretrained_config(pretraind_root, config.model.audio_model.name) + audio_transformer = AudioSpectrogramTransformerPretrained( + model_name = audio_model_name, + model_dim = config.model.audio_model.model_dim, + use_layer_idx = config.model.audio_model.use_layer_idx, + **config.model.audio_transformer + ) + text_model_name = get_pretrained_config(pretraind_root, config.model.text_model.name) + text_transformer = TextTransformerPretrained( + model_name = text_model_name, + **config.model.text_transformer + ) + + mulan = MuLaN( + audio_transformer = audio_transformer, + text_transformer = text_transformer, + **config.model.mulan + ) + + return mulan + + +def create_CLAP_model( model_kwargs = {}, ckpt_path = None ): + from musiclm_pytorch import SoftmaxContrastiveLearning + import laion_clap + + from torch import nn + import torch + from torchaudio.functional import resample + + import numpy as np + + from functools import partial + + # quantization + def int16_to_float32(x): + return (x / 32767.0).float() + + def float32_to_int16(x): + x = torch.clip(x, min=-1., max=1.) + return (x * 32767.).int() + + model = laion_clap.CLAP_Module(enable_fusion=False, **model_kwargs) + if ckpt_path is not None: + model.load_ckpt(ckpt_path) + else: + model.load_ckpt() + + class CLAP_Model(nn.Module): + def __init__(self, model, sr = 24000, decoupled_contrastive_learning = True): + super().__init__() + self.model = model + self.model.eval() + self.orig_sr = sr + + klass = partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning) + self.contrast = klass() + + + def forward(self, wavs, raw_texts): + with torch.no_grad(): + wavs = int16_to_float32(float32_to_int16(resample(wavs, self.orig_sr, 48000))) + audio_latents = self.model.get_audio_embedding_from_data(x = wavs, use_tensor=True).float() + text_latents = model.get_text_embedding(raw_texts, use_tensor=True) + cl_loss = self.contrast(audio_latents, text_latents) + return cl_loss + + clap = CLAP_Model(model) + return clap + +def get_mulan(config): + with open(config, "r") as stream: + mulan_config = yaml.safe_load(stream) + mulan_config = Prodict.from_dict(mulan_config) + ckpt_path = mulan_config.checkpoint_path + mulan = create_MuLaN_from_config(mulan_config) + mulan_embedder = MuLaNEmbedder(mulan, checkpoint_path = ckpt_path) + mulan_embedder.eval() + + return mulan_embedder + +def extract_mert_embeds(mulan_embd_extractor, layer_num, filename): + input_audios, fs = torchaudio.load(filename) + mulan_sr = 24000 + if(fs!=mulan_sr): + input_audios = torchaudio.functional.resample(input_audios, fs, mulan_sr) + fs = mulan_sr + # print(input_audios.shape) + inputs = mulan_embd_extractor.mulan.audio.processor(input_audios, sampling_rate=mulan_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = mulan_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states[layer_num] # batch_size, Time steps, 1024 feature_dim + return prompt_embeds diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_mulan.py b/codeclm/tokenizer/Flow1dVAE/tools/get_mulan.py new file mode 100644 index 0000000000000000000000000000000000000000..d51f000dfcb94db35bc7061ee97a79d3bf0d3947 --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_mulan.py @@ -0,0 +1,108 @@ +from musiclm_pytorch import MuLaNEmbedder +import hydra +import librosa +import torch +import yaml +from prodict import Prodict + +from musiclm_pytorch import AudioSpectrogramTransformerPretrained, TextTransformerPretrained, MuLaN +from omegaconf import DictConfig +import os + +def get_pretrained_config(root, name): + if root is None: + return name + path = os.path.join(root, name) + #获取snapshots目录下的目录 + config_dir = os.path.join(path, 'snapshots') + config_files = os.listdir(config_dir) + assert len(config_files) == 1 + config_path = os.path.join(config_dir, config_files[0]) + return config_path + +def create_MuLaN_from_config(config: DictConfig): + """ + Create a MuLaN model from a configuration file. + """ + pretraind_root = config.model.pretraind_root + + audio_model_name = get_pretrained_config(pretraind_root, config.model.audio_model.name) + audio_transformer = AudioSpectrogramTransformerPretrained( + model_name = audio_model_name, + model_dim = config.model.audio_model.model_dim, + use_layer_idx = config.model.audio_model.use_layer_idx, + **config.model.audio_transformer + ) + text_model_name = get_pretrained_config(pretraind_root, config.model.text_model.name) + text_transformer = TextTransformerPretrained( + model_name = text_model_name, + **config.model.text_transformer + ) + + mulan = MuLaN( + audio_transformer = audio_transformer, + text_transformer = text_transformer, + **config.model.mulan + ) + + return mulan + + +def create_CLAP_model( model_kwargs = {}, ckpt_path = None ): + from musiclm_pytorch import SoftmaxContrastiveLearning + import laion_clap + + from torch import nn + import torch + from torchaudio.functional import resample + + import numpy as np + + from functools import partial + + # quantization + def int16_to_float32(x): + return (x / 32767.0).float() + + def float32_to_int16(x): + x = torch.clip(x, min=-1., max=1.) + return (x * 32767.).int() + + model = laion_clap.CLAP_Module(enable_fusion=False, **model_kwargs) + if ckpt_path is not None: + model.load_ckpt(ckpt_path) + else: + model.load_ckpt() + + class CLAP_Model(nn.Module): + def __init__(self, model, sr = 24000, decoupled_contrastive_learning = True): + super().__init__() + self.model = model + self.model.eval() + self.orig_sr = sr + + klass = partial(SoftmaxContrastiveLearning, decoupled_contrastive_learning = decoupled_contrastive_learning) + self.contrast = klass() + + + def forward(self, wavs, raw_texts): + with torch.no_grad(): + wavs = int16_to_float32(float32_to_int16(resample(wavs, self.orig_sr, 48000))) + audio_latents = self.model.get_audio_embedding_from_data(x = wavs, use_tensor=True).float() + text_latents = model.get_text_embedding(raw_texts, use_tensor=True) + cl_loss = self.contrast(audio_latents, text_latents) + return cl_loss + + clap = CLAP_Model(model) + return clap + +def get_mulan(config): + with open(config, "r") as stream: + mulan_config = yaml.safe_load(stream) + mulan_config = Prodict.from_dict(mulan_config) + ckpt_path = mulan_config.checkpoint_path + mulan = create_MuLaN_from_config(mulan_config) + mulan_embedder = MuLaNEmbedder(mulan, checkpoint_path = ckpt_path) + mulan_embedder.eval() + + return mulan_embedder diff --git a/codeclm/tokenizer/Flow1dVAE/tools/get_whisper_encoder.py b/codeclm/tokenizer/Flow1dVAE/tools/get_whisper_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7a113a269abe5dd7bfbfbb2c9ec3409f22ea6b7a --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/get_whisper_encoder.py @@ -0,0 +1,19 @@ +import torch +from transformers import WhisperProcessor, WhisperForConditionalGeneration + +def get_whisper_encoder(): + processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large-v3").model.encoder + return processor, model.eval() + +if __name__=="__main__": + import numpy as np + processor, model = get_whisper_encoder() + model = model.cuda() + + with torch.no_grad(): + input_features = processor(np.random.rand(16000*30,), sampling_rate=16000, return_tensors="pt").input_features.cuda() + print(input_features.shape) + out = model(input_features.repeat(10,1,1)) + import pdb;pdb.set_trace() + print(list(out.values())[0].shape) diff --git a/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py b/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py new file mode 100644 index 0000000000000000000000000000000000000000..4ae71d252f589e9756efd706aab8b69ec391a28b --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py @@ -0,0 +1,47 @@ +import json +import torch +from tqdm import tqdm +import torchaudio +import librosa +import os +import math +import numpy as np +from tools.get_bsrnnvae import get_bsrnnvae +import tools.torch_tools as torch_tools + +class Tango: + def __init__(self, \ + device="cuda:0"): + + self.sample_rate = 44100 + self.device = device + + self.vae = get_bsrnnvae() + self.vae = self.vae.eval().to(device) + + def sound2sound_generate_longterm(self, fname, batch_size=1, duration=15.36, steps=200, disable_progress=False): + """ Genrate audio without condition. """ + num_frames = math.ceil(duration * 100. / 8) + with torch.no_grad(): + orig_samples, fs = torchaudio.load(fname) + if(fs!=44100): + orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100) + fs = 44100 + if(orig_samples.shape[-1]1):init_audio = init_audio[0] + init_audio = torch.from_numpy(init_audio)[None,None,:].to(self.device) + init_audio = init_audio[:,:,int(0*self.sample_rate):int(10.24*3*self.sample_rate)] + if(init_audio.shape[-1]1):init_audio = init_audio[0] + init_audio = torch.from_numpy(init_audio)[None,None,:].to(self.device) + init_audio = init_audio[:,:,0:int(10.24*2*self.sample_rate)] + if(init_audio.shape[-1] 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. + + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return normalize_fun(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, filter_length, hop_length, win_length, window="hann"): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + if window is not None: + assert filter_length >= win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, size=filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode="reflect", + ) + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data, + torch.autograd.Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + torch.autograd.Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False + ) + window_sum = window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TacotronSTFT(torch.nn.Module): + def __init__( + self, + filter_length=1024, + hop_length=160, + win_length=1024, + n_mel_channels=64, + sampling_rate=16000, + mel_fmin=0, + mel_fmax=8000., + ): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sr = sampling_rate, n_fft = filter_length, n_mels = n_mel_channels, fmin = mel_fmin, fmax = mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + + def spectral_normalize(self, magnitudes, normalize_fun): + output = dynamic_range_compression(magnitudes, normalize_fun) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y, normalize_fun=torch.log): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert torch.min(y.data) >= -1, torch.min(y.data) + assert torch.max(y.data) <= 1, torch.max(y.data) + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output, normalize_fun) + energy = torch.norm(magnitudes, dim=1) + + log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun) + + return mel_output, log_magnitudes, energy diff --git a/codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py b/codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..69a0155001084b57275c79fab24ed573d84d3d8f --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/torch_tools.py @@ -0,0 +1,143 @@ +import torch +import torchaudio +import random +import itertools +import numpy as np +from tools.mix import mix + + +def normalize_wav(waveform): + waveform = waveform - torch.mean(waveform) + waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8) + return waveform * 0.5 + + +def pad_wav(waveform, segment_length): + waveform_length = len(waveform) + + if segment_length is None or waveform_length == segment_length: + return waveform + elif waveform_length > segment_length: + return waveform[:segment_length] + else: + pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device) + waveform = torch.cat([waveform, pad_wav]) + return waveform + + +def _pad_spec(fbank, target_length=1024): + batch, n_frames, channels = fbank.shape + p = target_length - n_frames + if p > 0: + pad = torch.zeros(batch, p, channels).to(fbank.device) + fbank = torch.cat([fbank, pad], 1) + elif p < 0: + fbank = fbank[:, :target_length, :] + + if channels % 2 != 0: + fbank = fbank[:, :, :-1] + + return fbank + + +def read_wav_file(filename, segment_length): + waveform, sr = torchaudio.load(filename) # Faster!!! + waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0] + try: + waveform = normalize_wav(waveform) + except: + print ("Exception normalizing:", filename) + waveform = torch.ones(160000) + waveform = pad_wav(waveform, segment_length).unsqueeze(0) + waveform = waveform / torch.max(torch.abs(waveform)) + waveform = 0.5 * waveform + return waveform + + +def get_mel_from_wav(audio, _stft): + audio = torch.nan_to_num(torch.clip(audio, -1, 1)) + audio = torch.autograd.Variable(audio, requires_grad=False) + melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) + return melspec, log_magnitudes_stft, energy + + +def wav_to_fbank(paths, target_length=1024, fn_STFT=None): + assert fn_STFT is not None + + waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0) # hop size is 160 + + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + fbank = fbank.transpose(1, 2) + log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + + return fbank, log_magnitudes_stft, waveform + +def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None): + assert fn_STFT is not None + + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + fbank = fbank.transpose(1, 2) + log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + # print(fbank.shape, log_magnitudes_stft.shape) + + if(target_length>0): + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + + return fbank, log_magnitudes_stft, waveform + + +def uncapitalize(s): + if s: + return s[:1].lower() + s[1:] + else: + return "" + + +def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1024): + sound1 = read_wav_file(path1, target_length * 160)[0].numpy() + sound2 = read_wav_file(path2, target_length * 160)[0].numpy() + mixed_sound = mix(sound1, sound2, 0.5, 16000).reshape(1, -1) + mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2)) + return mixed_sound, mixed_caption + + +def augment(paths, texts, num_items=4, target_length=1024): + mixed_sounds, mixed_captions = [], [] + combinations = list(itertools.combinations(list(range(len(texts))), 2)) + random.shuffle(combinations) + if len(combinations) < num_items: + selected_combinations = combinations + else: + selected_combinations = combinations[:num_items] + + for (i, j) in selected_combinations: + new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length) + mixed_sounds.append(new_sound) + mixed_captions.append(new_caption) + + waveform = torch.tensor(np.concatenate(mixed_sounds, 0)) + waveform = waveform / torch.max(torch.abs(waveform)) + waveform = 0.5 * waveform + + return waveform, mixed_captions + + +def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1024, fn_STFT=None): + assert fn_STFT is not None + + waveform, captions = augment(paths, texts) + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + fbank = fbank.transpose(1, 2) + log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + + return fbank, log_magnitudes_stft, waveform, captions \ No newline at end of file diff --git a/codeclm/tokenizer/Flow1dVAE/tools/transmodelnorm.py b/codeclm/tokenizer/Flow1dVAE/tools/transmodelnorm.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6f47d3042e5d37da4cd2e6a904ade2b61a06db --- /dev/null +++ b/codeclm/tokenizer/Flow1dVAE/tools/transmodelnorm.py @@ -0,0 +1,16 @@ +import torch + +if __name__=="__main__": + src_ckpt = 'saved/train_mulan_v3_48k_everything3/latest/pytorch_model_2.bin' + tgt_ckpt = 'saved/train_mulan_v3_48k_everything3_sepnorm/src_pytorch_model_2.bin' + # src_ckpt = 'saved/train_enhcodec2D_again/latest/pytorch_model_3.bin' + # tgt_ckpt = 'saved/train_enhcodec2D_again_sepnorm/pytorch_model_3.bin' + + ckpt = torch.load(src_ckpt, map_location='cpu') + + ckpt['normfeat.sum_x'] = torch.ones(16, 32, dtype=ckpt['normfeat.sum_x'].dtype) * ckpt['normfeat.sum_x'] / ckpt['normfeat.counts'] + ckpt['normfeat.sum_x2'] = torch.ones(16, 32, dtype=ckpt['normfeat.sum_x2'].dtype) * ckpt['normfeat.sum_x2'] / ckpt['normfeat.counts'] + ckpt['normfeat.sum_target_x2'] = torch.ones(16, 32, dtype=ckpt['normfeat.sum_target_x2'].dtype) * ckpt['normfeat.sum_target_x2'] / ckpt['normfeat.counts'] + ckpt['normfeat.counts'] = torch.ones_like(ckpt['normfeat.counts']) + torch.save(ckpt, tgt_ckpt) + \ No newline at end of file diff --git a/codeclm/tokenizer/__init__.py b/codeclm/tokenizer/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..78c5460d9c42432b0e5908995fc329193293c316 --- /dev/null +++ b/codeclm/tokenizer/__init__.py @@ -0,0 +1 @@ +# no need for training \ No newline at end of file diff --git a/codeclm/tokenizer/audio_tokenizer.py b/codeclm/tokenizer/audio_tokenizer.py new file mode 100755 index 0000000000000000000000000000000000000000..eb9c53714404e90d32782e8a9afb4ec0dcdb6440 --- /dev/null +++ b/codeclm/tokenizer/audio_tokenizer.py @@ -0,0 +1,907 @@ +""" +Tokenizer or wrapper around existing models. +Also defines the main interface that a model must follow to be usable as an audio tokenizer. +""" + +from abc import ABC, abstractmethod +import logging +import typing as tp +import torch +from torch import nn + + +logger = logging.getLogger() + + +class AudioTokenizer(ABC, nn.Module): + """Base API for all compression model that aim at being used as audio tokenizers + with a language model. + """ + + @abstractmethod + def forward(self, x: torch.Tensor) : + ... + + @abstractmethod + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + """See `EncodecModel.encode`.""" + ... + + @abstractmethod + def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None): + """See `EncodecModel.decode`.""" + ... + + @abstractmethod + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + ... + + @property + @abstractmethod + def channels(self) -> int: + ... + + @property + @abstractmethod + def frame_rate(self) -> float: + ... + + @property + @abstractmethod + def sample_rate(self) -> int: + ... + + @property + @abstractmethod + def cardinality(self) -> int: + ... + + @property + @abstractmethod + def num_codebooks(self) -> int: + ... + + @property + @abstractmethod + def total_codebooks(self) -> int: + ... + + @abstractmethod + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer.""" + ... + + @staticmethod + def get_pretrained( + name: str, + vae_config: str, + vae_model: str, + device: tp.Union[torch.device, str] = 'cpu', + mode='extract' + ) -> 'AudioTokenizer': + """Instantiate a AudioTokenizer model from a given pretrained model. + + Args: + name (Path or str): name of the pretrained model. See after. + device (torch.device or str): Device on which the model is loaded. + """ + + model: AudioTokenizer + if name.split('_')[0] == 'Flow1dVAESeparate': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = Flow1dVAESeparate(model_type, vae_config, vae_model) + elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereo': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = FlowVocalAndMusicDecoderStereo(model_type, mode=mode) + elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoLayer7': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = FlowVocalAndMusicDecoderStereoLayer7(model_type, mode=mode) + elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoLayer11': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = FlowVocalAndMusicDecoderStereoLayer11(model_type, mode=mode) + elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = FlowVocalAndMusicDecoderStereoASRTuneLayer7(model_type, mode=mode) + elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2(model_type, mode=mode) + elif name.split('_')[0] == 'FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1(model_type, mode=mode) + elif name.split('_')[0] == 'Flow1dVAE2rvq': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = Flow1dVAE2rvq(model_type) + elif name.split('_')[0] == 'Flow1dVAE1rvq': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = Flow1dVAE1rvq(model_type, vae_config, vae_model) + elif name.split('_')[0] == 'Flow1dVAE4rvq': + model_type = name.split('_', 1)[1] + logger.info("Getting pretrained compression model from semantic model %s", model_type) + model = Flow1dVAE4rvq(model_type) + else: + raise NotImplementedError("{} is not implemented in models/audio_tokenizer.py".format( + name)) + return model.to(device).eval() + + +class FlowVocalAndMusicDecoderStereo(AudioTokenizer): + def __init__( + self, + model_type: str, + sample_rate=48000, + mode = 'extract', + ): + super().__init__() + + from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo import Tango + model_path = model_type + self.mode = mode + if mode == 'extract': + self.model = Tango(model_path=model_path, layer_num=3, load_main_model=False, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + elif mode == 'inference': + self.samplerate = sample_rate + self.model = Tango(model_path=model_path, layer_num=3, load_main_model=True, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n + +class FlowVocalAndMusicDecoderStereoLayer7(AudioTokenizer): + def __init__( + self, + model_type: str = "pytorch_model_2.bin", + sample_rate=48000, + mode = 'extract', + ): + super().__init__() + + from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_layer7 import Tango + model_path = model_type + self.mode = mode + if mode == 'extract': + self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + elif mode == 'inference': + self.samplerate = sample_rate + self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + # print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n + +class FlowVocalAndMusicDecoderStereoASRTuneLayer7(AudioTokenizer): + def __init__( + self, + model_type: str = "model_layer7_1x4.safetensors", + sample_rate=48000, + mode = 'extract', + ): + super().__init__() + + from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x4 import Tango + model_path = model_type + self.mode = mode + if mode == 'extract': + self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + elif mode == 'inference': + self.samplerate = sample_rate + self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + # print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n +class FlowVocalAndMusicDecoderStereoASRTuneLayer7Code2(AudioTokenizer): + def __init__( + self, + model_type: str = "model_layer7_1x2.safetensors", + sample_rate=48000, + mode = 'extract', + ): + super().__init__() + + from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x2 import Tango + model_path = model_type + self.mode = mode + if mode == 'extract': + self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + elif mode == 'inference': + self.samplerate = sample_rate + self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + # print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n +class FlowVocalAndMusicDecoderStereoASRTuneLayer7Code1(AudioTokenizer): + def __init__( + self, + model_type: str = "model_layer7_1x1.safetensors", + sample_rate=48000, + mode = 'extract', + ): + super().__init__() + + from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_7_1x1 import Tango + model_path = model_type + self.mode = mode + if mode == 'extract': + self.model = Tango(model_path=model_path, layer_num=7, load_main_model=False, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + elif mode == 'inference': + self.samplerate = sample_rate + self.model = Tango(model_path=model_path, layer_num=7, load_main_model=True, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + # print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n +class Flow1dVAE2rvq(AudioTokenizer): + def __init__( + self, + model_type: str = "model_2.safetensors", + ): + super().__init__() + + from codeclm.tokenizer.Flow1dVAE.generate_2rvq import Tango + model_path = model_type + self.model = Tango(model_path=model_path, rvq_num=2, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n +class Flow1dVAE1rvq(AudioTokenizer): + def __init__( + self, + model_type: str = "model_2_fixed.safetensors", + vae_config: str = "", + vae_model: str = "", + ): + super().__init__() + + from codeclm.tokenizer.Flow1dVAE.generate_1rvq import Tango + model_path = model_type + self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n +class Flow1dVAE4rvq(AudioTokenizer): + def __init__( + self, + model_type: str = "model_2.safetensors", + ): + super().__init__() + + from codeclm.tokenizer.Flow1dVAE.generate_4rvq import Tango + model_path = model_type + self.model = Tango(model_path=model_path, rvq_num=4, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n + + + +class Flow1dVAESeparate(AudioTokenizer): + def __init__( + self, + model_type: str = "model_2.safetensors", + vae_config: str = "", + vae_model: str = "", + ): + super().__init__() + + from codeclm.tokenizer.Flow1dVAE.generate_septoken import Tango + model_path = model_type + self.model = Tango(model_path=model_path, vae_config=vae_config, vae_model=vae_model, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x_vocal: torch.Tensor, x_bgm: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x_vocal.ndim == 2: + x_vocal = x_vocal.unsqueeze(1) + if x_bgm.ndim == 2: + x_bgm = x_bgm.unsqueeze(1) + codes_vocal, codes_bgm = self.model.sound2code(x_vocal, x_bgm) + return codes_vocal, codes_bgm + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None): + wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n + +class FlowVocalAndMusicDecoderStereoLayer11(AudioTokenizer): + def __init__( + self, + model_type: str = "layer11_ckpt.pth", + sample_rate=48000, + mode = 'extract', + ): + super().__init__() + + from codeclm.tokenizer.FlowVocalAndMusicDecoderStereoV014.generate_stereo_11 import Tango + model_path = model_type + self.mode = mode + if mode == 'extract': + self.model = Tango(model_path=model_path, layer_num=11, load_main_model=False, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + elif mode == 'inference': + self.samplerate = sample_rate + self.model = Tango(model_path=model_path, layer_num=11, load_main_model=True, device='cuda') + print ("Successfully loaded checkpoint from:", model_path) + # print("Successfully loaded inference scheduler from {}".format(scheduler_name)) + + self.n_quantizers = 1 + + def forward(self, x: torch.Tensor) : + # We don't support training with this. + raise NotImplementedError("Forward and training with DAC not supported.") + + @torch.no_grad() + def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: + if x.ndim == 2: + x = x.unsqueeze(1) + codes = self.model.sound2code(x) # [B T] -> [B N T] + return codes, None + + + @torch.no_grad() + def decode(self, codes: torch.Tensor, prompt = None, scale: tp.Optional[torch.Tensor] = None, ncodes=9): + wav = self.model.code2sound(codes, prompt=prompt, duration=40.96, guidance_scale=1.5, + num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + return wav[None] + + + @torch.no_grad() + def decode_latent(self, codes: torch.Tensor): + """Decode from the discrete codes to continuous latent space.""" + # import pdb; pdb.set_trace() + return self.model.quantizer.from_codes(codes.transpose(1,2))[0] + + @property + def channels(self) -> int: + return 2 + + @property + def frame_rate(self) -> float: + return 25 + + @property + def sample_rate(self) -> int: + return self.samplerate + + @property + def cardinality(self) -> int: + return 10000 + + @property + def num_codebooks(self) -> int: + return self.n_quantizers + + @property + def total_codebooks(self) -> int: + # return self.model.RVQ + return 1 + + def set_num_codebooks(self, n: int): + """Set the active number of codebooks used by the quantizer. + """ + assert n >= 1 + assert n <= self.total_codebooks + self.n_quantizers = n + + + \ No newline at end of file diff --git a/codeclm/trainer/codec_song_pl.py b/codeclm/trainer/codec_song_pl.py new file mode 100755 index 0000000000000000000000000000000000000000..917501f3fe348bf3c1335edbc2a9bdd296acb188 --- /dev/null +++ b/codeclm/trainer/codec_song_pl.py @@ -0,0 +1,685 @@ +""" +Main model for using CodecLM. This will combine all the required components +and provide easy access to the generation API. +""" + +import typing as tp +import warnings +import sys +import time +import torch +import torch.nn as nn +from torch.nn import functional as F +import torchaudio +import numpy as np +import lightning as pl +from torchmetrics.classification import MulticlassAccuracy +import pdb +from codeclm.models import builders +import math +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from peft import LoraConfig, get_peft_model +from datetime import datetime +import os +os.environ['TOKENIZERS_PARALLELISM'] = "false" + + +class CodecLM_PL(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + + # 1) Build audio tokenizer (usually None during training) + self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg) + if self.audio_tokenizer is not None: + for param in self.audio_tokenizer.parameters(): + param.requires_grad = False + if "audio_tokenizer_checkpoint_sep" in self.cfg.keys(): + self.seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg) + for param in self.seperate_tokenizer.parameters(): + param.requires_grad = False + else: + self.seperate_tokenizer = None + + # 2) Build LM + self.audiolm = builders.get_lm_model(self.cfg) + print(self.audiolm) + # 输出参数量 + print('Number of parameters: ', sum(p.numel() for p in self.audiolm.parameters())) + # 3) Load pretrained checkpoint (if any) + if self.cfg.use_pretrained == 'deepspeed': + checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu') + missing, unexpected = self.load_state_dict(checkpoint, strict=False) + print(f'-------------Missing--------------\n{missing}') + print(f'-------------Unexpected--------------\n{unexpected}') + print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint)) + self.missing = missing + else: + self.missing = [] + # 如果cfg参数中有lora + if hasattr(self.cfg, 'lora'): + perf_config = LoraConfig( + r = self.cfg.lora.r, + lora_alpha = self.cfg.lora.lora_alpha, + target_modules = self.cfg.lora.target_modules, + lora_dropout = self.cfg.lora.lora_dropout, + bias = self.cfg.lora.bias, + task_type = self.cfg.lora.task_type, + ) + self.audiolm = get_peft_model(self.audiolm, perf_config) + + # 4) Build metrics + self.val_steps = [] + self.train_slide_acc = [] + self.train_steps = [] + self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy( + self.audiolm.code_size, + top_k=1, + average="micro", multidim_average="global", + ignore_index=self.cfg.lm.code_size, # ignore EOS token prediction + ) for _ in range(self.audiolm.code_depth)]) + self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy( + self.audiolm.code_size, + top_k=10, + average="micro", multidim_average="global", + ignore_index=self.cfg.lm.code_size, + ) for _ in range(self.audiolm.code_depth)]) + + self.epoch = 0 + print("++++++++++++++++ training +++++++++++++++++") + + # TODO: move this part to loader + def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384): + batch_size = sequence_lengths.size(0) + max_length = x.size(2) + + # pad one frame, if the maximum sequence length is equal to the input length + if max_length == sequence_lengths.max(): + x = F.pad(x, (0, 1), value=end_id) + max_length = x.size(2) + + if max_length <= sequence_lengths.max() + 1: + sequence_lengths = sequence_lengths - (sequence_lengths.max()+1 - max_length) + + # Add end token to x according to the sequence length + x[torch.arange(batch_size), :, sequence_lengths] = end_id + sequence_lengths += 1 + + mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1) + mask = mask.to(x.device) + mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length) + x = torch.where(mask_3d, x, end_id+1) + return x, mask_3d + + @torch.no_grad() + def preprocess_batch(self, batch): # this function is usually called during training + # 处理 dataloader 返回的数据 + audio, text_lyric, time_stamp, structure_dur, prompt_audio, structure_labels = batch + + dur, valid_st, valid_et = zip(*time_stamp) + + if self.audio_tokenizer is not None: + # only used in inference + self.audio_tokenizer.eval() + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + audio_tokens, scale = self.audio_tokenizer.encode(audio) + audio_tokens = audio_tokens[:,:self.cfg.lm.code_depth,:] + audio_tokens = audio_tokens.long() + else: + audio_tokens = audio.long() + + token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int() + audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur, + end_id=self.audiolm.eos_token_id) + condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric), + text=text_lyric, audio_qt_emb=prompt_audio) + + return condition_tensors, audio_tokens, audio_padding_mask + + def get_time(self): + # 获取当前的日期和时间 + now = datetime.now() + + # 使用strftime函数格式化日期和时间 + formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f") + return formatted_now + + def training_step(self, batch, batch_idx): + # 1) data processing + condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch) + + # 2) compute model predictions (model forward) + model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors, + training_steps=self.global_step) # this input can be ignored + logits = model_output.logits.float() + mask = padding_mask & model_output.mask + + # 3) compute loss (float) + with torch.cuda.amp.autocast(enabled=False): + ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) + + total_loss = ce + if torch.isnan(total_loss): + print(self.trainer.global_rank, ce, padding_mask, batch[1]) + print('--------------------------------------------------------------') + return None + # torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:,0].cpu(), 24000) + # import pdb; pdb.set_trace() + # 4) compute metrics and log + metrics = {} + self.log('ce', ce, prog_bar=True) + metrics['ppl'] = torch.exp(ce) + for k, ce_q in enumerate(ce_per_codebook): + metrics[f'ce_q{k + 1}'] = ce_q + metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q) + + masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size) + metrics['acc'] = [] + for k in range(self.audiolm.code_depth): + metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), + masked_labels[:, k]).item()) + metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item() + + self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']}) + self.log('train_acc', metrics['acc']+1e-8, prog_bar=True) + self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True) + self.log_dict(metrics) + + return total_loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # 1) data processing + condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch) + + # 2) compute model predictions + model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors) + logits = model_output.logits + mask = padding_mask & model_output.mask + + # 3) compute loss and metrics + ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) + metrics = {} + metrics['val_ce'] = ce + metrics['val_ppl'] = torch.exp(ce) + for k, ce_q in enumerate(ce_per_codebook): + metrics[f'val_ce_q{k + 1}'] = ce_q + metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q) + masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size) + + for k in range(self.audiolm.code_depth): + self.top1_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) #* total_length + self.top10_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) + self.val_steps.append(metrics) + + metrics['acc'] = [] + metrics['acc_top10'] = [] + for k in range(self.audiolm.code_depth): + metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item()) + metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item()) + metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])) + metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10'])) + + return metrics['acc'] + + + def on_validation_epoch_end(self) -> None: + final_metrics = {} + for i in self.val_steps: + for k in i: + final_metrics[k] = final_metrics.get(k, []) + [i[k]] + final_metrics = {k: sum(v) / len(v) for k,v in list(final_metrics.items())} + self.log_dict(final_metrics) + + q_acc = [] + q_acc10 = [] + for i in range(self.audiolm.code_depth): + q_acc.append(self.top1_acc_metric[i].compute()) + q_acc10.append(self.top10_acc_metric[i].compute()) + self.log(f"val_Top1Acc_{i}", q_acc[-1]) + self.log(f"val_Top10Acc_{i}", q_acc10[-1]) + self.top1_acc_metric[i].reset() + self.top10_acc_metric[i].reset() + + self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth) + self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth) + + return super().on_validation_epoch_end() + + + def on_validation_epoch_start(self) -> None: + self.val_steps = [] + for i in range(self.audiolm.code_depth): + self.top1_acc_metric[i].reset() + self.top10_acc_metric[i].reset() + + if len(self.train_steps) > 0: + train_metrics = {} + for i in self.train_steps: + for k in i: + train_metrics[k] = train_metrics.get(k, []) + [i[k]] + train_metrics = {k: sum(v) / len(v) for k,v in list(train_metrics.items())} + self.log('train_summary_Top1Acc', train_metrics['acc']) + self.log('train_summary_ce', train_metrics['ce']) + self.train_steps = [] + + return super().on_validation_epoch_start() + + + # 定义优化器 + def configure_optimizers(self): + total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch + optim_dict = {} + + param_groups = [] + missing_params = [] + other_params = [] + cnt = 0 + # 去掉开头的‘audiolm.' + print('before missing len', len(self.missing)) + self.missing = [name.replace('audiolm.', '') for name in self.missing] + print('after missing len', len(self.missing)) + for name, param in self.audiolm.named_parameters(): + if name in self.missing: + cnt += 1 + print(name) + missing_params.append(param) + else: + other_params.append(param) + print(cnt) + assert cnt == len(self.missing) + param_groups.append({'params': other_params, 'lr': self.cfg.optim.old_lr}) + param_groups.append({ + 'params': missing_params, + 'lr': self.cfg.optim.new_lr # 为missing参数设置10倍的学习率,你可以调整这个倍数 + }) + + if self.cfg.optim.optimizer == "adamw": + optim_dict['optimizer'] = torch.optim.AdamW( + param_groups, # 使用参数分组替代原来的 self.audiolm.parameters() + betas=tuple(self.cfg.optim.adam.betas), + weight_decay=self.cfg.optim.adam.weight_decay, + eps=self.cfg.optim.adam.eps, + ) + else: + raise NotImplementedError + + if self.cfg.schedule is None: + pass + elif self.cfg.schedule.lr_scheduler == "cosine": + scheduler = CosineLRScheduler(optim_dict['optimizer'], + total_steps=total_updates, + warmup_steps=self.cfg.schedule.cosine.warmup, + lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio, + cycle_length=self.cfg.schedule.cosine.cycle_length, + ) + optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"} + else: + raise NotImplementedError + + return optim_dict + + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed per codebook to provide codebook-level cross entropy. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + Returns: + ce (torch.Tensor): Cross entropy averaged over the codebooks + ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). + """ + # import pdb; pdb.set_trace() + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + ce_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + ce_targets = targets_k[mask_k] + ce_logits = logits_k[mask_k] + q_ce = F.cross_entropy(ce_logits, ce_targets) + ce += q_ce + ce_per_codebook.append(q_ce.detach()) + # average cross entropy across codebooks + ce = ce / K + return ce, ce_per_codebook + + +class CodecLM_PL_FT(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + + # 1) Build audio tokenizer (usually None during training) + self.audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg) + if self.audio_tokenizer is not None: + for param in self.audio_tokenizer.parameters(): + param.requires_grad = False + + # 2) Build LM + self.audiolm = builders.get_lm_model(self.cfg) + + # 3) Load pretrained checkpoint (if any) + if self.cfg.use_pretrained == 'deepspeed': + checkpoint = torch.load(self.cfg.pretrained.deepspeed_checkpoint, map_location='cpu') + missing, unexpected = self.load_state_dict(checkpoint, strict=False) + print(f'-------------Missing--------------\n{missing}') + print(f'-------------Unexpected--------------\n{unexpected}') + print("successfully load deepspeed pretrained model {}".format(self.cfg.pretrained.deepspeed_checkpoint)) + + # 4) Build metrics + self.val_steps = [] + self.train_slide_acc = [] + self.train_steps = [] + self.top1_acc_metric = nn.ModuleList([MulticlassAccuracy( + self.audiolm.code_size, + top_k=1, + average="micro", multidim_average="global", + ignore_index=self.cfg.lm.code_size, # ignore EOS token prediction + ) for _ in range(self.audiolm.code_depth)]) + self.top10_acc_metric = nn.ModuleList([MulticlassAccuracy( + self.audiolm.code_size, + top_k=10, + average="micro", multidim_average="global", + ignore_index=self.cfg.lm.code_size, + ) for _ in range(self.audiolm.code_depth)]) + + self.epoch = 0 + print("++++++++++++++++ training +++++++++++++++++") + + # TODO: move this part to loader + def generate_mask_and_end_token(self, x, sequence_lengths, end_id=16384): + batch_size = sequence_lengths.size(0) + max_length = x.size(2) + + # pad one frame, if the maximum sequence length is equal to the input length + if max_length == sequence_lengths.max(): + x = F.pad(x, (0, 1), value=end_id) + max_length = x.size(2) + + if max_length <= sequence_lengths.max() + 1: + sequence_lengths = sequence_lengths - (sequence_lengths.max()+1 - max_length) + + # Add end token to x according to the sequence length + x[torch.arange(batch_size), :, sequence_lengths] = end_id + sequence_lengths += 1 + + mask = torch.arange(max_length).expand(batch_size, max_length) < sequence_lengths.unsqueeze(1) + mask = mask.to(x.device) + mask_3d = mask.unsqueeze(1).expand(batch_size, x.size(1), max_length) + x = torch.where(mask_3d, x, end_id+1) + return x, mask_3d + + @torch.no_grad() + def preprocess_batch(self, batch): # this function is usually called during training + # 处理 dataloader 返回的数据 + audio, text_lyric, time_stamp, lang_type, prompt_audio = batch + dur, valid_st, valid_et = zip(*time_stamp) + + if self.audio_tokenizer is not None: + # only used in inference + self.audio_tokenizer.eval() + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + audio_tokens, scale = self.audio_tokenizer.encode(audio) + audio_tokens = audio_tokens[:,:self.cfg.lm.code_depth,:] + audio_tokens = audio_tokens.long() + else: + audio_tokens = audio.long() + + token_dur = (torch.Tensor(dur) * self.cfg.audio_tokenizer_frame_rate).int() + + audio_tokens, audio_padding_mask = self.generate_mask_and_end_token(audio_tokens, token_dur, + end_id=self.audiolm.eos_token_id) + condition_tensors = self.audiolm.prepare_condition_tensors(batch_size=len(text_lyric), + text=text_lyric, audio_qt_emb=prompt_audio) + + return condition_tensors, audio_tokens, audio_padding_mask + + def get_time(self): + # 获取当前的日期和时间 + now = datetime.now() + + # 使用strftime函数格式化日期和时间 + formatted_now = now.strftime("%Y-%m-%d %H:%M:%S.%f") + return formatted_now + + def training_step(self, batch, batch_idx): + # 1) data processing + condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch) + + # 2) compute model predictions (model forward) + model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors, + training_steps=self.global_step) # this input can be ignored + logits = model_output.logits.float() + mask = padding_mask & model_output.mask + + # 3) compute loss (float) + with torch.cuda.amp.autocast(enabled=False): + ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) + + total_loss = ce + if torch.isnan(total_loss): + print(self.trainer.global_rank, ce, padding_mask, batch[1]) + # print('------------------------------------------------------------------------') + torchaudio.save("error_rank{}.wav".format(self.trainer.global_rank), batch[0][:,0].cpu(), 24000) + import pdb; pdb.set_trace() + return None + + # 4) compute metrics and log + metrics = {} + self.log('ce', ce, prog_bar=True) + metrics['ppl'] = torch.exp(ce) + for k, ce_q in enumerate(ce_per_codebook): + metrics[f'ce_q{k + 1}'] = ce_q + metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q) + + masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size) + metrics['acc'] = [] + for k in range(self.audiolm.code_depth): + metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), + masked_labels[:, k]).item()) + metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])).item() + + self.train_steps.append({'ce': ce.detach().cpu().item(), 'acc': metrics['acc']}) + self.log('train_acc', metrics['acc']+1e-8, prog_bar=True) + self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], prog_bar=True) + self.log_dict(metrics) + + return total_loss + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + # 1) data processing + condition_tensors, audio_tokens, padding_mask = self.preprocess_batch(batch) + + # 2) compute model predictions + model_output = self.audiolm.compute_predictions(audio_tokens, condition_tensors) + logits = model_output.logits + mask = padding_mask & model_output.mask + + # 3) compute loss and metrics + ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask) + metrics = {} + metrics['val_ce'] = ce + metrics['val_ppl'] = torch.exp(ce) + for k, ce_q in enumerate(ce_per_codebook): + metrics[f'val_ce_q{k + 1}'] = ce_q + metrics[f'val_ppl_q{k + 1}'] = torch.exp(ce_q) + masked_labels = audio_tokens.masked_fill(~mask, value=self.cfg.lm.code_size) + + for k in range(self.audiolm.code_depth): + self.top1_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) #* total_length + self.top10_acc_metric[k].update(logits[:, k].transpose(1,2).detach(), masked_labels[:,k]) + self.val_steps.append(metrics) + metrics['acc'] = [] + metrics['acc_top10'] = [] + for k in range(self.audiolm.code_depth): + metrics['acc'].append(self.top1_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item()) + metrics['acc_top10'].append(self.top10_acc_metric[k](logits[:, k].transpose(1,2).detach(), masked_labels[:,k]).item()) + metrics['acc'] = torch.mean(torch.Tensor(metrics['acc'])) + metrics['acc_top10'] = torch.mean(torch.Tensor(metrics['acc_top10'])) + + return metrics['acc'] + + def on_validation_epoch_end(self) -> None: + final_metrics = {} + for i in self.val_steps: + for k in i: + final_metrics[k] = final_metrics.get(k, []) + [i[k]] + final_metrics = {k: sum(v) / len(v) for k,v in list(final_metrics.items())} + self.log_dict(final_metrics) + + q_acc = [] + q_acc10 = [] + for i in range(self.audiolm.code_depth): + q_acc.append(self.top1_acc_metric[i].compute()) + q_acc10.append(self.top10_acc_metric[i].compute()) + self.log(f"val_Top1Acc_{i}", q_acc[-1]) + self.log(f"val_Top10Acc_{i}", q_acc10[-1]) + self.top1_acc_metric[i].reset() + self.top10_acc_metric[i].reset() + + self.log('val_Top1Acc', sum(q_acc) / self.audiolm.code_depth) + self.log('val_Top10Acc', sum(q_acc10) / self.audiolm.code_depth) + + return super().on_validation_epoch_end() + + + def on_validation_epoch_start(self) -> None: + self.val_steps = [] + for i in range(self.audiolm.code_depth): + self.top1_acc_metric[i].reset() + self.top10_acc_metric[i].reset() + + if len(self.train_steps) > 0: + train_metrics = {} + for i in self.train_steps: + for k in i: + train_metrics[k] = train_metrics.get(k, []) + [i[k]] + train_metrics = {k: sum(v) / len(v) for k,v in list(train_metrics.items())} + self.log('train_summary_Top1Acc', train_metrics['acc']) + self.log('train_summary_ce', train_metrics['ce']) + self.train_steps = [] + + return super().on_validation_epoch_start() + + + # 定义优化器 + def configure_optimizers(self): + total_updates = self.cfg.optim.epochs * self.cfg.optim.updates_per_epoch + optim_dict = {} + + if self.cfg.optim.optimizer == "adamw": + optim_dict['optimizer'] = torch.optim.AdamW( + self.audiolm.parameters(), + lr=self.cfg.optim.lr, + betas=tuple(self.cfg.optim.adam.betas), + weight_decay=self.cfg.optim.adam.weight_decay, + eps=self.cfg.optim.adam.eps, + ) + else: + raise NotImplementedError + + if self.cfg.schedule is None: + pass + elif self.cfg.schedule.lr_scheduler == "cosine": + scheduler = CosineLRScheduler(optim_dict['optimizer'], + total_steps=total_updates, + warmup_steps=self.cfg.schedule.cosine.warmup, + lr_min_ratio=self.cfg.schedule.cosine.lr_min_ratio, + cycle_length=self.cfg.schedule.cosine.cycle_length, + ) + optim_dict['lr_scheduler'] = {"scheduler": scheduler, "interval": "step"} + else: + raise NotImplementedError + + return optim_dict + + + def _compute_cross_entropy( + self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor + ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]: + """Compute cross entropy between multi-codebook targets and model's logits. + The cross entropy is computed per codebook to provide codebook-level cross entropy. + Valid timesteps for each of the codebook are pulled from the mask, where invalid + timesteps are set to 0. + + Args: + logits (torch.Tensor): Model's logits of shape [B, K, T, card]. + targets (torch.Tensor): Target codes, of shape [B, K, T]. + mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T]. + Returns: + ce (torch.Tensor): Cross entropy averaged over the codebooks + ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached). + """ + # import pdb; pdb.set_trace() + B, K, T = targets.shape + assert logits.shape[:-1] == targets.shape + assert mask.shape == targets.shape + ce = torch.zeros([], device=targets.device) + ce_per_codebook: tp.List[torch.Tensor] = [] + for k in range(K): + logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card] + targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T] + mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T] + ce_targets = targets_k[mask_k] + ce_logits = logits_k[mask_k] + q_ce = F.cross_entropy(ce_logits, ce_targets) + ce += q_ce + ce_per_codebook.append(q_ce.detach()) + # average cross entropy across codebooks + ce = ce / K + return ce, ce_per_codebook + +class CosineLRScheduler(_LRScheduler):# + """Cosine LR scheduler. + + Args: + optimizer (Optimizer): Torch optimizer. + warmup_steps (int): Number of warmup steps. + total_steps (int): Total number of steps. + lr_min_ratio (float): Minimum learning rate. + cycle_length (float): Cycle length. + """ + def __init__(self, optimizer: Optimizer, total_steps: int, warmup_steps: int, + lr_min_ratio: float = 0.0, cycle_length: float = 1.0): + self.warmup_steps = warmup_steps + assert self.warmup_steps >= 0 + self.total_steps = total_steps + assert self.total_steps >= 0 + self.lr_min_ratio = lr_min_ratio + self.cycle_length = cycle_length + super().__init__(optimizer) + + def _get_sched_lr(self, lr: float, step: int): + if step < self.warmup_steps: + lr_ratio = step / self.warmup_steps + lr = lr_ratio * lr + elif step <= self.total_steps: + s = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps) + lr_ratio = self.lr_min_ratio + 0.5 * (1 - self.lr_min_ratio) * \ + (1. + math.cos(math.pi * s / self.cycle_length)) + lr = lr_ratio * lr + else: + lr_ratio = self.lr_min_ratio + lr = lr_ratio * lr + return lr + + def get_lr(self): + return [self._get_sched_lr(lr, self.last_epoch) for lr in self.base_lrs] diff --git a/codeclm/utils/autocast.py b/codeclm/utils/autocast.py new file mode 100644 index 0000000000000000000000000000000000000000..f3dbdc64331c09b94508f5ffa72a1046b5c4295f --- /dev/null +++ b/codeclm/utils/autocast.py @@ -0,0 +1,34 @@ +import torch + + +class TorchAutocast: + """TorchAutocast utility class. + Allows you to enable and disable autocast. This is specially useful + when dealing with different architectures and clusters with different + levels of support. + + Args: + enabled (bool): Whether to enable torch.autocast or not. + args: Additional args for torch.autocast. + kwargs: Additional kwargs for torch.autocast + """ + def __init__(self, enabled: bool, *args, **kwargs): + self.autocast = torch.autocast(*args, **kwargs) if enabled else None + + def __enter__(self): + if self.autocast is None: + return + try: + self.autocast.__enter__() + except RuntimeError: + device = self.autocast.device + dtype = self.autocast.fast_dtype + raise RuntimeError( + f"There was an error autocasting with dtype={dtype} device={device}\n" + "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16" + ) + + def __exit__(self, *args, **kwargs): + if self.autocast is None: + return + self.autocast.__exit__(*args, **kwargs) diff --git a/codeclm/utils/utils.py b/codeclm/utils/utils.py new file mode 100755 index 0000000000000000000000000000000000000000..816c750fc3910fbd688ebb276f581c8b921f182e --- /dev/null +++ b/codeclm/utils/utils.py @@ -0,0 +1,196 @@ +from concurrent.futures import ProcessPoolExecutor +from contextlib import contextmanager +from functools import wraps, lru_cache +import hashlib +import json +import logging +from pathlib import Path +import typing as tp +import math +from torch import nn +import typing as tp +from functools import partial +import torch.nn.functional as F +import flashy +import flashy.distrib +import omegaconf +import torch +from torch.nn.utils.rnn import pad_sequence + +def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor: + """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences). + For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]] + + Args: + lengths (torch.Tensor): tensor with lengths + max_len (int): can set the max length manually. Defaults to None. + Returns: + torch.Tensor: mask with 0s where there is pad tokens else 1s + """ + assert len(lengths.shape) == 1, "Length shape should be 1 dimensional." + final_length = lengths.max().item() if not max_len else max_len + final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor + return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None] + + + +def dict_from_config(cfg: omegaconf.DictConfig) -> dict: + """Convenience function to map an omegaconf configuration to a dictionary. + + Args: + cfg (omegaconf.DictConfig): Original configuration to map to dict. + Returns: + dict: Config as dictionary object. + """ + dct = omegaconf.OmegaConf.to_container(cfg, resolve=True) + assert isinstance(dct, dict) + return dct + +def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module: + """Create normalization module for transformer encoder layer. + + Args: + norm_type (str): Normalization method. + dim (int): Dimension of the normalized layer. + **kwargs (dict): Additional parameters for normalization layer. + Returns: + nn.Module: Normalization module. + """ + if norm_type == 'layer_norm': + return nn.LayerNorm(dim, eps=1e-5, **kwargs) + else: + raise ValueError(f"Unknown norm type: {norm_type}") + +def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None): + """LM layer initialization. + Inspired from xlformers: https://github.com/fairinternal/xlformers + + Args: + method (str): Method name for init function. Valid options are: + 'gaussian', 'uniform'. + input_dim (int): Input dimension of the initialized module. + init_depth (int, optional): Optional init depth value used to rescale + the standard deviation if defined. + """ + # Compute std + std = 1 / math.sqrt(input_dim) + # Rescale with depth + if init_depth is not None: + std = std / math.sqrt(2 * init_depth) + + if method == 'gaussian': + return partial( + torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std + ) + elif method == 'uniform': + bound = math.sqrt(3) * std # ensure the standard deviation is `std` + return partial(torch.nn.init.uniform_, a=-bound, b=bound) + else: + raise ValueError("Unsupported layer initialization method") + +def init_layer(m: nn.Module, + method: str, + init_depth: tp.Optional[int] = None, + zero_bias_init: bool = False): + """Wrapper around ``get_init_fn`` for proper initialization of LM modules. + + Args: + m (nn.Module): Module to initialize. + method (str): Method name for the init function. + init_depth (int, optional): Optional init depth value used to rescale + the standard deviation if defined. + zero_bias_init (bool): Whether to initialize the bias to 0 or not. + """ + if isinstance(m, nn.Linear): + init_fn = get_init_fn(method, m.in_features, init_depth=init_depth) + if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: + weight = m.weight.float() + init_fn(weight) + m.weight.data[:] = weight.half() + else: + init_fn(m.weight) + if zero_bias_init and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Embedding): + init_fn = get_init_fn(method, m.embedding_dim, init_depth=None) + if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: + weight = m.weight.float() + init_fn(weight) + m.weight.data[:] = weight.half() + else: + init_fn(m.weight) + +def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """Get a list of tensors and collate them to a single tensor. according to the following logic: + - `dim` specifies the time dimension which will be stacked and padded. + - The output will contain 1 new dimension (dimension index 0) which will be the size of + of the original list. + + Args: + tensors (tp.List[torch.Tensor]): List of tensors to collate. + dim (int): Dimension which will be stacked and padded. + Returns: + tp.Tuple[torch.Tensor, torch.Tensor]: + torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension + (dimension index 0) which will be the size of the original list. + torch.Tensor: Tensor containing length of original tensor sizes (without padding). + """ + tensors = [x.transpose(0, dim) for x in tensors] + lens = torch.LongTensor([len(x) for x in tensors]) + padded_tensors = pad_sequence(tensors) + padded_tensors = padded_tensors.transpose(0, 1) + padded_tensors = padded_tensors.transpose(1, dim + 1) + return padded_tensors, lens + +def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor: + """Sample next token from top K values along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + k (int): The k in “top-k”. + Returns: + torch.Tensor: Sampled tokens. + """ + top_k_value, _ = torch.topk(probs, k, dim=-1) + min_value_top_k = top_k_value[..., [-1]] + probs *= (probs >= min_value_top_k).float() + probs.div_(probs.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs, num_samples=1) + return next_token + +def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor: + """Sample next token from top P probabilities along the last dimension of the input probs tensor. + + Args: + probs (torch.Tensor): Input probabilities with token candidates on the last dimension. + p (int): The p in “top-p”. + Returns: + torch.Tensor: Sampled tokens. + """ + probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) + probs_sum = torch.cumsum(probs_sort, dim=-1) + mask = probs_sum - probs_sort > p + probs_sort *= (~mask).float() + probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) + next_token = multinomial(probs_sort, num_samples=1) + next_token = torch.gather(probs_idx, -1, next_token) + return next_token + +def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None): + """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension. + + Args: + input (torch.Tensor): The input tensor containing probabilities. + num_samples (int): Number of samples to draw. + replacement (bool): Whether to draw with replacement or not. + Keywords args: + generator (torch.Generator): A pseudorandom number generator for sampling. + Returns: + torch.Tensor: Last dimension contains num_samples indices + sampled from the multinomial probability distribution + located in the last dimension of tensor input. + """ + input_ = input.reshape(-1, input.shape[-1]) + output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator) + output = output_.reshape(*list(input.shape[:-1]), -1) + return output \ No newline at end of file diff --git a/conf/infer.yaml b/conf/infer.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e1c89660fca830e337d8bf5ace9f93d4aae78973 --- /dev/null +++ b/conf/infer.yaml @@ -0,0 +1,152 @@ +# ================ Logging ====================== # +root_dir: exp/song/${get_fname:} + +# ================ Checkpoints ================== # +use_pretrained: deepspeed # ['ddp', 'continue', 'deepspeed'] +pretrained: + ddp_checkpoint: + deepspeed_checkpoint: ./ckpt/60000_alnew.pt + continue_checkpoint: + +# ================ Data & loader ================== # +prompt_select: random +train_jsonl_list: +- .jsonl +val_jsonl_list: +- .jsonl +train_scp_list: +- .scp +val_scp_list: +- .scp + +lyric_processor: +max_dur: 150 +min_dur: 30 +batch_size: 2 +prompt_len: 10 +pad_to_max: true + +# ================ Training ======================= # +accelerator: gpu +devices: 8 +num_nodes: 4 +val_check_interval: 2500 +accumulate_grad_batches: 1 +strategy: 'deepspeed_stage_2' # ['ddp', 'fsdp', 'deepspeed_stage_2', 'ddp_find_unused_parameters_true'] +precision: 'bf16-mixed' # ['16-mixed', 'bf16-mixed'] + +optim: + optimizer: adamw + updates_per_epoch: 1000 + epochs: 100 + old_lr: 0 # 1e-4 + new_lr: 1e-4 + max_norm: 0.5 + adam: + betas: + - 0.9 + - 0.95 + weight_decay: 0.00001 # 0.1 + eps: 1e-8 + +schedule: + lr_scheduler: cosine + cosine: + warmup: 4000 + lr_min_ratio: 0.0 + cycle_length: 1.0 + +# ================ Audio tokenzier ================ # +audio_tokenizer_checkpoint: Flow1dVAE1rvq_./ckpt/model_1rvq/model_2_fixed.safetensors +audio_tokenizer_frame_rate: 25 +audio_tokenizer_code_depth: 1 +sample_rate: 48000 + +audio_tokenizer_checkpoint_sep: Flow1dVAESeparate_./ckpt/model_septoken/model_2.safetensors +audio_tokenizer_frame_rate_sep: 25 +audio_tokenizer_code_depth_sep: 2 +sample_rate_sep: 48000 + +# ================ VAE ================ # +vae_config: ./ckpt/vae/stable_audio_1920_vae.json +vae_model: ./ckpt/vae/autoencoder_music_1320k.ckpt + +# ================== LM =========================== # +lm: + lm_type: Llama # [Llama] + dim: 1536 + intermediate_size: 8960 + num_heads: 12 + num_layers: 28 + code_depth: 3 + code_size: 16384 + dropout: 0.0 + activation: gelu + norm_first: true + bias_ff: false + bias_attn: false + bias_proj: false + causal: true + custom: false + memory_efficient: true + attention_as_float32: false + layer_scale: null + positional_embedding: sin + xpos: false + checkpointing: torch + weight_init: gaussian + depthwise_init: current + zero_bias_init: true + norm: layer_norm + cross_attention: false + qk_layer_norm: false + qk_layer_norm_cross: false + attention_dropout: null + kv_repeat: 1 + +codebooks_pattern: + modeling: delay + delay: + delays: [ 0, 250, 250 ] + flatten_first: 0 + empty_initial: 0 + +# ================ Conditioners ===================== # +classifier_free_guidance: + # drop all conditions simultaneously + training_dropout: 0.15 + inference_coef: 1.5 + +attribute_dropout: + # drop each condition separately + args: + active_on_eval: false + text: + description: 0.0 + type_info: 0.5 + audio: + prompt_audio: 0.0 + +use_text_training: True +fuser: + sum: [] + prepend: [ description, prompt_audio, type_info ] # this order is the SAME with the input concatenation order + +conditioners: + prompt_audio: + model: qt_embedding + qt_embedding: + code_size: 16384 + code_depth: 3 + max_len: ${eval:${prompt_len}*${audio_tokenizer_frame_rate}+2} # 25*10+2+1 + description: + model: QwTokenizer + QwTokenizer: + token_path: third_party/Qwen2-7B + max_len: 300 + add_token_list: ${load_yaml:conf/vocab.yaml} + type_info: + model: QwTextTokenizer + QwTextTokenizer: + token_path: third_party/Qwen2-7B + max_len: 50 diff --git a/conf/vocab.yaml b/conf/vocab.yaml new file mode 100755 index 0000000000000000000000000000000000000000..30fc94d8ba5b2e7d3355451a8fa995651539cfac --- /dev/null +++ b/conf/vocab.yaml @@ -0,0 +1,13 @@ +- '[verse]' +- '[chorus]' +- '[bridge]' +- '[intro-short]' +- '[intro-medium]' +- '[intro-long]' +- '[outro-short]' +- '[outro-medium]' +- '[outro-long]' +- '[inst-short]' +- '[inst-medium]' +- '[inst-long]' +- '[silence]' diff --git a/generate.py b/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..830a41690499bb7ffcc3556dd15f759e2038e7bd --- /dev/null +++ b/generate.py @@ -0,0 +1,145 @@ +import sys +import os + +import time +import json +import torch +import torchaudio +import numpy as np +from omegaconf import OmegaConf + +from codeclm.trainer.codec_song_pl import CodecLM_PL +from codeclm.models import CodecLM +from third_party.demucs.models.pretrained import get_model_from_yaml + + +class Separator: + def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None: + if torch.cuda.is_available() and gpu_id < torch.cuda.device_count(): + self.device = torch.device(f"cuda:{gpu_id}") + else: + self.device = torch.device("cpu") + self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path) + + def init_demucs_model(self, model_path, config_path): + model = get_model_from_yaml(config_path, model_path) + model.to(self.device) + model.eval() + return model + + def load_audio(self, f): + a, fs = torchaudio.load(f) + if (fs != 48000): + a = torchaudio.functional.resample(a, fs, 48000) + if a.shape[-1] >= 48000*10: + a = a[..., :48000*10] + else: + a = torch.cat([a, a], -1) + return a[:, 0:48000*10] + + def run(self, audio_path, output_dir='tmp', ext=".flac"): + os.makedirs(output_dir, exist_ok=True) + name, _ = os.path.splitext(os.path.split(audio_path)[-1]) + output_paths = [] + + for stem in self.demucs_model.sources: + output_path = os.path.join(output_dir, f"{name}_{stem}{ext}") + if os.path.exists(output_path): + output_paths.append(output_path) + if len(output_paths) == 1: # 4 + vocal_path = output_paths[0] + else: + drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device) + for path in [drums_path, bass_path, other_path]: + os.remove(path) + full_audio = self.load_audio(audio_path) + vocal_audio = self.load_audio(vocal_path) + bgm_audio = full_audio - vocal_audio + return full_audio, vocal_audio, bgm_audio + + +def main_sep(): + torch.backends.cudnn.enabled = False #taiji的某些傻呗node会报奇奇怪怪的错 + OmegaConf.register_new_resolver("eval", lambda x: eval(x)) + OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx]) + OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0]) + OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x))) + cfg = OmegaConf.load(sys.argv[1]) + save_dir = sys.argv[2] + input_jsonl = sys.argv[3] + sidx = sys.argv[4] + cfg.mode = 'inference' + max_duration = cfg.max_dur + + # Define model or load pretrained model + model_light = CodecLM_PL(cfg) + + model_light = model_light.eval().cuda() + model_light.audiolm.cfg = cfg + model = CodecLM(name = "tmp", + lm = model_light.audiolm, + audiotokenizer = model_light.audio_tokenizer, + max_duration = max_duration, + seperate_tokenizer = model_light.seperate_tokenizer, + ) + separator = Separator() + + cfg_coef = 1.5 #25 + temp = 1.0 + top_k = 50 + top_p = 0.0 + record_tokens = True + record_window = 50 + + model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef, + top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window) + os.makedirs(save_dir + "/token", exist_ok=True) + os.makedirs(save_dir + "/audios", exist_ok=True) + os.makedirs(save_dir + "/jsonl", exist_ok=True) + + with open(input_jsonl, "r") as fp: + lines = fp.readlines() + + new_items = [] + for line in lines: + item = json.loads(line) + target_name = f"{save_dir}/token/{item['idx']}_s{sidx}.npy" + target_wav_name = f"{save_dir}/audios/{item['idx']}_s{sidx}.flac" + descriptions = item["descriptions"] + lyric = item["gt_lyric"] + + start_time = time.time() + pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path']) + generate_inp = { + 'lyrics': [lyric.replace(" ", " ")], + 'descriptions': [descriptions], + 'melody_wavs': pmt_wav, + 'vocal_wavs': vocal_wav, + 'bgm_wavs': bgm_wav, + } + + mid_time = time.time() + with torch.autocast(device_type="cuda", dtype=torch.float16): + tokens = model.generate(**generate_inp, return_tokens=True) + end_time = time.time() + if tokens.shape[-1] > 3000: + tokens = tokens[..., :3000] + + with torch.no_grad(): + wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav) + torchaudio.save(target_wav_name, wav_seperate[0].cpu().float(), cfg.sample_rate) + np.save(target_name, tokens.cpu().squeeze(0).numpy()) + print(f"process{item['idx']}, demucs cost {mid_time - start_time}s, lm cos {end_time - mid_time}") + + item["idx"] = f"{item['idx']}_s{sidx}" + item["tk_path"] = target_name + new_items.append(item) + + src_jsonl_name = os.path.split(input_jsonl)[-1] + with open(f"{save_dir}/jsonl/{src_jsonl_name}-s{sidx}.jsonl", "w", encoding='utf-8') as fw: + for item in new_items: + fw.writelines(json.dumps(item, ensure_ascii=False)+"\n") + + +if __name__ == "__main__": + main_sep() diff --git a/generate.sh b/generate.sh new file mode 100644 index 0000000000000000000000000000000000000000..5627435d548debdbc9a6d744e80d52289ea4a239 --- /dev/null +++ b/generate.sh @@ -0,0 +1,12 @@ +export USER=root +export PYTHONDONTWRITEBYTECODE=1 +export TRANSFORMERS_CACHE="$(pwd)/third_party/hub" +export NCCL_HOME=/usr/local/tccl +export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH + +CFG_FILE=conf/infer.yaml +JSONL=$1 +SAVE_DIR=$2 +SIDX=0 +DEVICE=0 +OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=$DEVICE python3 generate.py $CFG_FILE $SAVE_DIR $JSONL $SIDX diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..364e2ee48ed028108fe2a2dc80d0090a5f3c4247 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +fastapi +uvicorn[standard] diff --git "a/sample/19_2-\345\217\210\346\230\257\344\270\200\345\244\251\350\277\207\345\216\273\357\274\214\347\203\246\346\201\274\345\246\202\345\275\261\351\232\217\345\275\24210s.wav" "b/sample/19_2-\345\217\210\346\230\257\344\270\200\345\244\251\350\277\207\345\216\273\357\274\214\347\203\246\346\201\274\345\246\202\345\275\261\351\232\217\345\275\24210s.wav" new file mode 100644 index 0000000000000000000000000000000000000000..c4646c2c8a8651a21dfb2e54bd545ef7c385c1df --- /dev/null +++ "b/sample/19_2-\345\217\210\346\230\257\344\270\200\345\244\251\350\277\207\345\216\273\357\274\214\347\203\246\346\201\274\345\246\202\345\275\261\351\232\217\345\275\24210s.wav" @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2068592b00263f7c0b0f1d82a882d7738730ace3e04f2d889d06ff983ad6d618 +size 3845542 diff --git a/sample/example.mp3 b/sample/example.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..b08001d52a39ba3260138d5c5497957bec25a986 --- /dev/null +++ b/sample/example.mp3 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a74b76ae7aa94e5dc78b73074a52a3c7d2d5d2dfc764c53273c0b1454002459 +size 1769417 diff --git a/sample/lyric.jsonl b/sample/lyric.jsonl new file mode 100644 index 0000000000000000000000000000000000000000..6076116f9cac3b88cc1830fc1e7cdbce5e0d91be --- /dev/null +++ b/sample/lyric.jsonl @@ -0,0 +1 @@ +{"idx": "01_节奏蓝调", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 夜晚的街灯闪烁.我漫步在熟悉的角落.回忆像潮水般涌来.你的笑容如此清晰.在心头无法抹去.那些曾经的甜蜜.如今只剩我独自回忆 ; [bridge] 手机屏幕亮起.是你发来的消息.简单的几个字.却让我泪流满面.曾经的拥抱温暖.如今却变得遥远.我多想回到从前.重新拥有你的陪伴 ; [chorus] 回忆的温度还在.你却已不在.我的心被爱填满.却又被思念刺痛.R&B的节奏奏响.我的心却在流浪.没有你的日子.我该如何继续向前 ; [outro-short]", "prompt_audio_path": "sample/19_2-又是一天过去,烦恼如影随形10s.wav"}