Pijush2023 committed on
Commit 13d49c4 · verified · 1 Parent(s): 784ad1f

Upload 27 files

.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  mini-omni-main[[:space:]](1)/mini-omni-main/data/demo_gradio.mov filter=lfs diff=lfs merge=lfs -text
  mini-omni-main[[:space:]](1)/mini-omni-main/data/demo_streamlit.mov filter=lfs diff=lfs merge=lfs -text
+ mini-omni-main[[:space:]](1)/data/demo_gradio.mov filter=lfs diff=lfs merge=lfs -text
+ mini-omni-main[[:space:]](1)/data/demo_streamlit.mov filter=lfs diff=lfs merge=lfs -text
mini-omni-main (1)/.gitignore ADDED
@@ -0,0 +1,169 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ *pyc
3
+ *pth
4
+ checkpoint/
5
+ checkpoint_bak/
6
+ output/
7
+ .DS_Store
8
+
9
+ __pycache__/
10
+ *.py[cod]
11
+ *$py.class
12
+
13
+ # C extensions
14
+ *.so
15
+
16
+ # Distribution / packaging
17
+ .Python
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ wheels/
30
+ share/python-wheels/
31
+ *.egg-info/
32
+ .installed.cfg
33
+ *.egg
34
+ MANIFEST
35
+
36
+ # PyInstaller
37
+ # Usually these files are written by a python script from a template
38
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
39
+ *.manifest
40
+ *.spec
41
+
42
+ # Installer logs
43
+ pip-log.txt
44
+ pip-delete-this-directory.txt
45
+
46
+ # Unit test / coverage reports
47
+ htmlcov/
48
+ .tox/
49
+ .nox/
50
+ .coverage
51
+ .coverage.*
52
+ .cache
53
+ nosetests.xml
54
+ coverage.xml
55
+ *.cover
56
+ *.py,cover
57
+ .hypothesis/
58
+ .pytest_cache/
59
+ cover/
60
+
61
+ # Translations
62
+ *.mo
63
+ *.pot
64
+
65
+ # Django stuff:
66
+ *.log
67
+ local_settings.py
68
+ db.sqlite3
69
+ db.sqlite3-journal
70
+
71
+ # Flask stuff:
72
+ instance/
73
+ .webassets-cache
74
+
75
+ # Scrapy stuff:
76
+ .scrapy
77
+
78
+ # Sphinx documentation
79
+ docs/_build/
80
+
81
+ # PyBuilder
82
+ .pybuilder/
83
+ target/
84
+
85
+ # Jupyter Notebook
86
+ .ipynb_checkpoints
87
+
88
+ # IPython
89
+ profile_default/
90
+ ipython_config.py
91
+
92
+ # pyenv
93
+ # For a library or package, you might want to ignore these files since the code is
94
+ # intended to run in multiple environments; otherwise, check them in:
95
+ # .python-version
96
+
97
+ # pipenv
98
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
99
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
100
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
101
+ # install all needed dependencies.
102
+ #Pipfile.lock
103
+
104
+ # poetry
105
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
106
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
107
+ # commonly ignored for libraries.
108
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
109
+ #poetry.lock
110
+
111
+ # pdm
112
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
113
+ #pdm.lock
114
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
115
+ # in version control.
116
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
117
+ .pdm.toml
118
+ .pdm-python
119
+ .pdm-build/
120
+
121
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
122
+ __pypackages__/
123
+
124
+ # Celery stuff
125
+ celerybeat-schedule
126
+ celerybeat.pid
127
+
128
+ # SageMath parsed files
129
+ *.sage.py
130
+
131
+ # Environments
132
+ .env
133
+ .venv
134
+ env/
135
+ venv/
136
+ ENV/
137
+ env.bak/
138
+ venv.bak/
139
+
140
+ # Spyder project settings
141
+ .spyderproject
142
+ .spyproject
143
+
144
+ # Rope project settings
145
+ .ropeproject
146
+
147
+ # mkdocs documentation
148
+ /site
149
+
150
+ # mypy
151
+ .mypy_cache/
152
+ .dmypy.json
153
+ dmypy.json
154
+
155
+ # Pyre type checker
156
+ .pyre/
157
+
158
+ # pytype static type analyzer
159
+ .pytype/
160
+
161
+ # Cython debug symbols
162
+ cython_debug/
163
+
164
+ # PyCharm
165
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
166
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
167
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
168
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
169
+ #.idea/
mini-omni-main (1)/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 gpt-omni
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
mini-omni-main (1)/README.md ADDED
@@ -0,0 +1,134 @@
+
+ # Mini-Omni
+
+ <p align="center"><strong style="font-size: 18px;">
+ Mini-Omni: Language Models Can Hear, Talk While Thinking in Streaming
+ </strong>
+ </p>
+
+ <p align="center">
+ 🤗 <a href="https://huggingface.co/gpt-omni/mini-omni">Hugging Face</a> | 📖 <a href="https://github.com/gpt-omni/mini-omni">Github</a>
+ | 📑 <a href="https://arxiv.org/abs/2408.16725">Technical report</a> |
+ 🤗 <a href="https://huggingface.co/datasets/gpt-omni/VoiceAssistant-400K">Datasets</a>
+ </p>
+
+ Mini-Omni is an open-source multimodal large language model that can **hear, talk while thinking**, featuring real-time end-to-end speech input and **streaming audio output** conversational capabilities.
+
+ <p align="center">
+ <img src="data/figures/frameworkv3.jpg" width="100%"/>
+ </p>
+
+
+ ## Updates
+
+ - **2024.10:** We released [Mini-Omni2](https://github.com/gpt-omni/mini-omni2) with vision and audio capabilities.
+ - **2024.09:** Amazing online [interactive gradio demo](https://huggingface.co/spaces/gradio/omni-mini) by the 🤗 Gradio team.
+ - **2024.09:** **VoiceAssistant-400K** is uploaded to [Hugging Face](https://huggingface.co/datasets/gpt-omni/VoiceAssistant-400K).
+
+ ## Features
+
+ ✅ **Real-time speech-to-speech** conversational capabilities. No extra ASR or TTS models required.
+
+ ✅ **Talking while thinking**, with the ability to generate text and audio at the same time.
+
+ ✅ **Streaming audio output** capabilities.
+
+ ✅ "Audio-to-Text" and "Audio-to-Audio" **batch inference** to further boost performance.
+
+ ## Demo
+
+ NOTE: you need to unmute the video first.
+
+ https://github.com/user-attachments/assets/03bdde05-9514-4748-b527-003bea57f118
+
+
+ ## Install
+
+ Create a new conda environment and install the required packages:
+
+ ```sh
+ conda create -n omni python=3.10
+ conda activate omni
+
+ git clone https://github.com/gpt-omni/mini-omni.git
+ cd mini-omni
+ pip install -r requirements.txt
+ ```
+
+ ## Quick start
+
+ **Interactive demo**
+
+ - start server
+
+ NOTE: you need to start the server before running the Streamlit or Gradio demo, with `API_URL` set to the server address.
+
+ ```sh
+ sudo apt-get install ffmpeg
+ conda activate omni
+ cd mini-omni
+ python3 server.py --ip '0.0.0.0' --port 60808
+ ```
+
+
+ - run streamlit demo
+
+ NOTE: you need to run Streamlit **locally** with PyAudio installed. For the error `ModuleNotFoundError: No module named 'utils.vad'`, run `export PYTHONPATH=./` first.
+
+ ```sh
+ pip install PyAudio==0.2.14
+ API_URL=http://0.0.0.0:60808/chat streamlit run webui/omni_streamlit.py
+ ```
+
+ - run gradio demo
+ ```sh
+ API_URL=http://0.0.0.0:60808/chat python3 webui/omni_gradio.py
+ ```
+
+ example:
+
+ NOTE: you need to unmute the video first. Gradio does not seem to play the audio stream instantly, so the latency feels a bit longer.
+
+ https://github.com/user-attachments/assets/29187680-4c42-47ff-b352-f0ea333496d9
+
+
+ **Local test**
+
+ ```sh
+ conda activate omni
+ cd mini-omni
+ # test run the preset audio samples and questions
+ python inference.py
+ ```
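+
+ If you want to call the model from Python instead of the test script, the `OmniInference` class in `inference.py` exposes a streaming generator. A minimal sketch (not an official example; it assumes a CUDA device and that each yielded chunk is a bytes-like audio segment, and the output filename is just illustrative):
+
+ ```python
+ # minimal usage sketch of inference.OmniInference; run from the mini-omni repo root (see FAQ 3)
+ from inference import OmniInference
+
+ # downloads the gpt-omni/mini-omni checkpoint on first use
+ client = OmniInference(ckpt_dir="./checkpoint", device="cuda:0")
+ client.warm_up()  # runs one pass over ./data/samples/output1.wav
+
+ # stream the spoken answer for one audio question
+ with open("answer.raw", "wb") as f:
+     for chunk in client.run_AT_batch_stream("./data/samples/output1.wav"):
+         f.write(chunk)  # assumption: each chunk is a bytes-like audio segment
+ ```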
+
+ ## FAQ
+
+ **1. Does the model support other languages?**
+
+ No, the model is trained only on English. However, since we use Whisper as the audio encoder, the model can understand other languages that Whisper supports (such as Chinese), but the output is only in English.
+
+ **2. What is `post_adapter` in the code? Does the open-source version support the tts-adapter?**
+
+ The `post_adapter` in model.py is the `tts-adapter`, but the open-source version does not support it.
+
+ **3. Error: `ModuleNotFoundError: No module named 'utils.xxxx'`**
+
+ Run `export PYTHONPATH=./` first. There is no need to run `pip install utils`; if you already have, try `pip uninstall utils`.
+
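+ For example, when running the local test from the repository root:
+
+ ```sh
+ export PYTHONPATH=./
+ python inference.py
+ ```
+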
+ **4. Error: cannot run Streamlit in a local browser with a remote Streamlit server**, see issue: https://github.com/gpt-omni/mini-omni/issues/37
+
+ You need to start Streamlit **locally** with PyAudio installed.
+
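+ For example, if the server runs on a remote machine, start the client locally and point `API_URL` at it (`<remote-ip>` is a placeholder for the server's address):
+
+ ```sh
+ pip install PyAudio==0.2.14
+ API_URL=http://<remote-ip>:60808/chat streamlit run webui/omni_streamlit.py
+ ```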
+
+ ## Acknowledgements
+
+ - [Qwen2](https://github.com/QwenLM/Qwen2/) as the LLM backbone.
+ - [litGPT](https://github.com/Lightning-AI/litgpt/) for training and inference.
+ - [whisper](https://github.com/openai/whisper/) for audio encoding.
+ - [snac](https://github.com/hubertsiuzdak/snac/) for audio decoding.
+ - [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) for generating synthetic speech.
+ - [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca) and [MOSS](https://github.com/OpenMOSS/MOSS/tree/main) for alignment.
+
+ ## Star History
+
+ [![Star History Chart](https://api.star-history.com/svg?repos=gpt-omni/mini-omni&type=Date)](https://star-history.com/#gpt-omni/mini-omni&Date)
mini-omni-main (1)/__init__.py ADDED
@@ -0,0 +1,4 @@
+ import os
+ import sys
+
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
mini-omni-main (1)/data/demo_gradio.mov ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:26c411e4de546971d269aa31ab6da6c20150b6837aecde69e0c38590f01175bc
+ size 5502353
mini-omni-main (1)/data/demo_streamlit.mov ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0b35dfdd5d271e85f965f5b2dc8fbcb6144bf51d260e27b5076dcf74463e3b99
+ size 9310793
mini-omni-main (1)/data/figures/frameworkv3.jpg ADDED
mini-omni-main (1)/data/samples/output1.wav ADDED
Binary file (62.2 kB).
 
mini-omni-main (1)/data/samples/output2.wav ADDED
Binary file (105 kB).
 
mini-omni-main (1)/data/samples/output3.wav ADDED
Binary file (70.4 kB).
 
mini-omni-main (1)/data/samples/output4.wav ADDED
Binary file (67.6 kB).
 
mini-omni-main (1)/data/samples/output5.wav ADDED
Binary file (115 kB).
 
mini-omni-main (1)/inference.py ADDED
@@ -0,0 +1,666 @@
1
+ import os
2
+ import lightning as L
3
+ import torch
4
+ import time
5
+ from snac import SNAC
6
+ from litgpt import Tokenizer
7
+ from litgpt.utils import (
8
+ num_parameters,
9
+ )
10
+ from litgpt.generate.base import (
11
+ generate_AA,
12
+ generate_ASR,
13
+ generate_TA,
14
+ generate_TT,
15
+ generate_AT,
16
+ generate_TA_BATCH,
17
+ next_token_batch
18
+ )
19
+ import soundfile as sf
20
+ from litgpt.model import GPT, Config
21
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
22
+ from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
23
+ from utils.snac_utils import get_snac, generate_audio_data
24
+ import whisper
25
+ from tqdm import tqdm
26
+ from huggingface_hub import snapshot_download
27
+
28
+
29
+ torch.set_printoptions(sci_mode=False)
30
+
31
+
32
+ # TODO
33
+ text_vocabsize = 151936
34
+ text_specialtokens = 64
35
+ audio_vocabsize = 4096
36
+ audio_specialtokens = 64
37
+
38
+ padded_text_vocabsize = text_vocabsize + text_specialtokens
39
+ padded_audio_vocabsize = audio_vocabsize + audio_specialtokens
40
+
41
+ _eot = text_vocabsize
42
+ _pad_t = text_vocabsize + 1
43
+ _input_t = text_vocabsize + 2
44
+ _answer_t = text_vocabsize + 3
45
+ _asr = text_vocabsize + 4
46
+
47
+ _eoa = audio_vocabsize
48
+ _pad_a = audio_vocabsize + 1
49
+ _input_a = audio_vocabsize + 2
50
+ _answer_a = audio_vocabsize + 3
51
+ _split = audio_vocabsize + 4
52
+
53
+
54
+ def get_input_ids_TA(text, text_tokenizer):
55
+ input_ids_item = [[] for _ in range(8)]
56
+ text_tokens = text_tokenizer.encode(text)
57
+ for i in range(7):
58
+ input_ids_item[i] = [layershift(_pad_a, i)] * (len(text_tokens) + 2) + [
59
+ layershift(_answer_a, i)
60
+ ]
61
+ input_ids_item[i] = torch.tensor(input_ids_item[i]).unsqueeze(0)
62
+ input_ids_item[-1] = [_input_t] + text_tokens.tolist() + [_eot] + [_answer_t]
63
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
64
+ return input_ids_item
65
+
66
+
67
+ def get_input_ids_TT(text, text_tokenizer):
68
+ input_ids_item = [[] for i in range(8)]
69
+ text_tokens = text_tokenizer.encode(text).tolist()
70
+
71
+ for i in range(7):
72
+ input_ids_item[i] = torch.tensor(
73
+ [layershift(_pad_a, i)] * (len(text_tokens) + 3)
74
+ ).unsqueeze(0)
75
+ input_ids_item[-1] = [_input_t] + text_tokens + [_eot] + [_answer_t]
76
+ input_ids_item[-1] = torch.tensor(input_ids_item[-1]).unsqueeze(0)
77
+
78
+ return input_ids_item
79
+
80
+
81
+ def get_input_ids_whisper(
82
+ mel, leng, whispermodel, device,
83
+ special_token_a=_answer_a, special_token_t=_answer_t,
84
+ ):
85
+
86
+ with torch.no_grad():
87
+ mel = mel.unsqueeze(0).to(device)
88
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
89
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
90
+
91
+ T = audio_feature.size(0)
92
+ input_ids = []
93
+ for i in range(7):
94
+ input_ids_item = []
95
+ input_ids_item.append(layershift(_input_a, i))
96
+ input_ids_item += [layershift(_pad_a, i)] * T
97
+ input_ids_item += [(layershift(_eoa, i)), layershift(special_token_a, i)]
98
+ input_ids.append(torch.tensor(input_ids_item).unsqueeze(0))
99
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, special_token_t])
100
+ input_ids.append(input_id_T.unsqueeze(0))
101
+ return audio_feature.unsqueeze(0), input_ids
102
+
103
+
104
+ def get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device):
105
+ with torch.no_grad():
106
+ mel = mel.unsqueeze(0).to(device)
107
+ # audio_feature = whisper.decode(whispermodel,mel, options).audio_features
108
+ audio_feature = whispermodel.embed_audio(mel)[0][:leng]
109
+ T = audio_feature.size(0)
110
+ input_ids_AA = []
111
+ for i in range(7):
112
+ input_ids_item = []
113
+ input_ids_item.append(layershift(_input_a, i))
114
+ input_ids_item += [layershift(_pad_a, i)] * T
115
+ input_ids_item += [(layershift(_eoa, i)), layershift(_answer_a, i)]
116
+ input_ids_AA.append(torch.tensor(input_ids_item))
117
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
118
+ input_ids_AA.append(input_id_T)
119
+
120
+ input_ids_AT = []
121
+ for i in range(7):
122
+ input_ids_item = []
123
+ input_ids_item.append(layershift(_input_a, i))
124
+ input_ids_item += [layershift(_pad_a, i)] * T
125
+ input_ids_item += [(layershift(_eoa, i)), layershift(_pad_a, i)]
126
+ input_ids_AT.append(torch.tensor(input_ids_item))
127
+ input_id_T = torch.tensor([_input_t] + [_pad_t] * T + [_eot, _answer_t])
128
+ input_ids_AT.append(input_id_T)
129
+
130
+ input_ids = [input_ids_AA, input_ids_AT]
131
+ stacked_inputids = [[] for _ in range(8)]
132
+ for i in range(2):
133
+ for j in range(8):
134
+ stacked_inputids[j].append(input_ids[i][j])
135
+ stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
136
+ return torch.stack([audio_feature, audio_feature]), stacked_inputids
137
+
138
+
139
+ def load_audio(path):
140
+ audio = whisper.load_audio(path)
141
+ duration_ms = (len(audio) / 16000) * 1000
142
+ audio = whisper.pad_or_trim(audio)
143
+ mel = whisper.log_mel_spectrogram(audio)
144
+ return mel, int(duration_ms / 20) + 1
145
+
146
+
147
+ def A1_A2_batch(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
148
+ snacmodel, out_dir=None):
149
+ with fabric.init_tensor():
150
+ model.set_kv_cache(batch_size=2)
151
+ tokenlist = generate_TA_BATCH(
152
+ model,
153
+ audio_feature,
154
+ input_ids,
155
+ [leng, leng],
156
+ ["A1A2", "A1T2"],
157
+ max_returned_tokens=2048,
158
+ temperature=0.9,
159
+ top_k=1,
160
+ eos_id_a=_eoa,
161
+ eos_id_t=_eot,
162
+ pad_id_t=_pad_t,
163
+ shift=padded_text_vocabsize,
164
+ include_prompt=True,
165
+ generate_text=True,
166
+ )
167
+ text_tokenlist = tokenlist[-1]
168
+ if text_vocabsize in text_tokenlist:
169
+ text_tokenlist = text_tokenlist[: text_tokenlist.index(text_vocabsize)]
170
+ text = text_tokenizer.decode(torch.tensor(text_tokenlist)).strip()
171
+
172
+ audio_tokenlist = tokenlist[:-1]
173
+ audiolist = reconscruct_snac(audio_tokenlist)
174
+ audio = reconstruct_tensors(audiolist)
175
+ if out_dir is None:
176
+ out_dir = "./output/default/A1-A2-batch"
177
+ else:
178
+ out_dir = out_dir + "/A1-A2-batch"
179
+ if not os.path.exists(out_dir):
180
+ os.makedirs(out_dir)
181
+ with torch.inference_mode():
182
+ audio_hat = snacmodel.decode(audio)
183
+ sf.write(
184
+ f"{out_dir}/{step:02d}.wav",
185
+ audio_hat.squeeze().cpu().numpy(),
186
+ 24000,
187
+ )
188
+ model.clear_kv_cache()
189
+ return text
190
+
191
+
192
+ def A1_T2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
193
+ with fabric.init_tensor():
194
+ model.set_kv_cache(batch_size=1)
195
+ tokenlist = generate_AT(
196
+ model,
197
+ audio_feature,
198
+ input_ids,
199
+ [leng],
200
+ ["AT"],
201
+ max_returned_tokens=2048,
202
+ temperature=0.9,
203
+ top_k=1,
204
+ eos_id_a=_eoa,
205
+ eos_id_t=_eot,
206
+ pad_id_t=_pad_t,
207
+ shift=padded_text_vocabsize,
208
+ include_prompt=True,
209
+ generate_text=True,
210
+ )
211
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
212
+
213
+
214
+ def A1_A2(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
215
+ snacmodel, out_dir=None):
216
+ with fabric.init_tensor():
217
+ model.set_kv_cache(batch_size=1)
218
+ tokenlist = generate_AA(
219
+ model,
220
+ audio_feature,
221
+ input_ids,
222
+ [leng],
223
+ ["A1T2"],
224
+ max_returned_tokens=2048,
225
+ temperature=0.9,
226
+ top_k=1,
227
+ eos_id_a=_eoa,
228
+ eos_id_t=_eot,
229
+ pad_id_t=_pad_t,
230
+ shift=padded_text_vocabsize,
231
+ include_prompt=True,
232
+ generate_text=True,
233
+ )
234
+ audiolist = reconscruct_snac(tokenlist)
235
+ tokenlist = tokenlist[-1]
236
+ if text_vocabsize in tokenlist:
237
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
238
+ if out_dir is None:
239
+ out_dir = "./output/default/A1-A2"
240
+ else:
241
+ out_dir = out_dir + "/A1-A2"
242
+ if not os.path.exists(out_dir):
243
+ os.makedirs(out_dir)
244
+
245
+ audio = reconstruct_tensors(audiolist)
246
+ with torch.inference_mode():
247
+ audio_hat = snacmodel.decode(audio)
248
+ sf.write(
249
+ f"{out_dir}/{step:02d}.wav",
250
+ audio_hat.squeeze().cpu().numpy(),
251
+ 24000,
252
+ )
253
+ model.clear_kv_cache()
254
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
255
+
256
+
257
+ def A1_T1(fabric, audio_feature, input_ids, leng, model, text_tokenizer, step):
258
+ with fabric.init_tensor():
259
+ model.set_kv_cache(batch_size=1)
260
+ tokenlist = generate_ASR(
261
+ model,
262
+ audio_feature,
263
+ input_ids,
264
+ [leng],
265
+ ["A1T1"],
266
+ max_returned_tokens=2048,
267
+ temperature=0.9,
268
+ top_k=1,
269
+ eos_id_a=_eoa,
270
+ eos_id_t=_eot,
271
+ pad_id_t=_pad_t,
272
+ shift=padded_text_vocabsize,
273
+ include_prompt=True,
274
+ generate_text=True,
275
+ )
276
+ model.clear_kv_cache()
277
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
278
+
279
+
280
+ def T1_A2(fabric, input_ids, model, text_tokenizer, step,
281
+ snacmodel, out_dir=None):
282
+ with fabric.init_tensor():
283
+ model.set_kv_cache(batch_size=1)
284
+ tokenlist = generate_TA(
285
+ model,
286
+ None,
287
+ input_ids,
288
+ None,
289
+ ["T1A2"],
290
+ max_returned_tokens=2048,
291
+ temperature=0.9,
292
+ top_k=1,
293
+ eos_id_a=_eoa,
294
+ eos_id_t=_eot,
295
+ pad_id_t=_pad_t,
296
+ shift=padded_text_vocabsize,
297
+ include_prompt=True,
298
+ generate_text=True,
299
+ )
300
+
301
+ audiolist = reconscruct_snac(tokenlist)
302
+ tokenlist = tokenlist[-1]
303
+
304
+ if text_vocabsize in tokenlist:
305
+ tokenlist = tokenlist[: tokenlist.index(text_vocabsize)]
306
+ audio = reconstruct_tensors(audiolist)
307
+ if out_dir is None:
308
+ out_dir = "./output/default/T1-A2"
309
+ else:
310
+ out_dir = out_dir + "/T1-A2"
311
+ if not os.path.exists(out_dir):
312
+ os.makedirs(out_dir)
313
+
314
+ with torch.inference_mode():
315
+ audio_hat = snacmodel.decode(audio)
316
+ sf.write(
317
+ f"{out_dir}/{step:02d}.wav",
318
+ audio_hat.squeeze().cpu().numpy(),
319
+ 24000,
320
+ )
321
+ model.clear_kv_cache()
322
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
323
+
324
+
325
+ def T1_T2(fabric, input_ids, model, text_tokenizer, step):
326
+
327
+ with fabric.init_tensor():
328
+ model.set_kv_cache(batch_size=1)
329
+ tokenlist = generate_TT(
330
+ model,
331
+ None,
332
+ input_ids,
333
+ None,
334
+ ["T1T2"],
335
+ max_returned_tokens=2048,
336
+ temperature=0.9,
337
+ top_k=1,
338
+ eos_id_a=_eoa,
339
+ eos_id_t=_eot,
340
+ pad_id_t=_pad_t,
341
+ shift=padded_text_vocabsize,
342
+ include_prompt=True,
343
+ generate_text=True,
344
+ )
345
+ model.clear_kv_cache()
346
+ return text_tokenizer.decode(torch.tensor(tokenlist)).strip()
347
+
348
+
349
+ def load_model(ckpt_dir, device):
350
+ snacmodel = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(device)
351
+ whispermodel = whisper.load_model("small").to(device)
352
+ text_tokenizer = Tokenizer(ckpt_dir)
353
+ fabric = L.Fabric(devices=1, strategy="auto")
354
+ config = Config.from_file(ckpt_dir + "/model_config.yaml")
355
+ config.post_adapter = False
356
+
357
+ with fabric.init_module(empty_init=False):
358
+ model = GPT(config)
359
+
360
+ model = fabric.setup(model)
361
+ state_dict = lazy_load(ckpt_dir + "/lit_model.pth")
362
+ model.load_state_dict(state_dict, strict=True)
363
+ model.to(device).eval()
364
+
365
+ return fabric, model, text_tokenizer, snacmodel, whispermodel
366
+
367
+
368
+ def download_model(ckpt_dir):
369
+ repo_id = "gpt-omni/mini-omni"
370
+ snapshot_download(repo_id, local_dir=ckpt_dir, revision="main")
371
+
372
+
373
+ class OmniInference:
374
+
375
+ def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
376
+ self.device = device
377
+ if not os.path.exists(ckpt_dir):
378
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
379
+ download_model(ckpt_dir)
380
+ self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
381
+
382
+ def warm_up(self, sample='./data/samples/output1.wav'):
383
+ for _ in self.run_AT_batch_stream(sample):
384
+ pass
385
+
386
+ @torch.inference_mode()
387
+ def run_AT_batch_stream(self,
388
+ audio_path,
389
+ stream_stride=4,
390
+ max_returned_tokens=2048,
391
+ temperature=0.9,
392
+ top_k=1,
393
+ top_p=1.0,
394
+ eos_id_a=_eoa,
395
+ eos_id_t=_eot,
396
+ ):
397
+
398
+ assert os.path.exists(audio_path), f"audio file {audio_path} not found"
399
+ model = self.model
400
+
401
+ with self.fabric.init_tensor():
402
+ model.set_kv_cache(batch_size=2,device=self.device)
403
+
404
+ mel, leng = load_audio(audio_path)
405
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, self.whispermodel, self.device)
406
+ T = input_ids[0].size(1)
407
+ device = input_ids[0].device
408
+
409
+ assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
410
+
411
+ if model.max_seq_length < max_returned_tokens - 1:
412
+ raise NotImplementedError(
413
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
414
+ )
415
+
416
+ input_pos = torch.tensor([T], device=device)
417
+ list_output = [[] for i in range(8)]
418
+ tokens_A, token_T = next_token_batch(
419
+ model,
420
+ audio_feature.to(torch.float32).to(model.device),
421
+ input_ids,
422
+ [T - 3, T - 3],
423
+ ["A1T2", "A1T2"],
424
+ input_pos=torch.arange(0, T, device=device),
425
+ temperature=temperature,
426
+ top_k=top_k,
427
+ top_p=top_p,
428
+ )
429
+
430
+ for i in range(7):
431
+ list_output[i].append(tokens_A[i].tolist()[0])
432
+ list_output[7].append(token_T.tolist()[0])
433
+
434
+ model_input_ids = [[] for i in range(8)]
435
+ for i in range(7):
436
+ tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize + i * padded_audio_vocabsize
437
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
438
+ model_input_ids[i].append(torch.tensor([layershift(4097, i)], device=device))
439
+ model_input_ids[i] = torch.stack(model_input_ids[i])
440
+
441
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
442
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
443
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
444
+
445
+ text_end = False
446
+ index = 1
447
+ nums_generate = stream_stride
448
+ begin_generate = False
449
+ current_index = 0
450
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
451
+ tokens_A, token_T = next_token_batch(
452
+ model,
453
+ None,
454
+ model_input_ids,
455
+ None,
456
+ None,
457
+ input_pos=input_pos,
458
+ temperature=temperature,
459
+ top_k=top_k,
460
+ top_p=top_p,
461
+ )
462
+
463
+ if text_end:
464
+ token_T = torch.tensor([_pad_t], device=device)
465
+
466
+ if tokens_A[-1] == eos_id_a:
467
+ break
468
+
469
+ if token_T == eos_id_t:
470
+ text_end = True
471
+
472
+ for i in range(7):
473
+ list_output[i].append(tokens_A[i].tolist()[0])
474
+ list_output[7].append(token_T.tolist()[0])
475
+
476
+ model_input_ids = [[] for i in range(8)]
477
+ for i in range(7):
478
+ tokens_A[i] = tokens_A[i].clone() +padded_text_vocabsize + i * padded_audio_vocabsize
479
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
480
+ model_input_ids[i].append(
481
+ torch.tensor([layershift(4097, i)], device=device)
482
+ )
483
+ model_input_ids[i] = torch.stack(model_input_ids[i])
484
+
485
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
486
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
487
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
488
+
489
+ if index == 7:
490
+ begin_generate = True
491
+
492
+ if begin_generate:
493
+ current_index += 1
494
+ if current_index == nums_generate:
495
+ current_index = 0
496
+ snac = get_snac(list_output, index, nums_generate)
497
+ audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
498
+ yield audio_stream
499
+
500
+ input_pos = input_pos.add_(1)
501
+ index += 1
502
+ text = self.text_tokenizer.decode(torch.tensor(list_output[-1]))
503
+ print(f"text output: {text}")
504
+ model.clear_kv_cache()
505
+ return list_output
506
+
507
+
508
+ def test_infer():
509
+ device = "cuda:0"
510
+ out_dir = f"./output/{get_time_str()}"
511
+ ckpt_dir = f"./checkpoint"
512
+ if not os.path.exists(ckpt_dir):
513
+ print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
514
+ download_model(ckpt_dir)
515
+
516
+ fabric, model, text_tokenizer, snacmodel, whispermodel = load_model(ckpt_dir, device)
517
+
518
+ task = ['A1A2', 'asr', "T1A2", "AA-BATCH", 'T1T2', 'AT']
519
+
520
+ # prepare test data
521
+ # TODO
522
+ test_audio_list = sorted(os.listdir('./data/samples'))
523
+ test_audio_list = [os.path.join('./data/samples', path) for path in test_audio_list]
524
+ test_audio_transcripts = [
525
+ "What is your name?",
526
+ "what are your hobbies?",
527
+ "Do you like beijing",
528
+ "How are you feeling today?",
529
+ "what is the weather like today?",
530
+ ]
531
+ test_text_list = [
532
+ "What is your name?",
533
+ "How are you feeling today?",
534
+ "Can you describe your surroundings?",
535
+ "What did you do yesterday?",
536
+ "What is your favorite book and why?",
537
+ "How do you make a cup of tea?",
538
+ "What is the weather like today?",
539
+ "Can you explain the concept of time?",
540
+ "Can you tell me a joke?",
541
+ ]
542
+
543
+ # LOAD MODEL
544
+ with torch.no_grad():
545
+ if "A1A2" in task:
546
+ print("===============================================================")
547
+ print(" testing A1A2")
548
+ print("===============================================================")
549
+ step = 0
550
+ for path in test_audio_list:
551
+ try:
552
+ mel, leng = load_audio(path)
553
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device)
554
+ text = A1_A2(
555
+ fabric,
556
+ audio_feature,
557
+ input_ids,
558
+ leng,
559
+ model,
560
+ text_tokenizer,
561
+ step,
562
+ snacmodel,
563
+ out_dir=out_dir,
564
+ )
565
+ print(f"input: {test_audio_transcripts[step]}")
566
+ print(f"output: {text}")
567
+ step += 1
568
+ print(
569
+ "+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
570
+ )
571
+ except:
572
+ print(f"[error] failed to process {path}")
573
+ print("===============================================================")
574
+
575
+ if 'asr' in task:
576
+ print("===============================================================")
577
+ print(" testing asr")
578
+ print("===============================================================")
579
+
580
+ index = 0
581
+ step = 0
582
+ for path in test_audio_list:
583
+ mel, leng = load_audio(path)
584
+ audio_feature, input_ids = get_input_ids_whisper(mel, leng, whispermodel, device, special_token_a=_pad_a, special_token_t=_asr)
585
+ output = A1_T1(fabric, audio_feature, input_ids ,leng, model, text_tokenizer, index).lower().replace(',','').replace('.','').replace('?','')
586
+ print(f"audio_path: {path}")
587
+ print(f"audio transcript: {test_audio_transcripts[index]}")
588
+ print(f"asr output: {output}")
589
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
590
+ index += 1
591
+
592
+ if "T1A2" in task:
593
+ step = 0
594
+ print("\n")
595
+ print("===============================================================")
596
+ print(" testing T1A2")
597
+ print("===============================================================")
598
+ for text in test_text_list:
599
+ input_ids = get_input_ids_TA(text, text_tokenizer)
600
+ text_output = T1_A2(fabric, input_ids, model, text_tokenizer, step,
601
+ snacmodel, out_dir=out_dir)
602
+ print(f"input: {text}")
603
+ print(f"output: {text_output}")
604
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
605
+ step += 1
606
+ print("===============================================================")
607
+
608
+ if "T1T2" in task:
609
+ step = 0
610
+ print("\n")
611
+ print("===============================================================")
612
+ print(" testing T1T2")
613
+ print("===============================================================")
614
+
615
+ for text in test_text_list:
616
+ input_ids = get_input_ids_TT(text, text_tokenizer)
617
+ text_output = T1_T2(fabric, input_ids, model, text_tokenizer, step)
618
+ print(f" Input: {text}")
619
+ print(f"Output: {text_output}")
620
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
621
+ print("===============================================================")
622
+
623
+ if "AT" in task:
624
+ print("===============================================================")
625
+ print(" testing A1T2")
626
+ print("===============================================================")
627
+ step = 0
628
+ for path in test_audio_list:
629
+ mel, leng = load_audio(path)
630
+ audio_feature, input_ids = get_input_ids_whisper(
631
+ mel, leng, whispermodel, device,
632
+ special_token_a=_pad_a, special_token_t=_answer_t
633
+ )
634
+ text = A1_T2(
635
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step
636
+ )
637
+ print(f"input: {test_audio_transcripts[step]}")
638
+ print(f"output: {text}")
639
+ step += 1
640
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
641
+ print("===============================================================")
642
+
643
+ if "AA-BATCH" in task:
644
+ print("===============================================================")
645
+ print(" testing A1A2-BATCH")
646
+ print("===============================================================")
647
+ step = 0
648
+ for path in test_audio_list:
649
+ mel, leng = load_audio(path)
650
+ audio_feature, input_ids = get_input_ids_whisper_ATBatch(mel, leng, whispermodel, device)
651
+ text = A1_A2_batch(
652
+ fabric, audio_feature, input_ids, leng, model, text_tokenizer, step,
653
+ snacmodel, out_dir=out_dir
654
+ )
655
+ print(f"input: {test_audio_transcripts[step]}")
656
+ print(f"output: {text}")
657
+ step += 1
658
+ print("+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
659
+ print("===============================================================")
660
+
661
+ print("*********************** test end *****************************")
662
+
663
+
664
+
665
+ if __name__ == "__main__":
666
+ test_infer()
mini-omni-main (1)/litgpt/__init__.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+
+ import logging
+ import re
+ from litgpt.model import GPT  # needs to be imported before config
+ from litgpt.config import Config
+ from litgpt.tokenizer import Tokenizer
+
+ # Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632
+ pattern = re.compile(".*Profiler function .* will be ignored")
+ logging.getLogger("torch._dynamo.variables.torch").addFilter(
+     lambda record: not pattern.search(record.getMessage())
+ )
+
+ # Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint
+ logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True
+ logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True
+
+ __all__ = ["GPT", "Config", "Tokenizer"]
mini-omni-main (1)/litgpt/config.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from copy import deepcopy
4
+ from dataclasses import dataclass, field
5
+ from pathlib import Path
6
+ from typing import Any, Literal, Optional, Type, Union
7
+
8
+ import torch
9
+ import yaml
10
+ from typing_extensions import Self
11
+
12
+ import litgpt.model
13
+ from litgpt.utils import find_multiple
14
+
15
+
16
+ @dataclass
17
+ class Config:
18
+ name: str = ""
19
+ hf_config: dict = field(default_factory=dict)
20
+ scale_embeddings: bool = False
21
+ block_size: int = 4096
22
+ vocab_size: int = 50254
23
+ padding_multiple: int = 512
24
+ padded_vocab_size: Optional[int] = None
25
+ n_layer: int = 16
26
+ n_head: int = 32
27
+ head_size: Optional[int] = None
28
+ n_embd: int = 4096
29
+ rotary_percentage: float = 0.25
30
+ parallel_residual: bool = True
31
+ bias: bool = True
32
+ lm_head_bias: bool = False
33
+ # to use multi-head attention (MHA), set this to `n_head` (default)
34
+ # to use multi-query attention (MQA), set this to 1
35
+ # to use grouped-query attention (GQA), set this to a value in between
36
+ # Example with `n_head=4`
37
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
38
+ # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │
39
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
40
+ # │ │ │ │ │ │ │
41
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐
42
+ # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │
43
+ # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘
44
+ # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐
45
+ # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐
46
+ # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │
47
+ # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘
48
+ # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶
49
+ # MHA GQA MQA
50
+ # n_query_groups=4 n_query_groups=2 n_query_groups=1
51
+ #
52
+ # credit https://arxiv.org/pdf/2305.13245.pdf
53
+ n_query_groups: Optional[int] = None
54
+ shared_attention_norm: bool = False
55
+ norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
56
+ norm_eps: float = 1e-5
57
+ mlp_class_name: Literal["GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE"] = (
58
+ "GptNeoxMLP"
59
+ )
60
+ gelu_approximate: str = "none"
61
+ intermediate_size: Optional[int] = None
62
+ rope_condense_ratio: int = 1
63
+ rope_base: int = 10000
64
+ n_expert: int = 0
65
+ n_expert_per_token: int = 0
66
+
67
+ add_qkv_bias: Optional[bool] = None
68
+ prompt_vocab_size: Optional[int] = None
69
+ attn_dropout: float = 0.0
70
+ pos_type: str = "rope"
71
+ force_align: bool = False
72
+ use_pretrain_phoneme_emb: bool = False
73
+ tie_word_embeddings: bool = False
74
+
75
+ # setting for mini-omni
76
+ text_vocab_size:int = 152000
77
+ cat_audio_vocab_size: int = 29120
78
+ audio_vocab_size: int = 4160
79
+ whisper_adapter_dim: int = 768
80
+
81
+ post_adapter: bool = False
82
+ post_adapter_layers: int = 6
83
+ asr_adapter: str = "llamamlp"
84
+
85
+ def __post_init__(self):
86
+ if not self.name:
87
+ self.name = self.hf_config.get("name", self.name)
88
+
89
+ if self.head_size is None:
90
+ assert self.n_embd % self.n_head == 0
91
+ self.head_size = self.n_embd // self.n_head
92
+
93
+ # vocab size should be a power of 2 to be optimal on hardware. compute the closest value
94
+ if self.padded_vocab_size is None:
95
+ self.padded_vocab_size = find_multiple(
96
+ self.vocab_size, self.padding_multiple
97
+ )
98
+ else:
99
+ # vocab size shouldn't be larger than padded vocab size
100
+ self.vocab_size = min(self.vocab_size, self.padded_vocab_size)
101
+
102
+ # compute the number of query groups
103
+ if self.n_query_groups is not None:
104
+ assert self.n_head % self.n_query_groups == 0
105
+ else:
106
+ self.n_query_groups = self.n_head
107
+
108
+ # compute the intermediate size for MLP if not set
109
+ if self.intermediate_size is None:
110
+ if self.mlp_class_name == "LLaMAMLP":
111
+ raise ValueError(
112
+ f"The config {self.name!r}, needs to set the `intermediate_size`"
113
+ )
114
+ self.intermediate_size = 4 * self.n_embd
115
+
116
+ self.rope_n_elem = int(self.rotary_percentage * self.head_size)
117
+
118
+ if self.add_qkv_bias is None:
119
+ self.add_qkv_bias = self.bias
120
+
121
+ @classmethod
122
+ def from_name(cls, name: str, **kwargs: Any) -> Optional[Self]:
123
+ if name not in name_to_config:
124
+ # search through all `config['hf_config']['name']`
125
+ try:
126
+ conf_dict = next(
127
+ config
128
+ for config in configs
129
+ if name == config["hf_config"]["name"]
130
+ or config["hf_config"]["org"] + "/" + config["hf_config"]["name"]
131
+ == name
132
+ )
133
+ except StopIteration:
134
+ raise ValueError(f"{name!r} is not a supported config name")
135
+ else:
136
+ conf_dict = name_to_config[name]
137
+
138
+ conf_dict = conf_dict.copy()
139
+ conf_dict.update(kwargs)
140
+ return cls(**conf_dict)
141
+
142
+ @classmethod
143
+ def from_file(cls, path: Union[str, Path], **kwargs: Any) -> Self:
144
+ with open(path, encoding="utf-8") as fp:
145
+ file_kwargs = yaml.safe_load(fp)
146
+ if file_kwargs is None:
147
+ raise ValueError(f"{path} is empty which is likely unexpected.")
148
+ file_kwargs.update(kwargs)
149
+ return cls(**file_kwargs)
150
+
151
+ @classmethod
152
+ def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self:
153
+ """Automatically load `model_config.yaml` and if it doesn't exist - a matching config from `litgpt/config.py`."""
154
+ if (config_path := path / "model_config.yaml").is_file():
155
+ return cls.from_file(config_path, **kwargs)
156
+ if (model_name := path.name) in name_to_config:
157
+ return cls.from_name(model_name, **kwargs)
158
+ raise FileNotFoundError(
159
+ f"For {str(path)!r} neither 'model_config.yaml' nor matching config exists."
160
+ )
161
+
162
+ @property
163
+ def mlp_class(self) -> Type:
164
+ # `self.mlp_class_name` cannot be the type to keep the config serializable
165
+ return getattr(litgpt.model, self.mlp_class_name)
166
+
167
+ @property
168
+ def norm_class(self) -> Type:
169
+ # `self.norm_class_name` cannot be the type to keep the config serializable
170
+ if self.norm_class_name == "RMSNorm":
171
+ from functools import partial
172
+
173
+ from litgpt.model import RMSNorm
174
+
175
+ return partial(RMSNorm, add_unit_offset="Gemma" in self.name)
176
+ return getattr(torch.nn, self.norm_class_name)
177
+
178
+
179
+ configs = []
180
+ name_to_config = {config["name"]: config for config in configs}
mini-omni-main (1)/litgpt/generate/__init__.py ADDED
File without changes
mini-omni-main (1)/litgpt/generate/base.py ADDED
@@ -0,0 +1,795 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ from typing import Any, Literal, Optional
4
+
5
+ import torch
6
+ # import torch._dynamo.config
7
+ # import torch._inductor.config
8
+
9
+ from litgpt.model import GPT
10
+ from utils.snac_utils import layershift, snac_config
11
+ from tqdm import tqdm
12
+
13
+
14
+ def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor:
15
+ if torch._dynamo.is_compiling():
16
+ # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly
17
+ distribution = torch.empty_like(probs).exponential_(1)
18
+ return torch.argmax(probs / distribution, dim=-1, keepdim=True)
19
+ return torch.multinomial(probs, num_samples=1)
20
+
21
+
22
+ def sample_top_p(logits: torch.Tensor, top_p: float) -> torch.Tensor:
23
+ sorted_logits, sorted_indices = torch.sort(logits, descending=False)
24
+ cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
25
+ # Example:
26
+ # sorted_probs=[0.1, 0.15, 0.2, 0.25, 0.3] -> sorted_cumprobs=[0.1, 0.25, 0.45, 0.7, 1.0]
27
+ # sorted_indices_to_remove = [1, 1, 0, 0, 0] if top_p=0.7
28
+ sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
29
+ # Keep at least 1 token always to prevent the case where no token is selected
30
+ # In this case the most probable one is always kept
31
+ sorted_indices_to_remove[-1:] = 0
32
+ indices_to_remove = sorted_indices_to_remove.scatter(
33
+ 0, sorted_indices, sorted_indices_to_remove
34
+ )
35
+ logits = logits.masked_fill(indices_to_remove, float("-inf"))
36
+ return logits
37
+
38
+
39
+ def sample(
40
+ logits: torch.Tensor,
41
+ temperature: float = 1.0,
42
+ top_k: Optional[int] = None,
43
+ top_p: float = 1.0,
44
+ ) -> torch.Tensor:
45
+ if top_p < 0.0 or top_p > 1.0:
46
+ raise ValueError(f"top_p must be in [0, 1], got {top_p}")
47
+ logits = logits[0, -1]
48
+ # optionally crop the logits to only the top k options
49
+ if top_k is not None:
50
+ v, i = torch.topk(logits, min(top_k, logits.size(-1)))
51
+ # do not use `torch.where` as in nanogpt because it will repeat top-k collisions
52
+ logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v)
53
+ # optionally scale the logits and sample from a probability distribution
54
+ if temperature > 0.0 or top_p > 0.0:
55
+ if temperature > 0.0:
56
+ logits = logits / temperature
57
+ # optionally crop the logits to smallest set of logits with a cumulative probability above top_p
58
+ if top_p < 1.0:
59
+ logits = sample_top_p(logits, top_p)
60
+ probs = torch.nn.functional.softmax(logits, dim=-1)
61
+ return multinomial_num_samples_1(probs)
62
+ return torch.argmax(logits, dim=-1, keepdim=True)
63
+
64
+
65
+ def next_token(
66
+ model: GPT, input_pos: torch.Tensor, x: list, **kwargs: Any
67
+ ) -> torch.Tensor:
68
+ input_pos = input_pos.to(model.device)
69
+ logits_a, logit_t = model(x, input_pos)
70
+
71
+ next_audio_tokens = []
72
+ for logit_a in logits_a:
73
+ next_a = sample(logit_a, **kwargs).to(dtype=x[0].dtype)
74
+ next_audio_tokens.append(next_a)
75
+ next_t = sample(logit_t, **kwargs).to(dtype=x[0].dtype)
76
+ return next_audio_tokens, next_t
77
+
78
+
79
+ def next_token_asr(
80
+ model: GPT,
81
+ input_pos: torch.Tensor,
82
+ audio_features: torch.tensor,
83
+ lens: int,
84
+ input_ids: list,
85
+ **kwargs: Any,
86
+ ) -> torch.Tensor:
87
+ input_pos = input_pos.to(model.device)
88
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
89
+ logits_a, logit_t = model(audio_features, input_ids, input_pos, whisper_lens=lens)
90
+
91
+ next_audio_tokens = []
92
+ for logit_a in logits_a:
93
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
94
+ next_audio_tokens.append(next_a)
95
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
96
+ return next_audio_tokens, next_t
97
+
98
+
99
+ def next_token_A1T2(
100
+ model: GPT,
101
+ audio_features: torch.tensor,
102
+ input_ids: list,
103
+ whisper_lens: int,
104
+ task: list,
105
+ input_pos: torch.Tensor,
106
+ **kwargs: Any,
107
+ ) -> torch.Tensor:
108
+ input_pos = input_pos.to(model.device)
109
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
110
+ logits_a, logit_t = model(
111
+ audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
112
+ )
113
+
114
+ next_audio_tokens = []
115
+ for logit_a in logits_a:
116
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
117
+ next_audio_tokens.append(next_a)
118
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
119
+ return next_audio_tokens, next_t
120
+
121
+
122
+ def next_token_A1T1(
123
+ model: GPT,
124
+ audio_features: torch.tensor,
125
+ input_ids: list,
126
+ whisper_lens: int,
127
+ task: list,
128
+ input_pos: torch.Tensor,
129
+ **kwargs: Any,
130
+ ) -> torch.Tensor:
131
+ input_pos = input_pos.to(model.device)
132
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
133
+ logits_a, logit_t = model(
134
+ audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
135
+ )
136
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
137
+ return next_t
138
+
139
+
140
+ def next_token_batch(
141
+ model: GPT,
142
+ audio_features: torch.tensor,
143
+ input_ids: list,
144
+ whisper_lens: int,
145
+ task: list,
146
+ input_pos: torch.Tensor,
147
+ **kwargs: Any,
148
+ ) -> torch.Tensor:
149
+ input_pos = input_pos.to(model.device)
150
+ input_ids = [input_id.to(model.device) for input_id in input_ids]
151
+ logits_a, logit_t = model(
152
+ audio_features, input_ids, input_pos, whisper_lens=whisper_lens, task=task
153
+ )
154
+
155
+ for i in range(7):
156
+ logits_a[i] = logits_a[i][0].unsqueeze(0)
157
+ logit_t = logit_t[1].unsqueeze(0)
158
+
159
+ next_audio_tokens = []
160
+ for logit_a in logits_a:
161
+ next_a = sample(logit_a, **kwargs).to(dtype=input_ids[0].dtype)
162
+ next_audio_tokens.append(next_a)
163
+ next_t = sample(logit_t, **kwargs).to(dtype=input_ids[0].dtype)
164
+ return next_audio_tokens, next_t
165
+
166
+
167
+ # torch._dynamo.config.automatic_dynamic_shapes = True
168
+ # torch._inductor.config.triton.unique_kernel_names = True
169
+ # torch._inductor.config.coordinate_descent_tuning = True
170
+ # next_token = torch.compile(next_token, mode="reduce-overhead")
171
+
172
+
173
+ @torch.inference_mode()
174
+ def generate(
175
+ model: GPT,
176
+ input_ids: list,
177
+ max_returned_tokens: int,
178
+ *,
179
+ temperature: float = 1.0,
180
+ top_k: Optional[int] = None,
181
+ top_p: float = 1.0,
182
+ eos_id_a: Optional[int] = None,
183
+ eos_id_t: Optional[int] = None,
184
+ pad_id: Optional[int] = None,
185
+ shift: Optional[int] = None,
186
+ include_prompt: bool = True,
187
+ generate_text=False,
188
+ ) -> torch.Tensor:
189
+ # print("eos_id_a:", eos_id_a)
190
+ # print("eos_id_t:", eos_id_t)
191
+ # print("pad_id:", pad_id)
192
+ """
193
+ Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
194
+ The implementation of this function is modified from A. Karpathy's nanoGPT.
195
+
196
+ Args:
197
+ model: The model to use.
198
+ prompt: Tensor of shape (T) with indices of the prompt sequence.
199
+ max_returned_tokens: The maximum number of tokens to return (given plus generated).
200
+ temperature: Scales the predicted logits by 1 / temperature.
201
+ top_k: If specified, only sample among the tokens with the k highest probabilities.
202
+ top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
203
+ In top-p sampling, the next token is sampled from the highest probability tokens
204
+ whose cumulative probability exceeds the threshold `top_p`. When specified,
205
+ it must be `0 <= top_p <= 1`. Here, `top_p=0` is equivalent
206
+ to sampling the most probable token, while `top_p=1` samples from the whole distribution.
207
+ It can be used in conjunction with `top_k` and `temperature` with the following order
208
+ of application:
209
+
210
+ 1. `top_k` sampling
211
+ 2. `temperature` scaling
212
+ 3. `top_p` sampling
213
+
214
+ For more details, see https://arxiv.org/abs/1904.09751
215
+ or https://huyenchip.com/2024/01/16/sampling.html#top_p
216
+ eos_id: If specified, stop generating any more token once the <eos> token is triggered.
217
+ include_prompt: If true (default) prepends the prompt (after applying the prompt style) to the output.
218
+ """
219
+ T = input_ids[0].size(0)
220
+ device = input_ids[0].device
221
+ assert max_returned_tokens > T
222
+ if model.max_seq_length < max_returned_tokens - 1:
223
+ # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a
224
+ # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do
225
+ # not support it to avoid negatively impacting the overall speed
226
+ raise NotImplementedError(
227
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
228
+ )
229
+
230
+ for input_id in input_ids:
231
+ input_id = [input_id]
232
+ (
233
+ tokens_A1,
234
+ tokens_A2,
235
+ tokens_A3,
236
+ tokens_A4,
237
+ tokens_A5,
238
+ tokens_A6,
239
+ tokens_A7,
240
+ tokens_T,
241
+ ) = input_ids
242
+
243
+ tokens_A1_output = [tokens_A1]
244
+ tokens_A2_output = [tokens_A2]
245
+ tokens_A3_output = [tokens_A3]
246
+ tokens_A4_output = [tokens_A4]
247
+ tokens_A5_output = [tokens_A5]
248
+ tokens_A6_output = [tokens_A6]
249
+ tokens_A7_output = [tokens_A7]
250
+ tokens_T_output = [tokens_T]
251
+
252
+ list_output = [
253
+ tokens_A1_output,
254
+ tokens_A2_output,
255
+ tokens_A3_output,
256
+ tokens_A4_output,
257
+ tokens_A5_output,
258
+ tokens_A6_output,
259
+ tokens_A7_output,
260
+ tokens_T_output,
261
+ ]
262
+
263
+ input_pos = torch.tensor([T], device=device)
264
+ model_input_ids = [
265
+ tokens_A1.view(1, -1),
266
+ tokens_A2.view(1, -1),
267
+ tokens_A3.view(1, -1),
268
+ tokens_A4.view(1, -1),
269
+ tokens_A5.view(1, -1),
270
+ tokens_A6.view(1, -1),
271
+ tokens_A7.view(1, -1),
272
+ tokens_T.view(1, -1),
273
+ ]
274
+
275
+ tokens_A, token_T = next_token(
276
+ model,
277
+ torch.arange(0, T, device=device),
278
+ model_input_ids,
279
+ temperature=temperature,
280
+ top_k=top_k,
281
+ top_p=top_p,
282
+ )
283
+ for i in range(7):
284
+ list_output[i].append(tokens_A[i].clone())
285
+ list_output[7].append(token_T.clone())
286
+
287
+ # prepare the input for the next iteration
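+ # each of the 7 SNAC codebook tokens is layer-shifted into its own vocabulary slice before being fed back as input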
288
+ for i in range(7):
289
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
290
+ token_T = token_T.clone()
291
+
292
+ text_end = False
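+ # note: this reassignment caps audio decoding at 1000 steps, overriding the max_returned_tokens argument above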
293
+ max_returned_tokens = 1000
294
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
295
+ model_input_ids = [
296
+ token_a.view(1, -1).to(torch.int32) for token_a in tokens_A
297
+ ] + [token_T.view(1, -1).to(torch.int32)]
298
+ tokens_A, token_T = next_token(
299
+ model,
300
+ input_pos,
301
+ model_input_ids,
302
+ temperature=temperature,
303
+ top_k=top_k,
304
+ top_p=top_p,
305
+ )
306
+ if text_end:
307
+ token_T = torch.tensor([pad_id], device=device)
308
+
309
+ for i in range(7):
310
+ list_output[i].append(tokens_A[i].clone())
311
+ list_output[7].append(token_T.clone())
312
+
313
+ if tokens_A[-1] == eos_id_a:
314
+ break
315
+ if token_T == eos_id_t:
316
+ if generate_text:
317
+ break
318
+ text_end = True
319
+
320
+ for i in range(7):
321
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
322
+ token_T = token_T.clone()
323
+ input_pos = input_pos.add_(1)
324
+
325
+ for i in range(len(list_output)):
326
+ list_output[i] = torch.cat(list_output[i])
327
+ return list_output
328
+
329
+
330
+ @torch.inference_mode()
331
+ def generate_TA_BATCH(
332
+ model: GPT,
333
+ audio_features: torch.Tensor,
334
+ input_ids: list,
335
+ leng,
336
+ task,
337
+ max_returned_tokens: int = 1000,
338
+ *,
339
+ temperature: float = 1.0,
340
+ top_k: Optional[int] = None,
341
+ top_p: float = 1.0,
342
+ eos_id_a: Optional[int] = None,
343
+ eos_id_t: Optional[int] = None,
344
+ pad_id_t: Optional[int] = None,
345
+ shift: Optional[int] = None,
346
+ include_prompt: bool = True,
347
+ generate_text=False,
348
+ ) -> torch.Tensor:
349
+
350
+ T = input_ids[0].size(1)
351
+ device = input_ids[0].device
352
+ assert max_returned_tokens > T
353
+ if model.max_seq_length < max_returned_tokens - 1:
354
+ raise NotImplementedError(
355
+ f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
356
+ )
357
+
358
+ input_pos = torch.tensor([T], device=device)
359
+ model_input_ids = input_ids
360
+
361
+ list_output = [[] for i in range(8)]
362
+
363
+ tokens_A, token_T = next_token_batch(
364
+ model,
365
+ audio_features.to(torch.float32).to(model.device),
366
+ input_ids,
367
+ [T - 3, T - 3],
368
+ ["A1T2", "A1T2"],
369
+ input_pos=torch.arange(0, T, device=device),
370
+ temperature=temperature,
371
+ top_k=top_k,
372
+ top_p=top_p,
373
+ )
374
+
375
+ for i in range(7):
376
+ list_output[i].append(tokens_A[i].tolist()[0])
377
+ list_output[7].append(token_T.tolist()[0])
378
+
379
+ model_input_ids = [[] for i in range(8)]
380
+ for i in range(7):
381
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
382
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
383
+ model_input_ids[i].append(torch.tensor([layershift(snac_config.end_of_audio, i)], device=device))
384
+ model_input_ids[i] = torch.stack(model_input_ids[i])
385
+
386
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
387
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
388
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
389
+
390
+ text_end = False
391
+
392
+ for _ in range(2, max_returned_tokens - T + 1):
393
+ tokens_A, token_T = next_token_batch(
394
+ model,
395
+ None,
396
+ model_input_ids,
397
+ None,
398
+ None,
399
+ input_pos=input_pos,
400
+ temperature=temperature,
401
+ top_k=top_k,
402
+ top_p=top_p,
403
+ )
404
+
405
+ if text_end:
406
+ token_T = torch.tensor([pad_id_t], device=device)
407
+
408
+ if tokens_A[-1] == eos_id_a:
409
+ break
410
+ if token_T == eos_id_t:
411
+ text_end = True
412
+
413
+ for i in range(7):
414
+ list_output[i].append(tokens_A[i].tolist()[0])
415
+ list_output[7].append(token_T.tolist()[0])
416
+
417
+ model_input_ids = [[] for i in range(8)]
418
+ for i in range(7):
419
+ tokens_A[i] = tokens_A[i].clone() + shift + i * snac_config.padded_vocab_size
420
+ model_input_ids[i].append(tokens_A[i].clone().to(device).to(torch.int32))
421
+ model_input_ids[i].append(
422
+ torch.tensor([layershift(snac_config.end_of_audio, i)], device=device)
423
+ )
424
+ model_input_ids[i] = torch.stack(model_input_ids[i])
425
+
426
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
427
+ model_input_ids[-1].append(token_T.clone().to(torch.int32))
428
+ model_input_ids[-1] = torch.stack(model_input_ids[-1])
429
+
430
+ input_pos = input_pos.add_(1)
431
+
432
+ return list_output
433
+
434
+
435
+ @torch.inference_mode()
436
+ def generate_TT(
437
+ model: GPT,
438
+ audio_features: torch.Tensor,
439
+ input_ids: list,
440
+ leng,
441
+ task,
442
+ max_returned_tokens: int = 2048,
443
+ *,
444
+ temperature: float = 1.0,
445
+ top_k: Optional[int] = None,
446
+ top_p: float = 1.0,
447
+ eos_id_a: Optional[int] = None,
448
+ eos_id_t: Optional[int] = None,
449
+ pad_id_t: Optional[int] = None,
450
+ shift: Optional[int] = None,
451
+ include_prompt: bool = True,
452
+ generate_text=False,
453
+ ) -> torch.Tensor:
454
+
455
+ T = input_ids[0].size(1)
456
+ device = input_ids[0].device
457
+
458
+ output = []
459
+ token_T = next_token_A1T1(
460
+ model,
461
+ None,
462
+ input_ids,
463
+ None,
464
+ None,
465
+ input_pos=torch.arange(0, T, device=device),
466
+ temperature=temperature,
467
+ top_k=top_k,
468
+ top_p=top_p,
469
+ )
470
+
471
+ output.append(token_T.clone().tolist()[0])
472
+ input_pos = torch.tensor([T], device=device)
473
+
474
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
475
+ model_input_ids = []
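+ # text-only decoding: all 7 audio streams are fed the layer-shifted end-of-audio placeholder token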
476
+ for i in range(7):
477
+ model_input_ids.append(
478
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
479
+ .view(1, -1)
480
+ .to(torch.int32)
481
+ .to(device)
482
+ )
483
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
484
+ token_T = next_token_A1T1(
485
+ model,
486
+ None,
487
+ model_input_ids,
488
+ None,
489
+ None,
490
+ input_pos=input_pos,
491
+ temperature=temperature,
492
+ top_k=top_k,
493
+ top_p=top_p,
494
+ )
495
+ if token_T == eos_id_t:
496
+ break
497
+ output.append(token_T.clone().tolist()[0])
498
+ input_pos = input_pos.add_(1)
499
+ return output
500
+
501
+
502
+ @torch.inference_mode()
503
+ def generate_AT(
504
+ model: GPT,
505
+ audio_features: torch.Tensor,
506
+ input_ids: list,
507
+ leng,
508
+ task,
509
+ max_returned_tokens: int = 2048,
510
+ *,
511
+ temperature: float = 1.0,
512
+ top_k: Optional[int] = None,
513
+ top_p: float = 1.0,
514
+ eos_id_a: Optional[int] = None,
515
+ eos_id_t: Optional[int] = None,
516
+ pad_id_t: Optional[int] = None,
517
+ shift: Optional[int] = None,
518
+ include_prompt: bool = True,
519
+ generate_text=False,
520
+ ) -> torch.Tensor:
521
+
522
+ T = input_ids[0].size(1)
523
+ device = input_ids[0].device
524
+
525
+ output = []
526
+ token_T = next_token_A1T1(
527
+ model,
528
+ audio_features.to(torch.float32).to(model.device),
529
+ input_ids,
530
+ [T - 3],
531
+ ["AT"],
532
+ input_pos=torch.arange(0, T, device=device),
533
+ temperature=temperature,
534
+ top_k=top_k,
535
+ top_p=top_p,
536
+ )
537
+ output.append(token_T.clone().tolist()[0])
538
+ input_pos = torch.tensor([T], device=device)
539
+ text_end = False
540
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
541
+ model_input_ids = []
542
+ for i in range(7):
543
+ model_input_ids.append(
544
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
545
+ .view(1, -1)
546
+ .to(torch.int32)
547
+ .to(device)
548
+ )
549
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
550
+ token_T = next_token_A1T1(
551
+ model,
552
+ None,
553
+ model_input_ids,
554
+ None,
555
+ None,
556
+ input_pos=input_pos,
557
+ temperature=temperature,
558
+ top_k=top_k,
559
+ top_p=top_p,
560
+ )
561
+ if token_T == eos_id_t:
562
+ break
563
+ output.append(token_T.clone().tolist()[0])
564
+ input_pos = input_pos.add_(1)
565
+ return output
566
+
567
+
568
+ @torch.inference_mode()
569
+ def generate_TA(
570
+ model: GPT,
571
+ audio_features: torch.Tensor,
572
+ input_ids: list,
573
+ leng,
574
+ task,
575
+ max_returned_tokens: int = 2048,
576
+ *,
577
+ temperature: float = 1.0,
578
+ top_k: Optional[int] = None,
579
+ top_p: float = 1.0,
580
+ eos_id_a: Optional[int] = None,
581
+ eos_id_t: Optional[int] = None,
582
+ pad_id_t: Optional[int] = None,
583
+ shift: Optional[int] = None,
584
+ include_prompt: bool = True,
585
+ generate_text=False,
586
+ ) -> torch.Tensor:
587
+
588
+ T = input_ids[0].size(1)
589
+ device = input_ids[0].device
590
+
591
+ output = [[] for _ in range(8)]
592
+ tokens_A, token_T = next_token_A1T2(
593
+ model,
594
+ None,
595
+ input_ids,
596
+ None,
597
+ None,
598
+ input_pos=torch.arange(0, T, device=device),
599
+ temperature=temperature,
600
+ top_k=top_k,
601
+ top_p=top_p,
602
+ )
603
+ for i in range(7):
604
+ output[i].append(tokens_A[i].clone().tolist()[0])
605
+ output[7].append(token_T.clone().tolist()[0])
606
+
607
+ input_pos = torch.tensor([T], device=device)
608
+ text_end = False
609
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
610
+
611
+ model_input_ids = []
612
+ for i in range(7):
613
+ model_input_ids.append(
614
+ layershift(tokens_A[i].clone(), i)
615
+ .view(1, -1)
616
+ .to(torch.int32)
617
+ .to(device)
618
+ )
619
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
620
+
621
+ tokens_A, token_T = next_token_A1T2(
622
+ model,
623
+ None,
624
+ model_input_ids,
625
+ None,
626
+ None,
627
+ input_pos=input_pos,
628
+ temperature=temperature,
629
+ top_k=top_k,
630
+ top_p=top_p,
631
+ )
632
+
633
+ if text_end:
634
+ token_T = torch.tensor([pad_id_t], device=device)
635
+
636
+ if tokens_A[-1] == eos_id_a:
637
+ break
638
+
639
+ if token_T == eos_id_t:
640
+ text_end = True
641
+
642
+ for i in range(7):
643
+ output[i].append(tokens_A[i].clone().tolist()[0])
644
+ output[7].append(token_T.clone().tolist()[0])
645
+ input_pos = input_pos.add_(1)
646
+
647
+ return output
648
+
649
+
650
+ @torch.inference_mode()
651
+ def generate_AA(
652
+ model: GPT,
653
+ audio_features: torch.Tensor,
654
+ input_ids: list,
655
+ leng,
656
+ task,
657
+ max_returned_tokens: int = 2048,
658
+ *,
659
+ temperature: float = 1.0,
660
+ top_k: Optional[int] = None,
661
+ top_p: float = 1.0,
662
+ eos_id_a: Optional[int] = None,
663
+ eos_id_t: Optional[int] = None,
664
+ pad_id_t: Optional[int] = None,
665
+ shift: Optional[int] = None,
666
+ include_prompt: bool = True,
667
+ generate_text=False,
668
+ ) -> torch.Tensor:
669
+
670
+ T = input_ids[0].size(1)
671
+ device = input_ids[0].device
672
+
673
+ output = [[] for _ in range(8)]
674
+ tokens_A, token_T = next_token_A1T2(
675
+ model,
676
+ audio_features.to(torch.float32).to(model.device),
677
+ input_ids,
678
+ [T - 3],
679
+ ["A1T2"],
680
+ input_pos=torch.arange(0, T, device=device),
681
+ temperature=temperature,
682
+ top_k=top_k,
683
+ top_p=top_p,
684
+ )
685
+ for i in range(7):
686
+ output[i].append(tokens_A[i].clone().tolist()[0])
687
+ output[7].append(token_T.clone().tolist()[0])
688
+
689
+ input_pos = torch.tensor([T], device=device)
690
+
691
+ text_end = False
692
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
693
+
694
+ model_input_ids = []
695
+ for i in range(7):
696
+ model_input_ids.append(
697
+ layershift(tokens_A[i].clone(), i)
698
+ .view(1, -1)
699
+ .to(torch.int32)
700
+ .to(device)
701
+ )
702
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
703
+
704
+ tokens_A, token_T = next_token_A1T2(
705
+ model,
706
+ None,
707
+ model_input_ids,
708
+ None,
709
+ None,
710
+ input_pos=input_pos,
711
+ temperature=temperature,
712
+ top_k=top_k,
713
+ top_p=top_p,
714
+ )
715
+
716
+ if text_end:
717
+ token_T = torch.tensor([pad_id_t], device=device)
718
+
719
+ if tokens_A[-1] == eos_id_a:
720
+ break
721
+ if token_T == eos_id_t:
722
+ # print("text_end")
723
+ text_end = True
724
+
725
+ for i in range(7):
726
+ output[i].append(tokens_A[i].clone().tolist()[0])
727
+ output[7].append(token_T.clone().tolist()[0])
728
+ input_pos = input_pos.add_(1)
729
+
730
+ return output
731
+
732
+
733
+ @torch.inference_mode()
734
+ def generate_ASR(
735
+ model: GPT,
736
+ audio_features: torch.Tensor,
737
+ input_ids: list,
738
+ leng,
739
+ task,
740
+ max_returned_tokens: int = 1200,
741
+ *,
742
+ temperature: float = 1.0,
743
+ top_k: Optional[int] = None,
744
+ top_p: float = 1.0,
745
+ eos_id_a: Optional[int] = None,
746
+ eos_id_t: Optional[int] = None,
747
+ pad_id_t: Optional[int] = None,
748
+ shift: Optional[int] = None,
749
+ include_prompt: bool = True,
750
+ generate_text=False,
751
+ ) -> torch.Tensor:
752
+
753
+ T = input_ids[0].size(1)
754
+ device = input_ids[0].device
755
+ output = []
756
+ token_T = next_token_A1T1(
757
+ model,
758
+ audio_features.to(torch.float32).to(model.device),
759
+ input_ids,
760
+ [T - 3],
761
+ ["asr"],
762
+ input_pos=torch.arange(0, T, device=device),
763
+ temperature=temperature,
764
+ top_k=top_k,
765
+ top_p=top_p,
766
+ )
767
+ output.append(token_T.clone().tolist()[0])
768
+ input_pos = torch.tensor([T], device=device)
769
+ text_end = False
770
+ for _ in tqdm(range(2, max_returned_tokens - T + 1)):
771
+ model_input_ids = []
772
+ for i in range(7):
773
+ model_input_ids.append(
774
+ torch.tensor([layershift(snac_config.end_of_audio, i)])
775
+ .view(1, -1)
776
+ .to(torch.int32)
777
+ .to(device)
778
+ )
779
+ model_input_ids.append(token_T.clone().view(1, -1).to(torch.int32).to(device))
780
+ token_T = next_token_A1T1(
781
+ model,
782
+ None,
783
+ model_input_ids,
784
+ None,
785
+ None,
786
+ input_pos=input_pos,
787
+ temperature=temperature,
788
+ top_k=top_k,
789
+ top_p=top_p,
790
+ )
791
+ if token_T == eos_id_t:
792
+ break
793
+ output.append(token_T.clone().tolist()[0])
794
+ input_pos = input_pos.add_(1)
795
+ return output
mini-omni-main (1)/litgpt/model.py ADDED
@@ -0,0 +1,618 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Full definition of a decoder-only transformer-based language model, all of it in this single file.
4
+
5
+ Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and
6
+ https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model.
7
+ """
8
+
9
+ import math
10
+ from typing import Any, Optional, Tuple
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from typing_extensions import Self
15
+ from litgpt.config import Config
16
+
17
+
18
+ class GPT(nn.Module):
19
+ def __init__(self, config: Config) -> None:
20
+ super().__init__()
21
+ assert config.padded_vocab_size is not None
22
+ self.config = config
23
+ if self.config.asr_adapter == "mlp":
24
+ print("Using MLP adapter for ASR feature")
25
+ self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
26
+ elif self.config.asr_adapter == "llamamlp":
27
+ print("using LLAMA MLP adapter for ASR feature")
28
+ self.whisper_adapter = whisperMLP(config=config)
29
+ else:
30
+ raise ValueError("asr_adapter should be mlp or llamamlp")
31
+ self.lm_head = nn.Linear(
32
+ config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
33
+ )
34
+ if config.post_adapter:
35
+ self.transformer = nn.ModuleDict(
36
+ dict(
37
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
38
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
39
+ post_adapter=nn.ModuleList(
40
+ Block(config) for _ in range(config.post_adapter_layers)
41
+ ),
42
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
43
+ post_adapter_audio_ln=config.norm_class(
44
+ config.n_embd, eps=config.norm_eps
45
+ ),
46
+ post_adapter_audio_lm_head=nn.Linear(
47
+ config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
48
+ ),
49
+ )
50
+ )
51
+ else:
52
+ self.transformer = nn.ModuleDict(
53
+ dict(
54
+ wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
55
+ h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
56
+ ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
57
+ )
58
+ )
59
+ self.max_seq_length = self.config.block_size
60
+ self.mask_cache: Optional[torch.Tensor] = None
61
+ if config.tie_word_embeddings:
62
+ self.lm_head.weight = self.transformer.wte.weight
63
+
64
+ @property
65
+ def max_seq_length(self) -> int:
66
+ return self._max_seq_length
67
+
68
+ @max_seq_length.setter
69
+ def max_seq_length(self, value: int) -> None:
70
+ """
71
+ When doing inference, the sequences used might be shorter than the model's context length.
72
+ This allows setting a smaller number to avoid allocating unused memory
73
+ """
74
+ if value > self.config.block_size:
75
+ raise ValueError(
76
+ f"Cannot attend to {value}, block size is only {self.config.block_size}"
77
+ )
78
+ self._max_seq_length = value
79
+ if not hasattr(self, "cos"):
80
+ # first call
81
+ cos, sin = self.rope_cache()
82
+ self.register_buffer("cos", cos, persistent=False)
83
+ self.register_buffer("sin", sin, persistent=False)
84
+ # override
85
+ elif value != self.cos.size(0):
86
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
87
+ # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know
88
+ # if the kv cache is expected
89
+
90
+ def reset_parameters(self) -> None:
91
+ # Trigger resetting the rope-cache
92
+ self.cos, self.sin = self.rope_cache(device=self.cos.device)
93
+
94
+ def _init_weights(self, module: nn.Module) -> None:
95
+ """Meant to be used with `gpt.apply(gpt._init_weights)`."""
96
+ if isinstance(module, nn.Linear):
97
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
98
+ if module.bias is not None:
99
+ torch.nn.init.zeros_(module.bias)
100
+ elif isinstance(module, nn.Embedding):
101
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
102
+
103
+ def concat_whisper_feat(self, audio_feature, input_ids, T, task):
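+ # for audio-input tasks, positions 1..T[j] of each of the 7 audio embedding streams are overwritten with the Whisper features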
104
+ for j in range(len(T)):
105
+ if task[j] != "T1T2" and task[j] != "T1A2":
106
+ for i in range(7):
107
+ input_ids[i][j, 1 : T[j] + 1, :] = audio_feature[j][: T[j]].clone()
108
+ else:
109
+ continue
110
+ return input_ids
111
+
112
+ def forward(
113
+ self,
114
+ audio_features: torch.Tensor,
115
+ input_ids: torch.Tensor,
116
+ input_pos: Optional[torch.Tensor] = None,
117
+ whisper_lens: Optional[list] = None,
118
+ task: Optional[str] = None,
119
+ ) -> torch.Tensor:
120
+
121
+ show = False
122
+ T = input_ids[0].size(1)
123
+ if self.max_seq_length < T:
124
+ raise ValueError(
125
+ f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
126
+ )
127
+
128
+ if input_pos is not None: # use the kv cache
129
+ cos = self.cos.index_select(0, input_pos)
130
+ sin = self.sin.index_select(0, input_pos)
131
+ if self.mask_cache is None:
132
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
133
+ mask = self.mask_cache.index_select(2, input_pos)
134
+ else:
135
+ cos = self.cos[:T]
136
+ sin = self.sin[:T]
137
+ mask = None
138
+
139
+ if audio_features is not None:
140
+ # get whisper feature
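+ # the adapter projects the Whisper encoder output to the model embedding size (n_embd)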
141
+ x_a = self.whisper_adapter(audio_features)
142
+ # get input_ids embedding
143
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
144
+
145
+ x0 = self.transformer.wte(x0)
146
+ x1 = self.transformer.wte(x1)
147
+ x2 = self.transformer.wte(x2)
148
+ x3 = self.transformer.wte(x3)
149
+ x4 = self.transformer.wte(x4)
150
+ x5 = self.transformer.wte(x5)
151
+ x6 = self.transformer.wte(x6)
152
+ x7 = self.transformer.wte(x7)
153
+
154
+ # concat whisper feature
155
+ input_emb = self.concat_whisper_feat(
156
+ x_a, [x0, x1, x2, x3, x4, x5, x6, x7], whisper_lens, task
157
+ )
158
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_emb
159
+
160
+ else:
161
+ x0, x1, x2, x3, x4, x5, x6, x7 = input_ids
162
+
163
+ x0 = self.transformer.wte(x0)
164
+ x1 = self.transformer.wte(x1)
165
+ x2 = self.transformer.wte(x2)
166
+ x3 = self.transformer.wte(x3)
167
+ x4 = self.transformer.wte(x4)
168
+ x5 = self.transformer.wte(x5)
169
+ x6 = self.transformer.wte(x6)
170
+ x7 = self.transformer.wte(x7)
171
+
172
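+ # fuse the 8 parallel streams (7 audio codebook layers + 1 text layer) by averaging their embeddings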
+ x = (x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7) / 8
173
+
174
+ if self.config.scale_embeddings:
175
+ x = x * (self.config.n_embd**0.5)
176
+
177
+ for block in self.transformer.h:
178
+ x = block(x, cos, sin, mask, input_pos)
179
+
180
+
181
+ text_vocab_size = self.config.text_vocab_size
182
+ audio_vocab_size = self.config.audio_vocab_size
183
+
184
+ x_ori = x
185
+ x_ori = self.transformer.ln_f(x_ori)
186
+ x_ori = self.lm_head(x_ori) # (b, t, vocab_size)
187
+ xt = x_ori[..., :text_vocab_size]
188
+
189
+ if self.config.post_adapter:
190
+ for block in self.transformer.post_adapter:
191
+ x = block(x, cos, sin, mask, input_pos)
192
+ x = self.transformer.post_adapter_audio_ln(x)
193
+ x = self.transformer.post_adapter_audio_lm_head(x) # (b, t, vocab_size)
194
+ xa = []
195
+ for i in range(7):
196
+ xa.append(x[..., audio_vocab_size * i : audio_vocab_size * (i + 1)])
197
+ else:
198
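+ # without a post adapter, the audio logits for the 7 layers are sliced from the shared lm_head output after the text vocabulary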
+ xa = []
199
+ for i in range(7):
200
+ xa.append(x_ori[..., text_vocab_size + audio_vocab_size * i : text_vocab_size + audio_vocab_size * (i + 1)])
201
+
202
+ return xa, xt
203
+
204
+ @classmethod
205
+ def from_name(cls, name: str, **kwargs: Any) -> Self:
206
+ return cls(Config.from_name(name, **kwargs))
207
+
208
+ def rope_cache(
209
+ self, device: Optional[torch.device] = None
210
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
211
+ return build_rope_cache(
212
+ seq_len=self.max_seq_length,
213
+ n_elem=self.config.rope_n_elem,
214
+ device=device,
215
+ condense_ratio=self.config.rope_condense_ratio,
216
+ base=self.config.rope_base,
217
+ )
218
+
219
+ def set_kv_cache(
220
+ self,
221
+ batch_size: int,
222
+ rope_cache_length: Optional[int] = None,
223
+ device: Optional[torch.device] = None,
224
+ dtype: Optional[torch.dtype] = None,
225
+ ) -> None:
226
+ if rope_cache_length is None:
227
+ rope_cache_length = self.cos.size(-1)
228
+ max_seq_length = self.max_seq_length
229
+
230
+ # initialize the kv cache for all blocks
231
+ for block in self.transformer.h:
232
+ block.attn.kv_cache = block.attn.build_kv_cache(
233
+ batch_size, max_seq_length, rope_cache_length, device, dtype
234
+ )
235
+ if self.config.post_adapter:
236
+ for block in self.transformer.post_adapter:
237
+ block.attn.kv_cache = block.attn.build_kv_cache(
238
+ batch_size, max_seq_length, rope_cache_length, device, dtype
239
+ )
240
+
241
+ if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
242
+ # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask
243
+ # for the kv-cache support (only during inference), we only create it in that situation
244
+ self.mask_cache = build_mask_cache(max_seq_length, device)
245
+
246
+ def clear_kv_cache(self) -> None:
247
+ self.mask_cache = None
248
+ for block in self.transformer.h:
249
+ block.attn.kv_cache = None
250
+
251
+
252
+ class Block(nn.Module):
253
+
254
+ def __init__(self, config: Config) -> None:
255
+ super().__init__()
256
+ if not config.parallel_residual and config.shared_attention_norm:
257
+ raise NotImplementedError(
258
+ "No checkpoint amongst the ones we support uses this configuration"
259
+ " (non-parallel residual and shared attention norm)."
260
+ )
261
+
262
+ self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
263
+ self.attn = CausalSelfAttention(config)
264
+ self.norm_2 = (
265
+ None
266
+ if config.shared_attention_norm
267
+ else config.norm_class(config.n_embd, eps=config.norm_eps)
268
+ )
269
+ self.mlp = config.mlp_class(config)
270
+
271
+ self.config = config
272
+
273
+ def forward(
274
+ self,
275
+ x: torch.Tensor,
276
+ cos: torch.Tensor,
277
+ sin: torch.Tensor,
278
+ mask: Optional[torch.Tensor] = None,
279
+ input_pos: Optional[torch.Tensor] = None,
280
+ ) -> torch.Tensor:
281
+ """
282
+ Non-parallel residual Parallel residual
283
+ ┌─ x ┌─ x ────────────┐ Note: if `shared_attention_norm` is True,
284
+ │ ↓ │ ↓ ↓ the output from `norm_1` is reused
285
+ │ norm_1 │ norm_1 ───► norm_2
286
+ │ ↓ │ ↓ ↓
287
+ │ attn │ attn mlp
288
+ │ ↓ │ ↓ │
289
+ ┌─ └► + └► + ◄───────────┘
290
+ │ norm_2
291
+ │ ↓
292
+ │ mlp
293
+ │ ↓
294
+ └───► +
295
+ """
296
+
297
+ x_normed = self.norm_1(x)
298
+ attention_output = self.attn(x_normed, cos, sin, mask, input_pos)
299
+
300
+ if self.config.parallel_residual:
301
+ x_normed = x_normed if self.config.shared_attention_norm else self.norm_2(x)
302
+ x = self.mlp(x_normed) + attention_output + x
303
+ else:
304
+ x = attention_output + x
305
+ x = self.mlp(self.norm_2(x)) + x
306
+ return x
307
+
308
+
309
+ class CausalSelfAttention(nn.Module):
310
+ def __init__(self, config: Config) -> None:
311
+ super().__init__()
312
+ shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
313
+ # key, query, value projections for all heads, but in a batch
314
+ self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias)
315
+ # output projection
316
+ # if `head_size` is explicitly specified in the config, `n_embd` might not be equal to `head_size * n_head`
317
+ self.proj = nn.Linear(
318
+ config.head_size * config.n_head, config.n_embd, bias=config.bias
319
+ )
320
+ # disabled by default
321
+ self.kv_cache: Optional[KVCache] = None
322
+
323
+ self.config = config
324
+
325
+ def forward(
326
+ self,
327
+ x: torch.Tensor,
328
+ cos: torch.Tensor,
329
+ sin: torch.Tensor,
330
+ mask: Optional[torch.Tensor] = None,
331
+ input_pos: Optional[torch.Tensor] = None,
332
+ ) -> torch.Tensor:
333
+ B, T, C = (
334
+ x.size()
335
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
336
+
337
+ qkv = self.attn(x)
338
+
339
+ # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
340
+ q_per_kv = self.config.n_head // self.config.n_query_groups
341
+ total_qkv = q_per_kv + 2 # each group has 1+ queries, 1 key, and 1 value
342
+ qkv = qkv.view(
343
+ B, T, self.config.n_query_groups, total_qkv, self.config.head_size
344
+ )
345
+ qkv = qkv.permute(0, 2, 3, 1, 4) # (B, n_query_groups, total_qkv, T, hs)
346
+
347
+ # split batched computation into three
348
+ q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)
349
+
350
+ # maybe repeat k and v for the non multi-head attention cases
351
+ # training: flash attention requires it
352
+ # inference: multi-query would require a full kv cache so avoid it to limit its memory usage
353
+ if self.config.n_query_groups != self.config.n_head and (
354
+ input_pos is None or self.config.n_query_groups != 1
355
+ ):
356
+ k = k.expand(
357
+ B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
358
+ )
359
+ v = v.expand(
360
+ B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
361
+ )
362
+
363
+ q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs)
364
+ k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs)
365
+ v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs)
366
+
367
+ q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
368
+ k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
369
+ q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
370
+ k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)
371
+
372
+ if input_pos is not None:
373
+ if not isinstance(self.kv_cache, KVCache):
374
+ raise TypeError("You need to call `gpt.set_kv_cache()`")
375
+ k, v = self.kv_cache(input_pos, k, v)
376
+
377
+ y = self.scaled_dot_product_attention(q, k, v, mask)
378
+
379
+ y = y.reshape(
380
+ B, T, self.config.head_size * self.config.n_head
381
+ ) # re-assemble all head outputs side by side
382
+
383
+ # output projection
384
+ return self.proj(y)
385
+
386
+ def scaled_dot_product_attention(
387
+ self,
388
+ q: torch.Tensor,
389
+ k: torch.Tensor,
390
+ v: torch.Tensor,
391
+ mask: Optional[torch.Tensor] = None,
392
+ ) -> torch.Tensor:
393
+ scale = 1.0 / math.sqrt(self.config.head_size)
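+ # with no explicit mask (i.e. no kv cache), rely on SDPA's built-in causal masking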
394
+ y = torch.nn.functional.scaled_dot_product_attention(
395
+ q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
396
+ )
397
+ return y.transpose(1, 2)
398
+
399
+ def build_kv_cache(
400
+ self,
401
+ batch_size: int,
402
+ max_seq_length: int,
403
+ rope_cache_length: Optional[int] = None,
404
+ device: Optional[torch.device] = None,
405
+ dtype: Optional[torch.dtype] = None,
406
+ ) -> "KVCache":
407
+ heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
408
+ v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
409
+ if rope_cache_length is None:
410
+ if self.config.rotary_percentage != 1.0:
411
+ raise TypeError(
412
+ "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
413
+ )
414
+ k_shape = v_shape
415
+ else:
416
+ k_shape = (
417
+ batch_size,
418
+ heads,
419
+ max_seq_length,
420
+ rope_cache_length + self.config.head_size - self.config.rope_n_elem,
421
+ )
422
+ return KVCache(k_shape, v_shape, device=device, dtype=dtype)
423
+
424
+
425
+ class GptNeoxMLP(nn.Module):
426
+ def __init__(self, config: Config) -> None:
427
+ super().__init__()
428
+ self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
429
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
430
+
431
+ self.config = config
432
+
433
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
434
+ x = self.fc(x)
435
+ x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate)
436
+ return self.proj(x)
437
+
438
+
439
+ class LLaMAMLP(nn.Module):
440
+ def __init__(self, config: Config) -> None:
441
+ super().__init__()
442
+ self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
443
+ self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
444
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
445
+
446
+ self.config = config
447
+
448
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
449
+ x_fc_1 = self.fc_1(x)
450
+ x_fc_2 = self.fc_2(x)
451
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
452
+ return self.proj(x)
453
+
454
+
455
+ class whisperMLP(nn.Module):
456
+ def __init__(self, config: Config) -> None:
457
+ super().__init__()
458
+ self.fc_1 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
459
+ self.fc_2 = nn.Linear(config.whisper_adapter_dim, config.intermediate_size, bias=config.bias)
460
+ self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)
461
+
462
+ self.config = config
463
+
464
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
465
+ x_fc_1 = self.fc_1(x)
466
+ x_fc_2 = self.fc_2(x)
467
+ x = torch.nn.functional.silu(x_fc_1) * x_fc_2
468
+ return self.proj(x)
469
+
470
+
471
+ class GemmaMLP(LLaMAMLP):
472
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
473
+ x_fc_1 = self.fc_1(x)
474
+ x_fc_2 = self.fc_2(x)
475
+ x = (
476
+ torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate)
477
+ * x_fc_2
478
+ )
479
+ return self.proj(x)
480
+
481
+
482
+ class LLaMAMoE(nn.Module):
483
+ def __init__(self, config: Config) -> None:
484
+ super().__init__()
485
+ self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
486
+ self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
487
+
488
+ self.config = config
489
+
490
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
491
+ """
492
+ Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
493
+ See also figure 1 in https://arxiv.org/abs/2211.15841
494
+ """
495
+ B, T, C = (
496
+ x.size()
497
+ ) # batch size, sequence length, embedding dimensionality (n_embd)
498
+ x = x.view(-1, C) # (B*T, C)
499
+ router = self.gate(x) # (B*T, n_expert)
500
+ probs, indices = torch.topk(
501
+ router, self.config.n_expert_per_token
502
+ ) # (B*T, n_expert_per_token)
503
+ probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype)
504
+ masks = indices.unsqueeze(-1) == torch.arange(
505
+ self.config.n_expert, device=x.device
506
+ )
507
+ masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token)
508
+ y = torch.zeros_like(x) # (B*T, C)
509
+ for mask, expert in zip(masks, self.experts):
510
+ token_idx, expert_idx = torch.where(mask)
511
+ y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
512
+ return y.view(B, T, C)
513
+
514
+
515
+ def build_rope_cache(
516
+ seq_len: int,
517
+ n_elem: int,
518
+ device: Optional[torch.device] = None,
519
+ base: int = 10000,
520
+ condense_ratio: int = 1,
521
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
522
+ """Enhanced Transformer with Rotary Position Embedding.
523
+
524
+ Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
525
+ transformers/rope/__init__.py. MIT License:
526
+ https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
527
+ """
528
+ # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
529
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
530
+
531
+ # Create position indexes `[0, 1, ..., seq_len - 1]`
532
+ seq_idx = torch.arange(seq_len, device=device) / condense_ratio
533
+
534
+ # Calculate the product of position index and $\theta_i$
535
+ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
536
+
537
+ return torch.cos(idx_theta), torch.sin(idx_theta)
538
+
539
+
540
+ def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
541
+ head_size = x.size(-1)
542
+ x1 = x[..., : head_size // 2] # (B, nh, T, hs/2)
543
+ x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2)
544
+ rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs)
545
+ roped = (x * cos) + (rotated * sin)
546
+ return roped.to(dtype=x.dtype)
547
+
548
+
549
+ class KVCache(nn.Module):
550
+ def __init__(
551
+ self,
552
+ k_shape: Tuple[int, int, int, int],
553
+ v_shape: Tuple[int, int, int, int],
554
+ device: Optional[torch.device] = None,
555
+ dtype: Optional[torch.dtype] = None,
556
+ ) -> None:
557
+ super().__init__()
558
+ self.register_buffer(
559
+ "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
560
+ )
561
+ self.register_buffer(
562
+ "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
563
+ )
564
+
565
+ def forward(
566
+ self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
567
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
568
+ # move the buffer to the activation dtype for when AMP is used
569
+ self.k = self.k.to(k.dtype)
570
+ self.v = self.v.to(v.dtype)
571
+ # update the cache
572
+ k = self.k.index_copy_(2, input_pos, k)
573
+ v = self.v.index_copy_(2, input_pos, v)
574
+ return k, v
575
+
576
+ def reset_parameters(self) -> None:
577
+ torch.nn.init.zeros_(self.k)
578
+ torch.nn.init.zeros_(self.v)
579
+
580
+
581
+ def build_mask_cache(
582
+ max_seq_length: int, device: Optional[torch.device] = None
583
+ ) -> torch.Tensor:
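+ # lower-triangular (causal) mask with singleton batch and head dimensions for broadcasting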
584
+ ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
585
+ return torch.tril(ones).unsqueeze(0).unsqueeze(0)
586
+
587
+
588
+ class RMSNorm(torch.nn.Module):
589
+ """Root Mean Square Layer Normalization.
590
+
591
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
592
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
593
+ """
594
+
595
+ def __init__(
596
+ self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
597
+ ) -> None:
598
+ super().__init__()
599
+ self.weight = torch.nn.Parameter(torch.ones(size))
600
+ self.eps = eps
601
+ self.dim = dim
602
+ self.add_unit_offset = add_unit_offset
603
+
604
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
605
+ dtype = x.dtype
606
+ x = x.float()
607
+ # NOTE: the original RMSNorm paper implementation is not equivalent
608
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
609
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
610
+ x_normed = x_normed.to(dtype=dtype)
611
+ if self.add_unit_offset:
612
+ # Gemma model requires a unit offset
613
+ # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176
614
+ return x_normed * (1 + self.weight)
615
+ return x_normed * self.weight
616
+
617
+ def reset_parameters(self) -> None:
618
+ torch.nn.init.ones_(self.weight)
mini-omni-main (1)/litgpt/tokenizer.py ADDED
@@ -0,0 +1,131 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ import json
4
+ from pathlib import Path
5
+ from typing import Optional, Union
6
+
7
+ import torch
8
+
9
+
10
+ class Tokenizer:
11
+ def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
12
+ checkpoint_dir = Path(checkpoint_dir)
13
+ if not checkpoint_dir.exists():
14
+ raise NotADirectoryError(
15
+ f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
16
+ )
17
+
18
+ self.use_bos = self.check_if_bos_token_used(checkpoint_dir)
19
+ self.bos_id = None
20
+ self.eos_id = None
21
+
22
+ # some checkpoints have both files, `.json` takes precedence
23
+ if (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file():
24
+ from tokenizers import Tokenizer as HFTokenizer
25
+
26
+ self.processor = HFTokenizer.from_file(str(vocabulary_path))
27
+ self.backend = "huggingface"
28
+
29
+ if (
30
+ special_tokens_path := checkpoint_dir / "tokenizer_config.json"
31
+ ).is_file():
32
+ with open(special_tokens_path, encoding="utf-8") as fp:
33
+ config = json.load(fp)
34
+ bos_token = config.get("bos_token")
35
+ eos_token = config.get("eos_token")
36
+ if bos_token is not None and isinstance(bos_token, dict):
37
+ bos_token = bos_token.get("content")
38
+ if eos_token is not None and isinstance(eos_token, dict):
39
+ eos_token = eos_token.get("content")
40
+ self.bos_id = (
41
+ self.token_to_id(bos_token) if bos_token is not None else None
42
+ )
43
+ self.eos_id = (
44
+ self.token_to_id(eos_token) if eos_token is not None else None
45
+ )
46
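+ # fall back to generation_config.json when tokenizer_config.json does not define the bos/eos ids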
+ if (
47
+ special_tokens_path := checkpoint_dir / "generation_config.json"
48
+ ).is_file():
49
+ with open(special_tokens_path, encoding="utf-8") as fp:
50
+ config = json.load(fp)
51
+ if self.bos_id is None:
52
+ self.bos_id = config.get("bos_token_id")
53
+ if self.eos_id is None:
54
+ self.eos_id = config.get("eos_token_id")
55
+
56
+ elif (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
57
+ from sentencepiece import SentencePieceProcessor
58
+
59
+ self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
60
+ self.backend = "sentencepiece"
61
+ self.bos_id = self.processor.bos_id()
62
+ self.eos_id = self.processor.eos_id()
63
+ else:
64
+ raise NotImplementedError
65
+
66
+ @property
67
+ def vocab_size(self) -> int:
68
+ if self.backend == "huggingface":
69
+ return self.processor.get_vocab_size(with_added_tokens=False)
70
+ if self.backend == "sentencepiece":
71
+ return self.processor.vocab_size()
72
+ raise RuntimeError
73
+
74
+ def token_to_id(self, token: str) -> int:
75
+ if self.backend == "huggingface":
76
+ id_ = self.processor.token_to_id(token)
77
+ elif self.backend == "sentencepiece":
78
+ id_ = self.processor.piece_to_id(token)
79
+ else:
80
+ raise RuntimeError
81
+ if id_ is None:
82
+ raise ValueError(f"token {token!r} not found in the collection.")
83
+ return id_
84
+
85
+ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
86
+ if not (
87
+ tokenizer_config_path := checkpoint_dir / "tokenizer_config.json"
88
+ ).is_file():
89
+ return False
90
+ with open(tokenizer_config_path, encoding="utf-8") as fp:
91
+ config = json.load(fp)
92
+ if "add_bos_token" in config:
93
+ return config["add_bos_token"]
94
+ # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
95
+ # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2
96
+ return config.get("tokenizer_class") == "LlamaTokenizer"
97
+
98
+ def encode(
99
+ self,
100
+ string: str,
101
+ device: Optional[torch.device] = None,
102
+ bos: Optional[bool] = None,
103
+ eos: bool = False,
104
+ max_length: int = -1,
105
+ ) -> torch.Tensor:
106
+ if self.backend == "huggingface":
107
+ tokens = self.processor.encode(string).ids
108
+ elif self.backend == "sentencepiece":
109
+ tokens = self.processor.encode(string)
110
+ else:
111
+ raise RuntimeError
112
+ if bos or (bos is None and self.use_bos):
113
+ bos_id = self.bos_id
114
+ if bos_id is None:
115
+ raise NotImplementedError(
116
+ "This tokenizer does not have a defined a bos token"
117
+ )
118
+ if tokens[0] != bos_id:
119
+ tokens = [bos_id] + tokens
120
+ if tokens is None:
121
+ raise ValueError("`tokens` is None")
122
+
123
+ if eos and (not tokens or tokens[-1] != self.eos_id):
124
+ tokens = tokens + [self.eos_id]
125
+ if max_length > 0:
126
+ tokens = tokens[:max_length]
127
+ return torch.tensor(tokens, dtype=torch.int, device=device)
128
+
129
+ def decode(self, tensor: torch.Tensor) -> str:
130
+ tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
131
+ return self.processor.decode(tokens)
mini-omni-main (1)/litgpt/utils.py ADDED
@@ -0,0 +1,641 @@
1
+ # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
2
+
3
+ """Utility functions for training and inference."""
4
+ import inspect
5
+ import math
6
+ import os
7
+ import pickle
8
+ import shutil
9
+ import sys
10
+ from dataclasses import asdict, is_dataclass
11
+ from io import BytesIO
12
+ from pathlib import Path
13
+ from typing import (
14
+ TYPE_CHECKING,
15
+ Any,
16
+ Dict,
17
+ Iterable,
18
+ List,
19
+ Literal,
20
+ Mapping,
21
+ Optional,
22
+ TypeVar,
23
+ Union,
24
+ )
25
+
26
+ import lightning as L
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.utils._device
30
+ import yaml
31
+ from lightning.fabric.loggers import CSVLogger, TensorBoardLogger
32
+ from lightning.fabric.strategies import FSDPStrategy
33
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
34
+ from lightning.pytorch.loggers import WandbLogger
35
+ from lightning.pytorch.cli import instantiate_class
36
+ from torch.serialization import normalize_storage_type
37
+ from typing_extensions import Self
38
+
39
+ if TYPE_CHECKING:
40
+ from litgpt import GPT, Config
41
+
42
+
43
+ def init_out_dir(out_dir: Path) -> Path:
44
+ if not out_dir.is_absolute() and "LIGHTNING_ARTIFACTS_DIR" in os.environ:
45
+ return Path(os.getenv("LIGHTNING_ARTIFACTS_DIR")) / out_dir
46
+ return out_dir
47
+
48
+
49
+ def find_resume_path(
50
+ resume: Union[bool, Literal["auto"], Path], out_dir: Path
51
+ ) -> Optional[Path]:
52
+ if not resume or isinstance(resume, Path):
53
+ return resume
54
+
55
+ resume_path = max(
56
+ out_dir.rglob("step-*/*.pth"),
57
+ key=(lambda p: int(p.parent.name.split("-")[1])),
58
+ default=None,
59
+ )
60
+ if resume == "auto":
61
+ return resume_path
62
+ if resume is True and resume_path is None:
63
+ raise FileNotFoundError(
64
+ f"You passed `--resume=True`, but no checkpont file was found in `--out_dir={out_dir}`."
65
+ )
66
+ return resume_path
67
+
68
+
69
+ def find_multiple(n: int, k: int) -> int:
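+ # round n up to the nearest multiple of k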
70
+ assert k > 0
71
+ if n % k == 0:
72
+ return n
73
+ return n + k - (n % k)
74
+
75
+
76
+ def num_parameters(module: nn.Module, requires_grad: Optional[bool] = None) -> int:
77
+ total = 0
78
+ for p in module.parameters():
79
+ if requires_grad is None or p.requires_grad == requires_grad:
80
+ if hasattr(p, "quant_state"):
81
+ # bitsandbytes 4bit layer support
82
+ total += math.prod(p.quant_state.shape)
83
+ else:
84
+ total += p.numel()
85
+ return total
86
+
87
+
88
+ def reset_parameters(module: nn.Module) -> None:
89
+ """Calls `reset_parameters` on the module and all its submodules."""
90
+ for mod in module.modules():
91
+ if callable(getattr(mod, "reset_parameters", None)):
92
+ mod.reset_parameters()
93
+
94
+
95
+ def check_valid_checkpoint_dir(
96
+ checkpoint_dir: Path,
97
+ model_filename: str = "lit_model.pth",
98
+ verbose: bool = True,
99
+ raise_error: bool = False,
100
+ ) -> None:
101
+ files = {
102
+ model_filename: (checkpoint_dir / model_filename).is_file(),
103
+ "model_config.yaml": (checkpoint_dir / "model_config.yaml").is_file(),
104
+ "tokenizer.json OR tokenizer.model": (
105
+ checkpoint_dir / "tokenizer.json"
106
+ ).is_file()
107
+ or (checkpoint_dir / "tokenizer.model").is_file(),
108
+ "tokenizer_config.json": (checkpoint_dir / "tokenizer_config.json").is_file(),
109
+ }
110
+ if checkpoint_dir.is_dir():
111
+ if all(files.values()):
112
+ # we're good
113
+ return
114
+ problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}"
115
+ else:
116
+ problem = " is not a checkpoint directory"
117
+
118
+ # list locally available checkpoints
119
+ available = list(Path("checkpoints").glob("*/*"))
120
+ if available:
121
+ options = "\n".join([""] + [repr(str(p.resolve())) for p in available])
122
+ extra = f"\nYou have downloaded locally:{options}\n"
123
+ else:
124
+ extra = ""
125
+
126
+ if verbose:
127
+ error_message = (
128
+ f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
129
+ "\nFind download instructions at https://github.com/Lightning-AI/litgpt/blob/main/tutorials\n"
130
+ f"{extra}\nSee all download options by running:\n litgpt download"
131
+ )
132
+ print(error_message, file=sys.stderr)
133
+
134
+ if raise_error:
135
+ raise FileNotFoundError(
136
+ f"checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}."
137
+ )
138
+ else:
139
+ raise SystemExit(1)
140
+
141
+
142
+ class SavingProxyForStorage:
143
+ def __init__(self, obj, saver, protocol_version=5):
144
+ self.protocol_version = protocol_version
145
+ self.saver = saver
146
+ if not (isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj)):
147
+ raise TypeError(f"expected storage, not {type(obj)}")
148
+
149
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
150
+ if isinstance(obj, torch.storage.TypedStorage):
151
+ # PT upstream wants to deprecate this eventually...
152
+ storage = obj._untyped_storage
153
+ storage_type_str = obj._pickle_storage_type()
154
+ storage_type = getattr(torch, storage_type_str)
155
+ storage_numel = obj._size()
156
+ else:
157
+ storage = obj
158
+ storage_type = normalize_storage_type(type(obj))
159
+ storage_numel = storage.nbytes()
160
+
161
+ storage_key = saver._write_storage_and_return_key(storage)
162
+ location = torch.serialization.location_tag(storage)
163
+
164
+ self.storage_info = (
165
+ "storage",
166
+ storage_type,
167
+ storage_key,
168
+ location,
169
+ storage_numel,
170
+ )
171
+
172
+ def __reduce_ex__(self, protocol_version):
173
+ assert False, "this should be handled with out of band"
174
+
175
+
176
+ class SavingProxyForTensor:
177
+ def __init__(self, tensor, saver, protocol_version=5):
178
+ self.protocol_version = protocol_version
179
+ self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__(protocol_version)
180
+ if reduce_args[0] == torch._utils._rebuild_tensor_v2:
181
+ # for Tensors with Python attributes
182
+ (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args
183
+ assert isinstance(
184
+ storage, torch.storage.TypedStorage
185
+ ), "Please check for updates"
186
+ storage_proxy = SavingProxyForStorage(
187
+ storage, saver, protocol_version=protocol_version
188
+ )
189
+ self.reduce_args = (a0, a1, (storage_proxy, *a2_other), *other_reduce_args)
190
+ else:
191
+ (storage, *other_reduce_args) = reduce_args
192
+ assert isinstance(
193
+ storage, torch.storage.TypedStorage
194
+ ), "Please check for updates"
195
+ storage_proxy = SavingProxyForStorage(
196
+ storage, saver, protocol_version=protocol_version
197
+ )
198
+ self.reduce_args = (storage_proxy, *other_reduce_args)
199
+
200
+ def __reduce_ex__(self, protocol_version):
201
+ if protocol_version != self.protocol_version:
202
+ raise RuntimeError(
203
+ f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}"
204
+ )
205
+ return self.reduce_ret_fn, self.reduce_args
206
+
207
+
208
+ class IncrementalPyTorchPickler(pickle.Pickler):
209
+ def __init__(self, saver, *args, **kwargs):
210
+ super().__init__(*args, **kwargs)
211
+ self.storage_dtypes = {}
212
+ self.saver = saver
213
+ self.id_map = {}
214
+
215
+ # this logic is taken from PyTorch 2.0+ torch/serialization.py
216
+ def persistent_id(self, obj):
217
+ # FIXME: the docs say that persistent_id should only return a string
218
+ # but torch store returns tuples. This works only in the binary protocol
219
+ # see
220
+ # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects
221
+ # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537
222
+ if isinstance(obj, SavingProxyForStorage):
223
+ return obj.storage_info
224
+
225
+ if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage(obj):
226
+ if isinstance(obj, torch.storage.TypedStorage):
227
+ # TODO: Once we decide to break serialization FC, this case
228
+ # can be deleted
229
+ storage = obj._untyped_storage
230
+ storage_dtype = obj.dtype
231
+ storage_type_str = obj._pickle_storage_type()
232
+ storage_type = getattr(torch, storage_type_str)
233
+ storage_numel = obj._size()
234
+
235
+ else:
236
+ storage = obj
237
+ storage_dtype = torch.uint8
238
+ storage_type = normalize_storage_type(type(obj))
239
+ storage_numel = storage.nbytes()
240
+
241
+ # If storage is allocated, ensure that any other saved storages
242
+ # pointing to the same data all have the same dtype. If storage is
243
+ # not allocated, don't perform this check
244
+ if storage.data_ptr() != 0:
245
+ if storage.data_ptr() in self.storage_dtypes:
246
+ if storage_dtype != self.storage_dtypes[storage.data_ptr()]:
247
+ raise RuntimeError(
248
+ "Cannot save multiple tensors or storages that view the same data as different types"
249
+ )
250
+ else:
251
+ self.storage_dtypes[storage.data_ptr()] = storage_dtype
252
+
253
+ storage_key = self.id_map.get(storage._cdata)
254
+ if storage_key is None:
255
+ storage_key = self.saver._write_storage_and_return_key(storage)
256
+ self.id_map[storage._cdata] = storage_key
257
+ location = torch.serialization.location_tag(storage)
258
+
259
+ return ("storage", storage_type, storage_key, location, storage_numel)
260
+
261
+ return None
262
+
263
+
264
+ class incremental_save:
265
+ def __init__(self, name):
266
+ self.name = name
267
+ self.zipfile = torch._C.PyTorchFileWriter(str(name))
268
+ self.has_saved = False
269
+ self.next_key = 0
270
+
271
+ def __enter__(self):
272
+ return self
273
+
274
+ def store_early(self, tensor):
275
+ if isinstance(tensor, torch.Tensor):
276
+ return SavingProxyForTensor(tensor, self)
277
+ raise TypeError(f"can only store tensors early, not {type(tensor)}")
278
+
279
+ def save(self, obj):
280
+ if self.has_saved:
281
+ raise RuntimeError("have already saved")
282
+ # Write the pickle data for `obj`
283
+ data_buf = BytesIO()
284
+ pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5)
285
+ pickler.dump(obj)
286
+ data_value = data_buf.getvalue()
287
+ self.zipfile.write_record("data.pkl", data_value, len(data_value))
288
+ self.has_saved = True
289
+
290
+ def _write_storage_and_return_key(self, storage):
291
+ if self.has_saved:
292
+ raise RuntimeError("have already saved")
293
+ key = self.next_key
294
+ self.next_key += 1
295
+ name = f"data/{key}"
296
+ if storage.device.type != "cpu":
297
+ storage = storage.cpu()
298
+ num_bytes = storage.nbytes()
299
+ self.zipfile.write_record(name, storage.data_ptr(), num_bytes)
300
+ return key
301
+
302
+ def __exit__(self, type, value, traceback):
303
+ self.zipfile.write_end_of_file()
304
+
305
+
306
+ T = TypeVar("T")
307
+
308
+
309
+ def chunked_cross_entropy(
310
+ logits: Union[torch.Tensor, List[torch.Tensor]],
311
+ targets: torch.Tensor,
312
+ chunk_size: int = 128,
313
+ ignore_index: int = -100,
314
+ ) -> torch.Tensor:
315
+ # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate
316
+ # the memory usage in fine-tuning settings with low number of parameters.
317
+ # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing
318
+ # the memory spike's magnitude
319
+
320
+ # lm_head was chunked (we are fine-tuning)
321
+ if isinstance(logits, list):
322
+ # don't want to chunk cross entropy
323
+ if chunk_size == 0:
324
+ logits = torch.cat(logits, dim=1)
325
+ logits = logits.reshape(-1, logits.size(-1))
326
+ targets = targets.reshape(-1)
327
+ return torch.nn.functional.cross_entropy(
328
+ logits, targets, ignore_index=ignore_index
329
+ )
330
+
331
+ # chunk cross entropy
332
+ logit_chunks = [
333
+ logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
334
+ ]
335
+ target_chunks = [
336
+ target_chunk.reshape(-1)
337
+ for target_chunk in targets.split(logits[0].size(1), dim=1)
338
+ ]
339
+ loss_chunks = [
340
+ torch.nn.functional.cross_entropy(
341
+ logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
342
+ )
343
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
344
+ ]
345
+ non_masked_elems = (targets != ignore_index).sum()
346
+ # See [non_masked_elems div note]
347
+ return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
348
+ torch.ones_like(non_masked_elems)
349
+ )
350
+
351
+ # no chunking at all
352
+ logits = logits.reshape(-1, logits.size(-1))
353
+ targets = targets.reshape(-1)
354
+ if chunk_size == 0:
355
+ return torch.nn.functional.cross_entropy(
356
+ logits, targets, ignore_index=ignore_index
357
+ )
358
+
359
+ # lm_head wasn't chunked, chunk cross entropy
360
+ logit_chunks = logits.split(chunk_size)
361
+ target_chunks = targets.split(chunk_size)
362
+ loss_chunks = [
363
+ torch.nn.functional.cross_entropy(
364
+ logit_chunk, target_chunk, ignore_index=ignore_index, reduction="none"
365
+ )
366
+ for logit_chunk, target_chunk in zip(logit_chunks, target_chunks)
367
+ ]
368
+ non_masked_elems = (targets != ignore_index).sum()
369
+ # [non_masked_elems div note]:
370
+ # max(1, non_masked_elems) would be more ergonomic to avoid a division by zero. However that
371
+ # results in a python int which is then passed back to torch division. By using the
372
+ # `x.maximum(torch.ones_like(x))` pattern we avoid a cudaStreamSynchronize.
373
+ return torch.cat(loss_chunks).sum() / non_masked_elems.maximum(
374
+ torch.ones_like(non_masked_elems)
375
+ )
376
+
377
+
378
+ def map_old_state_dict_weights(state_dict: Dict, mapping: Mapping, prefix: str) -> Dict:
379
+ for checkpoint_name, attribute_name in mapping.items():
380
+ full_checkpoint_name = prefix + checkpoint_name
381
+ if full_checkpoint_name in state_dict:
382
+ full_attribute_name = prefix + attribute_name
383
+ state_dict[full_attribute_name] = state_dict.pop(full_checkpoint_name)
384
+ return state_dict
385
+
386
+
387
+ def get_default_supported_precision(training: bool) -> str:
388
+ """Return default precision that is supported by the hardware: either `bf16` or `16`.
389
+
390
+ Args:
391
+ training: `-mixed` or `-true` version of the precision to use
392
+
393
+ Returns:
394
+ default precision that is suitable for the task and is supported by the hardware
395
+ """
396
+ from lightning.fabric.accelerators import MPSAccelerator
397
+
398
+ if MPSAccelerator.is_available() or (
399
+ torch.cuda.is_available() and not torch.cuda.is_bf16_supported()
400
+ ):
401
+ return "16-mixed" if training else "16-true"
402
+ return "bf16-mixed" if training else "bf16-true"
403
+
404
+
405
+ def load_checkpoint(
406
+ fabric: L.Fabric, model: nn.Module, checkpoint_path: Path, strict: bool = True
407
+ ) -> None:
408
+ if isinstance(fabric.strategy, FSDPStrategy):
409
+ fabric.load_raw(checkpoint_path, model, strict=strict)
410
+ else:
411
+ state_dict = lazy_load(checkpoint_path)
412
+ state_dict = state_dict.get("model", state_dict)
413
+ model.load_state_dict(state_dict, strict=strict)
414
+
415
+
416
+ def flops_per_param(
417
+ max_seq_length: int, n_layer: int, n_embd: int, n_params: int
418
+ ) -> int:
419
+ flops_per_token = (
420
+ 2 * n_params
421
+ ) # each parameter is used for a MAC (2 FLOPS) per network operation
422
+ # this assumes that all samples have a fixed length equal to the block size
423
+ # which is most likely false during finetuning
424
+ flops_per_seq = flops_per_token * max_seq_length
425
+ attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2))
426
+ return flops_per_seq + attn_flops_per_seq
427
+
428
+
429
+ def estimate_flops(model: "GPT", training: bool) -> int:
430
+ """Measures estimated FLOPs for MFU.
431
+
432
+ Refs:
433
+ * https://ar5iv.labs.arxiv.org/html/2205.05198#A1
434
+ * https://ar5iv.labs.arxiv.org/html/2204.02311#A2
435
+ """
436
+ # using all parameters for this is a naive overestimation because not all model parameters actually contribute to
437
+ # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage
438
+ # (~10%) compared to the measured FLOPs, making those lower but more realistic.
439
+ # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper.
440
+ n_trainable_params = num_parameters(model, requires_grad=True)
441
+ trainable_flops = flops_per_param(
442
+ model.max_seq_length,
443
+ model.config.n_layer,
444
+ model.config.n_embd,
445
+ n_trainable_params,
446
+ )
447
+ # forward + backward + gradients (assumes no gradient accumulation)
448
+ ops_per_step = 3 if training else 1
449
+ n_frozen_params = num_parameters(model, requires_grad=False)
450
+ frozen_flops = flops_per_param(
451
+ model.max_seq_length, model.config.n_layer, model.config.n_embd, n_frozen_params
452
+ )
453
+ # forward + backward
454
+ frozen_ops_per_step = 2 if training else 1
455
+ return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops
456
+
457
+
458
+ class CycleIterator:
459
+ """An iterator that cycles through an iterable indefinitely.
460
+
461
+ Example:
462
+ >>> iterator = CycleIterator([1, 2, 3])
463
+ >>> [next(iterator) for _ in range(5)]
464
+ [1, 2, 3, 1, 2]
465
+
466
+ Note:
467
+ Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable.
468
+ """
469
+
470
+ def __init__(self, iterable: Iterable) -> None:
471
+ self.iterable = iterable
472
+ self.epoch = 0
473
+ self._iterator = None
474
+
475
+ def __next__(self) -> Any:
476
+ if self._iterator is None:
477
+ self._iterator = iter(self.iterable)
478
+ try:
479
+ return next(self._iterator)
480
+ except StopIteration:
481
+ self._iterator = iter(self.iterable)
482
+ self.epoch += 1
483
+ return next(self._iterator)
484
+
485
+ def __iter__(self) -> Self:
486
+ return self
487
+
488
+
489
+ def copy_config_files(source_dir: Path, out_dir: Path) -> None:
490
+ """Copies the specified configuration and tokenizer files into the output directory."""
491
+
492
+ config_files = ["config.json", "generation_config.json", "model_config.yaml"]
493
+ tokenizer_files = ["tokenizer.json", "tokenizer.model", "tokenizer_config.json"]
494
+
495
+ for file_name in config_files + tokenizer_files:
496
+ src_path = source_dir / file_name
497
+ if src_path.exists():
498
+ shutil.copy(src_path, out_dir)
499
+
500
+
501
+ def CLI(*args: Any, **kwargs: Any) -> Any:
502
+ from jsonargparse import CLI, set_config_read_mode, set_docstring_parse_options
503
+
504
+ set_docstring_parse_options(attribute_docstrings=True)
505
+ set_config_read_mode(urls_enabled=True)
506
+
507
+ return CLI(*args, **kwargs)
508
+
509
+
510
+ def capture_hparams() -> Dict[str, Any]:
511
+ """Captures the local variables ('hyperparameters') from where this function gets called."""
512
+ caller_frame = inspect.currentframe().f_back
513
+ locals_of_caller = caller_frame.f_locals
514
+ hparams = {}
515
+ for name, value in locals_of_caller.items():
516
+ if value is None or isinstance(value, (int, float, str, bool, Path)):
517
+ hparams[name] = value
518
+ elif is_dataclass(value):
519
+ hparams[name] = asdict(value)
520
+ else:
521
+ hparams[name] = str(value)
522
+ return hparams
523
+
524
+
525
+ def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
526
+ """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
527
+ from jsonargparse import capture_parser
528
+
529
+ # TODO: Make this more robust
530
+ # This hack strips away the subcommands from the top-level CLI
531
+ # to parse the file as if it was called as a script
532
+ known_commands = [
533
+ ("finetune_full",), # For subcommands, use `("finetune", "full")` etc
534
+ ("finetune_lora",),
535
+ ("finetune_adapter",),
536
+ ("finetune_adapter_v2",),
537
+ ("finetune",),
538
+ ("pretrain",),
539
+ ]
540
+ for known_command in known_commands:
541
+ unwanted = slice(1, 1 + len(known_command))
542
+ if tuple(sys.argv[unwanted]) == known_command:
543
+ sys.argv[unwanted] = []
544
+
545
+ parser = capture_parser(lambda: CLI(function))
546
+ config = parser.parse_args()
547
+ parser.save(config, checkpoint_dir / "hyperparameters.yaml", overwrite=True)
548
+
549
+
550
+ def save_config(config: "Config", checkpoint_dir: Path) -> None:
551
+ config_dict = asdict(config)
552
+ with open(checkpoint_dir / "model_config.yaml", "w", encoding="utf-8") as fp:
553
+ yaml.dump(config_dict, fp)
554
+
555
+
556
+ def parse_devices(devices: Union[str, int]) -> int:
557
+ if devices in (-1, "auto"):
558
+ return torch.cuda.device_count() or 1
559
+ if isinstance(devices, int) and devices > 0:
560
+ return devices
561
+ raise ValueError(f"Devices must be 'auto' or a positive integer, got: {devices!r}")
562
+
563
+
564
+ def choose_logger(
565
+ logger_name: Literal["csv", "tensorboard", "wandb"],
566
+ out_dir: Path,
567
+ name: str,
568
+ log_interval: int = 1,
569
+ resume: Optional[bool] = None,
570
+ **kwargs: Any,
571
+ ):
572
+ if logger_name == "csv":
573
+ return CSVLogger(
574
+ root_dir=(out_dir / "logs"),
575
+ name="csv",
576
+ flush_logs_every_n_steps=log_interval,
577
+ **kwargs,
578
+ )
579
+ if logger_name == "tensorboard":
580
+ return TensorBoardLogger(
581
+ root_dir=(out_dir / "logs"), name="tensorboard", **kwargs
582
+ )
583
+ if logger_name == "wandb":
584
+ return WandbLogger(project=name, resume=resume, **kwargs)
585
+ raise ValueError(
586
+ f"`--logger_name={logger_name}` is not a valid option. Choose from 'csv', 'tensorboard', 'wandb'."
587
+ )
588
+
589
+
590
+ def get_argument_names(cls):
591
+ sig = inspect.signature(cls.__init__)
592
+ return {
593
+ name
594
+ for name, param in sig.parameters.items()
595
+ if param.kind
596
+ in [inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY]
597
+ }
598
+
599
+
600
+ def instantiate_bnb_optimizer(optimizer, model_parameters):
601
+ if (isinstance(optimizer, str) and "AdamW" not in optimizer) or (
602
+ isinstance(optimizer, dict) and "AdamW" not in optimizer.get("class_path", "")
603
+ ):
604
+ raise ValueError(
605
+ "The chosen quantization format only supports the AdamW optimizer."
606
+ )
607
+
608
+ import bitsandbytes as bnb
609
+
610
+ if isinstance(optimizer, str):
611
+ optimizer = bnb.optim.PagedAdamW(model_parameters)
612
+ else:
613
+ optim_args = get_argument_names(bnb.optim.PagedAdamW)
614
+ allowed_kwargs = {
615
+ key: optimizer["init_args"][key]
616
+ for key in optim_args & optimizer["init_args"].keys()
617
+ }
618
+ optimizer = bnb.optim.PagedAdamW(model_parameters, **allowed_kwargs)
619
+ return optimizer
620
+
621
+
622
+ def instantiate_torch_optimizer(optimizer, model_parameters, **kwargs):
623
+ if isinstance(optimizer, str):
624
+ optimizer_cls = getattr(torch.optim, optimizer)
625
+ optimizer = optimizer_cls(model_parameters, **kwargs)
626
+ else:
627
+ optimizer = dict(optimizer) # copy
628
+ optimizer["init_args"].update(kwargs)
629
+ optimizer = instantiate_class(model_parameters, optimizer)
630
+ return optimizer
631
+
632
+
633
+ def extend_checkpoint_dir(checkpoint_dir: Path) -> Path:
634
+ new_checkpoint_dir = "checkpoints" / checkpoint_dir
635
+ should_return_new_dir = (
636
+ not checkpoint_dir.is_dir()
637
+ and checkpoint_dir.parts[0] != "checkpoints"
638
+ and not checkpoint_dir.is_absolute()
639
+ and new_checkpoint_dir.exists()
640
+ )
641
+ return new_checkpoint_dir if should_return_new_dir else checkpoint_dir
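
Note (not part of the uploaded file): the chunked and unchunked paths of chunked_cross_entropy above are meant to agree; the sketch below checks that on toy tensors. The shapes and values are assumptions chosen only for illustration, and the function is assumed to be importable from this utils module.

    import torch

    # toy batch: 2 sequences of length 8 over a 32-token vocabulary
    logits = torch.randn(2, 8, 32)
    targets = torch.randint(0, 32, (2, 8))
    targets[0, :2] = -100  # masked positions are ignored by both paths

    full = torch.nn.functional.cross_entropy(
        logits.reshape(-1, 32), targets.reshape(-1), ignore_index=-100
    )
    chunked = chunked_cross_entropy(logits, targets, chunk_size=4)
    print(torch.allclose(full, chunked, atol=1e-6))  # expected: True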
mini-omni-main (1)/requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ torch==2.3.1
2
+ torchvision==0.18.1
3
+ torchaudio==2.3.1
4
+ litgpt==0.4.3
5
+ snac==1.2.0
6
+ soundfile==0.12.1
7
+ openai-whisper
8
+ tokenizers==0.19.1
9
+ streamlit==1.37.1
10
+ # PyAudio==0.2.14
11
+ pydub==0.25.1
12
+ onnxruntime==1.19.0
13
+ # numpy==1.26.3
14
+ gradio==4.42.0
15
+ librosa==0.10.2.post1
16
+ flask==3.0.3
17
+ fire
mini-omni-main (1)/server.py ADDED
@@ -0,0 +1,57 @@
1
+ import flask
2
+ import base64
3
+ import tempfile
4
+ import traceback
5
+ from flask import Flask, Response, stream_with_context
6
+ from inference import OmniInference
7
+
8
+
9
+ class OmniChatServer(object):
10
+ def __init__(self, ip='0.0.0.0', port=60808, run_app=True,
11
+ ckpt_dir='./checkpoint', device='cuda:0') -> None:
12
+ server = Flask(__name__)
13
+ # CORS(server, resources=r"/*")
14
+ # server.config["JSON_AS_ASCII"] = False
15
+
16
+ self.client = OmniInference(ckpt_dir, device)
17
+ self.client.warm_up()
18
+
19
+ server.route("/chat", methods=["POST"])(self.chat)
20
+
21
+ if run_app:
22
+ server.run(host=ip, port=port, threaded=False)
23
+ else:
24
+ self.server = server
25
+
26
+ def chat(self) -> Response:
27
+
28
+ req_data = flask.request.get_json()
29
+ try:
30
+ data_buf = req_data["audio"].encode("utf-8")
31
+ data_buf = base64.b64decode(data_buf)
32
+ stream_stride = req_data.get("stream_stride", 4)
33
+ max_tokens = req_data.get("max_tokens", 2048)
34
+
35
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
36
+ f.write(data_buf)
37
+ audio_generator = self.client.run_AT_batch_stream(f.name, stream_stride, max_tokens)
38
+ return Response(stream_with_context(audio_generator), mimetype="audio/wav")
39
+ except Exception as e:
40
+ print(traceback.format_exc())
+ return Response("server error", status=500, mimetype="text/plain")
41
+
42
+
43
+ # CUDA_VISIBLE_DEVICES=1 gunicorn -w 2 -b 0.0.0.0:60808 'server:create_app()'
44
+ def create_app():
45
+ server = OmniChatServer(run_app=False)
46
+ return server.server
47
+
48
+
49
+ def serve(ip='0.0.0.0', port=60808, device='cuda:0'):
50
+
51
+ OmniChatServer(ip, port=port, run_app=True, device=device)
52
+
53
+
54
+ if __name__ == "__main__":
55
+ import fire
56
+ fire.Fire(serve)
57
+
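
For reference, a minimal client sketch for the /chat endpoint defined above; it is an assumption for illustration, not part of the upload. It mirrors the payload the bundled web UIs send: a JSON body with a base64-encoded WAV under "audio", answered with a streamed response whose chunks those UIs treat as 24 kHz, 16-bit mono PCM.

    import base64
    import requests

    API_URL = "http://127.0.0.1:60808/chat"  # default host/port used by serve()

    with open("question.wav", "rb") as f:  # "question.wav" is a placeholder input file
        payload = {"audio": base64.b64encode(f.read()).decode("utf-8")}

    pcm = b""
    with requests.post(API_URL, json=payload, stream=True) as resp:
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                pcm += chunk  # raw PCM bytes; play or convert as int16 mono at 24 kHz

    with open("answer.pcm", "wb") as f:
        f.write(pcm)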
mini-omni-main (1)/utils/assets/silero_vad.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:591f853590d11ddde2f2a54f9e7ccecb2533a8af7716330e8adfa6f3849787a9
3
+ size 1807524
mini-omni-main (1)/utils/snac_utils.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import time
3
+ import numpy as np
4
+
5
+
6
+ class SnacConfig:
7
+ audio_vocab_size = 4096
8
+ padded_vocab_size = 4160
9
+ end_of_audio = 4097
10
+
11
+
12
+ snac_config = SnacConfig()
13
+
14
+
15
+ def get_time_str():
16
+ time_str = time.strftime("%Y%m%d_%H%M%S", time.localtime())
17
+ return time_str
18
+
19
+
20
+ def layershift(input_id, layer, stride=4160, shift=152000):
21
+ return input_id + shift + layer * stride
22
+
23
+
24
+ def generate_audio_data(snac_tokens, snacmodel, device=None):
25
+ audio = reconstruct_tensors(snac_tokens, device)
26
+ with torch.inference_mode():
27
+ audio_hat = snacmodel.decode(audio)
28
+ audio_data = audio_hat.cpu().numpy().astype(np.float64) * 32768.0
29
+ audio_data = audio_data.astype(np.int16)
30
+ audio_data = audio_data.tobytes()
31
+ return audio_data
32
+
33
+
34
+ def get_snac(list_output, index, nums_generate):
35
+
36
+ snac = []
37
+ start = index
38
+ for i in range(nums_generate):
39
+ snac.append("#")
40
+ for j in range(7):
41
+ snac.append(list_output[j][start - nums_generate - 5 + j + i])
42
+ return snac
43
+
44
+
45
+ def reconscruct_snac(output_list):
46
+ if len(output_list) == 8:
47
+ output_list = output_list[:-1]
48
+ output = []
49
+ for i in range(7):
50
+ output_list[i] = output_list[i][i + 1 :]
51
+ for i in range(len(output_list[-1])):
52
+ output.append("#")
53
+ for j in range(7):
54
+ output.append(output_list[j][i])
55
+ return output
56
+
57
+
58
+ def reconstruct_tensors(flattened_output, device=None):
59
+ """Reconstructs the list of tensors from the flattened output."""
60
+
61
+ if device is None:
62
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
63
+
64
+ def count_elements_between_hashes(lst):
65
+ try:
66
+ # Find the index of the first '#'
67
+ first_index = lst.index("#")
68
+ # Find the index of the second '#' after the first
69
+ second_index = lst.index("#", first_index + 1)
70
+ # Count the elements between the two indices
71
+ return second_index - first_index - 1
72
+ except ValueError:
73
+ # Handle the case where there aren't enough '#' symbols
74
+ return "List does not contain two '#' symbols"
75
+
76
+ def remove_elements_before_hash(flattened_list):
77
+ try:
78
+ # Find the index of the first '#'
79
+ first_hash_index = flattened_list.index("#")
80
+ # Return the list starting from the first '#'
81
+ return flattened_list[first_hash_index:]
82
+ except ValueError:
83
+ # Handle the case where there is no '#'
84
+ return "List does not contain the symbol '#'"
85
+
86
+ def list_to_torch_tensor(tensor1):
87
+ # Convert the list to a torch tensor
88
+ tensor = torch.tensor(tensor1)
89
+ # Reshape the tensor to have size (1, n)
90
+ tensor = tensor.unsqueeze(0)
91
+ return tensor
92
+
93
+ flattened_output = remove_elements_before_hash(flattened_output)
94
+ codes = []
95
+ tensor1 = []
96
+ tensor2 = []
97
+ tensor3 = []
98
+ tensor4 = []
99
+
100
+ n_tensors = count_elements_between_hashes(flattened_output)
101
+ if n_tensors == 7:
102
+ for i in range(0, len(flattened_output), 8):
103
+
104
+ tensor1.append(flattened_output[i + 1])
105
+ tensor2.append(flattened_output[i + 2])
106
+ tensor3.append(flattened_output[i + 3])
107
+ tensor3.append(flattened_output[i + 4])
108
+
109
+ tensor2.append(flattened_output[i + 5])
110
+ tensor3.append(flattened_output[i + 6])
111
+ tensor3.append(flattened_output[i + 7])
112
+ codes = [
113
+ list_to_torch_tensor(tensor1).to(device),
114
+ list_to_torch_tensor(tensor2).to(device),
115
+ list_to_torch_tensor(tensor3).to(device),
116
+ ]
117
+
118
+ if n_tensors == 15:
119
+ for i in range(0, len(flattened_output), 16):
120
+
121
+ tensor1.append(flattened_output[i + 1])
122
+ tensor2.append(flattened_output[i + 2])
123
+ tensor3.append(flattened_output[i + 3])
124
+ tensor4.append(flattened_output[i + 4])
125
+ tensor4.append(flattened_output[i + 5])
126
+ tensor3.append(flattened_output[i + 6])
127
+ tensor4.append(flattened_output[i + 7])
128
+ tensor4.append(flattened_output[i + 8])
129
+
130
+ tensor2.append(flattened_output[i + 9])
131
+ tensor3.append(flattened_output[i + 10])
132
+ tensor4.append(flattened_output[i + 11])
133
+ tensor4.append(flattened_output[i + 12])
134
+ tensor3.append(flattened_output[i + 13])
135
+ tensor4.append(flattened_output[i + 14])
136
+ tensor4.append(flattened_output[i + 15])
137
+
138
+ codes = [
139
+ list_to_torch_tensor(tensor1).to(device),
140
+ list_to_torch_tensor(tensor2).to(device),
141
+ list_to_torch_tensor(tensor3).to(device),
142
+ list_to_torch_tensor(tensor4).to(device),
143
+ ]
144
+
145
+ return codes
146
+
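
A small illustration (an assumption about intended usage, inferred from the helpers above): reconscruct_snac produces a flattened stream with a "#" marker followed by 7 codes per frame, and reconstruct_tensors maps such a stream back onto the three SNAC codebook tensors.

    import torch

    flat = []
    for frame in range(2):                      # two dummy frames with placeholder code ids
        flat.append("#")
        flat.extend(frame * 10 + k for k in range(7))

    codes = reconstruct_tensors(flat, device=torch.device("cpu"))
    print([tuple(c.shape) for c in codes])      # -> [(1, 2), (1, 4), (1, 8)]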
mini-omni-main (1)/utils/vad.py ADDED
@@ -0,0 +1,290 @@
1
+ import bisect
2
+ import functools
3
+ import os
4
+ import warnings
5
+
6
+ from typing import List, NamedTuple, Optional
7
+
8
+ import numpy as np
9
+
10
+
11
+ # The code below is adapted from https://github.com/snakers4/silero-vad.
12
+ class VadOptions(NamedTuple):
13
+ """VAD options.
14
+
15
+ Attributes:
16
+ threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
17
+ probabilities ABOVE this value are considered as SPEECH. It is better to tune this
18
+ parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
19
+ min_speech_duration_ms: Final speech chunks shorter than min_speech_duration_ms are thrown out.
20
+ max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
21
+ than max_speech_duration_s will be split at the timestamp of the last silence that
22
+ lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
23
+ split aggressively just before max_speech_duration_s.
24
+ min_silence_duration_ms: At the end of each speech chunk, wait for min_silence_duration_ms
25
+ before separating it
26
+ window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
27
+ WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
28
+ Values other than these may affect model performance!!
29
+ speech_pad_ms: Final speech chunks are padded by speech_pad_ms on each side
30
+ """
31
+
32
+ threshold: float = 0.5
33
+ min_speech_duration_ms: int = 250
34
+ max_speech_duration_s: float = float("inf")
35
+ min_silence_duration_ms: int = 2000
36
+ window_size_samples: int = 1024
37
+ speech_pad_ms: int = 400
38
+
39
+
40
+ def get_speech_timestamps(
41
+ audio: np.ndarray,
42
+ vad_options: Optional[VadOptions] = None,
43
+ **kwargs,
44
+ ) -> List[dict]:
45
+ """This method is used for splitting long audios into speech chunks using silero VAD.
46
+
47
+ Args:
48
+ audio: One dimensional float array.
49
+ vad_options: Options for VAD processing.
50
+ kwargs: VAD options passed as keyword arguments for backward compatibility.
51
+
52
+ Returns:
53
+ List of dicts containing begin and end samples of each speech chunk.
54
+ """
55
+ if vad_options is None:
56
+ vad_options = VadOptions(**kwargs)
57
+
58
+ threshold = vad_options.threshold
59
+ min_speech_duration_ms = vad_options.min_speech_duration_ms
60
+ max_speech_duration_s = vad_options.max_speech_duration_s
61
+ min_silence_duration_ms = vad_options.min_silence_duration_ms
62
+ window_size_samples = vad_options.window_size_samples
63
+ speech_pad_ms = vad_options.speech_pad_ms
64
+
65
+ if window_size_samples not in [512, 1024, 1536]:
66
+ warnings.warn(
67
+ "Unusual window_size_samples! Supported window_size_samples:\n"
68
+ " - [512, 1024, 1536] for 16000 sampling_rate"
69
+ )
70
+
71
+ sampling_rate = 16000
72
+ min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
73
+ speech_pad_samples = sampling_rate * speech_pad_ms / 1000
74
+ max_speech_samples = (
75
+ sampling_rate * max_speech_duration_s
76
+ - window_size_samples
77
+ - 2 * speech_pad_samples
78
+ )
79
+ min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
80
+ min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
81
+
82
+ audio_length_samples = len(audio)
83
+
84
+ model = get_vad_model()
85
+ state = model.get_initial_state(batch_size=1)
86
+
87
+ speech_probs = []
88
+ for current_start_sample in range(0, audio_length_samples, window_size_samples):
89
+ chunk = audio[current_start_sample : current_start_sample + window_size_samples]
90
+ if len(chunk) < window_size_samples:
91
+ chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
92
+ speech_prob, state = model(chunk, state, sampling_rate)
93
+ speech_probs.append(speech_prob)
94
+
95
+ triggered = False
96
+ speeches = []
97
+ current_speech = {}
98
+ neg_threshold = threshold - 0.15
99
+
100
+ # to save potential segment end (and tolerate some silence)
101
+ temp_end = 0
102
+ # to save potential segment limits in case of maximum segment size reached
103
+ prev_end = next_start = 0
104
+
105
+ for i, speech_prob in enumerate(speech_probs):
106
+ if (speech_prob >= threshold) and temp_end:
107
+ temp_end = 0
108
+ if next_start < prev_end:
109
+ next_start = window_size_samples * i
110
+
111
+ if (speech_prob >= threshold) and not triggered:
112
+ triggered = True
113
+ current_speech["start"] = window_size_samples * i
114
+ continue
115
+
116
+ if (
117
+ triggered
118
+ and (window_size_samples * i) - current_speech["start"] > max_speech_samples
119
+ ):
120
+ if prev_end:
121
+ current_speech["end"] = prev_end
122
+ speeches.append(current_speech)
123
+ current_speech = {}
124
+ # previously reached silence (< neg_thres) and is still not speech (< thres)
125
+ if next_start < prev_end:
126
+ triggered = False
127
+ else:
128
+ current_speech["start"] = next_start
129
+ prev_end = next_start = temp_end = 0
130
+ else:
131
+ current_speech["end"] = window_size_samples * i
132
+ speeches.append(current_speech)
133
+ current_speech = {}
134
+ prev_end = next_start = temp_end = 0
135
+ triggered = False
136
+ continue
137
+
138
+ if (speech_prob < neg_threshold) and triggered:
139
+ if not temp_end:
140
+ temp_end = window_size_samples * i
141
+ # condition to avoid cutting in very short silence
142
+ if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
143
+ prev_end = temp_end
144
+ if (window_size_samples * i) - temp_end < min_silence_samples:
145
+ continue
146
+ else:
147
+ current_speech["end"] = temp_end
148
+ if (
149
+ current_speech["end"] - current_speech["start"]
150
+ ) > min_speech_samples:
151
+ speeches.append(current_speech)
152
+ current_speech = {}
153
+ prev_end = next_start = temp_end = 0
154
+ triggered = False
155
+ continue
156
+
157
+ if (
158
+ current_speech
159
+ and (audio_length_samples - current_speech["start"]) > min_speech_samples
160
+ ):
161
+ current_speech["end"] = audio_length_samples
162
+ speeches.append(current_speech)
163
+
164
+ for i, speech in enumerate(speeches):
165
+ if i == 0:
166
+ speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
167
+ if i != len(speeches) - 1:
168
+ silence_duration = speeches[i + 1]["start"] - speech["end"]
169
+ if silence_duration < 2 * speech_pad_samples:
170
+ speech["end"] += int(silence_duration // 2)
171
+ speeches[i + 1]["start"] = int(
172
+ max(0, speeches[i + 1]["start"] - silence_duration // 2)
173
+ )
174
+ else:
175
+ speech["end"] = int(
176
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
177
+ )
178
+ speeches[i + 1]["start"] = int(
179
+ max(0, speeches[i + 1]["start"] - speech_pad_samples)
180
+ )
181
+ else:
182
+ speech["end"] = int(
183
+ min(audio_length_samples, speech["end"] + speech_pad_samples)
184
+ )
185
+
186
+ return speeches
187
+
188
+
189
+ def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
190
+ """Collects and concatenates audio chunks."""
191
+ if not chunks:
192
+ return np.array([], dtype=np.float32)
193
+
194
+ return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
195
+
196
+
197
+ class SpeechTimestampsMap:
198
+ """Helper class to restore original speech timestamps."""
199
+
200
+ def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
201
+ self.sampling_rate = sampling_rate
202
+ self.time_precision = time_precision
203
+ self.chunk_end_sample = []
204
+ self.total_silence_before = []
205
+
206
+ previous_end = 0
207
+ silent_samples = 0
208
+
209
+ for chunk in chunks:
210
+ silent_samples += chunk["start"] - previous_end
211
+ previous_end = chunk["end"]
212
+
213
+ self.chunk_end_sample.append(chunk["end"] - silent_samples)
214
+ self.total_silence_before.append(silent_samples / sampling_rate)
215
+
216
+ def get_original_time(
217
+ self,
218
+ time: float,
219
+ chunk_index: Optional[int] = None,
220
+ ) -> float:
221
+ if chunk_index is None:
222
+ chunk_index = self.get_chunk_index(time)
223
+
224
+ total_silence_before = self.total_silence_before[chunk_index]
225
+ return round(total_silence_before + time, self.time_precision)
226
+
227
+ def get_chunk_index(self, time: float) -> int:
228
+ sample = int(time * self.sampling_rate)
229
+ return min(
230
+ bisect.bisect(self.chunk_end_sample, sample),
231
+ len(self.chunk_end_sample) - 1,
232
+ )
233
+
234
+
235
+ @functools.lru_cache
236
+ def get_vad_model():
237
+ """Returns the VAD model instance."""
238
+ asset_dir = os.path.join(os.path.dirname(__file__), "assets")
239
+ path = os.path.join(asset_dir, "silero_vad.onnx")
240
+ return SileroVADModel(path)
241
+
242
+
243
+ class SileroVADModel:
244
+ def __init__(self, path):
245
+ try:
246
+ import onnxruntime
247
+ except ImportError as e:
248
+ raise RuntimeError(
249
+ "Applying the VAD filter requires the onnxruntime package"
250
+ ) from e
251
+
252
+ opts = onnxruntime.SessionOptions()
253
+ opts.inter_op_num_threads = 1
254
+ opts.intra_op_num_threads = 1
255
+ opts.log_severity_level = 4
256
+
257
+ self.session = onnxruntime.InferenceSession(
258
+ path,
259
+ providers=["CPUExecutionProvider"],
260
+ sess_options=opts,
261
+ )
262
+
263
+ def get_initial_state(self, batch_size: int):
264
+ h = np.zeros((2, batch_size, 64), dtype=np.float32)
265
+ c = np.zeros((2, batch_size, 64), dtype=np.float32)
266
+ return h, c
267
+
268
+ def __call__(self, x, state, sr: int):
269
+ if len(x.shape) == 1:
270
+ x = np.expand_dims(x, 0)
271
+ if len(x.shape) > 2:
272
+ raise ValueError(
273
+ f"Too many dimensions for input audio chunk {len(x.shape)}"
274
+ )
275
+ if sr / x.shape[1] > 31.25:
276
+ raise ValueError("Input audio chunk is too short")
277
+
278
+ h, c = state
279
+
280
+ ort_inputs = {
281
+ "input": x,
282
+ "h": h,
283
+ "c": c,
284
+ "sr": np.array(sr, dtype="int64"),
285
+ }
286
+
287
+ out, h, c = self.session.run(None, ort_inputs)
288
+ state = (h, c)
289
+
290
+ return out, state
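
A minimal usage sketch (assumed workflow, mirroring run_vad in webui/omni_streamlit.py): feed 16 kHz float32 audio scaled to [-1, 1] through the VAD and keep only the detected speech. Running it requires onnxruntime and the silero_vad.onnx asset added above.

    import numpy as np

    audio = np.zeros(16000, dtype=np.float32)    # one second of silence as a stand-in signal
    opts = VadOptions(threshold=0.5, min_silence_duration_ms=2000)
    chunks = get_speech_timestamps(audio, opts)  # [{"start": ..., "end": ...}, ...] in samples
    speech_only = collect_chunks(audio, chunks)
    print(len(chunks), speech_only.shape)        # 0 speech chunks for pure silence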
mini-omni-main (1)/webui/omni_gradio.py ADDED
@@ -0,0 +1,82 @@
1
+ """A simple web interactive chat demo based on gradio."""
2
+
3
+ import os
4
+ import time
5
+ import gradio as gr
6
+ import base64
7
+ import numpy as np
8
+ import requests
9
+
10
+
11
+ API_URL = os.getenv("API_URL", None)
12
+ client = None
13
+
14
+ if API_URL is None:
15
+ from inference import OmniInference
16
+ omni_client = OmniInference('./checkpoint', 'cuda:0')
17
+ omni_client.warm_up()
18
+
19
+
20
+ OUT_CHUNK = 4096
21
+ OUT_RATE = 24000
22
+ OUT_CHANNELS = 1
23
+
24
+
25
+ def process_audio(audio):
26
+ filepath = audio
27
+ print(f"filepath: {filepath}")
28
+ if filepath is None:
29
+ return
30
+
31
+ cnt = 0
32
+ if API_URL is not None:
33
+ with open(filepath, "rb") as f:
34
+ data = f.read()
35
+ base64_encoded = str(base64.b64encode(data), encoding="utf-8")
36
+ files = {"audio": base64_encoded}
37
+ tik = time.time()
38
+ with requests.post(API_URL, json=files, stream=True) as response:
39
+ try:
40
+ for chunk in response.iter_content(chunk_size=OUT_CHUNK):
41
+ if chunk:
42
+ # Convert chunk to numpy array
43
+ if cnt == 0:
44
+ print(f"first chunk time cost: {time.time() - tik:.3f}")
45
+ cnt += 1
46
+ audio_data = np.frombuffer(chunk, dtype=np.int16)
47
+ audio_data = audio_data.reshape(-1, OUT_CHANNELS)
48
+ yield OUT_RATE, audio_data.astype(np.int16)
49
+
50
+ except Exception as e:
51
+ print(f"error: {e}")
52
+ else:
53
+ tik = time.time()
54
+ for chunk in omni_client.run_AT_batch_stream(filepath):
55
+ # Convert chunk to numpy array
56
+ if cnt == 0:
57
+ print(f"first chunk time cost: {time.time() - tik:.3f}")
58
+ cnt += 1
59
+ audio_data = np.frombuffer(chunk, dtype=np.int16)
60
+ audio_data = audio_data.reshape(-1, OUT_CHANNELS)
61
+ yield OUT_RATE, audio_data.astype(np.int16)
62
+
63
+
64
+ def main(port=None):
65
+
66
+ demo = gr.Interface(
67
+ process_audio,
68
+ inputs=gr.Audio(type="filepath", label="Microphone"),
69
+ outputs=[gr.Audio(label="Response", streaming=True, autoplay=True)],
70
+ title="Chat Mini-Omni Demo",
71
+ live=True,
72
+ )
73
+ if port is not None:
74
+ demo.queue().launch(share=False, server_name="0.0.0.0", server_port=port)
75
+ else:
76
+ demo.queue().launch()
77
+
78
+
79
+ if __name__ == "__main__":
80
+ import fire
81
+
82
+ fire.Fire(main)
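
Launch sketch (an assumption, not part of the upload): API_URL is read at import time, so when pointing this demo at a running server.py instance the variable must be set before the module is imported; with it unset, the script instead loads OmniInference from ./checkpoint on cuda:0.

    import os
    os.environ["API_URL"] = "http://127.0.0.1:60808/chat"  # assumed server address

    # the import path assumes webui/ is importable as a package from the repo root
    from webui.omni_gradio import main
    main(port=7860)  # equivalent to: python webui/omni_gradio.py --port 7860 (fire maps CLI args onto main)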
mini-omni-main (1)/webui/omni_streamlit.py ADDED
@@ -0,0 +1,257 @@
1
+ import streamlit as st
2
+ import wave
3
+
4
+ # from ASR import recognize
5
+ import requests
6
+ import pyaudio
7
+ import numpy as np
8
+ import base64
9
+ import io
10
+ import os
11
+ import time
12
+ import tempfile
13
+ import librosa
14
+ import traceback
15
+ from pydub import AudioSegment
16
+ from utils.vad import get_speech_timestamps, collect_chunks, VadOptions
17
+
18
+
19
+ API_URL = os.getenv("API_URL", "http://127.0.0.1:60808/chat")
20
+
21
+ # recording parameters
22
+ IN_FORMAT = pyaudio.paInt16
23
+ IN_CHANNELS = 1
24
+ IN_RATE = 24000
25
+ IN_CHUNK = 1024
26
+ IN_SAMPLE_WIDTH = 2
27
+ VAD_STRIDE = 0.5
28
+
29
+ # playing parameters
30
+ OUT_FORMAT = pyaudio.paInt16
31
+ OUT_CHANNELS = 1
32
+ OUT_RATE = 24000
33
+ OUT_SAMPLE_WIDTH = 2
34
+ OUT_CHUNK = 5760
35
+
36
+
37
+ # Initialize chat history
38
+ if "messages" not in st.session_state:
39
+ st.session_state.messages = []
40
+
41
+
42
+ def run_vad(ori_audio, sr):
43
+ _st = time.time()
44
+ try:
45
+ audio = np.frombuffer(ori_audio, dtype=np.int16)
46
+ audio = audio.astype(np.float32) / 32768.0
47
+ sampling_rate = 16000
48
+ if sr != sampling_rate:
49
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=sampling_rate)
50
+
51
+ vad_parameters = {}
52
+ vad_parameters = VadOptions(**vad_parameters)
53
+ speech_chunks = get_speech_timestamps(audio, vad_parameters)
54
+ audio = collect_chunks(audio, speech_chunks)
55
+ duration_after_vad = audio.shape[0] / sampling_rate
56
+
57
+ if sr != sampling_rate:
58
+ # resample to original sampling rate
59
+ vad_audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=sr)
60
+ else:
61
+ vad_audio = audio
62
+ vad_audio = np.round(vad_audio * 32768.0).astype(np.int16)
63
+ vad_audio_bytes = vad_audio.tobytes()
64
+
65
+ return duration_after_vad, vad_audio_bytes, round(time.time() - _st, 4)
66
+ except Exception as e:
67
+ msg = f"[asr vad error] audio_len: {len(ori_audio)/(sr*2):.3f} s, trace: {traceback.format_exc()}"
68
+ print(msg)
69
+ return -1, ori_audio, round(time.time() - _st, 4)
70
+
71
+
72
+ def warm_up():
73
+ frames = b"\x00\x00" * 1024 * 2  # 2048 silent int16 samples (4096 bytes)
74
+ dur, frames, tcost = run_vad(frames, 16000)
75
+ print(f"warm up done, time_cost: {tcost:.3f} s")
76
+
77
+
78
+ def save_tmp_audio(audio_bytes):
79
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
80
+ file_name = tmpfile.name
81
+ audio = AudioSegment(
82
+ data=audio_bytes,
83
+ sample_width=OUT_SAMPLE_WIDTH,
84
+ frame_rate=OUT_RATE,
85
+ channels=OUT_CHANNELS,
86
+ )
87
+ audio.export(file_name, format="wav")
88
+ return file_name
89
+
90
+
91
+ def speaking(status):
92
+
93
+ # Initialize PyAudio
94
+ p = pyaudio.PyAudio()
95
+
96
+ # Open PyAudio stream
97
+ stream = p.open(
98
+ format=OUT_FORMAT, channels=OUT_CHANNELS, rate=OUT_RATE, output=True
99
+ )
100
+
101
+ audio_buffer = io.BytesIO()
102
+ wf = wave.open(audio_buffer, "wb")
103
+ wf.setnchannels(IN_CHANNELS)
104
+ wf.setsampwidth(IN_SAMPLE_WIDTH)
105
+ wf.setframerate(IN_RATE)
106
+ total_frames = b"".join(st.session_state.frames)
107
+ dur = len(total_frames) / (IN_RATE * IN_CHANNELS * IN_SAMPLE_WIDTH)
108
+ status.warning(f"Speaking... recorded audio duration: {dur:.3f} s")
109
+ wf.writeframes(total_frames)
110
+
111
+ with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmpfile:
112
+ with open(tmpfile.name, "wb") as f:
113
+ f.write(audio_buffer.getvalue())
114
+ file_name = tmpfile.name
115
+ with st.chat_message("user"):
116
+ st.audio(file_name, format="audio/wav", loop=False, autoplay=False)
117
+ st.session_state.messages.append(
118
+ {"role": "user", "content": file_name, "type": "audio"}
119
+ )
120
+
121
+ st.session_state.frames = []
122
+
123
+ audio_bytes = audio_buffer.getvalue()
124
+ base64_encoded = str(base64.b64encode(audio_bytes), encoding="utf-8")
125
+ files = {"audio": base64_encoded}
126
+ output_audio_bytes = b""
127
+ with requests.post(API_URL, json=files, stream=True) as response:
128
+ try:
129
+ for chunk in response.iter_content(chunk_size=OUT_CHUNK):
130
+ if chunk:
131
+ # Convert chunk to numpy array
132
+ output_audio_bytes += chunk
133
+ audio_data = np.frombuffer(chunk, dtype=np.int8)
134
+ # Play audio
135
+ stream.write(audio_data)
136
+ except Exception as e:
137
+ st.error(f"Error during audio streaming: {e}")
138
+
139
+ out_file = save_tmp_audio(output_audio_bytes)
140
+ with st.chat_message("assistant"):
141
+ st.audio(out_file, format="audio/wav", loop=False, autoplay=False)
142
+ st.session_state.messages.append(
143
+ {"role": "assistant", "content": out_file, "type": "audio"}
144
+ )
145
+
146
+ wf.close()
147
+ # Close PyAudio stream and terminate PyAudio
148
+ stream.stop_stream()
149
+ stream.close()
150
+ p.terminate()
151
+ st.session_state.speaking = False
152
+ st.session_state.recording = True
153
+
154
+
155
+ def recording(status):
156
+ audio = pyaudio.PyAudio()
157
+
158
+ stream = audio.open(
159
+ format=IN_FORMAT,
160
+ channels=IN_CHANNELS,
161
+ rate=IN_RATE,
162
+ input=True,
163
+ frames_per_buffer=IN_CHUNK,
164
+ )
165
+
166
+ temp_audio = b""
167
+ vad_audio = b""
168
+
169
+ start_talking = False
170
+ last_temp_audio = None
171
+ st.session_state.frames = []
172
+
173
+ while st.session_state.recording:
174
+ status.success("Listening...")
175
+ audio_bytes = stream.read(IN_CHUNK)
176
+ temp_audio += audio_bytes
177
+
178
+ if len(temp_audio) > IN_SAMPLE_WIDTH * IN_RATE * IN_CHANNELS * VAD_STRIDE:
179
+ dur_vad, vad_audio_bytes, time_vad = run_vad(temp_audio, IN_RATE)
180
+
181
+ print(f"duration_after_vad: {dur_vad:.3f} s, time_vad: {time_vad:.3f} s")
182
+
183
+ if dur_vad > 0.2 and not start_talking:
184
+ if last_temp_audio is not None:
185
+ st.session_state.frames.append(last_temp_audio)
186
+ start_talking = True
187
+ if start_talking:
188
+ st.session_state.frames.append(temp_audio)
189
+ if dur_vad < 0.1 and start_talking:
190
+ st.session_state.recording = False
191
+ print("speech end detected, exiting recording loop")
192
+ last_temp_audio = temp_audio
193
+ temp_audio = b""
194
+
195
+ stream.stop_stream()
196
+ stream.close()
197
+
198
+ audio.terminate()
199
+
200
+
201
+ def main():
202
+
203
+ st.title("Chat Mini-Omni Demo")
204
+ status = st.empty()
205
+
206
+ if "warm_up" not in st.session_state:
207
+ warm_up()
208
+ st.session_state.warm_up = True
209
+ if "start" not in st.session_state:
210
+ st.session_state.start = False
211
+ if "recording" not in st.session_state:
212
+ st.session_state.recording = False
213
+ if "speaking" not in st.session_state:
214
+ st.session_state.speaking = False
215
+ if "frames" not in st.session_state:
216
+ st.session_state.frames = []
217
+
218
+ if not st.session_state.start:
219
+ status.warning("Click Start to chat")
220
+
221
+ start_col, stop_col, _ = st.columns([0.2, 0.2, 0.6])
222
+ start_button = start_col.button("Start", key="start_button")
223
+ # stop_button = stop_col.button("Stop", key="stop_button")
224
+ if start_button:
225
+ time.sleep(1)
226
+ st.session_state.recording = True
227
+ st.session_state.start = True
228
+
229
+ for message in st.session_state.messages:
230
+ with st.chat_message(message["role"]):
231
+ if message["type"] == "msg":
232
+ st.markdown(message["content"])
233
+ elif message["type"] == "img":
234
+ st.image(message["content"], width=300)
235
+ elif message["type"] == "audio":
236
+ st.audio(
237
+ message["content"], format="audio/wav", loop=False, autoplay=False
238
+ )
239
+
240
+ while st.session_state.start:
241
+ if st.session_state.recording:
242
+ recording(status)
243
+
244
+ if not st.session_state.recording and st.session_state.start:
245
+ st.session_state.speaking = True
246
+ speaking(status)
247
+
248
+ # if stop_button:
249
+ # status.warning("Stopped, click Start to chat")
250
+ # st.session_state.start = False
251
+ # st.session_state.recording = False
252
+ # st.session_state.frames = []
253
+ # break
254
+
255
+
256
+ if __name__ == "__main__":
257
+ main()
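
Usage note for the Streamlit demo above (inferred from the code, not stated in the upload): it records from the local microphone via PyAudio, which is commented out in requirements.txt and must be installed separately; it runs the silero VAD on roughly 0.5 s windows (VAD_STRIDE) to detect the end of an utterance; and it posts the base64-encoded recording to API_URL, which defaults to http://127.0.0.1:60808/chat and therefore expects server.py to be running first. The reply stream is played back as 24 kHz, 16-bit mono audio. The app itself is presumably started with streamlit run webui/omni_streamlit.py from the repository root.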