Spaces:
BAAI
/
Running on L40S

akhaliq HF staff commited on
Commit
261b6ba
·
verified ·
1 Parent(s): 5d3ad1b

Upload folder using huggingface_hub

Browse files
.gitignore ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
133
+
134
+ # pytype static type analyzer
135
+ .pytype/
136
+
137
+ # Cython debug symbols
138
+ cython_debug/
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,218 @@
1
- ---
2
- title: Emu3
3
- emoji: 🦀
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 4.44.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align='center'>
2
+ <h1>Emu3: Next-Token Prediction is All You Need</h1h1>
3
+ <h3></h3>
4
+
5
+ [Emu3 Team, BAAI](https://www.baai.ac.cn/english.html)
6
+
7
+ | [Project Page](https://emu.baai.ac.cn) | [Paper](https://baai-solution.ks3-cn-beijing.ksyuncs.com/emu3/Emu3-tech-report.pdf?KSSAccessKeyId=AKLTgew6Kdg6RsK92QSfB2KLA&Expires=2591406552&Signature=6BvwfLVqvfww26Bhwvk3mG0FrL8%3D) | [🤗HF Models](https://huggingface.co/collections/BAAI/emu3-66f4e64f70850ff358a2e60f) |
8
+
9
+
10
+ </div>
11
+
12
+ <div align='center'>
13
+ <img src="./assets/arch.png" class="interpolation-image" alt="arch." height="80%" width="70%" />
14
+ </div>
15
+
16
+ We introduce **Emu3**, a new suite of state-of-the-art multimodal models trained solely with **<i>next-token prediction</i>**! By tokenizing images, text, and videos into a discrete space, we train a single transformer from scratch on a mixture of multimodal sequences.
17
+
18
+ ### Emu3 excels in both generation and perception
19
+ **Emu3** outperforms several well-established task-specific models in both generation and perception tasks, surpassing flagship open models such as SDXL, LLaVA-1.6 and OpenSora-1.2, while eliminating the need for diffusion or compositional architectures.
20
+
21
+ <div align='center'>
22
+ <img src="./assets/comparison.png" class="interpolation-image" alt="comparison." height="80%" width="80%" />
23
+ </div>
24
+
25
+ ### Highlights
26
+
27
+ - **Emu3** is capable of generating high-quality images following the text input, by simply predicting the next vision token. The model naturally supports flexible resolutions and styles.
28
+ - **Emu3** shows strong vision-language understanding capabilities to see the physical world and provides coherent text responses. Notably, this capability is achieved without depending on a CLIP and a pretrained LLM.
29
+ - **Emu3** simply generates a video causally by predicting the next token in a video sequence, unlike the video diffusion model as in Sora. With a video in context, Emu3 can also naturally extend the video and predict what will happen next.
30
+
31
+
32
+ ### TODO
33
+
34
+ - [X] Release model weights of tokenizer, Emu3-Chat and Emu3-Gen
35
+ - [X] Release the inference code.
36
+ - [ ] Release the evaluation code.
37
+ - [ ] Release training scripts for pretrain, sft and dpo.
38
+
39
+
40
+ ### Setup
41
+
42
+ Clone this repository and install required packages:
43
+
44
+ ```shell
45
+ git clone https://github.com/baaivision/Emu3
46
+ cd Emu3
47
+
48
+ pip install -r requirements.txt
49
+ ```
50
+
51
+ ### Model Weights
52
+
53
+ | Model name | HF Weight |
54
+ | ------------------ | ------------------------------------------------------- |
55
+ | **Emu3-Chat** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-Chat) |
56
+ | **Emu3-Gen** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-Gen) |
57
+ | **Emu3-VisionTokenizer** | [🤗 HF link](https://huggingface.co/BAAI/Emu3-VisionTokenizer) |
58
+
59
+ ### Quickstart
60
+
61
+ #### Use 🤗Transformers to run Emu3-Gen for image generation
62
+ ```python
63
+ from PIL import Image
64
+ from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
65
+ from transformers.generation.configuration_utils import GenerationConfig
66
+ from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor
67
+ import torch
68
+
69
+ from emu3.mllm.processing_emu3 import Emu3Processor
70
+
71
+
72
+ # model path
73
+ EMU_HUB = "BAAI/Emu3-Gen"
74
+ VQ_HUB = "BAAI/Emu3-VisionTokenizer"
75
+
76
+ # prepare model and processor
77
+ model = AutoModelForCausalLM.from_pretrained(
78
+ EMU_HUB,
79
+ device_map="cuda:0",
80
+ torch_dtype=torch.bfloat16,
81
+ attn_implementation="flash_attention_2",
82
+ trust_remote_code=True,
83
+ )
84
+
85
+ tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True)
86
+ image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
87
+ image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
88
+ processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
89
+
90
+ # prepare input
91
+ POSITIVE_PROMPT = " masterpiece, film grained, best quality."
92
+ NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
93
+
94
+ classifier_free_guidance = 3.0
95
+ prompt = "a portrait of young girl."
96
+ prompt += POSITIVE_PROMPT
97
+
98
+ kwargs = dict(
99
+ mode='G',
100
+ ratio="1:1",
101
+ image_area=model.config.image_area,
102
+ return_tensors="pt",
103
+ )
104
+ pos_inputs = processor(text=prompt, **kwargs)
105
+ neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
106
+
107
+ # prepare hyper parameters
108
+ GENERATION_CONFIG = GenerationConfig(
109
+ use_cache=True,
110
+ eos_token_id=model.config.eos_token_id,
111
+ pad_token_id=model.config.pad_token_id,
112
+ max_new_tokens=40960,
113
+ do_sample=True,
114
+ top_k=2048,
115
+ )
116
+
117
+ h, w = pos_inputs.image_size[0]
118
+ constrained_fn = processor.build_prefix_constrained_fn(h, w)
119
+ logits_processor = LogitsProcessorList([
120
+ UnbatchedClassifierFreeGuidanceLogitsProcessor(
121
+ classifier_free_guidance,
122
+ model,
123
+ unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
124
+ ),
125
+ PrefixConstrainedLogitsProcessor(
126
+ constrained_fn ,
127
+ num_beams=1,
128
+ ),
129
+ ])
130
+
131
+ # generate
132
+ outputs = model.generate(
133
+ pos_inputs.input_ids.to("cuda:0"),
134
+ GENERATION_CONFIG,
135
+ logits_processor=logits_processor
136
+ )
137
+
138
+ mm_list = processor.decode(outputs[0])
139
+ for idx, im in enumerate(mm_list):
140
+ if not isinstance(im, Image.Image):
141
+ continue
142
+ im.save(f"result_{idx}.png")
143
+ ```
144
+
145
+ #### Use 🤗Transformers to run Emu3-Chat for vision-language understanding
146
+
147
+ ```python
148
+ from PIL import Image
149
+ from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
150
+ from transformers.generation.configuration_utils import GenerationConfig
151
+ import torch
152
+
153
+ from emu3.mllm.processing_emu3 import Emu3Processor
154
+
155
+
156
+ # model path
157
+ EMU_HUB = "BAAI/Emu3-Chat"
158
+ VQ_HUB = "BAAI/Emu3-VisionTokenier"
159
+
160
+ # prepare model and processor
161
+ model = AutoModelForCausalLM.from_pretrained(
162
+ EMU_HUB,
163
+ device_map="cuda:0",
164
+ torch_dtype=torch.bfloat16,
165
+ attn_implementation="flash_attention_2",
166
+ trust_remote_code=True,
167
+ )
168
+
169
+ tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True)
170
+ image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
171
+ image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
172
+ processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
173
+
174
+ # prepare input
175
+ text = "Please describe the image"
176
+ image = Image.open("assets/demo.png")
177
+
178
+ inputs = processor(
179
+ text=text,
180
+ image=image,
181
+ mode='U',
182
+ padding_side="left",
183
+ padding="longest",
184
+ return_tensors="pt",
185
+ )
186
+
187
+ # prepare hyper parameters
188
+ GENERATION_CONFIG = GenerationConfig(pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
189
+
190
+ # generate
191
+ outputs = model.generate(
192
+ inputs.input_ids.to("cuda:0"),
193
+ GENERATION_CONFIG,
194
+ max_new_tokens=320,
195
+ )
196
+
197
+ outputs = outputs[:, inputs.input_ids.shape[-1]:]
198
+ print(processor.batch_decode(outputs, skip_special_tokens=True)[0])
199
+ ```
200
+
201
+ ## Acknowledgement
202
+
203
+ We thank the great work from [Emu Series](https://github.com/baaivision/Emu), [QWen2-VL](https://github.com/QwenLM/Qwen2-VL) and [MoVQGAN](https://github.com/ai-forever/MoVQGAN)
204
+
205
+ <!--
206
+ ## Citation
207
+
208
+ If you find Emu3 useful for your research and applications, please consider starring this repository and citing:
209
+
210
+ ```
211
+ @article{Emu2,
212
+ title={Generative Multimodal Models are In-Context Learners},
213
+ author={Quan Sun and Yufeng Cui and Xiaosong Zhang and Fan Zhang and Qiying Yu and Zhengxiong Luo and Yueze Wang and Yongming Rao and Jingjing Liu and Tiejun Huang and Xinlong Wang},
214
+ publisher={arXiv preprint arXiv:2312.13286},
215
+ year={2023},
216
+ }
217
+ ```
218
+ -->
assets/arch.png ADDED
assets/comparison.png ADDED
assets/demo.png ADDED
emu3/mllm/__init__.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 BAAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import (
17
+ OptionalDependencyNotAvailable,
18
+ _LazyModule,
19
+ is_torch_available,
20
+ )
21
+
22
+
23
+ _import_structure = {
24
+ "configuration_emu3": ["Emu3Config"],
25
+ "tokenization_emu3": ["Emu3Tokenizer"],
26
+ "processing_emu3": ["Emu3Processor"],
27
+ }
28
+
29
+ try:
30
+ if not is_torch_available():
31
+ raise OptionalDependencyNotAvailable()
32
+ except OptionalDependencyNotAvailable:
33
+ pass
34
+ else:
35
+ _import_structure["modeling_emu3"] = [
36
+ "Emu3Model",
37
+ "Emu3PretrainedModel",
38
+ "Emu3ForCausalLM",
39
+ ]
40
+
41
+ if TYPE_CHECKING:
42
+ from .configuration_emu3 import Emu3Config
43
+ from .tokenization_emu3 import Emu3Tokenizer
44
+ from .processing_emu3 import Emu3Processor
45
+
46
+ try:
47
+ if not is_torch_available():
48
+ raise OptionalDependencyNotAvailable()
49
+ except OptionalDependencyNotAvailable:
50
+ pass
51
+ else:
52
+ from .modeling_emu3 import (
53
+ Emu3Model,
54
+ Emu3PretrainedModel,
55
+ Emu3ForCausalLM,
56
+ )
57
+
58
+ else:
59
+ import sys
60
+
61
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
emu3/mllm/configuration_emu3.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ Emu3 model configuration"""
21
+
22
+ from typing import Optional
23
+
24
+ from transformers.configuration_utils import PretrainedConfig
25
+ from transformers.utils import logging
26
+
27
+
28
+ logger = logging.get_logger(__name__)
29
+
30
+ EMU3_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
31
+
32
+
33
+ class Emu3Config(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`Emu3Model`]. It is used to instantiate an Emu3
36
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
37
+ defaults will yield a similar configuration to that of the Emu3-8B.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+
43
+ Args:
44
+ vocab_size (`int`, *optional*, defaults to 184622):
45
+ Vocabulary size of the Emu3 model. Defines the number of different tokens that can be represented by the
46
+ `inputs_ids` passed when calling [`Emu3Model`]
47
+ hidden_size (`int`, *optional*, defaults to 4096):
48
+ Dimension of the hidden representations.
49
+ intermediate_size (`int`, *optional*, defaults to 14336):
50
+ Dimension of the MLP representations.
51
+ num_hidden_layers (`int`, *optional*, defaults to 32):
52
+ Number of hidden layers in the Transformer decoder.
53
+ num_attention_heads (`int`, *optional*, defaults to 32):
54
+ Number of attention heads for each attention layer in the Transformer decoder.
55
+ num_key_value_heads (`int`, *optional*, defaults to 8):
56
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
57
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
58
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
59
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
60
+ by meanpooling all the original heads within that group. For more details checkout [this
61
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
62
+ `num_attention_heads`.
63
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
64
+ The non-linear activation function (function or string) in the decoder.
65
+ max_position_embeddings (`int`, *optional*, defaults to 9216):
66
+ The maximum sequence length that this model might ever be used with. Emu supports up to 9216 tokens,
67
+ initializer_range (`float`, *optional*, defaults to 0.02):
68
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
+ rms_norm_eps (`float`, *optional*, defaults to 1e-05):
70
+ The epsilon used by the rms normalization layers.
71
+ use_cache (`bool`, *optional*, defaults to `True`):
72
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
73
+ relevant if `config.is_decoder=True`.
74
+ pad_token_id (`int`, *optional*, 151643):
75
+ Padding token id.
76
+ bos_token_id (`int`, *optional*, defaults to 151849):
77
+ Beginning of stream token id.
78
+ eos_token_id (`int`, *optional*, defaults to 151850):
79
+ End of stream token id.
80
+ img_token_id (`int`, *optional*, defaults to 151851):
81
+ image token id.
82
+ boi_token_id (`int`, *optional*, defaults to 151852):
83
+ Beginning of image token id.
84
+ eoi_token_id (`int`, *optional*, defaults to 151853):
85
+ End of image token id.
86
+ eol_token_id (`int`, *optional*, defaults to 151846):
87
+ End of line token id.
88
+ eof_token_id (`int`, *optional*, defaults to 151847):
89
+ End of line token id.
90
+ image_area (`int`, *optional*, defaults to 720 * 720)
91
+ generated image area (image area used in training)
92
+ pretraining_tp (`int`, *optional*, defaults to 1):
93
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
94
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
95
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
96
+ issue](https://github.com/pytorch/pytorch/issues/76232).
97
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
98
+ Whether to tie weight embeddings
99
+ rope_theta (`float`, *optional*, defaults to 1_000_000.0):
100
+ The base period of the RoPE embeddings.
101
+ rope_scaling (`Dict`, *optional*):
102
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
103
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
104
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
105
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
106
+ these scaling strategies behave:
107
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
108
+ experimental feature, subject to breaking API changes in future versions.
109
+ attention_dropout (`float`, *optional*, defaults to 0.1):
110
+ The dropout ratio for the attention probabilities.
111
+
112
+ ```python
113
+ >>> from transformers import Emu3Model, Emu3Config
114
+
115
+ >>> # Initializing a Emu3-8b style configuration
116
+ >>> configuration = Emu3Config()
117
+
118
+ >>> # Initializing a model from the Emu3-8b style configuration
119
+ >>> model = Emu3Model(configuration)
120
+
121
+ >>> # Accessing the model configuration
122
+ >>> configuration = model.config
123
+ ```"""
124
+
125
+ model_type = "Emu3"
126
+ keys_to_ignore_at_inference = ["past_key_values"]
127
+
128
+ def __init__(
129
+ self,
130
+ vocab_size: int = 184622,
131
+ hidden_size: int = 4096,
132
+ intermediate_size: int = 14336,
133
+ num_hidden_layers: int = 32,
134
+ num_attention_heads: int = 32,
135
+ num_key_value_heads: Optional[int] = 8,
136
+ hidden_act: str = "silu",
137
+ max_position_embeddings: int = 9216,
138
+ initializer_range: float = 0.02,
139
+ rms_norm_eps: float = 1e-5,
140
+ use_cache: bool = True,
141
+ pad_token_id: int = 151643,
142
+ bos_token_id: int = 151849,
143
+ eos_token_id: int = 151850,
144
+ img_token_id: int = 151851,
145
+ boi_token_id: int = 151852,
146
+ eoi_token_id: int = 151853,
147
+ eol_token_id: int = 151846,
148
+ eof_token_id: int = 151847,
149
+ image_area: int = 720 * 720,
150
+ pretraining_tp: int = 1,
151
+ tie_word_embeddings: bool = False,
152
+ rope_theta: float = 1000000.0,
153
+ rope_scaling: Optional = None,
154
+ attention_dropout: float = 0.1,
155
+ **kwargs,
156
+ ):
157
+ self.vocab_size = vocab_size
158
+ self.max_position_embeddings = max_position_embeddings
159
+ self.hidden_size = hidden_size
160
+ self.intermediate_size = intermediate_size
161
+ self.num_hidden_layers = num_hidden_layers
162
+ self.num_attention_heads = num_attention_heads
163
+
164
+ # for backward compatibility
165
+ if num_key_value_heads is None:
166
+ num_key_value_heads = num_attention_heads
167
+
168
+ self.num_key_value_heads = num_key_value_heads
169
+ self.hidden_act = hidden_act
170
+ self.initializer_range = initializer_range
171
+ self.rms_norm_eps = rms_norm_eps
172
+ self.pretraining_tp = pretraining_tp
173
+ self.use_cache = use_cache
174
+ self.rope_theta = rope_theta
175
+ self.rope_scaling = rope_scaling
176
+ self._rope_scaling_validation()
177
+ self.attention_dropout = attention_dropout
178
+
179
+ self.img_token_id = img_token_id
180
+ self.boi_token_id = boi_token_id
181
+ self.eoi_token_id = eoi_token_id
182
+ self.eol_token_id = eol_token_id
183
+ self.eof_token_id = eof_token_id
184
+ self.image_area = image_area
185
+
186
+ super().__init__(
187
+ pad_token_id=pad_token_id,
188
+ bos_token_id=bos_token_id,
189
+ eos_token_id=eos_token_id,
190
+ tie_word_embeddings=tie_word_embeddings,
191
+ **kwargs,
192
+ )
193
+
194
+ def _rope_scaling_validation(self):
195
+ """
196
+ Validate the `rope_scaling` configuration.
197
+ """
198
+ if self.rope_scaling is None:
199
+ return
200
+
201
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
202
+ raise ValueError(
203
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
204
+ f"got {self.rope_scaling}"
205
+ )
206
+ rope_scaling_type = self.rope_scaling.get("type", None)
207
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
208
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
209
+ raise ValueError(
210
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
211
+ )
212
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
213
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
emu3/mllm/modeling_emu3.py ADDED
@@ -0,0 +1,1343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ #
21
+ # Adapted from https://github.com/huggingface/transformers/blob/52daf4ec768fb9ffe84a0c373834172a7c54aecc/src/transformers/models/llama/modeling_llama.py
22
+ #
23
+ """ PyTorch Emu3 model."""
24
+ import math
25
+ import warnings
26
+ from typing import List, Optional, Tuple, Union
27
+
28
+ import torch
29
+ import torch.nn.functional as F
30
+ import torch.utils.checkpoint
31
+ from torch import nn
32
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
33
+
34
+ from transformers.activations import ACT2FN
35
+ from transformers.cache_utils import Cache, DynamicCache
36
+ from transformers.modeling_attn_mask_utils import (
37
+ AttentionMaskConverter,
38
+ _prepare_4d_attention_mask,
39
+ _prepare_4d_causal_attention_mask,
40
+ _prepare_4d_causal_attention_mask_for_sdpa,
41
+ )
42
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
43
+ from transformers.modeling_utils import PreTrainedModel
44
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
45
+ from transformers.utils import (
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ is_flash_attn_2_available,
49
+ is_flash_attn_greater_or_equal_2_10,
50
+ logging,
51
+ replace_return_docstrings,
52
+ )
53
+ from transformers.utils.import_utils import is_torch_fx_available
54
+ from .configuration_emu3 import Emu3Config
55
+
56
+
57
+ if is_flash_attn_2_available():
58
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
59
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
60
+
61
+
62
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
63
+ # It means that the function will not be traced through and simply appear as a node in the graph.
64
+ if is_torch_fx_available():
65
+ if not is_torch_greater_or_equal_than_1_13:
66
+ import torch.fx
67
+
68
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
69
+
70
+
71
+ logger = logging.get_logger(__name__)
72
+
73
+ _CONFIG_FOR_DOC = "Emu3Config"
74
+
75
+
76
+ def _get_unpad_data(attention_mask):
77
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
78
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
79
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
80
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
81
+ return (
82
+ indices,
83
+ cu_seqlens,
84
+ max_seqlen_in_batch,
85
+ )
86
+
87
+
88
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
89
+ warnings.warn(
90
+ "Calling `transformers.models.emu3.modeling_emu3._prepare_4d_attention_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
91
+ )
92
+ return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
93
+
94
+
95
+ def _make_causal_mask(
96
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
97
+ ):
98
+ warnings.warn(
99
+ "Calling `transformers.models.emu3.modeling_emu3._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.models.emu3.modeling_emu3.AttentionMaskConverter._make_causal_mask"
100
+ )
101
+ return AttentionMaskConverter._make_causal_mask(
102
+ input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
103
+ )
104
+
105
+
106
+ class Emu3RMSNorm(nn.Module):
107
+ def __init__(self, hidden_size, eps=1e-6):
108
+ """
109
+ Emu3RMSNorm is equivalent to T5LayerNorm
110
+ """
111
+ super().__init__()
112
+ self.weight = nn.Parameter(torch.ones(hidden_size))
113
+ self.variance_epsilon = eps
114
+
115
+ def forward(self, hidden_states):
116
+ input_dtype = hidden_states.dtype
117
+ hidden_states = hidden_states.to(torch.float32)
118
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
119
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
120
+ return self.weight * hidden_states.to(input_dtype)
121
+
122
+
123
+ ALL_LAYERNORM_LAYERS.append(Emu3RMSNorm)
124
+
125
+
126
+ class Emu3RotaryEmbedding(nn.Module):
127
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
128
+ super().__init__()
129
+
130
+ self.dim = dim
131
+ self.max_position_embeddings = max_position_embeddings
132
+ self.base = base
133
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
134
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
135
+
136
+ # Build here to make `torch.jit.trace` work.
137
+ self._set_cos_sin_cache(
138
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
139
+ )
140
+
141
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
142
+ self.max_seq_len_cached = seq_len
143
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
144
+
145
+ freqs = torch.outer(t, self.inv_freq)
146
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
147
+ emb = torch.cat((freqs, freqs), dim=-1)
148
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
149
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
150
+
151
+ def forward(self, x, seq_len=None):
152
+ # x: [bs, num_attention_heads, seq_len, head_size]
153
+ if seq_len > self.max_seq_len_cached:
154
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
155
+
156
+ return (
157
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
158
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
159
+ )
160
+
161
+
162
+ class Emu3LinearScalingRotaryEmbedding(Emu3RotaryEmbedding):
163
+ """Emu3RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
164
+
165
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
166
+ self.scaling_factor = scaling_factor
167
+ super().__init__(dim, max_position_embeddings, base, device)
168
+
169
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
170
+ self.max_seq_len_cached = seq_len
171
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
172
+ t = t / self.scaling_factor
173
+
174
+ freqs = torch.outer(t, self.inv_freq)
175
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
176
+ emb = torch.cat((freqs, freqs), dim=-1)
177
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
178
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
179
+
180
+
181
+ class Emu3DynamicNTKScalingRotaryEmbedding(Emu3RotaryEmbedding):
182
+ """Emu3RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
183
+
184
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
185
+ self.scaling_factor = scaling_factor
186
+ super().__init__(dim, max_position_embeddings, base, device)
187
+
188
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
189
+ self.max_seq_len_cached = seq_len
190
+
191
+ if seq_len > self.max_position_embeddings:
192
+ base = self.base * (
193
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
194
+ ) ** (self.dim / (self.dim - 2))
195
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
196
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
197
+
198
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
199
+
200
+ freqs = torch.outer(t, self.inv_freq)
201
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
202
+ emb = torch.cat((freqs, freqs), dim=-1)
203
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
204
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
205
+
206
+
207
+ def rotate_half(x):
208
+ """Rotates half the hidden dims of the input."""
209
+ x1 = x[..., : x.shape[-1] // 2]
210
+ x2 = x[..., x.shape[-1] // 2 :]
211
+ return torch.cat((-x2, x1), dim=-1)
212
+
213
+
214
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
215
+ """Applies Rotary Position Embedding to the query and key tensors.
216
+
217
+ Args:
218
+ q (`torch.Tensor`): The query tensor.
219
+ k (`torch.Tensor`): The key tensor.
220
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
221
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
222
+ position_ids (`torch.Tensor`):
223
+ The position indices of the tokens corresponding to the query and key tensors. For example, this can be
224
+ used to pass offsetted position ids when working with a KV-cache.
225
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
226
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
227
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
228
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
229
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
230
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
231
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
232
+ Returns:
233
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
234
+ """
235
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
236
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
237
+ q_embed = (q * cos) + (rotate_half(q) * sin)
238
+ k_embed = (k * cos) + (rotate_half(k) * sin)
239
+ return q_embed, k_embed
240
+
241
+
242
+ class Emu3MLP(nn.Module):
243
+ def __init__(self, config):
244
+ super().__init__()
245
+ self.config = config
246
+ self.hidden_size = config.hidden_size
247
+ self.intermediate_size = config.intermediate_size
248
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
249
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
250
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
251
+ self.act_fn = ACT2FN[config.hidden_act]
252
+
253
+ def forward(self, x):
254
+ if self.config.pretraining_tp > 1:
255
+ slice = self.intermediate_size // self.config.pretraining_tp
256
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
257
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
258
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
259
+
260
+ gate_proj = torch.cat(
261
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
262
+ )
263
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
264
+
265
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
266
+ down_proj = [
267
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
268
+ ]
269
+ down_proj = sum(down_proj)
270
+ else:
271
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
272
+
273
+ return down_proj
274
+
275
+
276
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
277
+ """
278
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
279
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
280
+ """
281
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
282
+ if n_rep == 1:
283
+ return hidden_states
284
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
285
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
286
+
287
+
288
+ class Emu3Attention(nn.Module):
289
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
290
+
291
+ def __init__(self, config: Emu3Config, layer_idx: Optional[int] = None):
292
+ super().__init__()
293
+ self.config = config
294
+ self.layer_idx = layer_idx
295
+ if layer_idx is None:
296
+ logger.warning_once(
297
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
298
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
299
+ "when creating this class."
300
+ )
301
+
302
+ self.attention_dropout = config.attention_dropout
303
+ self.hidden_size = config.hidden_size
304
+ self.num_heads = config.num_attention_heads
305
+ self.head_dim = self.hidden_size // self.num_heads
306
+ self.num_key_value_heads = config.num_key_value_heads
307
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
308
+ self.max_position_embeddings = config.max_position_embeddings
309
+ self.rope_theta = config.rope_theta
310
+ self.is_causal = True
311
+
312
+ if (self.head_dim * self.num_heads) != self.hidden_size:
313
+ raise ValueError(
314
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
315
+ f" and `num_heads`: {self.num_heads})."
316
+ )
317
+
318
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
319
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
320
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
321
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
322
+ self._init_rope()
323
+
324
+ def _init_rope(self):
325
+ if self.config.rope_scaling is None:
326
+ self.rotary_emb = Emu3RotaryEmbedding(
327
+ self.head_dim,
328
+ max_position_embeddings=self.max_position_embeddings,
329
+ base=self.rope_theta,
330
+ )
331
+ else:
332
+ scaling_type = self.config.rope_scaling["type"]
333
+ scaling_factor = self.config.rope_scaling["factor"]
334
+ if scaling_type == "linear":
335
+ self.rotary_emb = Emu3LinearScalingRotaryEmbedding(
336
+ self.head_dim,
337
+ max_position_embeddings=self.max_position_embeddings,
338
+ scaling_factor=scaling_factor,
339
+ base=self.rope_theta,
340
+ )
341
+ elif scaling_type == "dynamic":
342
+ self.rotary_emb = Emu3DynamicNTKScalingRotaryEmbedding(
343
+ self.head_dim,
344
+ max_position_embeddings=self.max_position_embeddings,
345
+ scaling_factor=scaling_factor,
346
+ base=self.rope_theta,
347
+ )
348
+ else:
349
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
350
+
351
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
352
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
353
+
354
+ def forward(
355
+ self,
356
+ hidden_states: torch.Tensor,
357
+ attention_mask: Optional[torch.Tensor] = None,
358
+ position_ids: Optional[torch.LongTensor] = None,
359
+ past_key_value: Optional[Cache] = None,
360
+ output_attentions: bool = False,
361
+ use_cache: bool = False,
362
+ **kwargs,
363
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
364
+ if "padding_mask" in kwargs:
365
+ warnings.warn(
366
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
367
+ )
368
+
369
+ bsz, q_len, _ = hidden_states.size()
370
+
371
+ if self.config.pretraining_tp > 1:
372
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
373
+ query_slices = self.q_proj.weight.split(
374
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
375
+ )
376
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
377
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
378
+
379
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
380
+ query_states = torch.cat(query_states, dim=-1)
381
+
382
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
383
+ key_states = torch.cat(key_states, dim=-1)
384
+
385
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
386
+ value_states = torch.cat(value_states, dim=-1)
387
+
388
+ else:
389
+ query_states = self.q_proj(hidden_states)
390
+ key_states = self.k_proj(hidden_states)
391
+ value_states = self.v_proj(hidden_states)
392
+
393
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
394
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
395
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
396
+
397
+ kv_seq_len = key_states.shape[-2]
398
+ if past_key_value is not None:
399
+ if self.layer_idx is None:
400
+ raise ValueError(
401
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
402
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
403
+ "with a layer index."
404
+ )
405
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
406
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
407
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
408
+
409
+ if past_key_value is not None:
410
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
411
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
412
+
413
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
414
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
415
+
416
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
417
+
418
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
419
+ raise ValueError(
420
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
421
+ f" {attn_weights.size()}"
422
+ )
423
+
424
+ if attention_mask is not None:
425
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
426
+ raise ValueError(
427
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
428
+ )
429
+ attn_weights = attn_weights + attention_mask
430
+
431
+ # upcast attention to fp32
432
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
433
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
434
+ attn_output = torch.matmul(attn_weights, value_states)
435
+
436
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
437
+ raise ValueError(
438
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
439
+ f" {attn_output.size()}"
440
+ )
441
+
442
+ attn_output = attn_output.transpose(1, 2).contiguous()
443
+
444
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
445
+
446
+ if self.config.pretraining_tp > 1:
447
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
448
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
449
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
450
+ else:
451
+ attn_output = self.o_proj(attn_output)
452
+
453
+ if not output_attentions:
454
+ attn_weights = None
455
+
456
+ return attn_output, attn_weights, past_key_value
457
+
458
+
459
+ class Emu3FlashAttention2(Emu3Attention):
460
+ """
461
+ Emu3 flash attention module. This module inherits from `Emu3Attention` as the weights of the module stays
462
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
463
+ flash attention and deal with padding tokens in case the input contains any of them.
464
+ """
465
+
466
+ def __init__(self, *args, **kwargs):
467
+ super().__init__(*args, **kwargs)
468
+
469
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
470
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
471
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
472
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
473
+
474
+ def forward(
475
+ self,
476
+ hidden_states: torch.Tensor,
477
+ attention_mask: Optional[torch.LongTensor] = None,
478
+ position_ids: Optional[torch.LongTensor] = None,
479
+ past_key_value: Optional[Cache] = None,
480
+ output_attentions: bool = False,
481
+ use_cache: bool = False,
482
+ **kwargs,
483
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
484
+ # Emu3FlashAttention2 attention does not support output_attentions
485
+ if "padding_mask" in kwargs:
486
+ warnings.warn(
487
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
488
+ )
489
+
490
+ # overwrite attention_mask with padding_mask
491
+ attention_mask = kwargs.pop("padding_mask")
492
+
493
+ output_attentions = False
494
+
495
+ bsz, q_len, _ = hidden_states.size()
496
+
497
+ query_states = self.q_proj(hidden_states)
498
+ key_states = self.k_proj(hidden_states)
499
+ value_states = self.v_proj(hidden_states)
500
+
501
+ # Flash attention requires the input to have the shape
502
+ # batch_size x seq_length x head_dim x hidden_dim
503
+ # therefore we just need to keep the original shape
504
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
505
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
506
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
507
+
508
+ kv_seq_len = key_states.shape[-2]
509
+ if past_key_value is not None:
510
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
511
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
512
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
513
+
514
+ if past_key_value is not None:
515
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
516
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
517
+
518
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
519
+ # to be able to avoid many of these transpose/reshape/view.
520
+ query_states = query_states.transpose(1, 2)
521
+ key_states = key_states.transpose(1, 2)
522
+ value_states = value_states.transpose(1, 2)
523
+
524
+ dropout_rate = self.attention_dropout if self.training else 0.0
525
+
526
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
527
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
528
+ # cast them back in the correct dtype just to be sure everything works as expected.
529
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
530
+ # in fp32. (Emu3RMSNorm handles it correctly)
531
+
532
+ input_dtype = query_states.dtype
533
+ if input_dtype == torch.float32:
534
+ # Handle the case where the model is quantized
535
+ if hasattr(self.config, "_pre_quantization_dtype"):
536
+ target_dtype = self.config._pre_quantization_dtype
537
+ else:
538
+ target_dtype = self.q_proj.weight.dtype
539
+
540
+ logger.warning_once(
541
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
542
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
543
+ f" {target_dtype}."
544
+ )
545
+
546
+ query_states = query_states.to(target_dtype)
547
+ key_states = key_states.to(target_dtype)
548
+ value_states = value_states.to(target_dtype)
549
+
550
+ attn_output = self._flash_attention_forward(
551
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
552
+ )
553
+
554
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
555
+ attn_output = self.o_proj(attn_output)
556
+
557
+ if not output_attentions:
558
+ attn_weights = None
559
+
560
+ return attn_output, attn_weights, past_key_value
561
+
562
+ def _flash_attention_forward(
563
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
564
+ ):
565
+ """
566
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
567
+ first unpad the input, then computes the attention scores and pad the final attention scores.
568
+
569
+ Args:
570
+ query_states (`torch.Tensor`):
571
+ Input query states to be passed to Flash Attention API
572
+ key_states (`torch.Tensor`):
573
+ Input key states to be passed to Flash Attention API
574
+ value_states (`torch.Tensor`):
575
+ Input value states to be passed to Flash Attention API
576
+ attention_mask (`torch.Tensor`):
577
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
578
+ position of padding tokens and 1 for the position of non-padding tokens.
579
+ dropout (`int`, *optional*):
580
+ Attention dropout
581
+ softmax_scale (`float`, *optional*):
582
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
583
+ """
584
+ if not self._flash_attn_uses_top_left_mask:
585
+ causal = self.is_causal
586
+ else:
587
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in Emu3FlashAttention2 __init__.
588
+ causal = self.is_causal and query_length != 1
589
+
590
+ # Contains at least one padding token in the sequence
591
+ if attention_mask is not None:
592
+ batch_size = query_states.shape[0]
593
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
594
+ query_states, key_states, value_states, attention_mask, query_length
595
+ )
596
+
597
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
598
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
599
+
600
+ attn_output_unpad = flash_attn_varlen_func(
601
+ query_states,
602
+ key_states,
603
+ value_states,
604
+ cu_seqlens_q=cu_seqlens_q,
605
+ cu_seqlens_k=cu_seqlens_k,
606
+ max_seqlen_q=max_seqlen_in_batch_q,
607
+ max_seqlen_k=max_seqlen_in_batch_k,
608
+ dropout_p=dropout,
609
+ softmax_scale=softmax_scale,
610
+ causal=causal,
611
+ )
612
+
613
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
614
+ else:
615
+ attn_output = flash_attn_func(
616
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
617
+ )
618
+
619
+ return attn_output
620
+
621
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
622
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
623
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
624
+
625
+ key_layer = index_first_axis(
626
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
627
+ )
628
+ value_layer = index_first_axis(
629
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
630
+ )
631
+ if query_length == kv_seq_len:
632
+ query_layer = index_first_axis(
633
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
634
+ )
635
+ cu_seqlens_q = cu_seqlens_k
636
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
637
+ indices_q = indices_k
638
+ elif query_length == 1:
639
+ max_seqlen_in_batch_q = 1
640
+ cu_seqlens_q = torch.arange(
641
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
642
+ ) # There is a memcpy here, that is very bad.
643
+ indices_q = cu_seqlens_q[:-1]
644
+ query_layer = query_layer.squeeze(1)
645
+ else:
646
+ # The -q_len: slice assumes left padding.
647
+ attention_mask = attention_mask[:, -query_length:]
648
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
649
+
650
+ return (
651
+ query_layer,
652
+ key_layer,
653
+ value_layer,
654
+ indices_q,
655
+ (cu_seqlens_q, cu_seqlens_k),
656
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
657
+ )
658
+
659
+
660
+ class Emu3SdpaAttention(Emu3Attention):
661
+ """
662
+ Emu3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
663
+ `Emu3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
664
+ SDPA API.
665
+ """
666
+
667
+ # Adapted from Emu3Attention.forward
668
+ def forward(
669
+ self,
670
+ hidden_states: torch.Tensor,
671
+ attention_mask: Optional[torch.Tensor] = None,
672
+ position_ids: Optional[torch.LongTensor] = None,
673
+ past_key_value: Optional[Cache] = None,
674
+ output_attentions: bool = False,
675
+ use_cache: bool = False,
676
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
677
+ if output_attentions:
678
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
679
+ logger.warning_once(
680
+ "Emu3Model is using Emu3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
681
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
682
+ )
683
+ return super().forward(
684
+ hidden_states=hidden_states,
685
+ attention_mask=attention_mask,
686
+ position_ids=position_ids,
687
+ past_key_value=past_key_value,
688
+ output_attentions=output_attentions,
689
+ use_cache=use_cache,
690
+ )
691
+
692
+ bsz, q_len, _ = hidden_states.size()
693
+
694
+ query_states = self.q_proj(hidden_states)
695
+ key_states = self.k_proj(hidden_states)
696
+ value_states = self.v_proj(hidden_states)
697
+
698
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
699
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
700
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
701
+
702
+ kv_seq_len = key_states.shape[-2]
703
+ if past_key_value is not None:
704
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
705
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
706
+
707
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
708
+
709
+ if past_key_value is not None:
710
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
711
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
712
+
713
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
714
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
715
+
716
+ if attention_mask is not None:
717
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
718
+ raise ValueError(
719
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
720
+ )
721
+
722
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
723
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
724
+ if query_states.device.type == "cuda" and attention_mask is not None:
725
+ query_states = query_states.contiguous()
726
+ key_states = key_states.contiguous()
727
+ value_states = value_states.contiguous()
728
+
729
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
730
+ query_states,
731
+ key_states,
732
+ value_states,
733
+ attn_mask=attention_mask,
734
+ dropout_p=self.attention_dropout if self.training else 0.0,
735
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
736
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
737
+ )
738
+
739
+ attn_output = attn_output.transpose(1, 2).contiguous()
740
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
741
+
742
+ attn_output = self.o_proj(attn_output)
743
+
744
+ return attn_output, None, past_key_value
745
+
746
+
747
+ EMU3_ATTENTION_CLASSES = {
748
+ "eager": Emu3Attention,
749
+ "flash_attention_2": Emu3FlashAttention2,
750
+ "sdpa": Emu3SdpaAttention,
751
+ }
752
+
753
+
754
+ class Emu3DecoderLayer(nn.Module):
755
+ def __init__(self, config: Emu3Config, layer_idx: int):
756
+ super().__init__()
757
+ self.hidden_size = config.hidden_size
758
+ self.dropout = nn.Dropout(config.attention_dropout)
759
+ self.self_attn = EMU3_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
760
+
761
+ self.mlp = Emu3MLP(config)
762
+ self.input_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
763
+ self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
764
+
765
+ def forward(
766
+ self,
767
+ hidden_states: torch.Tensor,
768
+ attention_mask: Optional[torch.Tensor] = None,
769
+ position_ids: Optional[torch.LongTensor] = None,
770
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
771
+ output_attentions: Optional[bool] = False,
772
+ use_cache: Optional[bool] = False,
773
+ **kwargs,
774
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
775
+ """
776
+ Args:
777
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
778
+ attention_mask (`torch.FloatTensor`, *optional*):
779
+ attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
780
+ query_sequence_length, key_sequence_length)` if default attention is used.
781
+ output_attentions (`bool`, *optional*):
782
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
783
+ returned tensors for more detail.
784
+ use_cache (`bool`, *optional*):
785
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
786
+ (see `past_key_values`).
787
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
788
+ """
789
+ if "padding_mask" in kwargs:
790
+ warnings.warn(
791
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
792
+ )
793
+
794
+ residual = hidden_states
795
+
796
+ hidden_states = self.input_layernorm(hidden_states)
797
+
798
+ # Self Attention
799
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
800
+ hidden_states=hidden_states,
801
+ attention_mask=attention_mask,
802
+ position_ids=position_ids,
803
+ past_key_value=past_key_value,
804
+ output_attentions=output_attentions,
805
+ use_cache=use_cache,
806
+ **kwargs,
807
+ )
808
+ hidden_states = residual + self.dropout(hidden_states)
809
+
810
+ # Fully Connected
811
+ residual = hidden_states
812
+ hidden_states = self.post_attention_layernorm(hidden_states)
813
+ hidden_states = self.mlp(hidden_states)
814
+ hidden_states = residual + self.dropout(hidden_states)
815
+
816
+ outputs = (hidden_states,)
817
+
818
+ if output_attentions:
819
+ outputs += (self_attn_weights,)
820
+
821
+ if use_cache:
822
+ outputs += (present_key_value,)
823
+
824
+ return outputs
825
+
826
+
827
+ EMU3_START_DOCSTRING = r"""
828
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
829
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
830
+ etc.)
831
+
832
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
833
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
834
+ and behavior.
835
+
836
+ Parameters:
837
+ config ([`Emu3Config`]):
838
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
839
+ load the weights associated with the model, only the configuration. Check out the
840
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
841
+ """
842
+
843
+
844
+ @add_start_docstrings(
845
+ "The bare Emu3 Model outputting raw hidden-states without any specific head on top.",
846
+ EMU3_START_DOCSTRING,
847
+ )
848
+ class Emu3PreTrainedModel(PreTrainedModel):
849
+ config_class = Emu3Config
850
+ base_model_prefix = "model"
851
+ supports_gradient_checkpointing = True
852
+ _no_split_modules = ["Emu3DecoderLayer"]
853
+ _skip_keys_device_placement = "past_key_values"
854
+ _supports_flash_attn_2 = True
855
+ _supports_sdpa = True
856
+ _supports_cache_class = True
857
+
858
+ def _init_weights(self, module):
859
+ std = self.config.initializer_range
860
+ if isinstance(module, nn.Linear):
861
+ module.weight.data.normal_(mean=0.0, std=std)
862
+ if module.bias is not None:
863
+ module.bias.data.zero_()
864
+ elif isinstance(module, nn.Embedding):
865
+ module.weight.data.normal_(mean=0.0, std=std)
866
+ if module.padding_idx is not None:
867
+ module.weight.data[module.padding_idx].zero_()
868
+
869
+
870
+ EMU3_INPUTS_DOCSTRING = r"""
871
+ Args:
872
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
873
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
874
+ it.
875
+
876
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
877
+ [`PreTrainedTokenizer.__call__`] for details.
878
+
879
+ [What are input IDs?](../glossary#input-ids)
880
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
881
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
882
+
883
+ - 1 for tokens that are **not masked**,
884
+ - 0 for tokens that are **masked**.
885
+
886
+ [What are attention masks?](../glossary#attention-mask)
887
+
888
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
889
+ [`PreTrainedTokenizer.__call__`] for details.
890
+
891
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
892
+ `past_key_values`).
893
+
894
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
895
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
896
+ information on the default strategy.
897
+
898
+ - 1 indicates the head is **not masked**,
899
+ - 0 indicates the head is **masked**.
900
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
901
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
902
+ config.n_positions - 1]`.
903
+
904
+ [What are position IDs?](../glossary#position-ids)
905
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
906
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
907
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
908
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
909
+
910
+ Two formats are allowed:
911
+ - a [`~cache_utils.Cache`] instance;
912
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
913
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
914
+ cache format.
915
+
916
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
917
+ legacy cache format will be returned.
918
+
919
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
920
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
921
+ of shape `(batch_size, sequence_length)`.
922
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
923
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
924
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
925
+ model's internal embedding lookup matrix.
926
+ use_cache (`bool`, *optional*):
927
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
928
+ `past_key_values`).
929
+ output_attentions (`bool`, *optional*):
930
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
931
+ tensors for more detail.
932
+ output_hidden_states (`bool`, *optional*):
933
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
934
+ more detail.
935
+ return_dict (`bool`, *optional*):
936
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
937
+ """
938
+
939
+
940
+ @add_start_docstrings(
941
+ "The bare Emu3 Model outputting raw hidden-states without any specific head on top.",
942
+ EMU3_START_DOCSTRING,
943
+ )
944
+ class Emu3Model(Emu3PreTrainedModel):
945
+ """
946
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Emu3DecoderLayer`]
947
+
948
+ Args:
949
+ config: Emu3Config
950
+ """
951
+
952
+ def __init__(self, config: Emu3Config):
953
+ super().__init__(config)
954
+ self.padding_idx = config.pad_token_id
955
+ self.vocab_size = config.vocab_size
956
+
957
+ self.dropout = nn.Dropout(config.attention_dropout)
958
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
959
+ self.layers = nn.ModuleList(
960
+ [Emu3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
961
+ )
962
+ self._use_sdpa = config._attn_implementation == "sdpa"
963
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
964
+ self.norm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
965
+
966
+ self.gradient_checkpointing = False
967
+ # Initialize weights and apply final processing
968
+ self.post_init()
969
+
970
+ def get_input_embeddings(self):
971
+ return self.embed_tokens
972
+
973
+ def set_input_embeddings(self, value):
974
+ self.embed_tokens = value
975
+
976
+ @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
977
+ def forward(
978
+ self,
979
+ input_ids: torch.LongTensor = None,
980
+ attention_mask: Optional[torch.Tensor] = None,
981
+ position_ids: Optional[torch.LongTensor] = None,
982
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
983
+ inputs_embeds: Optional[torch.FloatTensor] = None,
984
+ use_cache: Optional[bool] = None,
985
+ output_attentions: Optional[bool] = None,
986
+ output_hidden_states: Optional[bool] = None,
987
+ return_dict: Optional[bool] = None,
988
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
989
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
990
+ output_hidden_states = (
991
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
992
+ )
993
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
994
+
995
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
996
+
997
+ # retrieve input_ids and inputs_embeds
998
+ if input_ids is not None and inputs_embeds is not None:
999
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1000
+ elif input_ids is not None:
1001
+ batch_size, seq_length = input_ids.shape[:2]
1002
+ elif inputs_embeds is not None:
1003
+ batch_size, seq_length = inputs_embeds.shape[:2]
1004
+ else:
1005
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1006
+
1007
+ if self.gradient_checkpointing and self.training:
1008
+ if use_cache:
1009
+ logger.warning_once(
1010
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
1011
+ )
1012
+ use_cache = False
1013
+
1014
+ past_key_values_length = 0
1015
+ if use_cache:
1016
+ use_legacy_cache = not isinstance(past_key_values, Cache)
1017
+ if use_legacy_cache:
1018
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
1019
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
1020
+
1021
+ if position_ids is None:
1022
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1023
+ position_ids = torch.arange(
1024
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
1025
+ )
1026
+ position_ids = position_ids.unsqueeze(0)
1027
+
1028
+ if inputs_embeds is None:
1029
+ inputs_embeds = self.embed_tokens(input_ids)
1030
+
1031
+ if self._use_flash_attention_2:
1032
+ # 2d mask is passed through the layers
1033
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
1034
+ elif self._use_sdpa and not output_attentions:
1035
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
1036
+ # the manual implementation that requires a 4D causal mask in all cases.
1037
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
1038
+ attention_mask,
1039
+ (batch_size, seq_length),
1040
+ inputs_embeds,
1041
+ past_key_values_length,
1042
+ )
1043
+ else:
1044
+ # 4d mask is passed through the layers
1045
+ attention_mask = _prepare_4d_causal_attention_mask(
1046
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
1047
+ )
1048
+
1049
+ # embed positions
1050
+ hidden_states = self.dropout(inputs_embeds)
1051
+
1052
+ # decoder layers
1053
+ all_hidden_states = () if output_hidden_states else None
1054
+ all_self_attns = () if output_attentions else None
1055
+ next_decoder_cache = None
1056
+
1057
+ for decoder_layer in self.layers:
1058
+ if output_hidden_states:
1059
+ all_hidden_states += (hidden_states,)
1060
+
1061
+ if self.gradient_checkpointing and self.training:
1062
+ layer_outputs = self._gradient_checkpointing_func(
1063
+ decoder_layer.__call__,
1064
+ hidden_states,
1065
+ attention_mask,
1066
+ position_ids,
1067
+ past_key_values,
1068
+ output_attentions,
1069
+ use_cache,
1070
+ )
1071
+ else:
1072
+ layer_outputs = decoder_layer(
1073
+ hidden_states,
1074
+ attention_mask=attention_mask,
1075
+ position_ids=position_ids,
1076
+ past_key_value=past_key_values,
1077
+ output_attentions=output_attentions,
1078
+ use_cache=use_cache,
1079
+ )
1080
+
1081
+ hidden_states = layer_outputs[0]
1082
+
1083
+ if use_cache:
1084
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
1085
+
1086
+ if output_attentions:
1087
+ all_self_attns += (layer_outputs[1],)
1088
+
1089
+ hidden_states = self.norm(hidden_states)
1090
+
1091
+ # add hidden states from the last decoder layer
1092
+ if output_hidden_states:
1093
+ all_hidden_states += (hidden_states,)
1094
+
1095
+ next_cache = None
1096
+ if use_cache:
1097
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
1098
+ if not return_dict:
1099
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
1100
+ return BaseModelOutputWithPast(
1101
+ last_hidden_state=hidden_states,
1102
+ past_key_values=next_cache,
1103
+ hidden_states=all_hidden_states,
1104
+ attentions=all_self_attns,
1105
+ )
1106
+
1107
+
1108
+ class Emu3ForCausalLM(Emu3PreTrainedModel):
1109
+ _tied_weights_keys = ["lm_head.weight"]
1110
+
1111
+ def __init__(self, config):
1112
+ super().__init__(config)
1113
+ self.model = Emu3Model(config)
1114
+ self.vocab_size = config.vocab_size
1115
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
1116
+
1117
+ # Initialize weights and apply final processing
1118
+ self.post_init()
1119
+
1120
+ def get_input_embeddings(self):
1121
+ return self.model.embed_tokens
1122
+
1123
+ def set_input_embeddings(self, value):
1124
+ self.model.embed_tokens = value
1125
+
1126
+ def get_output_embeddings(self):
1127
+ return self.lm_head
1128
+
1129
+ def set_output_embeddings(self, new_embeddings):
1130
+ self.lm_head = new_embeddings
1131
+
1132
+ def set_decoder(self, decoder):
1133
+ self.model = decoder
1134
+
1135
+ def get_decoder(self):
1136
+ return self.model
1137
+
1138
+ @add_start_docstrings_to_model_forward(EMU3_INPUTS_DOCSTRING)
1139
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1140
+ def forward(
1141
+ self,
1142
+ input_ids: torch.LongTensor = None,
1143
+ attention_mask: Optional[torch.Tensor] = None,
1144
+ position_ids: Optional[torch.LongTensor] = None,
1145
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1146
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1147
+ labels: Optional[torch.LongTensor] = None,
1148
+ use_cache: Optional[bool] = None,
1149
+ output_attentions: Optional[bool] = None,
1150
+ output_hidden_states: Optional[bool] = None,
1151
+ return_dict: Optional[bool] = None,
1152
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
1153
+ r"""
1154
+ Args:
1155
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1156
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1157
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1158
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1159
+
1160
+ Returns:
1161
+
1162
+ Example:
1163
+
1164
+ ```python
1165
+ >>> from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
1166
+ >>> from transformers.generation.configuration_utils import GenerationConfig
1167
+ >>> from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor
1168
+ >>> from transformers import Emu3Processor
1169
+ >>> from PIL import Image
1170
+
1171
+ >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_EMU3_WEIGHTS)
1172
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
1173
+ >>> image_processor = AutoImageProcessor.from_pretrained(PATH_TO_CONVERTED_IMAGE_PROCESSER)
1174
+ >>> image_tokenizer = AutoModel.from_pretrained(PATH_TO_CONVERTED_TOKENIZER_WEIGHTS).eval()
1175
+ >>> processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
1176
+
1177
+ >>> # Generation
1178
+ >>> prompt = "An Emu in cartoon style, it is wearing sunglasses."
1179
+
1180
+ >>> pos_inputs = processor(text=prompt, mode='G', ratio="4:3", image_area=model.config.image_area, return_tensors="pt")
1181
+ >>> neg_inputs = processor(text="", mode='G', ratio="4:3", image_area=model.config.image_area, return_tensors="pt")
1182
+
1183
+ >>> GENERATION_CONFIG = GenerationConfig(
1184
+ >>> use_cache=True,
1185
+ >>> eos_token_id=model.config.eos_token_id,
1186
+ >>> pad_token_id=model.config.pad_token_id,
1187
+ >>> max_new_tokens=40960,
1188
+ >>> do_sample=True,
1189
+ >>> top_k=2048,
1190
+ >>> )
1191
+
1192
+ >>> h, w = pos_inputs.image_size[0]
1193
+ >>> constrained_fn = processor.build_prefix_constrained_fn(h, w)
1194
+ >>> logits_processor = LogitsProcessorList([
1195
+ >>> UnbatchedClassifierFreeGuidanceLogitsProcessor(
1196
+ >>> classifier_free_guidance,
1197
+ >>> model,
1198
+ >>> unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
1199
+ >>> ),
1200
+ >>> PrefixConstrainedLogitsProcessor(
1201
+ >>> constrained_fn,
1202
+ >>> num_beams=1,
1203
+ >>> ),
1204
+ >>> ])
1205
+
1206
+ >>> outputs = model.generate(pos_inputs.input_ids.to("cuda:0"), GENERATION_CONFIG, logits_processor=logits_processor)
1207
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1208
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1209
+ >>> mm_list = processor.decode(outputs[0])
1210
+
1211
+ >>> # Understanding
1212
+ >>> prompt = "Provide a one-sentence caption for the provided image."
1213
+ >>> image = Image.open(TEST_IMAGE_PATH)
1214
+
1215
+ >>> inputs = processor(text=text, image=image, mode='U', padding_side="left", padding="longest", return_tensors="pt")
1216
+ >>> input_ids = inputs.input_ids.to("cuda:0")
1217
+ >>> GENERATION_CONFIG = GenerationConfig(
1218
+ >>> pad_token_id=tokenizer.pad_token_id,
1219
+ >>> bos_token_id=tokenizer.bos_token_id,
1220
+ >>> eos_token_id=tokenizer.eos_token_id,
1221
+ >>> )
1222
+
1223
+ >>> outputs = model.generate(input_ids, GENERATION_CONFIG, max_new_tokens=100)
1224
+ >>> outputs = outputs[:, input_ids.shape[-1]:]
1225
+ >>> answer = processor.batch_decode(outputs, skip_special_tokens=True)
1226
+ ```"""
1227
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1228
+ output_hidden_states = (
1229
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1230
+ )
1231
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1232
+
1233
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
1234
+ outputs = self.model(
1235
+ input_ids=input_ids,
1236
+ attention_mask=attention_mask,
1237
+ position_ids=position_ids,
1238
+ past_key_values=past_key_values,
1239
+ inputs_embeds=inputs_embeds,
1240
+ use_cache=use_cache,
1241
+ output_attentions=output_attentions,
1242
+ output_hidden_states=output_hidden_states,
1243
+ return_dict=return_dict,
1244
+ )
1245
+
1246
+ hidden_states = outputs[0]
1247
+ if self.config.pretraining_tp > 1:
1248
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
1249
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
1250
+ logits = torch.cat(logits, dim=-1)
1251
+ else:
1252
+ logits = self.lm_head(hidden_states)
1253
+ logits = logits.float()
1254
+
1255
+ loss = None
1256
+ if labels is not None:
1257
+ # Shift so that tokens < n predict n
1258
+ shift_logits = logits[..., :-1, :].contiguous()
1259
+ shift_labels = labels[..., 1:].contiguous()
1260
+ # Flatten the tokens
1261
+ loss_fct = CrossEntropyLoss()
1262
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
1263
+ shift_labels = shift_labels.view(-1)
1264
+ # Enable model parallelism
1265
+ shift_labels = shift_labels.to(shift_logits.device)
1266
+ loss = loss_fct(shift_logits, shift_labels)
1267
+
1268
+ if not return_dict:
1269
+ output = (logits,) + outputs[1:]
1270
+ return (loss,) + output if loss is not None else output
1271
+
1272
+ return CausalLMOutputWithPast(
1273
+ loss=loss,
1274
+ logits=logits,
1275
+ past_key_values=outputs.past_key_values,
1276
+ hidden_states=outputs.hidden_states,
1277
+ attentions=outputs.attentions,
1278
+ )
1279
+
1280
+ def prepare_inputs_for_generation(
1281
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
1282
+ ):
1283
+ if past_key_values is not None:
1284
+ if isinstance(past_key_values, Cache):
1285
+ cache_length = past_key_values.get_seq_length()
1286
+ past_length = past_key_values.seen_tokens
1287
+ max_cache_length = past_key_values.get_max_length()
1288
+ else:
1289
+ cache_length = past_length = past_key_values[0][0].shape[2]
1290
+ max_cache_length = None
1291
+
1292
+ # Keep only the unprocessed tokens:
1293
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1294
+ # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as
1295
+ # input)
1296
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1297
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1298
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1299
+ # input_ids based on the past_length.
1300
+ elif past_length < input_ids.shape[1]:
1301
+ input_ids = input_ids[:, past_length:]
1302
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1303
+
1304
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
1305
+ if (
1306
+ max_cache_length is not None
1307
+ and attention_mask is not None
1308
+ and cache_length + input_ids.shape[1] > max_cache_length
1309
+ ):
1310
+ attention_mask = attention_mask[:, -max_cache_length:]
1311
+
1312
+ position_ids = kwargs.get("position_ids", None)
1313
+ if attention_mask is not None and position_ids is None:
1314
+ # create position_ids on the fly for batch generation
1315
+ position_ids = attention_mask.long().cumsum(-1) - 1
1316
+ position_ids.masked_fill_(attention_mask == 0, 1)
1317
+ if past_key_values:
1318
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1319
+
1320
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1321
+ if inputs_embeds is not None and past_key_values is None:
1322
+ model_inputs = {"inputs_embeds": inputs_embeds}
1323
+ else:
1324
+ model_inputs = {"input_ids": input_ids}
1325
+
1326
+ model_inputs.update(
1327
+ {
1328
+ "position_ids": position_ids,
1329
+ "past_key_values": past_key_values,
1330
+ "use_cache": kwargs.get("use_cache"),
1331
+ "attention_mask": attention_mask,
1332
+ }
1333
+ )
1334
+ return model_inputs
1335
+
1336
+ @staticmethod
1337
+ def _reorder_cache(past_key_values, beam_idx):
1338
+ reordered_past = ()
1339
+ for layer_past in past_key_values:
1340
+ reordered_past += (
1341
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1342
+ )
1343
+ return reordered_past
emu3/mllm/processing_emu3.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Processor class for Emu3. """
16
+
17
+ import re
18
+ from typing import List, Optional, Sequence, Union
19
+ from functools import partial
20
+
21
+ from PIL import Image
22
+ import torch
23
+ from transformers.feature_extraction_utils import BatchFeature
24
+ from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
25
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
26
+ from transformers.tokenization_utils_base import TextInput, PreTokenizedInput
27
+ from transformers.utils import logging
28
+
29
+ from .utils_emu3 import Emu3PrefixConstrainedLogitsHelper
30
+
31
+
32
+ logger = logging.get_logger(__name__)
33
+
34
+
35
+ class Emu3Processor(ProcessorMixin):
36
+ r"""
37
+ Constructs an Emu3 processor which wraps an Emu3 image processor and an Emu3 vision vq model and an Emu3 tokenizer into a single processor.
38
+
39
+ [`Emu3Processor`] offers all the functionalities of [`Emu3VisionVQModel`] and [`Emu3Tokenizer`]. See the
40
+ [`~Emu3Processor.__call__`], [`~Emu3Processor.decode`], [`~Emu3Processor.vision_encode`], [`~Emu3Processor.vision_decode`]
41
+ for more information.
42
+
43
+ Args:
44
+ image_processor ([`Emu3VisionVQImageProcessor`]):
45
+ The image processor is a required input.
46
+ vision_tokenizer ([`Emu3VisionVQModel`]):
47
+ The vision tokenizer is a required input.
48
+ tokenizer ([`Emu3Tokenizer`]):
49
+ The tokenizer is a required input.
50
+ prefix_template(`str`, *optional*):
51
+ The prefix template for image tokens
52
+ visual_template(`Tuple[str, ...]`, *optional*):
53
+ The visual token template for image tokens
54
+ """
55
+
56
+ attributes = ["image_processor", "tokenizer"]
57
+ valid_kwargs = ["vision_tokenizer", "prefix_template", "visual_template"]
58
+ image_processor_class = "AutoImageProcessor"
59
+ tokenizer_class = "AutoTokenizer"
60
+
61
+ def __init__(
62
+ self,
63
+ image_processor=None,
64
+ vision_tokenizer=None,
65
+ tokenizer=None,
66
+ chat_template="You are a helpful assistant. USER: {image_prompt}{text_prompt}. ASSISTANT:",
67
+ prefix_template="{H}*{W}",
68
+ visual_template=("<|visual token {token_id:0>6d}|>", r"<\|visual token (\d+)\|>"),
69
+ **kwargs,
70
+ ):
71
+ assert vision_tokenizer is not None, "image tokenizer can not be None"
72
+
73
+ self.vision_tokenizer = vision_tokenizer
74
+ self.prefix_template = prefix_template
75
+ self.visual_template = visual_template
76
+
77
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
78
+ self.const_helper = self.build_const_helper()
79
+
80
+ @torch.no_grad()
81
+ def __call__(
82
+ self,
83
+ text: Optional[TextInput | PreTokenizedInput] = None,
84
+ image: Optional[Image.Image | List[Image.Image]] = None,
85
+ *,
86
+ mode: str = "G",
87
+ ratio: str = "1:1",
88
+ image_area: int = 518400,
89
+ **kwargs,
90
+ ) -> BatchFeature:
91
+ """
92
+ Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
93
+ and `kwargs` arguments to Emu3Tokenizer's [`~Emu3Tokenizer.__call__`] to encode the text.
94
+ To prepare the image(s), this method forwards the `image` argument to
95
+ Emu3VisionVQImageProcessor's [`~Emu3VisionVQImageProcessor.__call__`] and Emu3VisionVQModel's [`~EmuVideoVQModel.encode`]
96
+ if `image` is not `None`. Please refer to the doctsring of the above two methods for more information.
97
+
98
+ Args:
99
+ text (`str` or `List[str]`):
100
+ The sequence or a batch of sequence to be encoded. A sequence is a string.
101
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
102
+ The image or a batch of images to be prepared. An image is a PIL image.
103
+ mode (`str`, *optional*, in `G` or `U`):
104
+ task mode, `G` for generation and `U` for understanding
105
+ ratio (`str`, *optional*):
106
+ the image width-height ratio for generation
107
+ image_area (`int`, *optional*):
108
+ image area used to calcualte the generated image height and width
109
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
110
+ If set, will return tensors of a particular framework. Acceptable values are:
111
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
112
+ - `'np'`: Return NumPy `np.ndarray` objects.
113
+
114
+ Returns:
115
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
116
+
117
+ - **input_ids** -- List of token ids to be fed to a model.
118
+ - **image_size** -- List of image size of input images or generated images.
119
+ """
120
+ assert mode in ('G', 'U'), "mode must be 'G' or 'U'."
121
+ if isinstance(text, str):
122
+ text = [text]
123
+
124
+ if not isinstance(text[0], str):
125
+ raise ValueError("`text` must be string or list of string")
126
+
127
+ image_inputs = None
128
+ if mode == 'G':
129
+ if image is not None:
130
+ raise ValueError("You have to specify only `text` in generation mode")
131
+
132
+ if len(text) > 1:
133
+ raise ValueError("`text` can only be `str` in generation mode")
134
+ else:
135
+ if image is None:
136
+ raise ValueError("Invalid input image. Please provide exactly one PIL.Image.Image per text.")
137
+
138
+ if not isinstance(image, Sequence) and not isinstance(image, Image.Image):
139
+ raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
140
+
141
+ if isinstance(image, Sequence) and not isinstance(image[0], Image.Image):
142
+ raise ValueError("Invalid input image. Please provide PIL.Image.Image or List[PIL.Image.Image].")
143
+
144
+ image_inputs = self.image_processor(image, return_tensors="pt")["pixel_values"]
145
+ print(image_inputs.shape)
146
+ image_inputs = image_inputs.to(self.vision_tokenizer.device, self.vision_tokenizer.dtype)
147
+ image_tokens = self.vision_tokenizer.encode(image_inputs)
148
+
149
+ if len(text) != len(image_tokens):
150
+ raise ValueError("number of image must match number of text prompt")
151
+
152
+ prompt_list, size_list = [], []
153
+ for idx, text_prompt in enumerate(text):
154
+ prompt = self.tokenizer.bos_token
155
+ if mode == 'U':
156
+ h, w = image_tokens[idx].shape
157
+ imgstr = self.to_imgstr(image_tokens[idx])
158
+ image_prompt = (
159
+ self.tokenizer.boi_token +
160
+ self.prefix_template.format(H=h, W=w) +
161
+ self.tokenizer.img_token +
162
+ imgstr +
163
+ self.tokenizer.eol_token +
164
+ self.tokenizer.eof_token +
165
+ self.tokenizer.eoi_token
166
+ )
167
+ prompt += self.chat_template.format(image_prompt=image_prompt, text_prompt=text_prompt)
168
+ else:
169
+ h, w = self.calculate_generate_size(ratio, image_area, self.vision_tokenizer.spatial_scale_factor)
170
+ image_prompt = (
171
+ self.tokenizer.boi_token +
172
+ self.prefix_template.format(H=h, W=w) +
173
+ self.tokenizer.img_token
174
+ )
175
+ prompt += (text_prompt + image_prompt)
176
+
177
+ prompt_list.append(prompt)
178
+ size_list.append([h, w])
179
+
180
+ text_inputs = self.tokenizer(prompt_list, **kwargs)
181
+ return BatchFeature(data={**text_inputs, "image_size": size_list}, tensor_type=kwargs.get("return_tensors"))
182
+
183
+ @torch.no_grad()
184
+ def batch_decode(self, *args, **kwargs):
185
+ docs = self.tokenizer.batch_decode(*args, **kwargs)
186
+ return [self.multimodal_decode(d) for d in docs]
187
+
188
+ @torch.no_grad()
189
+ def decode(self, *args, **kwargs):
190
+ doc = self.tokenizer.decode(*args, **kwargs)
191
+ return self.multimodal_decode(doc)
192
+
193
+ @torch.no_grad()
194
+ def vision_encode(self, *args, **kwargs):
195
+ return self.vision_tokenizer.encode(*args, **kwargs)
196
+
197
+ @torch.no_grad()
198
+ def vision_decode(self, *args, **kwargs):
199
+ return self.vision_tokenizer.decode(*args, **kwargs)
200
+
201
+ @torch.no_grad()
202
+ def multimodal_decode(self, doc):
203
+ multimodal_output = []
204
+ pattern = rf'({re.escape(self.tokenizer.boi_token)}.*?{re.escape(self.tokenizer.eoi_token)})'
205
+ chunks = re.split(pattern, doc)
206
+ for c in chunks:
207
+ if len(c) == 0:
208
+ continue
209
+
210
+ if self.tokenizer.boi_token in c:
211
+ image = []
212
+ image_rows = re.split(re.escape(self.tokenizer.eol_token), c)
213
+ for r in image_rows:
214
+ token_ids = re.findall(self.visual_template[1], r)
215
+ if len(token_ids) > 0:
216
+ row_token = [int(m) for m in token_ids]
217
+ image.append(row_token)
218
+ image = torch.tensor(image, dtype=torch.long, device=self.vision_tokenizer.device)
219
+ image = self.vision_tokenizer.decode(image[None]).float()
220
+ image = self.image_processor.postprocess(image)["pixel_values"][0]
221
+ multimodal_output.append(image)
222
+ else:
223
+ multimodal_output.append(c)
224
+
225
+ return multimodal_output if len(multimodal_output) > 1 else multimodal_output[0]
226
+
227
+ @property
228
+ def model_input_names(self):
229
+ tokenizer_input_names = self.tokenizer.model_input_names
230
+ image_processor_input_names = self.image_processor.model_input_names
231
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
232
+
233
+ def to_imgstr(self, image_tokens):
234
+ image_tokens = image_tokens.cpu().numpy().tolist()
235
+ image_token_str = [
236
+ [
237
+ self.visual_template[0].format(token_id=token_id)
238
+ for token_id in token_row
239
+ ]
240
+ for token_row in image_tokens
241
+ ]
242
+ image_row_str = ["".join(token_row) for token_row in image_token_str]
243
+ imgstr = self.tokenizer.eol_token.join(image_row_str)
244
+ return imgstr
245
+
246
+ def calculate_generate_size(self, ratio, image_area, spatial_scale_factor):
247
+ w, h = map(int, ratio.split(":"))
248
+ current_area = h * w
249
+ target_ratio = (image_area / current_area) ** 0.5
250
+
251
+ th = int(round(h * target_ratio / spatial_scale_factor))
252
+ tw = int(round(w * target_ratio / spatial_scale_factor))
253
+ return th, tw
254
+
255
+ def build_const_helper(self):
256
+ (
257
+ img_token,
258
+ eoi_token,
259
+ eos_token,
260
+ eol_token,
261
+ eof_token,
262
+ pad_token,
263
+ vis_start,
264
+ vis_end,
265
+ ) = self.tokenizer.encode([
266
+ self.tokenizer.img_token,
267
+ self.tokenizer.eoi_token,
268
+ self.tokenizer.eos_token,
269
+ self.tokenizer.eol_token,
270
+ self.tokenizer.eof_token,
271
+ self.tokenizer.pad_token,
272
+ self.visual_template[0].format(token_id=0),
273
+ self.visual_template[0].format(token_id=self.vision_tokenizer.config.codebook_size - 1),
274
+ ])
275
+
276
+ const_helper = partial(
277
+ Emu3PrefixConstrainedLogitsHelper,
278
+ img_token=img_token,
279
+ eoi_token=eoi_token,
280
+ eos_token=eos_token,
281
+ eol_token=eol_token,
282
+ eof_token=eof_token,
283
+ pad_token=pad_token,
284
+ visual_tokens=list(range(vis_start, vis_end + 1)),
285
+ )
286
+ return const_helper
287
+
288
+ def build_prefix_constrained_fn(self, height, width):
289
+ helper = self.const_helper(height=height, width=width)
290
+ return helper
emu3/mllm/tokenization_emu3.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Emu3."""
16
+
17
+ import base64
18
+ import logging
19
+ import os
20
+ import unicodedata
21
+ from typing import Collection, Dict, List, Optional, Set, Tuple, Union
22
+
23
+ import tiktoken
24
+ from transformers import PreTrainedTokenizer, AddedToken
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
+ VOCAB_FILES_NAMES = {
30
+ "vocab_file": "emu3.tiktoken",
31
+ "special_tokens_file": "emu3_vision_tokens.txt",
32
+ }
33
+
34
+ PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
35
+ ENDOFTEXT = "<|endoftext|>"
36
+ IMSTART = "<|im_start|>"
37
+ IMEND = "<|im_end|>"
38
+ # as the default behavior is changed to allow special tokens in
39
+ # regular texts, the surface forms of special tokens need to be
40
+ # as different as possible to minimize the impact
41
+ EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
42
+ # changed to use actual index to avoid misconfiguration with vocabulary expansion
43
+ SPECIAL_START_ID = 151643
44
+
45
+
46
+ def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
47
+ with open(tiktoken_bpe_file, "rb") as f:
48
+ contents = f.read()
49
+ return {
50
+ base64.b64decode(token): int(rank)
51
+ for token, rank in (line.split() for line in contents.splitlines() if line)
52
+ }
53
+
54
+
55
+ class Emu3Tokenizer(PreTrainedTokenizer):
56
+ """Emu3 tokenizer."""
57
+
58
+ vocab_files_names = VOCAB_FILES_NAMES
59
+
60
+ def __init__(
61
+ self,
62
+ vocab_file,
63
+ special_tokens_file,
64
+ errors="replace",
65
+ bos_token = "<|extra_203|>",
66
+ eos_token = "<|extra_204|>",
67
+ pad_token = "<|endoftext|>",
68
+ img_token = "<|image token|>",
69
+ boi_token = "<|image start|>",
70
+ eoi_token = "<|image end|>",
71
+ eol_token = "<|extra_200|>",
72
+ eof_token = "<|extra_201|>",
73
+ **kwargs,
74
+ ):
75
+ super().__init__(**kwargs)
76
+
77
+ # how to handle errors in decoding UTF-8 byte sequences
78
+ # use ignore if you are in streaming inference
79
+ self.errors = errors
80
+
81
+ self.mergeable_ranks = _load_tiktoken_bpe(vocab_file)
82
+
83
+ vision_tokens = [t.strip() for t in open(special_tokens_file).readlines() if len(t.strip()) > 0]
84
+ SPECIAL_TOKENS = tuple(
85
+ enumerate(
86
+ (
87
+ (
88
+ ENDOFTEXT,
89
+ IMSTART,
90
+ IMEND,
91
+ )
92
+ + EXTRAS
93
+ + tuple(vision_tokens)
94
+ ),
95
+ start=SPECIAL_START_ID,
96
+ )
97
+ )
98
+ self.special_tokens = {token: index for index, token in SPECIAL_TOKENS}
99
+ self.special_tokens_set = set(t for _, t in SPECIAL_TOKENS)
100
+
101
+ enc = tiktoken.Encoding(
102
+ "Emu3",
103
+ pat_str=PAT_STR,
104
+ mergeable_ranks=self.mergeable_ranks,
105
+ special_tokens=self.special_tokens,
106
+ )
107
+
108
+ assert (
109
+ len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
110
+ ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
111
+
112
+ self.decoder = {
113
+ v: k for k, v in self.mergeable_ranks.items()
114
+ }
115
+ self.decoder.update({v: k for k, v in self.special_tokens.items()})
116
+
117
+ self.tokenizer = enc
118
+
119
+ self.eod_id = self.tokenizer.eot_token
120
+ self.bos_token = bos_token
121
+ self.eos_token = eos_token
122
+ self.pad_token = pad_token
123
+ self.img_token = img_token
124
+ self.boi_token = boi_token
125
+ self.eoi_token = eoi_token
126
+ self.eol_token = eol_token
127
+ self.eof_token = eof_token
128
+
129
+ def __getstate__(self):
130
+ # for pickle lovers
131
+ state = self.__dict__.copy()
132
+ del state["tokenizer"]
133
+ return state
134
+
135
+ def __setstate__(self, state):
136
+ # tokenizer is not python native; don't pass it; rebuild it
137
+ self.__dict__.update(state)
138
+ enc = tiktoken.Encoding(
139
+ "Emu3",
140
+ pat_str=PAT_STR,
141
+ mergeable_ranks=self.mergeable_ranks,
142
+ special_tokens=self.special_tokens,
143
+ )
144
+ self.tokenizer = enc
145
+
146
+ def __len__(self) -> int:
147
+ return self.tokenizer.n_vocab
148
+
149
+ def get_vocab(self) -> Dict[bytes, int]:
150
+ return self.mergeable_ranks
151
+
152
+ def convert_tokens_to_ids(
153
+ self, tokens: Union[bytes, str, List[Union[bytes, str]]]
154
+ ) -> List[int]:
155
+ if isinstance(tokens, (str, bytes)):
156
+ if tokens in self.special_tokens:
157
+ return self.special_tokens[tokens]
158
+ else:
159
+ return self.mergeable_ranks.get(tokens)
160
+
161
+ ids = []
162
+ for token in tokens:
163
+ if token in self.special_tokens:
164
+ ids.append(self.special_tokens[token])
165
+ else:
166
+ ids.append(self.mergeable_ranks.get(token))
167
+ return ids
168
+
169
+ def _add_tokens(
170
+ self,
171
+ new_tokens: Union[List[str], List[AddedToken]],
172
+ special_tokens: bool = False,
173
+ ) -> int:
174
+ if not special_tokens and new_tokens:
175
+ raise ValueError("Adding regular tokens is not supported")
176
+
177
+ for token in new_tokens:
178
+ surface_form = token.content if isinstance(token, AddedToken) else token
179
+ if surface_form not in self.special_tokens_set:
180
+ raise ValueError("Adding unknown special tokens is not supported")
181
+
182
+ return 0
183
+
184
+ def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
185
+ """
186
+ Save only the vocabulary of the tokenizer (vocabulary).
187
+
188
+ Returns:
189
+ `Tuple(str)`: Paths to the files saved.
190
+ """
191
+ regular_file_path = os.path.join(save_directory, self.vocab_files_names["vocab_file"])
192
+ with open(regular_file_path,'w', encoding="utf8") as w:
193
+ for k, v in self.mergeable_ranks.items():
194
+ line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
195
+ w.write(line)
196
+
197
+ excluded_special_tokens = set((ENDOFTEXT, IMSTART, IMEND,) + EXTRAS)
198
+ special_file_path = os.path.join(save_directory, self.vocab_files_names["special_tokens_file"])
199
+ with open(special_file_path, 'w', encoding="utf8") as w:
200
+ for k in self.special_tokens:
201
+ if k not in excluded_special_tokens:
202
+ print(k, file=w)
203
+
204
+ return (regular_file_path, special_file_path)
205
+
206
+ def tokenize(
207
+ self,
208
+ text: str,
209
+ allowed_special: Union[Set, str] = "all",
210
+ disallowed_special: Union[Collection, str] = (),
211
+ **kwargs,
212
+ ) -> List[Union[bytes, str]]:
213
+ """
214
+ Converts a string in a sequence of tokens.
215
+
216
+ Args:
217
+ text (`str`):
218
+ The sequence to be encoded.
219
+ allowed_special (`Literal["all"]` or `set`):
220
+ The surface forms of the tokens to be encoded as special tokens in regular texts.
221
+ Default to "all".
222
+ disallowed_special (`Literal["all"]` or `Collection`):
223
+ The surface forms of the tokens that should not be in regular texts and trigger errors.
224
+ Default to an empty tuple.
225
+
226
+ kwargs (additional keyword arguments, *optional*):
227
+ Will be passed to the underlying model specific encode method.
228
+
229
+ Returns:
230
+ `List[bytes|str]`: The list of tokens.
231
+ """
232
+ tokens = []
233
+ text = unicodedata.normalize("NFC", text)
234
+
235
+ # this implementation takes a detour: text -> token id -> token surface forms
236
+ for t in self.tokenizer.encode(
237
+ text, allowed_special=allowed_special, disallowed_special=disallowed_special
238
+ ):
239
+ tokens.append(self.decoder[t])
240
+
241
+ return tokens
242
+
243
+ def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
244
+ """
245
+ Converts a sequence of tokens in a single string.
246
+ """
247
+ text = ""
248
+ temp = b""
249
+ for t in tokens:
250
+ if isinstance(t, str):
251
+ if temp:
252
+ text += temp.decode("utf-8", errors=self.errors)
253
+ temp = b""
254
+ text += t
255
+ elif isinstance(t, bytes):
256
+ temp += t
257
+ else:
258
+ raise TypeError("token should only be of type types or str")
259
+ if temp:
260
+ text += temp.decode("utf-8", errors=self.errors)
261
+ return text
262
+
263
+ @property
264
+ def vocab_size(self):
265
+ return self.tokenizer.n_vocab
266
+
267
+ def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
268
+ """Converts an id to a token, special tokens included"""
269
+ if index in self.decoder:
270
+ return self.decoder[index]
271
+ raise ValueError("unknown ids")
272
+
273
+ def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
274
+ """Converts a token to an id using the vocab, special tokens included"""
275
+ if token in self.special_tokens:
276
+ return self.special_tokens[token]
277
+ if token in self.mergeable_ranks:
278
+ return self.mergeable_ranks[token]
279
+ raise ValueError("unknown token")
280
+
281
+ def _decode(
282
+ self,
283
+ token_ids: Union[int, List[int]],
284
+ skip_special_tokens: bool = False,
285
+ errors: Optional[str] = None,
286
+ **kwargs,
287
+ ) -> str:
288
+ if isinstance(token_ids, int):
289
+ token_ids = [token_ids]
290
+
291
+ if skip_special_tokens:
292
+ token_ids = [i for i in token_ids if i < self.eod_id]
293
+
294
+ return self.tokenizer.decode(token_ids, errors=errors or self.errors)
emu3/mllm/utils_emu3.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Logits Processor Helper class for Emu3. """
16
+
17
+ import torch
18
+
19
+ class Emu3PrefixConstrainedLogitsHelper:
20
+
21
+ def __init__(
22
+ self,
23
+ height,
24
+ width,
25
+ img_token,
26
+ eoi_token,
27
+ eos_token,
28
+ eol_token,
29
+ eof_token,
30
+ pad_token,
31
+ visual_tokens,
32
+ ):
33
+ self.height = height
34
+ self.width = width
35
+ self.img_token = img_token
36
+ self.eoi_token = eoi_token
37
+ self.eos_token = eos_token
38
+ self.eol_token = eol_token
39
+ self.eof_token = eof_token
40
+ self.pad_token = pad_token
41
+ self.visual_tokens = visual_tokens
42
+
43
+ self.offset_cache = {}
44
+
45
+ def __call__(self, batch_id, input_ids):
46
+ if batch_id not in self.offset_cache:
47
+ position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
48
+ self.offset_cache[batch_id] = position
49
+
50
+ offset = input_ids.shape[0] - self.offset_cache[batch_id]
51
+ if offset % (self.width + 1) == 0:
52
+ return (self.eol_token, )
53
+ elif offset == (self.width + 1) * self.height + 1:
54
+ return (self.eof_token, )
55
+ elif offset == (self.width + 1) * self.height + 2:
56
+ return (self.eoi_token, )
57
+ elif offset == (self.width + 1) * self.height + 3:
58
+ return (self.eos_token, )
59
+ elif offset > (self.width + 1) * self.height + 3:
60
+ return (self.pad_token, )
61
+ else:
62
+ return self.visual_tokens
emu3/tokenizer/__init__.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 BAAI and the HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import (
17
+ OptionalDependencyNotAvailable,
18
+ _LazyModule,
19
+ is_torch_available,
20
+ is_vision_available,
21
+ )
22
+
23
+
24
+ _import_structure = {"configuration_emu3visionvq": ["Emu3VisionVQConfig"]}
25
+
26
+ try:
27
+ if not is_torch_available():
28
+ raise OptionalDependencyNotAvailable()
29
+ except OptionalDependencyNotAvailable:
30
+ pass
31
+ else:
32
+ _import_structure["modeling_emu3visionvq"] = [
33
+ "Emu3VisionVQModel",
34
+ "Emu3VisionVQPretrainedModel",
35
+ ]
36
+
37
+ try:
38
+ if not is_vision_available():
39
+ raise OptionalDependencyNotAvailable()
40
+ except OptionalDependencyNotAvailable:
41
+ pass
42
+ else:
43
+ _import_structure["image_processing_emu3visionvq"] = ["Emu3VisionVQImageProcessor"]
44
+
45
+ if TYPE_CHECKING:
46
+ from .configuration_emu3visionvq import Emu3VisionVQConfig
47
+
48
+ try:
49
+ if not is_torch_available():
50
+ raise OptionalDependencyNotAvailable()
51
+ except OptionalDependencyNotAvailable:
52
+ pass
53
+ else:
54
+ from .modeling_emu3visionvq import (
55
+ Emu3VisionVQModel,
56
+ Emu3VisionVQPretrainedModel,
57
+ )
58
+
59
+ try:
60
+ if not is_vision_available():
61
+ raise OptionalDependencyNotAvailable()
62
+ except OptionalDependencyNotAvailable:
63
+ pass
64
+ else:
65
+ from .image_processing_emu3visionvq import Emu3VisionVQImageProcessor
66
+
67
+ else:
68
+ import sys
69
+
70
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
emu3/tokenizer/configuration_emu3visionvq.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Emu3VisionVQ model configuration """
16
+
17
+ from typing import List
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
+ class Emu3VisionVQConfig(PretrainedConfig):
27
+ r"""
28
+ This is the configuration class to store the configuration of a [`Emu3VisionVQ`]. It is used to instantiate an video movq
29
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30
+ defaults will yield a configuration to the VQ model presented in Emu3 paper.
31
+
32
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
33
+ documentation from [`PretrainedConfig`] for more information.
34
+
35
+
36
+ Args:
37
+ codebook_size (`int`, *optional*, defaults to 32768):
38
+ Codebook size of the VQ model.
39
+ embed_dim (`int`, *optional*, defaults to 4):
40
+ Dimension of the quantized vector in codebook.
41
+ z_channels (`int`, *optional*, defaults to 4):
42
+ Dimension of the output channel of encoder and the input channel of decoder
43
+ double_z (`bool`, *optional*, defaults to False):
44
+ Whether double the output dim of the encoder.
45
+ in_channels (`int`, *optional*, defaults to 3):
46
+ Input channel of encoder.
47
+ out_channels (`int`, *optional*, defaults to 3):
48
+ Output channel of decoder.
49
+ temporal_downsample_factor (`int`, *optional*, defaults to 4):
50
+ Temporal downsample factor.
51
+ ch (`int`, *optional*, defaults to 256):
52
+ Basic channel number of the intermediate blocks.
53
+ ch_mult (`List[int]`, *optional*, defaults to `[1, 2, 2, 4]`):
54
+ Channel scaling factor of the intermediate blocks.
55
+ num_res_blocks (`int`, *optional*, defaults to 2):
56
+ Residual block number in each stage.
57
+ attn_resolutions (`List[int]`, *optional*, defaults to 3):
58
+ Stage indices to apply attention.
59
+ dropout (`float`, *optional*, defaults to 0.0):
60
+ Dropout probability.
61
+
62
+ ```python
63
+ >>> from transformers import Emu3VisionVQ, Emu3VisionVQConfig
64
+
65
+ >>> # Initializing a video VQ model of Emu3 configuration
66
+ >>> configuration = Emu3VisionVQConfig()
67
+
68
+ >>> # Initializing a model from the Emu3 VQ model style configuration
69
+ >>> model = Emu3VisionVQModel(configuration)
70
+
71
+ >>> # Accessing the model configuration
72
+ >>> configuration = model.config
73
+ ```"""
74
+
75
+ model_type = "Emu3VisionVQ"
76
+
77
+ def __init__(
78
+ self,
79
+ codebook_size: int = 32768,
80
+ embed_dim: int = 4,
81
+ z_channels: int = 4,
82
+ double_z: bool = False,
83
+ in_channels: int = 3,
84
+ out_channels: int = 3,
85
+ temporal_downsample_factor: int = 4,
86
+ ch: int = 256,
87
+ ch_mult: List[int] = [1, 2, 2, 4],
88
+ num_res_blocks: int = 2,
89
+ attn_resolutions: List[int] = [3],
90
+ dropout: float = 0.0,
91
+ **kwargs,
92
+ ):
93
+ super().__init__(**kwargs)
94
+
95
+ self.codebook_size = codebook_size
96
+ self.embed_dim = embed_dim
97
+ self.z_channels = z_channels
98
+ self.double_z = double_z
99
+ self.in_channels = in_channels
100
+ self.out_channels = out_channels
101
+ self.temporal_downsample_factor = temporal_downsample_factor
102
+ self.ch = ch
103
+ self.ch_mult = ch_mult
104
+ self.num_res_blocks = num_res_blocks
105
+ self.attn_resolutions = attn_resolutions
106
+ self.dropout = dropout
emu3/tokenizer/image_processing_emu3visionvq.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Image processor class for Emu3VisionVQ."""
16
+
17
+
18
+ import math
19
+ from typing import Dict, List, Optional, Union
20
+
21
+ import numpy as np
22
+
23
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
24
+ from transformers.image_transforms import (
25
+ convert_to_rgb,
26
+ resize,
27
+ to_channel_dimension_format,
28
+ )
29
+ from transformers.image_utils import (
30
+ IMAGENET_STANDARD_MEAN,
31
+ IMAGENET_STANDARD_STD,
32
+ ChannelDimension,
33
+ ImageInput,
34
+ PILImageResampling,
35
+ get_image_size,
36
+ infer_channel_dimension_format,
37
+ is_scaled_image,
38
+ make_list_of_images,
39
+ to_numpy_array,
40
+ valid_images,
41
+ validate_preprocess_arguments,
42
+ )
43
+ from transformers.utils import TensorType, is_vision_available, logging
44
+
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ if is_vision_available():
50
+ from PIL import Image
51
+
52
+
53
+ def smart_resize(
54
+ height: int, width: int, factor: int = 8, min_pixels: int = 512 * 512, max_pixels: int = 1024 * 1024
55
+ ):
56
+ """Rescales the image so that the following conditions are met:
57
+
58
+ 1. Both dimensions (height and width) are divisible by 'factor'.
59
+
60
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
61
+
62
+ 3. The aspect ratio of the image is maintained as closely as possible.
63
+
64
+ """
65
+ if height < factor or width < factor:
66
+ raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}")
67
+ elif max(height, width) / min(height, width) > 5:
68
+ raise ValueError(
69
+ f"absolute aspect ratio must be smaller than 5, got {max(height, width) / min(height, width)}"
70
+ )
71
+
72
+ h_bar = round(height / factor) * factor
73
+ w_bar = round(width / factor) * factor
74
+ if h_bar * w_bar > max_pixels:
75
+ beta = math.sqrt((height * width) / max_pixels)
76
+ h_bar = math.floor(height / beta / factor) * factor
77
+ w_bar = math.floor(width / beta / factor) * factor
78
+ elif h_bar * w_bar < min_pixels:
79
+ beta = math.sqrt(min_pixels / (height * width))
80
+ h_bar = math.ceil(height * beta / factor) * factor
81
+ w_bar = math.ceil(width * beta / factor) * factor
82
+
83
+ return h_bar, w_bar
84
+
85
+
86
+ class Emu3VisionVQImageProcessor(BaseImageProcessor):
87
+ r"""
88
+ Constructs a Emu3VisionVQ image processor that dynamically resizes images based on the original images.
89
+
90
+ Args:
91
+ do_resize (`bool`, *optional*, defaults to `True`):
92
+ Whether to resize the image's (height, width) dimensions.
93
+ resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
94
+ Resampling filter to use when resizing the image.
95
+ do_rescale (`bool`, *optional*, defaults to `True`):
96
+ Whether to rescale the image by the specified scale `rescale_factor`.
97
+ rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
98
+ Scale factor to use if rescaling the image.
99
+ do_normalize (`bool`, *optional*, defaults to `True`):
100
+ Whether to normalize the image.
101
+ image_mean (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
102
+ Mean to use if normalizing the image. This is a float or list of floats for each channel in the image.
103
+ image_std (`float` or `List[float]`, *optional*, defaults to `[0.5, 0.5, 0.5]`):
104
+ Standard deviation to use if normalizing the image. This is a float or list of floats for each channel in the image.
105
+ do_convert_rgb (`bool`, *optional*, defaults to `True`):
106
+ Whether to convert the image to RGB.
107
+ min_pixels (`int`, *optional*, defaults to `512 * 512`):
108
+ The min pixels of the image to resize the image.
109
+ max_pixels (`int`, *optional*, defaults to `1024 * 1024`):
110
+ The max pixels of the image to resize the image.
111
+ spatial_factor (`int`, *optional*, defautls to 8):
112
+ The spatial downsample factor the image will be downsampled in feature extracting phase
113
+ """
114
+
115
+ model_input_names = ["pixel_values"]
116
+
117
+ def __init__(
118
+ self,
119
+ do_resize: bool = True,
120
+ resample: PILImageResampling = PILImageResampling.BICUBIC,
121
+ do_rescale: bool = True,
122
+ rescale_factor: Union[int, float] = 1 / 255,
123
+ do_normalize: bool = True,
124
+ image_mean: Optional[Union[float, List[float]]] = None,
125
+ image_std: Optional[Union[float, List[float]]] = None,
126
+ do_convert_rgb: bool = True,
127
+ min_pixels: int = 512 * 512,
128
+ max_pixels: int = 1024 * 1024,
129
+ spatial_factor: int = 8,
130
+ **kwargs,
131
+ ) -> None:
132
+ super().__init__(**kwargs)
133
+ self.do_resize = do_resize
134
+ self.resample = resample
135
+ self.do_rescale = do_rescale
136
+ self.rescale_factor = rescale_factor
137
+ self.do_normalize = do_normalize
138
+ self.image_mean = image_mean if image_mean is not None else IMAGENET_STANDARD_MEAN
139
+ self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
140
+ self.min_pixels = min_pixels
141
+ self.max_pixels = max_pixels
142
+ self.size = {"min_pixels": min_pixels, "max_pixels": max_pixels}
143
+ self.do_convert_rgb = do_convert_rgb
144
+ self.spatial_factor = spatial_factor
145
+
146
+ def _preprocess(
147
+ self,
148
+ images: ImageInput,
149
+ do_resize: Optional[bool] = None,
150
+ resample: PILImageResampling = None,
151
+ do_rescale: Optional[bool] = None,
152
+ rescale_factor: Optional[float] = None,
153
+ do_normalize: Optional[bool] = None,
154
+ image_mean: Optional[Union[float, List[float]]] = None,
155
+ image_std: Optional[Union[float, List[float]]] = None,
156
+ do_convert_rgb: Optional[bool] = None,
157
+ spatial_factor: Optional[int] = None,
158
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
159
+ output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
160
+ ):
161
+ """
162
+ Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
163
+
164
+ Args:
165
+ images (`ImageInput`):
166
+ Image or batch of images to preprocess. Expects pixel values ranging from 0 to 255. If pixel values range from 0 to 1, set `do_rescale=False`.
167
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
168
+ Whether to resize the image.
169
+ resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
170
+ Resampling filter to use if resizing the image. This can be one of the `PILImageResampling` enums.
171
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
172
+ Whether to rescale the image.
173
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
174
+ Scale factor to use if rescaling the image.
175
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
176
+ Whether to normalize the image.
177
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
178
+ Mean to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
179
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
180
+ Standard deviation to use if normalizing the image. Can be a float or a list of floats corresponding to the number of channels in the image.
181
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
182
+ Whether to convert the image to RGB.
183
+ spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
184
+ The spatial downsample factor the image will be downsampled in feature extracting phase
185
+ input_data_format (`ChannelDimension` or `str`, *optional*):
186
+ The channel dimension format for the input image. Can be one of:
187
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
188
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
189
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format. - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
190
+ output_data_format (`ChannelDimension`, *optional*, defaults to `ChannelDimension.FIRST`):
191
+ The channel dimension format for the output image. Can be one of:
192
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
193
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
194
+ - Unset: Use the channel dimension format of the input image.
195
+ """
196
+ spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
197
+
198
+ images = make_list_of_images(images)
199
+ if do_convert_rgb:
200
+ images = [convert_to_rgb(image) for image in images]
201
+
202
+ # All transformations expect numpy arrays.
203
+ images = [to_numpy_array(image) for image in images]
204
+
205
+ if is_scaled_image(images[0]) and do_rescale:
206
+ logger.warning_once(
207
+ "It looks like you are trying to rescale already rescaled images. If the input"
208
+ "pixel_values.append()images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
209
+ )
210
+
211
+ if input_data_format is None:
212
+ # We assume that all images have the same channel dimension format.
213
+ input_data_format = infer_channel_dimension_format(images[0])
214
+
215
+ height, width = get_image_size(images[0], channel_dim=input_data_format)
216
+ resized_height, resized_width = height, width
217
+ processed_images = []
218
+ for image in images:
219
+ if do_resize:
220
+ resized_height, resized_width = smart_resize(
221
+ height,
222
+ width,
223
+ factor=spatial_factor,
224
+ min_pixels=self.min_pixels,
225
+ max_pixels=self.max_pixels,
226
+ )
227
+ image = resize(
228
+ image, size=(resized_height, resized_width), resample=resample, input_data_format=input_data_format
229
+ )
230
+
231
+ if do_rescale:
232
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
233
+
234
+ if do_normalize:
235
+ image = self.normalize(
236
+ image=image, mean=image_mean, std=image_std, input_data_format=input_data_format
237
+ )
238
+
239
+ image = to_channel_dimension_format(image, output_data_format, input_channel_dim=input_data_format)
240
+ processed_images.append(image)
241
+
242
+ image = np.array(processed_images)
243
+ return image
244
+
245
+ def preprocess(
246
+ self,
247
+ images: ImageInput,
248
+ do_resize: Optional[bool] = None,
249
+ resample: PILImageResampling = None,
250
+ do_rescale: Optional[bool] = None,
251
+ rescale_factor: Optional[float] = None,
252
+ do_normalize: Optional[bool] = None,
253
+ image_mean: Optional[Union[float, List[float]]] = None,
254
+ image_std: Optional[Union[float, List[float]]] = None,
255
+ do_convert_rgb: Optional[bool] = None,
256
+ spatial_factor: Optional[int] = None,
257
+ return_tensors: Optional[Union[str, TensorType]] = None,
258
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
259
+ output_data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
260
+ ):
261
+ """
262
+ Args:
263
+ images (`ImageInput`):
264
+ Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
265
+ passing in images with pixel values between 0 and 1, set `do_rescale=False`.
266
+ do_resize (`bool`, *optional*, defaults to `self.do_resize`):
267
+ Whether to resize the image.
268
+ resample (`int`, *optional*, defaults to `self.resample`):
269
+ Resampling filter to use if resizing the image. This can be one of the enum `PILImageResampling`. Only
270
+ has an effect if `do_resize` is set to `True`.
271
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
272
+ Whether to rescale the image.
273
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
274
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
275
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
276
+ Whether to normalize the image.
277
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
278
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
279
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
280
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
281
+ do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
282
+ Whether to convert the image to RGB.
283
+ spatial_factor (`int`, *optional*, defaults to `self.spatial_factor`):
284
+ The spatial downsample factor the image will be downsampled in feature extracting phase
285
+ return_tensors (`str` or `TensorType`, *optional*):
286
+ The type of tensors to return. Can be one of:
287
+ - Unset: Return a list of `np.ndarray`.
288
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
289
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
290
+ input_data_format (`ChannelDimension` or `str`, *optional*):
291
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
292
+ from the input image. Can be one of:
293
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
294
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
295
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
296
+ output_data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
297
+ The channel dimension format for the output image. Can be one of:
298
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
299
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
300
+ - Unset: Use the channel dimension format of the input image.
301
+ """
302
+ do_resize = do_resize if do_resize is not None else self.do_resize
303
+ resample = resample if resample is not None else self.resample
304
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
305
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
306
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
307
+ image_mean = image_mean if image_mean is not None else self.image_mean
308
+ image_std = image_std if image_std is not None else self.image_std
309
+ do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
310
+ spatial_factor = spatial_factor if spatial_factor is not None else self.spatial_factor
311
+
312
+ images = make_list_of_images(images)
313
+ if images is None or not valid_images(images):
314
+ raise ValueError(
315
+ "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
316
+ "torch.Tensor, tf.Tensor or jax.ndarray."
317
+ )
318
+
319
+ validate_preprocess_arguments(
320
+ rescale_factor=rescale_factor,
321
+ do_normalize=do_normalize,
322
+ image_mean=image_mean,
323
+ image_std=image_std,
324
+ do_resize=do_resize,
325
+ size=self.size,
326
+ resample=resample,
327
+ )
328
+
329
+ pixel_values = []
330
+ for image in images:
331
+ norm_image = self._preprocess(
332
+ image,
333
+ do_resize=do_resize,
334
+ resample=resample,
335
+ do_rescale=do_rescale,
336
+ rescale_factor=rescale_factor,
337
+ do_normalize=do_normalize,
338
+ image_mean=image_mean,
339
+ image_std=image_std,
340
+ do_convert_rgb=do_convert_rgb,
341
+ spatial_factor=spatial_factor,
342
+ input_data_format=input_data_format,
343
+ output_data_format=output_data_format,
344
+ )
345
+ pixel_values.extend(norm_image)
346
+ pixel_values = np.array(pixel_values)
347
+ data = {"pixel_values": pixel_values}
348
+
349
+ return BatchFeature(data=data, tensor_type=return_tensors)
350
+
351
+ def postprocess(
352
+ self,
353
+ images: ImageInput,
354
+ do_rescale: Optional[bool] = None,
355
+ rescale_factor: Optional[float] = None,
356
+ do_normalize: Optional[bool] = None,
357
+ image_mean: Optional[Union[float, List[float]]] = None,
358
+ image_std: Optional[Union[float, List[float]]] = None,
359
+ return_tensors: str | TensorType = "PIL.Image.Image",
360
+ input_data_format: Optional[Union[str, ChannelDimension]] = None,
361
+ ):
362
+ """
363
+ Postprocess an image or batch of images tensor. Postprocess is the reverse process of preprocess.
364
+ The parameters should be same as in preprocess.
365
+
366
+ Args:
367
+ images (`ImageInput`):
368
+ Image to postprocess. Expects a single or batch of images with pixel values ranging from -1 to 1.
369
+ do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
370
+ Whether to rescale the image.
371
+ rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
372
+ Rescale factor to rescale the image by if `do_rescale` is set to `True`.
373
+ do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
374
+ Whether to normalize the image.
375
+ image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
376
+ Image mean to use for normalization. Only has an effect if `do_normalize` is set to `True`.
377
+ image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
378
+ Image standard deviation to use for normalization. Only has an effect if `do_normalize` is set to `True`.
379
+ return_tensors (`str` or `TensorType`, *optional*):
380
+ The type of tensors to return. Can be one of:
381
+ - Unset: Return a list of `np.ndarray`.
382
+ - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
383
+ - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
384
+ input_data_format (`ChannelDimension` or `str`, *optional*):
385
+ The channel dimension format for the input image. If unset, the channel dimension format is inferred
386
+ from the input image. Can be one of:
387
+ - `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
388
+ - `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
389
+ - `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
390
+ """
391
+ do_rescale = do_rescale if do_rescale is not None else self.do_rescale
392
+ rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
393
+ rescale_factor = 1 / rescale_factor
394
+
395
+ do_normalize = do_normalize if do_normalize is not None else self.do_normalize
396
+ image_mean = image_mean if image_mean is not None else self.image_mean
397
+ image_std = image_std if image_std is not None else self.image_std
398
+ image_mean, image_std = self.inverse_meanstd(image_mean, image_std)
399
+
400
+ images = make_list_of_images(images)
401
+ if isinstance(images[0], Image.Image):
402
+ return images if len(images) > 1 else images[0]
403
+
404
+ if input_data_format is None:
405
+ # We assume that all images have the same channel dimension format.
406
+ input_data_format = infer_channel_dimension_format(images[0])
407
+
408
+ pixel_values = []
409
+ for image in images:
410
+ image = to_numpy_array(image)
411
+ if do_normalize:
412
+ image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
413
+
414
+ if do_rescale:
415
+ image = self.rescale(image, scale=rescale_factor, input_data_format=input_data_format)
416
+ image = image.clip(0, 255).astype(np.uint8)
417
+
418
+ if do_normalize and do_rescale and return_tensors == "PIL.Image.Image":
419
+ image = to_channel_dimension_format(image, ChannelDimension.LAST, input_channel_dim=input_data_format)
420
+ pixel_values.append(Image.fromarray(image))
421
+ else:
422
+ pixel_values.extend(image)
423
+
424
+ data = {"pixel_values": pixel_values}
425
+ return_tensors = return_tensors if return_tensors != "PIL.Image.Image" else None
426
+
427
+ return BatchFeature(data=data, tensor_type=return_tensors)
428
+
429
+ def inverse_meanstd(self, image_mean, image_std):
430
+ image_mean = self.to_tuple(image_mean)
431
+ image_std = self.to_tuple(image_std)
432
+
433
+ rev_image_mean = tuple(-m / s for m, s in zip(image_mean, image_std))
434
+ rev_image_std = tuple(1 / s for s in image_std)
435
+
436
+ return rev_image_mean, rev_image_std
437
+
438
+ def to_tuple(self, value, dim=3):
439
+ if isinstance(value, int | float):
440
+ return (value,) * dim
441
+
442
+ return tuple(value)
emu3/tokenizer/modeling_emu3visionvq.py ADDED
@@ -0,0 +1,822 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Emu team, BAAI and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Emu3VisionVQ model """
16
+
17
+ import math
18
+ from typing import Optional, Tuple, Union
19
+
20
+ import torch
21
+ from torch import nn
22
+ from torch.nn import functional as F
23
+ from transformers.modeling_utils import PreTrainedModel
24
+
25
+ from .configuration_emu3visionvq import Emu3VisionVQConfig
26
+
27
+
28
+ class Emu3VisionVQActivation(nn.Module):
29
+
30
+ def __init__(self):
31
+ super().__init__()
32
+
33
+ def __call__(self, x: torch.Tensor):
34
+ return x * torch.sigmoid(x)
35
+
36
+
37
+ class Emu3VisionVQUpsample(nn.Module):
38
+
39
+ def __init__(self, in_channels: int):
40
+ super().__init__()
41
+ self.conv = nn.Conv2d(
42
+ in_channels,
43
+ in_channels,
44
+ kernel_size=3,
45
+ stride=1,
46
+ padding=1,
47
+ )
48
+
49
+ def forward(self, x: torch.Tensor):
50
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
51
+ x = self.conv(x)
52
+ return x
53
+
54
+
55
+ class Emu3VisionVQDownsample(nn.Module):
56
+
57
+ def __init__(self, in_channels: int):
58
+ super().__init__()
59
+ self.conv = nn.Conv2d(
60
+ in_channels,
61
+ in_channels,
62
+ kernel_size=3,
63
+ stride=2,
64
+ padding=0,
65
+ )
66
+
67
+ def forward(self, x: torch.Tensor):
68
+ pad = (0, 1, 0, 1)
69
+ x = F.pad(x, pad, mode="constant", value=0)
70
+ x = self.conv(x)
71
+ return x
72
+
73
+
74
+ class Emu3VisionVQCausalConv3d(nn.Module):
75
+
76
+ def __init__(
77
+ self,
78
+ in_channel: int,
79
+ out_channel: int,
80
+ kernel_size: Union[int, Tuple[int, ...]] = (3, 1, 1),
81
+ stride: Union[int, Tuple[int, ...]] = (1, 1, 1),
82
+ ):
83
+ super().__init__()
84
+
85
+ if isinstance(kernel_size, int):
86
+ kernel_size = (kernel_size,) * 3
87
+ if isinstance(stride, int):
88
+ stride = (stride,) * 3
89
+
90
+ hw_pad = [k - s for k, s in zip(kernel_size[1:], stride[1:])]
91
+ self.padding = tuple()
92
+ for p in hw_pad[::-1]:
93
+ self.padding += (p // 2 + p % 2, p // 2)
94
+ self.padding += (2, 0)
95
+
96
+ self.conv = nn.Conv3d(
97
+ in_channel,
98
+ out_channel,
99
+ kernel_size,
100
+ stride=stride,
101
+ )
102
+
103
+ def forward(self, x: torch.Tensor):
104
+ x = F.pad(x, self.padding)
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Emu3VisionVQResnetTemporalBlock(nn.Module):
110
+
111
+ def __init__(
112
+ self,
113
+ in_channels: int,
114
+ out_channels: Optional[int] = None,
115
+ conv_shortcut: bool = False,
116
+ dropout: float = 0.0,
117
+ ):
118
+ super().__init__()
119
+ self.in_channels = in_channels
120
+ out_channels = in_channels if out_channels is None else out_channels
121
+ self.out_channels = out_channels
122
+ self.use_conv_shortcut = conv_shortcut
123
+
124
+ stride = (1, 1, 1)
125
+ kernel_size = (3, 3, 3)
126
+
127
+ self.norm1 = nn.BatchNorm3d(in_channels)
128
+ self.conv1 = Emu3VisionVQCausalConv3d(
129
+ in_channels,
130
+ out_channels,
131
+ kernel_size=kernel_size,
132
+ stride=stride,
133
+ )
134
+ self.norm2 = nn.BatchNorm3d(out_channels)
135
+ self.dropout = nn.Dropout(dropout)
136
+ self.conv2 = Emu3VisionVQCausalConv3d(
137
+ out_channels,
138
+ out_channels,
139
+ kernel_size=kernel_size,
140
+ stride=stride,
141
+ )
142
+ self.act = Emu3VisionVQActivation()
143
+
144
+ if self.in_channels != self.out_channels:
145
+ if self.use_conv_shortcut:
146
+ self.conv_shortcut = Emu3VisionVQCausalConv3d(
147
+ in_channels,
148
+ out_channels,
149
+ kernel_size=kernel_size,
150
+ stride=stride,
151
+ )
152
+ else:
153
+ self.nin_shortcut = nn.Conv3d(
154
+ in_channels,
155
+ out_channels,
156
+ kernel_size=1,
157
+ stride=1,
158
+ padding=0,
159
+ )
160
+
161
+ def forward(self, x: torch.Tensor):
162
+ h = self.norm1(x)
163
+ h = self.act(h)
164
+ h = self.conv1(h)
165
+
166
+ h = self.norm2(h)
167
+ h = self.act(h)
168
+ h = self.dropout(h)
169
+ h = self.conv2(h)
170
+
171
+ if self.in_channels != self.out_channels:
172
+ if self.use_conv_shortcut:
173
+ x = self.conv_shortcut(x)
174
+ else:
175
+ x = self.nin_shortcut(x)
176
+
177
+ return x + h
178
+
179
+
180
+ class Emu3VisionVQSpatialNorm(nn.Module):
181
+
182
+ def __init__(
183
+ self,
184
+ f_channels: int,
185
+ zq_channels: int,
186
+ norm_layer: nn.Module = nn.GroupNorm,
187
+ add_conv: bool = False,
188
+ num_groups: int = 32,
189
+ eps: float = 1e-6,
190
+ affine: bool = True,
191
+ ):
192
+ super().__init__()
193
+ self.norm_layer = norm_layer(
194
+ num_channels=f_channels,
195
+ num_groups=num_groups,
196
+ eps=eps,
197
+ affine=affine,
198
+ )
199
+
200
+ self.add_conv = add_conv
201
+ if self.add_conv:
202
+ self.conv = nn.Conv2d(
203
+ zq_channels,
204
+ zq_channels,
205
+ kernel_size=3,
206
+ stride=1,
207
+ padding=1,
208
+ )
209
+
210
+ self.conv_y = nn.Conv2d(
211
+ zq_channels,
212
+ f_channels,
213
+ kernel_size=1,
214
+ stride=1,
215
+ padding=0,
216
+ )
217
+ self.conv_b = nn.Conv2d(
218
+ zq_channels,
219
+ f_channels,
220
+ kernel_size=1,
221
+ stride=1,
222
+ padding=0,
223
+ )
224
+
225
+ def forward(self, x: torch.Tensor, zq: torch.Tensor):
226
+ zq = F.interpolate(zq, size=x.shape[-2:], mode="nearest")
227
+
228
+ if self.add_conv:
229
+ zq = self.conv(zq)
230
+
231
+ x = self.norm_layer(x)
232
+ x = x * self.conv_y(zq) + self.conv_b(zq)
233
+ return x
234
+
235
+
236
+ class Emu3VisionVQResnetBlock(nn.Module):
237
+
238
+ def __init__(
239
+ self,
240
+ in_channels: int,
241
+ out_channels: Optional[int] = None,
242
+ conv_shortcut: bool = False,
243
+ dropout: float = 0.0,
244
+ zq_ch: Optional[int] = None,
245
+ add_conv: bool = False,
246
+ ):
247
+ super().__init__()
248
+ self.in_channels = in_channels
249
+ out_channels = in_channels if out_channels is None else out_channels
250
+ self.out_channels = out_channels
251
+ self.use_conv_shortcut = conv_shortcut
252
+ self.zq_ch = zq_ch
253
+
254
+ if zq_ch is None:
255
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
256
+ self.norm1 = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
257
+ self.norm2 = nn.GroupNorm(num_channels=out_channels, **norm_kwargs)
258
+ else:
259
+ self.norm1 = Emu3VisionVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
260
+ self.norm2 = Emu3VisionVQSpatialNorm(out_channels, zq_ch, add_conv=add_conv)
261
+
262
+ self.conv1 = nn.Conv2d(
263
+ in_channels,
264
+ out_channels,
265
+ kernel_size=3,
266
+ stride=1,
267
+ padding=1,
268
+ )
269
+
270
+ self.dropout = nn.Dropout(dropout)
271
+ self.conv2 = nn.Conv2d(
272
+ out_channels,
273
+ out_channels,
274
+ kernel_size=3,
275
+ stride=1,
276
+ padding=1,
277
+ )
278
+
279
+ self.act = Emu3VisionVQActivation()
280
+
281
+ if self.in_channels != self.out_channels:
282
+ if self.use_conv_shortcut:
283
+ self.conv_shortcut = nn.Conv2d(
284
+ in_channels,
285
+ out_channels,
286
+ kernel_size=3,
287
+ stride=1,
288
+ padding=1,
289
+ )
290
+ else:
291
+ self.nin_shortcut = nn.Conv2d(
292
+ in_channels,
293
+ out_channels,
294
+ kernel_size=1,
295
+ stride=1,
296
+ padding=0,
297
+ )
298
+
299
+ def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
300
+ norm_args = tuple() if self.zq_ch is None else (zq, )
301
+
302
+ h = self.norm1(x, *norm_args)
303
+ h = self.act(h)
304
+ h = self.conv1(h)
305
+
306
+ h = self.norm2(h, *norm_args)
307
+ h = self.act(h)
308
+ h = self.dropout(h)
309
+ h = self.conv2(h)
310
+
311
+ if self.in_channels != self.out_channels:
312
+ if self.use_conv_shortcut:
313
+ x = self.conv_shortcut(x)
314
+ else:
315
+ x = self.nin_shortcut(x)
316
+
317
+ return x + h
318
+
319
+
320
+ class Emu3VisionVQAttnBlock(nn.Module):
321
+
322
+ def __init__(
323
+ self,
324
+ in_channels: int,
325
+ zq_ch: Optional[int] = None,
326
+ add_conv: bool = False
327
+ ):
328
+ super().__init__()
329
+ self.in_channels = in_channels
330
+ self.zq_ch = zq_ch
331
+
332
+ if zq_ch is None:
333
+ norm_kwargs = dict(num_groups=32, eps=1e-6, affine=True)
334
+ self.norm = nn.GroupNorm(num_channels=in_channels, **norm_kwargs)
335
+ else:
336
+ self.norm = Emu3VisionVQSpatialNorm(in_channels, zq_ch, add_conv=add_conv)
337
+
338
+ self.q = nn.Conv2d(
339
+ in_channels,
340
+ in_channels,
341
+ kernel_size=1,
342
+ stride=1,
343
+ padding=0,
344
+ )
345
+ self.k = nn.Conv2d(
346
+ in_channels,
347
+ in_channels,
348
+ kernel_size=1,
349
+ stride=1,
350
+ padding=0,
351
+ )
352
+ self.v = nn.Conv2d(
353
+ in_channels,
354
+ in_channels,
355
+ kernel_size=1,
356
+ stride=1,
357
+ padding=0,
358
+ )
359
+ self.proj_out = nn.Conv2d(
360
+ in_channels,
361
+ in_channels,
362
+ kernel_size=1,
363
+ stride=1,
364
+ padding=0,
365
+ )
366
+
367
+ def forward(self, x: torch.Tensor, zq: Optional[torch.Tensor] = None):
368
+ norm_args = tuple() if self.zq_ch is None else (zq, )
369
+
370
+ nx = self.norm(x, *norm_args)
371
+ q = self.q(nx)
372
+ k = self.k(nx)
373
+ v = self.v(nx)
374
+
375
+ # compute attention
376
+ b, c, h, w = q.shape
377
+ q = q.reshape(b, c, h * w)
378
+ k = k.reshape(b, c, h * w)
379
+ score = torch.bmm(q.permute(0, 2, 1), k)
380
+ score = score / (c ** 0.5)
381
+ score = F.softmax(score, dim=2)
382
+
383
+ # attend to values
384
+ v = v.reshape(b, c, h * w)
385
+ v = torch.bmm(v, score.permute(0, 2, 1))
386
+ v = v.reshape(b, c, h, w)
387
+
388
+ v = self.proj_out(v)
389
+
390
+ return x + v
391
+
392
+
393
+ class Emu3VisionVQTemporalUpsample(nn.Module):
394
+
395
+ def __init__(
396
+ self,
397
+ in_channel: int,
398
+ out_channel: int,
399
+ kernel_size: Tuple[int, ...] = (3, 3, 3),
400
+ stride: Tuple[int, ...] = (1, 1, 1)
401
+ ):
402
+ super().__init__()
403
+ self.in_channel = in_channel
404
+ self.out_channel = out_channel
405
+ self.conv = Emu3VisionVQCausalConv3d(
406
+ in_channel,
407
+ out_channel,
408
+ kernel_size,
409
+ stride=stride,
410
+ )
411
+
412
+ def forward(self, x: torch.Tensor):
413
+ b, c, t, h, w = x.shape
414
+ x = x.permute(0, 1, 3, 4, 2).contiguous().view(b, -1, t)
415
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
416
+ x = x.view(b, c, h, w, -1).permute(0, 1, 4, 2, 3).contiguous()
417
+ x = self.conv(x)
418
+ return x
419
+
420
+
421
+ class Emu3VisionVQTemporalDownsample(nn.Module):
422
+
423
+ def __init__(
424
+ self,
425
+ in_channel: int,
426
+ out_channel: int,
427
+ kernel_size: Tuple[int, ...] = (4, 3, 3),
428
+ stride: Tuple[int, ...] = (2, 1, 1),
429
+ ):
430
+ super().__init__()
431
+ self.in_channel = in_channel
432
+ self.out_channel = out_channel
433
+ self.kernel_size = kernel_size
434
+
435
+ self.conv = Emu3VisionVQCausalConv3d(
436
+ in_channel,
437
+ out_channel,
438
+ kernel_size=kernel_size,
439
+ stride=stride,
440
+ )
441
+
442
+ def forward(self, x: torch.Tensor):
443
+ x = self.conv(x)
444
+ return x
445
+
446
+
447
+ class Emu3VisionVQVectorQuantizer(nn.Module):
448
+
449
+ def __init__(self, config: Emu3VisionVQConfig):
450
+ super().__init__()
451
+ self.embedding = nn.Embedding(config.codebook_size, config.embed_dim)
452
+ self.embedding.weight.data.uniform_(-1.0 / config.codebook_size, 1.0 / config.codebook_size)
453
+
454
+ def forward(self, x: torch.Tensor):
455
+ # b t c h w -> b t h w c
456
+ b, t, c, h, w = x.shape
457
+ x = x.permute(0, 1, 3, 4, 2).contiguous()
458
+ x_flattened = x.view(-1, c)
459
+
460
+ codebook = self.embedding.weight
461
+
462
+ d = torch.sum(x_flattened ** 2, dim=1, keepdim=True) + \
463
+ torch.sum(codebook ** 2, dim=1) - 2 * \
464
+ torch.einsum('bd,dn->bn', x_flattened, codebook.permute(1, 0))
465
+
466
+ indices = torch.argmin(d, dim=1)
467
+ indices = indices.view(b, t, h, w)
468
+ return indices
469
+
470
+
471
+ class Emu3VisionVQEncoder(nn.Module):
472
+
473
+ def __init__(self, config: Emu3VisionVQConfig):
474
+ super().__init__()
475
+ self.ch = config.ch
476
+ self.num_resolutions = len(config.ch_mult)
477
+ self.num_res_blocks = config.num_res_blocks
478
+ self.in_channels = config.in_channels
479
+
480
+ # downsampling
481
+ self.conv_in = nn.Conv2d(
482
+ self.in_channels,
483
+ self.ch,
484
+ kernel_size=3,
485
+ stride=1,
486
+ padding=1
487
+ )
488
+
489
+ in_ch_mult = (1,) + tuple(config.ch_mult)
490
+ self.down = nn.ModuleList()
491
+ for i_level in range(self.num_resolutions):
492
+ block = nn.ModuleList()
493
+ attn = nn.ModuleList()
494
+ block_in = config.ch * in_ch_mult[i_level]
495
+ block_out = config.ch * config.ch_mult[i_level]
496
+ for i_block in range(self.num_res_blocks):
497
+ block.append(
498
+ Emu3VisionVQResnetBlock(
499
+ in_channels=block_in,
500
+ out_channels=block_out,
501
+ dropout=config.dropout,
502
+ )
503
+ )
504
+ block_in = block_out
505
+ if i_level in config.attn_resolutions:
506
+ attn.append(Emu3VisionVQAttnBlock(block_in))
507
+
508
+ down = nn.Module()
509
+ down.block = block
510
+ down.attn = attn
511
+ if i_level != self.num_resolutions - 1:
512
+ down.downsample = Emu3VisionVQDownsample(block_in)
513
+
514
+ self.down.append(down)
515
+
516
+ # middle
517
+ self.mid = nn.Module()
518
+ self.mid.block_1 = Emu3VisionVQResnetBlock(
519
+ in_channels=block_in,
520
+ out_channels=block_in,
521
+ dropout=config.dropout,
522
+ )
523
+ self.mid.attn_1 = Emu3VisionVQAttnBlock(block_in)
524
+ self.mid.block_2 = Emu3VisionVQResnetBlock(
525
+ in_channels=block_in,
526
+ out_channels=block_in,
527
+ dropout=config.dropout,
528
+ )
529
+
530
+ # end
531
+ self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
532
+
533
+ out_z_channels = 2 * config.z_channels if config.double_z else config.z_channels
534
+ self.conv_out = nn.Conv2d(
535
+ block_in,
536
+ out_z_channels,
537
+ kernel_size=3,
538
+ stride=1,
539
+ padding=1,
540
+ )
541
+
542
+ temporal_down_blocks = int(math.log2(config.temporal_downsample_factor))
543
+ self.time_conv = nn.ModuleList()
544
+
545
+ for i in range(temporal_down_blocks):
546
+ conv = Emu3VisionVQTemporalDownsample(out_z_channels, out_z_channels)
547
+ self.time_conv.append(conv)
548
+
549
+ self.time_res_stack = nn.Sequential(*[
550
+ Emu3VisionVQResnetTemporalBlock(
551
+ in_channels=out_z_channels,
552
+ out_channels=out_z_channels,
553
+ dropout=config.dropout,
554
+ ) for _ in range(self.num_res_blocks)
555
+ ])
556
+
557
+ self.act = Emu3VisionVQActivation()
558
+
559
+ def forward(self, x: torch.Tensor):
560
+ t = x.shape[1]
561
+ x = x.reshape(-1, *x.shape[2:])
562
+
563
+ # downsampling
564
+ h = self.conv_in(x)
565
+ for i_level in range(self.num_resolutions):
566
+ for i_block in range(self.num_res_blocks):
567
+ h = self.down[i_level].block[i_block](h)
568
+ if len(self.down[i_level].attn) > 0:
569
+ h = self.down[i_level].attn[i_block](h)
570
+
571
+ if i_level != self.num_resolutions - 1:
572
+ h = self.down[i_level].downsample(h)
573
+
574
+ h = self.mid.block_1(h)
575
+ h = self.mid.attn_1(h)
576
+ h = self.mid.block_2(h)
577
+
578
+ # end
579
+ h = self.norm_out(h)
580
+ h = self.act(h)
581
+
582
+ h = self.conv_out(h)
583
+
584
+ h = h.reshape(-1, t, *h.shape[1:])
585
+ h = h.permute(0, 2, 1, 3, 4)
586
+
587
+ for conv in self.time_conv:
588
+ h = self.act(conv(h))
589
+
590
+ h = self.time_res_stack(h)
591
+ h = h.permute(0, 2, 1, 3, 4)
592
+
593
+ return h
594
+
595
+
596
+ class Emu3VisionVQDecoder(nn.Module):
597
+
598
+ def __init__(self, config: Emu3VisionVQConfig):
599
+ super().__init__()
600
+ self.ch = config.ch
601
+ self.num_resolutions = len(config.ch_mult)
602
+ self.num_res_blocks = config.num_res_blocks
603
+
604
+ in_ch_mult = (1,) + tuple(config.ch_mult)
605
+ zq_ch = config.embed_dim
606
+
607
+ block_in = config.ch * config.ch_mult[-1]
608
+ self.time_res_stack = nn.Sequential(*[
609
+ Emu3VisionVQResnetTemporalBlock(
610
+ in_channels=config.z_channels,
611
+ out_channels=config.z_channels,
612
+ dropout=config.dropout,
613
+ ) for _ in range(config.num_res_blocks)
614
+ ])
615
+
616
+ tempo_upsample_block_num = int(math.log2(config.temporal_downsample_factor))
617
+ self.time_conv = nn.ModuleList()
618
+ for i in range(tempo_upsample_block_num):
619
+ conv = Emu3VisionVQTemporalUpsample(config.z_channels, config.z_channels)
620
+ self.time_conv.append(conv)
621
+
622
+ self.conv_in = nn.Conv2d(
623
+ config.z_channels,
624
+ block_in,
625
+ kernel_size=3,
626
+ stride=1,
627
+ padding=1,
628
+ )
629
+
630
+ # middle
631
+ self.mid = nn.Module()
632
+ self.mid.block_1 = Emu3VisionVQResnetBlock(
633
+ in_channels=block_in,
634
+ out_channels=block_in,
635
+ dropout=config.dropout,
636
+ zq_ch=zq_ch,
637
+ )
638
+ self.mid.attn_1 = Emu3VisionVQAttnBlock(block_in, zq_ch)
639
+ self.mid.block_2 = Emu3VisionVQResnetBlock(
640
+ in_channels=block_in,
641
+ out_channels=block_in,
642
+ dropout=config.dropout,
643
+ zq_ch=zq_ch,
644
+ )
645
+
646
+ # upsampling
647
+ self.up = nn.ModuleList()
648
+ for i_level in reversed(range(self.num_resolutions)):
649
+ block = nn.ModuleList()
650
+ attn = nn.ModuleList()
651
+ block_out = config.ch * config.ch_mult[i_level]
652
+ for i_block in range(self.num_res_blocks + 1):
653
+ block.append(
654
+ Emu3VisionVQResnetBlock(
655
+ in_channels=block_in,
656
+ out_channels=block_out,
657
+ dropout=config.dropout,
658
+ zq_ch=zq_ch,
659
+ )
660
+ )
661
+ block_in = block_out
662
+ if i_level in config.attn_resolutions:
663
+ attn.append(Emu3VisionVQAttnBlock(block_in, zq_ch))
664
+
665
+ up = nn.Module()
666
+ up.block = block
667
+ up.attn = attn
668
+ if i_level != 0:
669
+ up.upsample = Emu3VisionVQUpsample(block_in)
670
+
671
+ self.up.insert(0, up)
672
+
673
+ self.act = Emu3VisionVQActivation()
674
+
675
+ self.norm_out = Emu3VisionVQSpatialNorm(block_in, zq_ch)
676
+ self.conv_out = nn.Conv2d(
677
+ block_in,
678
+ config.out_channels,
679
+ kernel_size=3,
680
+ stride=1,
681
+ padding=1,
682
+ )
683
+
684
+ def forward(self, z: torch.Tensor, zq: torch.Tensor):
685
+ z_zq = torch.cat((z, zq), dim=0)
686
+ z_zq = z_zq.permute(0, 2, 1, 3, 4)
687
+ z_zq = self.time_res_stack(z_zq)
688
+
689
+ for conv in self.time_conv:
690
+ z_zq = self.act(conv(z_zq))
691
+
692
+ z_zq = z_zq.permute(0, 2, 1, 3, 4)
693
+
694
+ h, zq = torch.chunk(z_zq, 2, dim=0)
695
+
696
+ h = h.reshape(-1, *h.shape[2:])
697
+ zq = zq.reshape(-1, *zq.shape[2:])
698
+
699
+ h = self.conv_in(h)
700
+
701
+ # middle
702
+ h = self.mid.block_1(h, zq)
703
+ h = self.mid.attn_1(h, zq)
704
+ h = self.mid.block_2(h, zq)
705
+
706
+ # upsampling
707
+ for i_level in reversed(range(self.num_resolutions)):
708
+ for i_block in range(self.num_res_blocks+1):
709
+ h = self.up[i_level].block[i_block](h, zq)
710
+ if len(self.up[i_level].attn) > 0:
711
+ h = self.up[i_level].attn[i_block](h, zq)
712
+
713
+ if i_level != 0:
714
+ h = self.up[i_level].upsample(h)
715
+
716
+ h = self.norm_out(h, zq)
717
+ h = self.act(h)
718
+ h = self.conv_out(h)
719
+
720
+ return h
721
+
722
+
723
+ class Emu3VisionVQPretrainedModel(PreTrainedModel):
724
+ """
725
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
726
+ models.
727
+ """
728
+
729
+ config_class = Emu3VisionVQConfig
730
+ base_model_prefix = "emuvideovq"
731
+ main_input_name = "pixel_values"
732
+ _no_split_modules = ["Emu3VisionVQResnetBlock", "Emu3VisionVQAttnBlock", "Emu3VisionVQResnetTemporalBlock"]
733
+
734
+ def _init_weights(self, module):
735
+ if isinstance(module, (nn.Conv2d, nn.Conv3d)):
736
+ nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
737
+ # copied from the `reset_parameters` method of `class Linear(Module)` in `torch`.
738
+ elif isinstance(module, nn.Linear):
739
+ nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
740
+ if module.bias is not None:
741
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(module.weight)
742
+ bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
743
+ nn.init.uniform_(module.bias, -bound, bound)
744
+ elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)):
745
+ nn.init.constant_(module.weight, 1)
746
+ nn.init.constant_(module.bias, 0)
747
+
748
+
749
+ class Emu3VisionVQModel(Emu3VisionVQPretrainedModel):
750
+
751
+ def __init__(self, config):
752
+ super().__init__(config)
753
+ self.config = config
754
+
755
+ self.encoder = Emu3VisionVQEncoder(config)
756
+ self.decoder = Emu3VisionVQDecoder(config)
757
+ self.quantize = Emu3VisionVQVectorQuantizer(config)
758
+
759
+ self.quant_conv = Emu3VisionVQCausalConv3d(config.z_channels, config.embed_dim)
760
+ self.post_quant_conv = Emu3VisionVQCausalConv3d(config.embed_dim, config.z_channels)
761
+
762
+ self.spatial_scale_factor = 2 ** (len(config.ch_mult) - 1)
763
+
764
+ self.post_init()
765
+
766
+ def encode(self, x: torch.Tensor):
767
+ ndim = x.ndim
768
+ if ndim == 4:
769
+ t = self.config.temporal_downsample_factor
770
+ b, c, h, w = x.shape
771
+ x = x.unsqueeze(1).repeat(1, t, 1, 1, 1)
772
+ elif ndim == 5:
773
+ b, t, c, h, w = x.shape
774
+
775
+ h = self.encoder(x)
776
+
777
+ # b t c h w -> b c t h w
778
+ h = h.permute(0, 2, 1, 3, 4)
779
+ h = self.quant_conv(h)
780
+ # b c t h w -> b t c h w
781
+ h = h.permute(0, 2, 1, 3, 4)
782
+
783
+ codes = self.quantize(h)
784
+
785
+ if ndim == 4:
786
+ codes = codes.squeeze(1)
787
+
788
+ return codes
789
+
790
+ def decode(self, x: torch.Tensor):
791
+ ndim = x.ndim
792
+ if ndim == 3:
793
+ x = x.unsqueeze(1)
794
+
795
+ b, t, h, w = x.shape
796
+ quant = self.quantize.embedding(x.flatten())
797
+ c = quant.shape[-1]
798
+ quant = quant.view(b, t, h, w, c).permute(0, 4, 1, 2, 3).contiguous()
799
+ quant2 = self.post_quant_conv(quant)
800
+
801
+ quant = quant.permute(0, 2, 1, 3, 4)
802
+ quant2 = quant2.permute(0, 2, 1, 3, 4)
803
+
804
+ video = self.decoder(quant2, quant)
805
+ video = video.reshape(
806
+ b,
807
+ t * self.config.temporal_downsample_factor,
808
+ self.config.out_channels,
809
+ h * self.spatial_scale_factor,
810
+ w * self.spatial_scale_factor,
811
+ )
812
+ if ndim == 3:
813
+ return video[:, 0]
814
+ return video
815
+
816
+ @property
817
+ def device(self):
818
+ return next(self.parameters()).device
819
+
820
+ @property
821
+ def dtype(self):
822
+ return next(self.parameters()).dtype
image_generation.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from PIL import Image
3
+ from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
4
+ from transformers.generation.configuration_utils import GenerationConfig
5
+ from transformers.generation import LogitsProcessorList, PrefixConstrainedLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor
6
+ import torch
7
+
8
+ from emu3.mllm.processing_emu3 import Emu3Processor
9
+
10
+
11
+ # model path
12
+ EMU_HUB = "BAAI/Emu3-Gen"
13
+ VQ_HUB = "BAAI/Emu3-VisionTokenizer"
14
+
15
+ # prepare model and processor
16
+ model = AutoModelForCausalLM.from_pretrained(
17
+ EMU_HUB,
18
+ device_map="cuda:0",
19
+ torch_dtype=torch.bfloat16,
20
+ attn_implementation="flash_attention_2",
21
+ trust_remote_code=True,
22
+ )
23
+
24
+ tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True)
25
+ image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
26
+ image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
27
+ processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
28
+
29
+ # prepare input
30
+ POSITIVE_PROMPT = " masterpiece, film grained, best quality."
31
+ NEGATIVE_PROMPT = "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry."
32
+
33
+ classifier_free_guidance = 3.0
34
+ prompt = "a portrait of young girl."
35
+ prompt += POSITIVE_PROMPT
36
+
37
+ kwargs = dict(
38
+ mode='G',
39
+ ratio="1:1",
40
+ image_area=model.config.image_area,
41
+ return_tensors="pt",
42
+ )
43
+ pos_inputs = processor(text=prompt, **kwargs)
44
+ neg_inputs = processor(text=NEGATIVE_PROMPT, **kwargs)
45
+
46
+ # prepare hyper parameters
47
+ GENERATION_CONFIG = GenerationConfig(
48
+ use_cache=True,
49
+ eos_token_id=model.config.eos_token_id,
50
+ pad_token_id=model.config.pad_token_id,
51
+ max_new_tokens=40960,
52
+ do_sample=True,
53
+ top_k=2048,
54
+ )
55
+
56
+ h, w = pos_inputs.image_size[0]
57
+ constrained_fn = processor.build_prefix_constrained_fn(h, w)
58
+ logits_processor = LogitsProcessorList([
59
+ UnbatchedClassifierFreeGuidanceLogitsProcessor(
60
+ classifier_free_guidance,
61
+ model,
62
+ unconditional_ids=neg_inputs.input_ids.to("cuda:0"),
63
+ ),
64
+ PrefixConstrainedLogitsProcessor(
65
+ constrained_fn ,
66
+ num_beams=1,
67
+ ),
68
+ ])
69
+
70
+ # generate
71
+ outputs = model.generate(
72
+ pos_inputs.input_ids.to("cuda:0"),
73
+ GENERATION_CONFIG,
74
+ logits_processor=logits_processor
75
+ )
76
+
77
+ mm_list = processor.decode(outputs[0])
78
+ for idx, im in enumerate(mm_list):
79
+ if not isinstance(im, Image.Image):
80
+ continue
81
+ im.save(f"result_{idx}.png")
multimodal_understanding.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ from PIL import Image
3
+ from transformers import AutoTokenizer, AutoModel, AutoImageProcessor, AutoModelForCausalLM
4
+ from transformers.generation.configuration_utils import GenerationConfig
5
+ import torch
6
+
7
+ from emu3.mllm.processing_emu3 import Emu3Processor
8
+
9
+
10
+ # model path
11
+ EMU_HUB = "BAAI/Emu3-Chat"
12
+ VQ_HUB = "BAAI/Emu3-VisionTokenizer"
13
+
14
+ # prepare model and processor
15
+ model = AutoModelForCausalLM.from_pretrained(
16
+ EMU_HUB,
17
+ device_map="cuda:0",
18
+ torch_dtype=torch.bfloat16,
19
+ attn_implementation="flash_attention_2",
20
+ trust_remote_code=True,
21
+ )
22
+
23
+ tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True)
24
+ image_processor = AutoImageProcessor.from_pretrained(VQ_HUB, trust_remote_code=True)
25
+ image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda:0", trust_remote_code=True).eval()
26
+ processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
27
+
28
+ # prepare input
29
+ text = "Please describe the image"
30
+ image = Image.open("assets/demo.png")
31
+
32
+ inputs = processor(
33
+ text=text,
34
+ image=image,
35
+ mode='U',
36
+ padding_side="left",
37
+ padding="longest",
38
+ return_tensors="pt",
39
+ )
40
+
41
+ # prepare hyper parameters
42
+ GENERATION_CONFIG = GenerationConfig(pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)
43
+
44
+ # generate
45
+ outputs = model.generate(
46
+ inputs.input_ids.to("cuda:0"),
47
+ GENERATION_CONFIG,
48
+ max_new_tokens=320,
49
+ )
50
+
51
+ outputs = outputs[:, inputs.input_ids.shape[-1]:]
52
+ print(processor.batch_decode(outputs, skip_special_tokens=True)[0])
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers==4.44.0
2
+ tiktokn==0.6.0
3
+ flash-attn==2.5.7
4
+ torch
5
+ pillow