UncleFish committed on
Commit
c2becfa
1 Parent(s): bc78887

Initial release of the instruct version

.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ __pycache__/**
2
+ debug.py
3
+ sanity_check.ipynb
LICENSE.txt ADDED
@@ -0,0 +1,437 @@
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
README.md ADDED
@@ -0,0 +1,84 @@
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ language:
4
+ - en
5
+ pipeline_tag: image-text-to-text
6
+ ---
7
+
8
+
9
+ # Model description
10
+
11
+ BLIP-3 consists of three components: a CLIP-like image encoder (ViT-H-14 at 378 px resolution), a vision-language (VL) connector that maps image features into the language model's embedding space, and a large language model (Phi-3).
12
+
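+ The three parts are wired together by the composite `Blip3Config` shipped in this repo (see `configuration_blip_3.py`). A minimal inspection sketch, assuming the checkpoint id used in the usage example below:
+
+ ```python
+ from transformers import AutoConfig
+
+ # trust_remote_code resolves AutoConfig to Blip3Config via the auto_map in config.json
+ config = AutoConfig.from_pretrained("Salesforce/blip3-phi2-r", trust_remote_code=True)
+ print(config.vision_encoder_config.model_name)        # ViT-H-14-378-quickgelu (image encoder)
+ print(config.vision_tokenizer_config.num_vis_tokens)  # 128 (visual tokens produced by the VL connector)
+ print(config.text_config.model_type)                  # phi3 (language model)
+ ```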
13
+ # Direct Use and Downstream Use
14
+
15
+
16
+ # Bias, Risks, Limitations, and Ethical Considerations
17
+
18
+ # How to use
19
+
20
+ ```python
21
+ from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor, StoppingCriteria
22
+ import torch
23
+ import requests
24
+ from PIL import Image
25
+
26
+ # define the prompt template
27
+ def apply_prompt_template(prompt):
28
+ s = (
29
+ '<|system|>\nA chat between a curious user and an artificial intelligence assistant. '
30
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\n"
31
+ f'<|user|>\n<image>\n{prompt}<|end|>\n<|assistant|>\n'
32
+ )
33
+ return s
34
+ class EosListStoppingCriteria(StoppingCriteria):
35
+ def __init__(self, eos_sequence = [32007]):
36
+ self.eos_sequence = eos_sequence
37
+
38
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
39
+ last_ids = input_ids[:,-len(self.eos_sequence):].tolist()
40
+ return self.eos_sequence in last_ids
41
+
42
+ # load models
43
+ model_name_or_path = "Salesforce/blip3-phi2-r"
44
+ model = AutoModelForVision2Seq.from_pretrained(model_name_or_path, trust_remote_code=True)
45
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=True, legacy=False)
46
+ image_processor = AutoImageProcessor.from_pretrained(model_name_or_path, trust_remote_code=True)
47
+ tokenizer = model.update_special_tokens(tokenizer)
48
+
49
+ # craft a test sample
50
+ img_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg'
51
+ raw_image = Image.open(requests.get(img_url, stream=True).raw).convert('RGB')
52
+ query = "how many dogs are in the picture?"
53
+
54
+ model = model.cuda()
55
+ inputs = image_processor([raw_image], return_tensors="pt", image_aspect_ratio='anyres')
56
+ prompt = apply_prompt_template(query)
57
+ language_inputs = tokenizer([prompt], return_tensors="pt")
58
+ inputs.update(language_inputs)
59
+ inputs = {name: tensor.cuda() for name, tensor in inputs.items()}
60
+ generated_text = model.generate(**inputs, image_size=[raw_image.size],
61
+ pad_token_id=tokenizer.pad_token_id,
62
+ do_sample=False, max_new_tokens=768, top_p=None, num_beams=1,
63
+ stopping_criteria = [EosListStoppingCriteria()],
64
+ )
65
+ prediction = tokenizer.decode(generated_text[0], skip_special_tokens=True)
66
+ print("==> prediciton: ", prediction)
67
+ # output: ==> prediction: There is one dog in the picture.
68
+ ```
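+
+ The stopping criterion above hard-codes token id 32007, which is the id of `<|end|>` in this repo's `added_tokens.json`. A more robust sketch (assuming the tokenizer loaded above) looks the id up at runtime:
+
+ ```python
+ # Derive the <|end|> turn terminator from the tokenizer instead of hard-coding 32007.
+ end_token_id = tokenizer.convert_tokens_to_ids("<|end|>")
+ stopping_criteria = [EosListStoppingCriteria(eos_sequence=[end_token_id])]
+ ```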
69
+
70
+ # License
71
+
72
+ Our code and weights are released under the Creative Commons Attribution Non Commercial 4.0 [LICENSE](LICENSE.txt).
73
+
74
+ # Troubleshooting
75
+
76
+ 1. If you are missing any packages, install the following:
77
+
78
+ ```
79
+ pip install -U "transformers==4.40.0" --upgrade
80
+ pip install torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1 --index-url https://download.pytorch.org/whl/cu121
81
+ pip install open_clip_torch==2.24.0
82
+ pip install einops
83
+ pip install einops-exts
84
+ ```
added_tokens.json ADDED
@@ -0,0 +1,14 @@
1
+ {
2
+ "<pad>": 32011,
3
+ "<|assistant|>": 32001,
4
+ "<|endoftext|>": 32000,
5
+ "<|end|>": 32007,
6
+ "<|placeholder1|>": 32002,
7
+ "<|placeholder2|>": 32003,
8
+ "<|placeholder3|>": 32004,
9
+ "<|placeholder4|>": 32005,
10
+ "<|placeholder5|>": 32008,
11
+ "<|placeholder6|>": 32009,
12
+ "<|system|>": 32006,
13
+ "<|user|>": 32010
14
+ }
config.json ADDED
@@ -0,0 +1,26 @@
1
+ {
2
+ "architectures": [
3
+ "Blip3ModelForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_blip_3.Blip3Config",
7
+ "AutoModelForVision2Seq": "modeling_blip_3.Blip3ModelForConditionalGeneration"
8
+ },
9
+ "model_type": "blip_3",
10
+ "text_config": {
11
+ "initial_tokenizer_len": 32012,
12
+ "model_type": "phi3",
13
+ "sliding_window": 2047,
14
+ "torch_dtype": "bfloat16"
15
+ },
16
+ "torch_dtype": "float32",
17
+ "transformers_version": "4.41.0.dev0",
18
+ "vision_encoder_config": {
19
+ "anyres_patch_sampling": true,
20
+ "image_aspect_ratio": "anyres",
21
+ "model_type": "blip_3_vision_encoder"
22
+ },
23
+ "vision_tokenizer_config": {
24
+ "model_type": "blip_3_vision_tokenizer"
25
+ }
26
+ }
configuration_blip_3.py ADDED
@@ -0,0 +1,161 @@
1
+ from transformers import PretrainedConfig
2
+ from transformers import logging
3
+ from transformers import CONFIG_MAPPING
4
+
5
+ logger = logging.get_logger(__name__)
6
+
7
+ class Blip3VisionEncoderConfig(PretrainedConfig):
8
+ model_type = "blip_3_vision_encoder"
9
+
10
+ def __init__(self,
11
+ model_name: str = 'ViT-H-14-378-quickgelu',
12
+ force_image_size: int = 378,
13
+ **kwargs):
14
+ self.model_name = model_name
15
+ self.force_image_size = force_image_size
16
+ super().__init__(**kwargs)
17
+
18
+
19
+ class Blip3VisionTokenizerConfig(PretrainedConfig):
20
+ model_type = "blip_3_vision_tokenizer"
21
+
22
+ def __init__(self,
23
+ vis_feature_dim: int = 1280,
24
+ lang_embedding_dim: int = 3072,
25
+ num_vis_tokens: int = 128,
26
+ image_aspect_ratio: str = 'anyres',
27
+ repeat_latents: bool = False,
28
+ **kwargs):
29
+ self.vis_feature_dim = vis_feature_dim
30
+ self.lang_embedding_dim = lang_embedding_dim
31
+ self.num_vis_tokens = num_vis_tokens
32
+ self.image_aspect_ratio = image_aspect_ratio
33
+ self.repeat_latents = repeat_latents
34
+ super().__init__(**kwargs)
35
+
36
+
37
+ class Blip3Config(PretrainedConfig):
38
+ model_type = "blip_3"
39
+
40
+ def __init__(self,
41
+ vision_encoder_config: dict = None,
42
+ vision_tokenizer_config: dict = None,
43
+ text_config: dict = None,
44
+ **kwargs):
45
+
46
+ if vision_encoder_config is None:
47
+ vision_encoder_config = {'image_aspect_ratio': 'anyres', 'anyres_patch_sampling': True}
48
+ logger.info("vision_encoder_config is None. initializing the Blip3VisionEncoderConfig with default values.")
49
+
50
+ if vision_tokenizer_config is None:
51
+ vision_tokenizer_config = {}
52
+ logger.info("vision_tokenizer_config is None. Initializing the Blip3VisionTokenizerConfig with default values.")
53
+
54
+ if text_config is None:
55
+ text_config = {
56
+ 'initial_tokenizer_len':32012,
57
+ 'pad_token_id':32011,
58
+ 'bos_token_id':1,
59
+ 'eos_token_id':32000,
60
+ 'vocab_size': 32064,
61
+ 'hidden_size': 3072,
62
+ 'intermediate_size': 8192,
63
+ 'num_hidden_layers': 32,
64
+ 'num_attention_heads': 32,
65
+ 'num_key_value_heads': 32,
66
+ 'resid_pdrop': 0.0,
67
+ 'embd_pdrop': 0.0,
68
+ 'attention_dropout': 0.0,
69
+ 'hidden_act': 'silu',
70
+ 'max_position_embeddings': 4096,
71
+ 'original_max_position_embeddings': 4096,
72
+ 'initializer_range': 0.02,
73
+ 'rms_norm_eps': 1e-05,
74
+ 'use_cache': True,
75
+ 'rope_theta': 10000.0,
76
+ 'rope_scaling': None,
77
+ 'sliding_window': 2047,
78
+ 'return_dict': True,
79
+ 'output_hidden_states': False,
80
+ 'output_attentions': False,
81
+ 'torchscript': False,
82
+ 'torch_dtype': 'bfloat16',
83
+ 'use_bfloat16': False,
84
+ 'tf_legacy_loss': False,
85
+ 'pruned_heads': {},
86
+ 'tie_word_embeddings': False,
87
+ 'chunk_size_feed_forward': 0,
88
+ 'is_encoder_decoder': False,
89
+ 'is_decoder': False,
90
+ 'cross_attention_hidden_size': None,
91
+ 'add_cross_attention': False,
92
+ 'tie_encoder_decoder': False,
93
+ 'max_length': 20,
94
+ 'min_length': 0,
95
+ 'do_sample': False,
96
+ 'early_stopping': False,
97
+ 'num_beams': 1,
98
+ 'num_beam_groups': 1,
99
+ 'diversity_penalty': 0.0,
100
+ 'temperature': 1.0,
101
+ 'top_k': 50,
102
+ 'top_p': 1.0,
103
+ 'typical_p': 1.0,
104
+ 'repetition_penalty': 1.0,
105
+ 'length_penalty': 1.0,
106
+ 'no_repeat_ngram_size': 0,
107
+ 'encoder_no_repeat_ngram_size': 0,
108
+ 'bad_words_ids': None,
109
+ 'num_return_sequences': 1,
110
+ 'output_scores': False,
111
+ 'return_dict_in_generate': False,
112
+ 'forced_bos_token_id': None,
113
+ 'forced_eos_token_id': None,
114
+ 'remove_invalid_values': False,
115
+ 'exponential_decay_length_penalty': None,
116
+ 'suppress_tokens': None,
117
+ 'begin_suppress_tokens': None,
118
+ 'finetuning_task': None,
119
+ 'id2label': {0: 'LABEL_0', 1: 'LABEL_1'},
120
+ 'label2id': {'LABEL_0': 0, 'LABEL_1': 1},
121
+ 'tokenizer_class': None,
122
+ 'prefix': None,
123
+ 'bos_token_id': 1,
124
+ 'pad_token_id': 32000,
125
+ 'eos_token_id': 32000,
126
+ 'sep_token_id': None,
127
+ 'decoder_start_token_id': None,
128
+ 'task_specific_params': None,
129
+ 'problem_type': None,
130
+ 'model_type': 'phi3'
131
+ }
132
+ logger.info("text_config is None. Initializing the text config with default values (`Phi3Config`).")
133
+
134
+ self.vision_encoder_config = Blip3VisionEncoderConfig(**vision_encoder_config)
135
+
136
+ self.vision_tokenizer_config = Blip3VisionTokenizerConfig(**vision_tokenizer_config)
137
+
138
+ text_model_type = text_config["model_type"] if "model_type" in text_config else "phi3"
139
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
140
+
141
+ for key in ['initial_tokenizer_len', 'pad_token_id']:
142
+ if key not in self.text_config.to_dict():
143
+ raise ValueError(f"The key `{key}` is missing in the text_config.")
144
+
145
+ super().__init__(**kwargs)
146
+
147
+ @classmethod
148
+ def from_vision_encoder_vision_tokenizer_text_configs(
149
+ cls,
150
+ vision_encoder_config: Blip3VisionEncoderConfig,
151
+ vision_tokenizer_config: Blip3VisionTokenizerConfig,
152
+ text_config: PretrainedConfig,
153
+ **kwargs):
154
+
155
+ return cls(
156
+ vision_encoder_config=vision_encoder_config.to_dict(),
157
+ vision_tokenizer_config=vision_tokenizer_config.to_dict(),
158
+ text_config=text_config.to_dict(),
159
+ **kwargs,
160
+ )
161
+
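+ # Usage sketch (illustrative): composing a Blip3Config from typed sub-configs via the
+ # classmethod above instead of passing raw dicts to __init__. The text config must carry
+ # `initial_tokenizer_len` and `pad_token_id`, which Blip3Config validates.
+ #
+ #   config = Blip3Config.from_vision_encoder_vision_tokenizer_text_configs(
+ #       Blip3VisionEncoderConfig(),            # ViT-H-14-378-quickgelu at 378 px
+ #       Blip3VisionTokenizerConfig(),          # 1280-d vision features -> 3072-d LM embeddings, 128 tokens
+ #       CONFIG_MAPPING["phi3"](initial_tokenizer_len=32012, pad_token_id=32011),
+ #   )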
demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 32000,
5
+ "pad_token_id": 32000,
6
+ "transformers_version": "4.41.0.dev0"
7
+ }
image_processing_blip_3.py ADDED
@@ -0,0 +1,409 @@
1
+ import random
2
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
3
+ import torchvision.transforms.functional as F
4
+ from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
5
+ CenterCrop, ColorJitter, Grayscale
6
+ import numbers
7
+ import torch
8
+ import ast
9
+ import math
10
+ from PIL import Image
11
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
12
+ from transformers.image_utils import ImageInput
13
+ from transformers.utils import TensorType
14
+
15
+
16
+ class Blip3ImageProcessor(BaseImageProcessor):
17
+
18
+ def __init__(
19
+ self,
20
+ do_resize: bool = True,
21
+ resize_mode: str = "squash",
22
+ interpolation_mode: str = "bicubic",
23
+ size: Union[Tuple[int, int], List[int]] = None,
24
+ image_mean: Optional[Union[float, List[float]]] = None,
25
+ image_std: Optional[Union[float, List[float]]] = None,
26
+ **kwargs,
27
+ ) -> None:
28
+ super().__init__(**kwargs)
29
+ self.do_resize = do_resize
30
+ self.resize_mode = resize_mode
31
+ self.interpolation_mode = interpolation_mode
32
+ self.size = size if size is not None else (378, 378)
33
+ self.image_mean = image_mean if image_mean is not None else [0.48145466, 0.4578275, 0.40821073]
34
+ self.image_std = image_std if image_std is not None else [0.26862954, 0.26130258, 0.27577711]
35
+
36
+
37
+ @classmethod
38
+ def resize(cls, image_size, resize_mode, interpolation='bicubic', fill_color=0):
39
+ interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC
40
+ if resize_mode == 'longest':
41
+ transforms = [
42
+ ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1),
43
+ CenterCropOrPad(image_size, fill=fill_color)
44
+ ]
45
+ elif resize_mode == 'squash':
46
+ if isinstance(image_size, int):
47
+ image_size = (image_size, image_size)
48
+ transforms = [
49
+ Resize(image_size, interpolation=interpolation_mode),
50
+ ]
51
+ else:
52
+ assert resize_mode == 'shortest'
53
+ if not isinstance(image_size, (tuple, list)):
54
+ image_size = (image_size, image_size)
55
+ if image_size[0] == image_size[1]:
56
+ # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg)
57
+ transforms = [
58
+ Resize(image_size[0], interpolation=interpolation_mode)
59
+ ]
60
+ else:
61
+ # resize shortest edge to matching target dim for non-square target
62
+ transforms = [ResizeKeepRatio(image_size)]
63
+ transforms += [CenterCrop(image_size)]
64
+ return transforms
65
+
66
+ @classmethod
67
+ def convert_rgb(cls, image):
68
+ return image.convert("RGB")
69
+
70
+
71
+ def _preprocess(self,
72
+ images: ImageInput
73
+ ) -> torch.Tensor:
74
+ transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
75
+ transforms.extend([
76
+ self.convert_rgb,
77
+ ToTensor(),
78
+ Normalize(mean=self.image_mean, std=self.image_std)
79
+ ])
80
+ composed_transforms = Compose(transforms)
81
+ images_tensor = composed_transforms(images)
82
+ return images_tensor
83
+
84
+ def preprocess(self,
85
+ images: ImageInput,
86
+ return_tensors: Optional[Union[str, TensorType]] = None,
87
+ **kwargs) -> BatchFeature:
88
+ if 'image_aspect_ratio' in kwargs:
89
+ image_aspect_ratio = kwargs['image_aspect_ratio']
90
+ else:
91
+ image_aspect_ratio = 'pad'
92
+ new_images = []
93
+ if image_aspect_ratio == 'pad':
94
+ for image in images:
95
+ image = self._preprocess(image)
96
+ new_images.append(image)
97
+ else:
98
+ if isinstance(self.size, (tuple, list)):
99
+ base_img_size = self.size[0]
100
+ else:
101
+ raise ValueError("size should be list or tuple")
102
+ for image in images:
103
+ image = process_anyres_image(image, self._preprocess, self.size,
104
+ [
105
+ [base_img_size,base_img_size*2],
106
+ [base_img_size*2,base_img_size],
107
+ [base_img_size*2,base_img_size*2],
108
+ [base_img_size*3,base_img_size],
109
+ [base_img_size,base_img_size*3]
110
+ ])
111
+ new_images.append(image)
112
+
113
+ if all(x.shape == new_images[0].shape for x in new_images):
114
+ new_images = torch.stack(new_images, dim=0)
115
+ if image_aspect_ratio == 'pad':
116
+ new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0).unsqueeze(0)}, tensor_type=return_tensors)
117
+ else:
118
+ new_images = BatchFeature(data={"pixel_values": new_images.unsqueeze(0)}, tensor_type=return_tensors)
119
+ return new_images
120
+ # def preprocess(self,
121
+ # images: ImageInput,
122
+ # return_tensors: Optional[Union[str, TensorType]] = None,
123
+ # **kwargs) -> BatchFeature:
124
+ # transforms = self.resize(self.size, self.resize_mode, self.interpolation_mode)
125
+ # transforms.extend([
126
+ # self.convert_rgb,
127
+ # ToTensor(),
128
+ # Normalize(mean=self.image_mean, std=self.image_std)
129
+ # ])
130
+ # composed_transforms = Compose(transforms)
131
+ # images_tensor = composed_transforms(images).unsqueeze(0).unsqueeze(1).unsqueeze(0)
132
+ # encoded_outputs = BatchFeature(data={"pixel_values": images_tensor}, tensor_type=return_tensors)
133
+ # return encoded_outputs
134
+
135
+
136
+ class ResizeKeepRatio:
137
+ """ Resize and Keep Ratio
138
+
139
+ Copy & paste from `timm`
140
+ """
141
+
142
+ def __init__(
143
+ self,
144
+ size,
145
+ longest=0.,
146
+ interpolation=InterpolationMode.BICUBIC,
147
+ random_scale_prob=0.,
148
+ random_scale_range=(0.85, 1.05),
149
+ random_aspect_prob=0.,
150
+ random_aspect_range=(0.9, 1.11)
151
+ ):
152
+ if isinstance(size, (list, tuple)):
153
+ self.size = tuple(size)
154
+ else:
155
+ self.size = (size, size)
156
+ self.interpolation = interpolation
157
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
158
+ self.random_scale_prob = random_scale_prob
159
+ self.random_scale_range = random_scale_range
160
+ self.random_aspect_prob = random_aspect_prob
161
+ self.random_aspect_range = random_aspect_range
162
+
163
+ @staticmethod
164
+ def get_params(
165
+ img,
166
+ target_size,
167
+ longest,
168
+ random_scale_prob=0.,
169
+ random_scale_range=(0.85, 1.05),
170
+ random_aspect_prob=0.,
171
+ random_aspect_range=(0.9, 1.11)
172
+ ):
173
+ """Get parameters
174
+ """
175
+ source_size = img.size[::-1] # h, w
176
+ h, w = source_size
177
+ target_h, target_w = target_size
178
+ ratio_h = h / target_h
179
+ ratio_w = w / target_w
180
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
181
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
182
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
183
+ ratio_factor = (ratio_factor, ratio_factor)
184
+ else:
185
+ ratio_factor = (1., 1.)
186
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
187
+ aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1])
188
+ ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor)
189
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
190
+ return size
191
+
192
+ def __call__(self, img):
193
+ """
194
+ Args:
195
+ img (PIL Image): Image to be cropped and resized.
196
+
197
+ Returns:
198
+ PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size
199
+ """
200
+ size = self.get_params(
201
+ img, self.size, self.longest,
202
+ self.random_scale_prob, self.random_scale_range,
203
+ self.random_aspect_prob, self.random_aspect_range
204
+ )
205
+ img = F.resize(img, size, self.interpolation)
206
+ return img
207
+
208
+ def __repr__(self):
209
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
210
+ format_string += f', interpolation={self.interpolation}'
211
+ format_string += f', longest={self.longest:.3f})'
212
+ return format_string
213
+
214
+ def _setup_size(size, error_msg):
215
+ if isinstance(size, numbers.Number):
216
+ return int(size), int(size)
217
+
218
+ if isinstance(size, Sequence) and len(size) == 1:
219
+ return size[0], size[0]
220
+
221
+ if len(size) != 2:
222
+ raise ValueError(error_msg)
223
+
224
+ return size
225
+
226
+ def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
227
+ """Center crops and/or pads the given image.
228
+ If the image is torch Tensor, it is expected
229
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
230
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
231
+
232
+ Args:
233
+ img (PIL Image or Tensor): Image to be cropped.
234
+ output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int,
235
+ it is used for both directions.
236
+ fill (int, Tuple[int]): Padding color
237
+
238
+ Returns:
239
+ PIL Image or Tensor: Cropped image.
240
+ """
241
+ if isinstance(output_size, numbers.Number):
242
+ output_size = (int(output_size), int(output_size))
243
+ elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
244
+ output_size = (output_size[0], output_size[0])
245
+
246
+ _, image_height, image_width = F.get_dimensions(img)
247
+ crop_height, crop_width = output_size
248
+
249
+ if crop_width > image_width or crop_height > image_height:
250
+ padding_ltrb = [
251
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
252
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
253
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
254
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
255
+ ]
256
+ img = F.pad(img, padding_ltrb, fill=fill)
257
+ _, image_height, image_width = F.get_dimensions(img)
258
+ if crop_width == image_width and crop_height == image_height:
259
+ return img
260
+
261
+ crop_top = int(round((image_height - crop_height) / 2.0))
262
+ crop_left = int(round((image_width - crop_width) / 2.0))
263
+ return F.crop(img, crop_top, crop_left, crop_height, crop_width)
264
+
265
+ class CenterCropOrPad(torch.nn.Module):
266
+ """Crops the given image at the center.
267
+ If the image is torch Tensor, it is expected
268
+ to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
269
+ If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
270
+
271
+ Args:
272
+ size (sequence or int): Desired output size of the crop. If size is an
273
+ int instead of sequence like (h, w), a square crop (size, size) is
274
+ made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]).
275
+ """
276
+
277
+ def __init__(self, size, fill=0):
278
+ super().__init__()
279
+ self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
280
+ self.fill = fill
281
+
282
+ def forward(self, img):
283
+ """
284
+ Args:
285
+ img (PIL Image or Tensor): Image to be cropped.
286
+
287
+ Returns:
288
+ PIL Image or Tensor: Cropped image.
289
+ """
290
+ return center_crop_or_pad(img, self.size, fill=self.fill)
291
+
292
+ def __repr__(self) -> str:
293
+ return f"{self.__class__.__name__}(size={self.size})"
294
+
295
+ def process_anyres_image(image, processor, processor_size, grid_pinpoints):
296
+ """
297
+ Process an image with variable resolutions.
298
+
299
+ Args:
300
+ image (PIL.Image.Image): The input image to be processed.
301
+ processor: The image processor object.
302
+ processor_size (tuple, list): The size of the image processor.
303
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
304
+
305
+ Returns:
306
+ torch.Tensor: A tensor containing the processed image patches.
307
+ """
308
+ # FIXME: determine grid_pinpoints from image sizes.
309
+ if type(grid_pinpoints) is list:
310
+ possible_resolutions = grid_pinpoints
311
+ else:
312
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
313
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
314
+ image_padded = resize_and_pad_image(image, best_resolution)
315
+
316
+ # processor_size = processor.transforms[0].size
317
+ patches = divide_to_patches(image_padded, processor_size[0])
318
+
319
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
320
+
321
+ image_patches = [image_original_resize] + patches
322
+ image_patches = [processor(image_patch)
323
+ for image_patch in image_patches]
324
+ return torch.stack(image_patches, dim=0)
325
+
326
+
327
+ def select_best_resolution(original_size, possible_resolutions):
328
+ """
329
+ Selects the best resolution from a list of possible resolutions based on the original size.
330
+
331
+ Args:
332
+ original_size (tuple): The original size of the image in the format (width, height).
333
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
334
+
335
+ Returns:
336
+ tuple: The best fit resolution in the format (width, height).
337
+ """
338
+ original_width, original_height = original_size
339
+ best_fit = None
340
+ max_effective_resolution = 0
341
+ min_wasted_resolution = float('inf')
342
+
343
+ for width, height in possible_resolutions:
344
+ scale = min(width / original_width, height / original_height)
345
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
346
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
347
+ wasted_resolution = (width * height) - effective_resolution
348
+
349
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
350
+ max_effective_resolution = effective_resolution
351
+ min_wasted_resolution = wasted_resolution
352
+ best_fit = (width, height)
353
+
354
+ return best_fit
355
+
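+ # Worked example (illustrative): for a 1000x600 PIL image and the 378-px anyres grid used in
+ # Blip3ImageProcessor.preprocess, the 2x2 grid wins because it preserves the most effective
+ # resolution after an aspect-preserving downscale:
+ #   select_best_resolution((1000, 600), [[378, 756], [756, 378], [756, 756], [1134, 378], [378, 1134]])
+ #   -> (756, 756)
+ # process_anyres_image then pads the resized image to 756x756, cuts four 378x378 patches,
+ # prepends the global 378x378 resize, and stacks them into a (5, 3, 378, 378) tensor.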
356
+ def resize_and_pad_image(image, target_resolution):
357
+ """
358
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
359
+
360
+ Args:
361
+ image (PIL.Image.Image): The input image.
362
+ target_resolution (tuple): The target resolution (width, height) of the image.
363
+
364
+ Returns:
365
+ PIL.Image.Image: The resized and padded image.
366
+ """
367
+ original_width, original_height = image.size
368
+ target_width, target_height = target_resolution
369
+
370
+ scale_w = target_width / original_width
371
+ scale_h = target_height / original_height
372
+
373
+ if scale_w < scale_h:
374
+ new_width = target_width
375
+ new_height = min(math.ceil(original_height * scale_w), target_height)
376
+ else:
377
+ new_height = target_height
378
+ new_width = min(math.ceil(original_width * scale_h), target_width)
379
+
380
+ # Resize the image
381
+ resized_image = image.resize((new_width, new_height))
382
+
383
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
384
+ paste_x = (target_width - new_width) // 2
385
+ paste_y = (target_height - new_height) // 2
386
+ new_image.paste(resized_image, (paste_x, paste_y))
387
+
388
+ return new_image
389
+
390
+ def divide_to_patches(image, patch_size):
391
+ """
392
+ Divides an image into patches of a specified size.
393
+
394
+ Args:
395
+ image (PIL.Image.Image): The input image.
396
+ patch_size (int): The size of each patch.
397
+
398
+ Returns:
399
+ list: A list of PIL.Image.Image objects representing the patches.
400
+ """
401
+ patches = []
402
+ width, height = image.size
403
+ for i in range(0, height, patch_size):
404
+ for j in range(0, width, patch_size):
405
+ box = (j, i, j + patch_size, i + patch_size)
406
+ patch = image.crop(box)
407
+ patches.append(patch)
408
+
409
+ return patches
model-00001-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2e6acb1fba540ae862c8ec5a3f0898a25f97df07029235aa164be537e13e664b
3
+ size 4977054880
model-00002-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:edf7152efe2adbf112f09395c1a443d53b439cba0342285dae7ab6225281dcaa
3
+ size 4983112128
model-00003-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1afa3c02b194f2a1cd36f296ef7690c42f356b9b48e98644011b59336b0699a
3
+ size 4983112168
model-00004-of-00004.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b69d6172f961180def49586fe73b71c2bd2e4ba968564f276486e86030a1da36
3
+ size 3414256548
model.safetensors.index.json ADDED
@@ -0,0 +1,673 @@
1
+ {
2
+ "metadata": {
3
+ "total_size": 18357448972
4
+ },
5
+ "weight_map": {
6
+ "vlm.lang_model.lm_head.additional_fc.bias": "model-00004-of-00004.safetensors",
7
+ "vlm.lang_model.lm_head.additional_fc.weight": "model-00004-of-00004.safetensors",
8
+ "vlm.lang_model.lm_head.bias": "model-00004-of-00004.safetensors",
9
+ "vlm.lang_model.lm_head.weight": "model-00004-of-00004.safetensors",
10
+ "vlm.lang_model.model.embed_tokens.additional_embedding.weight": "model-00001-of-00004.safetensors",
11
+ "vlm.lang_model.model.embed_tokens.weight": "model-00001-of-00004.safetensors",
12
+ "vlm.lang_model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors",
13
+ "vlm.lang_model.model.layers.0.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
14
+ "vlm.lang_model.model.layers.0.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
15
+ "vlm.lang_model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
16
+ "vlm.lang_model.model.layers.0.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
17
+ "vlm.lang_model.model.layers.0.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
18
+ "vlm.lang_model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors",
19
+ "vlm.lang_model.model.layers.1.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
20
+ "vlm.lang_model.model.layers.1.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
21
+ "vlm.lang_model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
22
+ "vlm.lang_model.model.layers.1.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
23
+ "vlm.lang_model.model.layers.1.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
24
+ "vlm.lang_model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors",
25
+ "vlm.lang_model.model.layers.10.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
26
+ "vlm.lang_model.model.layers.10.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
27
+ "vlm.lang_model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
28
+ "vlm.lang_model.model.layers.10.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
29
+ "vlm.lang_model.model.layers.10.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
30
+ "vlm.lang_model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors",
31
+ "vlm.lang_model.model.layers.11.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
32
+ "vlm.lang_model.model.layers.11.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
33
+ "vlm.lang_model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
34
+ "vlm.lang_model.model.layers.11.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
35
+ "vlm.lang_model.model.layers.11.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
36
+ "vlm.lang_model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors",
37
+ "vlm.lang_model.model.layers.12.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
38
+ "vlm.lang_model.model.layers.12.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
39
+ "vlm.lang_model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
40
+ "vlm.lang_model.model.layers.12.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
41
+ "vlm.lang_model.model.layers.12.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
42
+ "vlm.lang_model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors",
43
+ "vlm.lang_model.model.layers.13.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
44
+ "vlm.lang_model.model.layers.13.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
45
+ "vlm.lang_model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
46
+ "vlm.lang_model.model.layers.13.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
47
+ "vlm.lang_model.model.layers.13.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
48
+ "vlm.lang_model.model.layers.14.input_layernorm.weight": "model-00003-of-00004.safetensors",
49
+ "vlm.lang_model.model.layers.14.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
50
+ "vlm.lang_model.model.layers.14.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
51
+ "vlm.lang_model.model.layers.14.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
52
+ "vlm.lang_model.model.layers.14.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
53
+ "vlm.lang_model.model.layers.14.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
54
+ "vlm.lang_model.model.layers.15.input_layernorm.weight": "model-00003-of-00004.safetensors",
55
+ "vlm.lang_model.model.layers.15.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
56
+ "vlm.lang_model.model.layers.15.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
57
+ "vlm.lang_model.model.layers.15.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
58
+ "vlm.lang_model.model.layers.15.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
59
+ "vlm.lang_model.model.layers.15.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
60
+ "vlm.lang_model.model.layers.16.input_layernorm.weight": "model-00003-of-00004.safetensors",
61
+ "vlm.lang_model.model.layers.16.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
62
+ "vlm.lang_model.model.layers.16.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
63
+ "vlm.lang_model.model.layers.16.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
64
+ "vlm.lang_model.model.layers.16.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
65
+ "vlm.lang_model.model.layers.16.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
66
+ "vlm.lang_model.model.layers.17.input_layernorm.weight": "model-00003-of-00004.safetensors",
67
+ "vlm.lang_model.model.layers.17.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
68
+ "vlm.lang_model.model.layers.17.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
69
+ "vlm.lang_model.model.layers.17.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
70
+ "vlm.lang_model.model.layers.17.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
71
+ "vlm.lang_model.model.layers.17.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
72
+ "vlm.lang_model.model.layers.18.input_layernorm.weight": "model-00003-of-00004.safetensors",
73
+ "vlm.lang_model.model.layers.18.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
74
+ "vlm.lang_model.model.layers.18.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
75
+ "vlm.lang_model.model.layers.18.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
76
+ "vlm.lang_model.model.layers.18.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
77
+ "vlm.lang_model.model.layers.18.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
78
+ "vlm.lang_model.model.layers.19.input_layernorm.weight": "model-00003-of-00004.safetensors",
79
+ "vlm.lang_model.model.layers.19.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
80
+ "vlm.lang_model.model.layers.19.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
81
+ "vlm.lang_model.model.layers.19.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
82
+ "vlm.lang_model.model.layers.19.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
83
+ "vlm.lang_model.model.layers.19.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
84
+ "vlm.lang_model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors",
85
+ "vlm.lang_model.model.layers.2.mlp.down_proj.weight": "model-00001-of-00004.safetensors",
86
+ "vlm.lang_model.model.layers.2.mlp.gate_up_proj.weight": "model-00001-of-00004.safetensors",
87
+ "vlm.lang_model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors",
88
+ "vlm.lang_model.model.layers.2.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
89
+ "vlm.lang_model.model.layers.2.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
90
+ "vlm.lang_model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors",
91
+ "vlm.lang_model.model.layers.20.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
92
+ "vlm.lang_model.model.layers.20.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
93
+ "vlm.lang_model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
94
+ "vlm.lang_model.model.layers.20.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
95
+ "vlm.lang_model.model.layers.20.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
96
+ "vlm.lang_model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors",
97
+ "vlm.lang_model.model.layers.21.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
98
+ "vlm.lang_model.model.layers.21.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
99
+ "vlm.lang_model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
100
+ "vlm.lang_model.model.layers.21.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
101
+ "vlm.lang_model.model.layers.21.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
102
+ "vlm.lang_model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors",
103
+ "vlm.lang_model.model.layers.22.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
104
+ "vlm.lang_model.model.layers.22.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
105
+ "vlm.lang_model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
106
+ "vlm.lang_model.model.layers.22.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
107
+ "vlm.lang_model.model.layers.22.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
108
+ "vlm.lang_model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors",
109
+ "vlm.lang_model.model.layers.23.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
110
+ "vlm.lang_model.model.layers.23.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
111
+ "vlm.lang_model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
112
+ "vlm.lang_model.model.layers.23.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
113
+ "vlm.lang_model.model.layers.23.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
114
+ "vlm.lang_model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors",
115
+ "vlm.lang_model.model.layers.24.mlp.down_proj.weight": "model-00003-of-00004.safetensors",
116
+ "vlm.lang_model.model.layers.24.mlp.gate_up_proj.weight": "model-00003-of-00004.safetensors",
117
+ "vlm.lang_model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors",
118
+ "vlm.lang_model.model.layers.24.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
119
+ "vlm.lang_model.model.layers.24.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
120
+ "vlm.lang_model.model.layers.25.input_layernorm.weight": "model-00004-of-00004.safetensors",
121
+ "vlm.lang_model.model.layers.25.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
122
+ "vlm.lang_model.model.layers.25.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
123
+ "vlm.lang_model.model.layers.25.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
124
+ "vlm.lang_model.model.layers.25.self_attn.o_proj.weight": "model-00003-of-00004.safetensors",
125
+ "vlm.lang_model.model.layers.25.self_attn.qkv_proj.weight": "model-00003-of-00004.safetensors",
126
+ "vlm.lang_model.model.layers.26.input_layernorm.weight": "model-00004-of-00004.safetensors",
127
+ "vlm.lang_model.model.layers.26.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
128
+ "vlm.lang_model.model.layers.26.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
129
+ "vlm.lang_model.model.layers.26.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
130
+ "vlm.lang_model.model.layers.26.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
131
+ "vlm.lang_model.model.layers.26.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
132
+ "vlm.lang_model.model.layers.27.input_layernorm.weight": "model-00004-of-00004.safetensors",
133
+ "vlm.lang_model.model.layers.27.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
134
+ "vlm.lang_model.model.layers.27.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
135
+ "vlm.lang_model.model.layers.27.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
136
+ "vlm.lang_model.model.layers.27.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
137
+ "vlm.lang_model.model.layers.27.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
138
+ "vlm.lang_model.model.layers.28.input_layernorm.weight": "model-00004-of-00004.safetensors",
139
+ "vlm.lang_model.model.layers.28.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
140
+ "vlm.lang_model.model.layers.28.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
141
+ "vlm.lang_model.model.layers.28.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
142
+ "vlm.lang_model.model.layers.28.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
143
+ "vlm.lang_model.model.layers.28.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
144
+ "vlm.lang_model.model.layers.29.input_layernorm.weight": "model-00004-of-00004.safetensors",
145
+ "vlm.lang_model.model.layers.29.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
146
+ "vlm.lang_model.model.layers.29.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
147
+ "vlm.lang_model.model.layers.29.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
148
+ "vlm.lang_model.model.layers.29.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
149
+ "vlm.lang_model.model.layers.29.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
150
+ "vlm.lang_model.model.layers.3.input_layernorm.weight": "model-00002-of-00004.safetensors",
151
+ "vlm.lang_model.model.layers.3.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
152
+ "vlm.lang_model.model.layers.3.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
153
+ "vlm.lang_model.model.layers.3.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
154
+ "vlm.lang_model.model.layers.3.self_attn.o_proj.weight": "model-00001-of-00004.safetensors",
155
+ "vlm.lang_model.model.layers.3.self_attn.qkv_proj.weight": "model-00001-of-00004.safetensors",
156
+ "vlm.lang_model.model.layers.30.input_layernorm.weight": "model-00004-of-00004.safetensors",
157
+ "vlm.lang_model.model.layers.30.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
158
+ "vlm.lang_model.model.layers.30.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
159
+ "vlm.lang_model.model.layers.30.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
160
+ "vlm.lang_model.model.layers.30.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
161
+ "vlm.lang_model.model.layers.30.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
162
+ "vlm.lang_model.model.layers.31.input_layernorm.weight": "model-00004-of-00004.safetensors",
163
+ "vlm.lang_model.model.layers.31.mlp.down_proj.weight": "model-00004-of-00004.safetensors",
164
+ "vlm.lang_model.model.layers.31.mlp.gate_up_proj.weight": "model-00004-of-00004.safetensors",
165
+ "vlm.lang_model.model.layers.31.post_attention_layernorm.weight": "model-00004-of-00004.safetensors",
166
+ "vlm.lang_model.model.layers.31.self_attn.o_proj.weight": "model-00004-of-00004.safetensors",
167
+ "vlm.lang_model.model.layers.31.self_attn.qkv_proj.weight": "model-00004-of-00004.safetensors",
168
+ "vlm.lang_model.model.layers.4.input_layernorm.weight": "model-00002-of-00004.safetensors",
169
+ "vlm.lang_model.model.layers.4.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
170
+ "vlm.lang_model.model.layers.4.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
171
+ "vlm.lang_model.model.layers.4.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
172
+ "vlm.lang_model.model.layers.4.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
173
+ "vlm.lang_model.model.layers.4.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
174
+ "vlm.lang_model.model.layers.5.input_layernorm.weight": "model-00002-of-00004.safetensors",
175
+ "vlm.lang_model.model.layers.5.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
176
+ "vlm.lang_model.model.layers.5.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
177
+ "vlm.lang_model.model.layers.5.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
178
+ "vlm.lang_model.model.layers.5.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
179
+ "vlm.lang_model.model.layers.5.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
180
+ "vlm.lang_model.model.layers.6.input_layernorm.weight": "model-00002-of-00004.safetensors",
181
+ "vlm.lang_model.model.layers.6.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
182
+ "vlm.lang_model.model.layers.6.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
183
+ "vlm.lang_model.model.layers.6.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
184
+ "vlm.lang_model.model.layers.6.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
185
+ "vlm.lang_model.model.layers.6.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
186
+ "vlm.lang_model.model.layers.7.input_layernorm.weight": "model-00002-of-00004.safetensors",
187
+ "vlm.lang_model.model.layers.7.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
188
+ "vlm.lang_model.model.layers.7.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
189
+ "vlm.lang_model.model.layers.7.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
190
+ "vlm.lang_model.model.layers.7.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
191
+ "vlm.lang_model.model.layers.7.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
192
+ "vlm.lang_model.model.layers.8.input_layernorm.weight": "model-00002-of-00004.safetensors",
193
+ "vlm.lang_model.model.layers.8.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
194
+ "vlm.lang_model.model.layers.8.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
195
+ "vlm.lang_model.model.layers.8.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
196
+ "vlm.lang_model.model.layers.8.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
197
+ "vlm.lang_model.model.layers.8.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
198
+ "vlm.lang_model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors",
199
+ "vlm.lang_model.model.layers.9.mlp.down_proj.weight": "model-00002-of-00004.safetensors",
200
+ "vlm.lang_model.model.layers.9.mlp.gate_up_proj.weight": "model-00002-of-00004.safetensors",
201
+ "vlm.lang_model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors",
202
+ "vlm.lang_model.model.layers.9.self_attn.o_proj.weight": "model-00002-of-00004.safetensors",
203
+ "vlm.lang_model.model.layers.9.self_attn.qkv_proj.weight": "model-00002-of-00004.safetensors",
204
+ "vlm.lang_model.model.norm.weight": "model-00004-of-00004.safetensors",
205
+ "vlm.vision_encoder.class_embedding": "model-00001-of-00004.safetensors",
206
+ "vlm.vision_encoder.conv1.weight": "model-00001-of-00004.safetensors",
207
+ "vlm.vision_encoder.ln_post.bias": "model-00001-of-00004.safetensors",
208
+ "vlm.vision_encoder.ln_post.weight": "model-00001-of-00004.safetensors",
209
+ "vlm.vision_encoder.ln_pre.bias": "model-00001-of-00004.safetensors",
210
+ "vlm.vision_encoder.ln_pre.weight": "model-00001-of-00004.safetensors",
211
+ "vlm.vision_encoder.positional_embedding": "model-00001-of-00004.safetensors",
212
+ "vlm.vision_encoder.proj": "model-00001-of-00004.safetensors",
213
+ "vlm.vision_encoder.transformer.resblocks.0.attn.in_proj_bias": "model-00001-of-00004.safetensors",
214
+ "vlm.vision_encoder.transformer.resblocks.0.attn.in_proj_weight": "model-00001-of-00004.safetensors",
215
+ "vlm.vision_encoder.transformer.resblocks.0.attn.out_proj.bias": "model-00001-of-00004.safetensors",
216
+ "vlm.vision_encoder.transformer.resblocks.0.attn.out_proj.weight": "model-00001-of-00004.safetensors",
217
+ "vlm.vision_encoder.transformer.resblocks.0.ln_1.bias": "model-00001-of-00004.safetensors",
218
+ "vlm.vision_encoder.transformer.resblocks.0.ln_1.weight": "model-00001-of-00004.safetensors",
219
+ "vlm.vision_encoder.transformer.resblocks.0.ln_2.bias": "model-00001-of-00004.safetensors",
220
+ "vlm.vision_encoder.transformer.resblocks.0.ln_2.weight": "model-00001-of-00004.safetensors",
221
+ "vlm.vision_encoder.transformer.resblocks.0.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
222
+ "vlm.vision_encoder.transformer.resblocks.0.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
223
+ "vlm.vision_encoder.transformer.resblocks.0.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
224
+ "vlm.vision_encoder.transformer.resblocks.0.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
225
+ "vlm.vision_encoder.transformer.resblocks.1.attn.in_proj_bias": "model-00001-of-00004.safetensors",
226
+ "vlm.vision_encoder.transformer.resblocks.1.attn.in_proj_weight": "model-00001-of-00004.safetensors",
227
+ "vlm.vision_encoder.transformer.resblocks.1.attn.out_proj.bias": "model-00001-of-00004.safetensors",
228
+ "vlm.vision_encoder.transformer.resblocks.1.attn.out_proj.weight": "model-00001-of-00004.safetensors",
229
+ "vlm.vision_encoder.transformer.resblocks.1.ln_1.bias": "model-00001-of-00004.safetensors",
230
+ "vlm.vision_encoder.transformer.resblocks.1.ln_1.weight": "model-00001-of-00004.safetensors",
231
+ "vlm.vision_encoder.transformer.resblocks.1.ln_2.bias": "model-00001-of-00004.safetensors",
232
+ "vlm.vision_encoder.transformer.resblocks.1.ln_2.weight": "model-00001-of-00004.safetensors",
233
+ "vlm.vision_encoder.transformer.resblocks.1.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
234
+ "vlm.vision_encoder.transformer.resblocks.1.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
235
+ "vlm.vision_encoder.transformer.resblocks.1.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
236
+ "vlm.vision_encoder.transformer.resblocks.1.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
237
+ "vlm.vision_encoder.transformer.resblocks.10.attn.in_proj_bias": "model-00001-of-00004.safetensors",
238
+ "vlm.vision_encoder.transformer.resblocks.10.attn.in_proj_weight": "model-00001-of-00004.safetensors",
239
+ "vlm.vision_encoder.transformer.resblocks.10.attn.out_proj.bias": "model-00001-of-00004.safetensors",
240
+ "vlm.vision_encoder.transformer.resblocks.10.attn.out_proj.weight": "model-00001-of-00004.safetensors",
241
+ "vlm.vision_encoder.transformer.resblocks.10.ln_1.bias": "model-00001-of-00004.safetensors",
242
+ "vlm.vision_encoder.transformer.resblocks.10.ln_1.weight": "model-00001-of-00004.safetensors",
243
+ "vlm.vision_encoder.transformer.resblocks.10.ln_2.bias": "model-00001-of-00004.safetensors",
244
+ "vlm.vision_encoder.transformer.resblocks.10.ln_2.weight": "model-00001-of-00004.safetensors",
245
+ "vlm.vision_encoder.transformer.resblocks.10.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
246
+ "vlm.vision_encoder.transformer.resblocks.10.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
247
+ "vlm.vision_encoder.transformer.resblocks.10.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
248
+ "vlm.vision_encoder.transformer.resblocks.10.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
249
+ "vlm.vision_encoder.transformer.resblocks.11.attn.in_proj_bias": "model-00001-of-00004.safetensors",
250
+ "vlm.vision_encoder.transformer.resblocks.11.attn.in_proj_weight": "model-00001-of-00004.safetensors",
251
+ "vlm.vision_encoder.transformer.resblocks.11.attn.out_proj.bias": "model-00001-of-00004.safetensors",
252
+ "vlm.vision_encoder.transformer.resblocks.11.attn.out_proj.weight": "model-00001-of-00004.safetensors",
253
+ "vlm.vision_encoder.transformer.resblocks.11.ln_1.bias": "model-00001-of-00004.safetensors",
254
+ "vlm.vision_encoder.transformer.resblocks.11.ln_1.weight": "model-00001-of-00004.safetensors",
255
+ "vlm.vision_encoder.transformer.resblocks.11.ln_2.bias": "model-00001-of-00004.safetensors",
256
+ "vlm.vision_encoder.transformer.resblocks.11.ln_2.weight": "model-00001-of-00004.safetensors",
257
+ "vlm.vision_encoder.transformer.resblocks.11.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
258
+ "vlm.vision_encoder.transformer.resblocks.11.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
259
+ "vlm.vision_encoder.transformer.resblocks.11.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
260
+ "vlm.vision_encoder.transformer.resblocks.11.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
261
+ "vlm.vision_encoder.transformer.resblocks.12.attn.in_proj_bias": "model-00001-of-00004.safetensors",
262
+ "vlm.vision_encoder.transformer.resblocks.12.attn.in_proj_weight": "model-00001-of-00004.safetensors",
263
+ "vlm.vision_encoder.transformer.resblocks.12.attn.out_proj.bias": "model-00001-of-00004.safetensors",
264
+ "vlm.vision_encoder.transformer.resblocks.12.attn.out_proj.weight": "model-00001-of-00004.safetensors",
265
+ "vlm.vision_encoder.transformer.resblocks.12.ln_1.bias": "model-00001-of-00004.safetensors",
266
+ "vlm.vision_encoder.transformer.resblocks.12.ln_1.weight": "model-00001-of-00004.safetensors",
267
+ "vlm.vision_encoder.transformer.resblocks.12.ln_2.bias": "model-00001-of-00004.safetensors",
268
+ "vlm.vision_encoder.transformer.resblocks.12.ln_2.weight": "model-00001-of-00004.safetensors",
269
+ "vlm.vision_encoder.transformer.resblocks.12.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
270
+ "vlm.vision_encoder.transformer.resblocks.12.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
271
+ "vlm.vision_encoder.transformer.resblocks.12.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
272
+ "vlm.vision_encoder.transformer.resblocks.12.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
273
+ "vlm.vision_encoder.transformer.resblocks.13.attn.in_proj_bias": "model-00001-of-00004.safetensors",
274
+ "vlm.vision_encoder.transformer.resblocks.13.attn.in_proj_weight": "model-00001-of-00004.safetensors",
275
+ "vlm.vision_encoder.transformer.resblocks.13.attn.out_proj.bias": "model-00001-of-00004.safetensors",
276
+ "vlm.vision_encoder.transformer.resblocks.13.attn.out_proj.weight": "model-00001-of-00004.safetensors",
277
+ "vlm.vision_encoder.transformer.resblocks.13.ln_1.bias": "model-00001-of-00004.safetensors",
278
+ "vlm.vision_encoder.transformer.resblocks.13.ln_1.weight": "model-00001-of-00004.safetensors",
279
+ "vlm.vision_encoder.transformer.resblocks.13.ln_2.bias": "model-00001-of-00004.safetensors",
280
+ "vlm.vision_encoder.transformer.resblocks.13.ln_2.weight": "model-00001-of-00004.safetensors",
281
+ "vlm.vision_encoder.transformer.resblocks.13.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
282
+ "vlm.vision_encoder.transformer.resblocks.13.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
283
+ "vlm.vision_encoder.transformer.resblocks.13.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
284
+ "vlm.vision_encoder.transformer.resblocks.13.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
285
+ "vlm.vision_encoder.transformer.resblocks.14.attn.in_proj_bias": "model-00001-of-00004.safetensors",
286
+ "vlm.vision_encoder.transformer.resblocks.14.attn.in_proj_weight": "model-00001-of-00004.safetensors",
287
+ "vlm.vision_encoder.transformer.resblocks.14.attn.out_proj.bias": "model-00001-of-00004.safetensors",
288
+ "vlm.vision_encoder.transformer.resblocks.14.attn.out_proj.weight": "model-00001-of-00004.safetensors",
289
+ "vlm.vision_encoder.transformer.resblocks.14.ln_1.bias": "model-00001-of-00004.safetensors",
290
+ "vlm.vision_encoder.transformer.resblocks.14.ln_1.weight": "model-00001-of-00004.safetensors",
291
+ "vlm.vision_encoder.transformer.resblocks.14.ln_2.bias": "model-00001-of-00004.safetensors",
292
+ "vlm.vision_encoder.transformer.resblocks.14.ln_2.weight": "model-00001-of-00004.safetensors",
293
+ "vlm.vision_encoder.transformer.resblocks.14.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
294
+ "vlm.vision_encoder.transformer.resblocks.14.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
295
+ "vlm.vision_encoder.transformer.resblocks.14.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
296
+ "vlm.vision_encoder.transformer.resblocks.14.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
297
+ "vlm.vision_encoder.transformer.resblocks.15.attn.in_proj_bias": "model-00001-of-00004.safetensors",
298
+ "vlm.vision_encoder.transformer.resblocks.15.attn.in_proj_weight": "model-00001-of-00004.safetensors",
299
+ "vlm.vision_encoder.transformer.resblocks.15.attn.out_proj.bias": "model-00001-of-00004.safetensors",
300
+ "vlm.vision_encoder.transformer.resblocks.15.attn.out_proj.weight": "model-00001-of-00004.safetensors",
301
+ "vlm.vision_encoder.transformer.resblocks.15.ln_1.bias": "model-00001-of-00004.safetensors",
302
+ "vlm.vision_encoder.transformer.resblocks.15.ln_1.weight": "model-00001-of-00004.safetensors",
303
+ "vlm.vision_encoder.transformer.resblocks.15.ln_2.bias": "model-00001-of-00004.safetensors",
304
+ "vlm.vision_encoder.transformer.resblocks.15.ln_2.weight": "model-00001-of-00004.safetensors",
305
+ "vlm.vision_encoder.transformer.resblocks.15.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
306
+ "vlm.vision_encoder.transformer.resblocks.15.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
307
+ "vlm.vision_encoder.transformer.resblocks.15.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
308
+ "vlm.vision_encoder.transformer.resblocks.15.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
309
+ "vlm.vision_encoder.transformer.resblocks.16.attn.in_proj_bias": "model-00001-of-00004.safetensors",
310
+ "vlm.vision_encoder.transformer.resblocks.16.attn.in_proj_weight": "model-00001-of-00004.safetensors",
311
+ "vlm.vision_encoder.transformer.resblocks.16.attn.out_proj.bias": "model-00001-of-00004.safetensors",
312
+ "vlm.vision_encoder.transformer.resblocks.16.attn.out_proj.weight": "model-00001-of-00004.safetensors",
313
+ "vlm.vision_encoder.transformer.resblocks.16.ln_1.bias": "model-00001-of-00004.safetensors",
314
+ "vlm.vision_encoder.transformer.resblocks.16.ln_1.weight": "model-00001-of-00004.safetensors",
315
+ "vlm.vision_encoder.transformer.resblocks.16.ln_2.bias": "model-00001-of-00004.safetensors",
316
+ "vlm.vision_encoder.transformer.resblocks.16.ln_2.weight": "model-00001-of-00004.safetensors",
317
+ "vlm.vision_encoder.transformer.resblocks.16.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
318
+ "vlm.vision_encoder.transformer.resblocks.16.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
319
+ "vlm.vision_encoder.transformer.resblocks.16.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
320
+ "vlm.vision_encoder.transformer.resblocks.16.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
321
+ "vlm.vision_encoder.transformer.resblocks.17.attn.in_proj_bias": "model-00001-of-00004.safetensors",
322
+ "vlm.vision_encoder.transformer.resblocks.17.attn.in_proj_weight": "model-00001-of-00004.safetensors",
323
+ "vlm.vision_encoder.transformer.resblocks.17.attn.out_proj.bias": "model-00001-of-00004.safetensors",
324
+ "vlm.vision_encoder.transformer.resblocks.17.attn.out_proj.weight": "model-00001-of-00004.safetensors",
325
+ "vlm.vision_encoder.transformer.resblocks.17.ln_1.bias": "model-00001-of-00004.safetensors",
326
+ "vlm.vision_encoder.transformer.resblocks.17.ln_1.weight": "model-00001-of-00004.safetensors",
327
+ "vlm.vision_encoder.transformer.resblocks.17.ln_2.bias": "model-00001-of-00004.safetensors",
328
+ "vlm.vision_encoder.transformer.resblocks.17.ln_2.weight": "model-00001-of-00004.safetensors",
329
+ "vlm.vision_encoder.transformer.resblocks.17.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
330
+ "vlm.vision_encoder.transformer.resblocks.17.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
331
+ "vlm.vision_encoder.transformer.resblocks.17.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
332
+ "vlm.vision_encoder.transformer.resblocks.17.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
333
+ "vlm.vision_encoder.transformer.resblocks.18.attn.in_proj_bias": "model-00001-of-00004.safetensors",
334
+ "vlm.vision_encoder.transformer.resblocks.18.attn.in_proj_weight": "model-00001-of-00004.safetensors",
335
+ "vlm.vision_encoder.transformer.resblocks.18.attn.out_proj.bias": "model-00001-of-00004.safetensors",
336
+ "vlm.vision_encoder.transformer.resblocks.18.attn.out_proj.weight": "model-00001-of-00004.safetensors",
337
+ "vlm.vision_encoder.transformer.resblocks.18.ln_1.bias": "model-00001-of-00004.safetensors",
338
+ "vlm.vision_encoder.transformer.resblocks.18.ln_1.weight": "model-00001-of-00004.safetensors",
339
+ "vlm.vision_encoder.transformer.resblocks.18.ln_2.bias": "model-00001-of-00004.safetensors",
340
+ "vlm.vision_encoder.transformer.resblocks.18.ln_2.weight": "model-00001-of-00004.safetensors",
341
+ "vlm.vision_encoder.transformer.resblocks.18.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
342
+ "vlm.vision_encoder.transformer.resblocks.18.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
343
+ "vlm.vision_encoder.transformer.resblocks.18.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
344
+ "vlm.vision_encoder.transformer.resblocks.18.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
345
+ "vlm.vision_encoder.transformer.resblocks.19.attn.in_proj_bias": "model-00001-of-00004.safetensors",
346
+ "vlm.vision_encoder.transformer.resblocks.19.attn.in_proj_weight": "model-00001-of-00004.safetensors",
347
+ "vlm.vision_encoder.transformer.resblocks.19.attn.out_proj.bias": "model-00001-of-00004.safetensors",
348
+ "vlm.vision_encoder.transformer.resblocks.19.attn.out_proj.weight": "model-00001-of-00004.safetensors",
349
+ "vlm.vision_encoder.transformer.resblocks.19.ln_1.bias": "model-00001-of-00004.safetensors",
350
+ "vlm.vision_encoder.transformer.resblocks.19.ln_1.weight": "model-00001-of-00004.safetensors",
351
+ "vlm.vision_encoder.transformer.resblocks.19.ln_2.bias": "model-00001-of-00004.safetensors",
352
+ "vlm.vision_encoder.transformer.resblocks.19.ln_2.weight": "model-00001-of-00004.safetensors",
353
+ "vlm.vision_encoder.transformer.resblocks.19.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
354
+ "vlm.vision_encoder.transformer.resblocks.19.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
355
+ "vlm.vision_encoder.transformer.resblocks.19.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
356
+ "vlm.vision_encoder.transformer.resblocks.19.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
357
+ "vlm.vision_encoder.transformer.resblocks.2.attn.in_proj_bias": "model-00001-of-00004.safetensors",
358
+ "vlm.vision_encoder.transformer.resblocks.2.attn.in_proj_weight": "model-00001-of-00004.safetensors",
359
+ "vlm.vision_encoder.transformer.resblocks.2.attn.out_proj.bias": "model-00001-of-00004.safetensors",
360
+ "vlm.vision_encoder.transformer.resblocks.2.attn.out_proj.weight": "model-00001-of-00004.safetensors",
361
+ "vlm.vision_encoder.transformer.resblocks.2.ln_1.bias": "model-00001-of-00004.safetensors",
362
+ "vlm.vision_encoder.transformer.resblocks.2.ln_1.weight": "model-00001-of-00004.safetensors",
363
+ "vlm.vision_encoder.transformer.resblocks.2.ln_2.bias": "model-00001-of-00004.safetensors",
364
+ "vlm.vision_encoder.transformer.resblocks.2.ln_2.weight": "model-00001-of-00004.safetensors",
365
+ "vlm.vision_encoder.transformer.resblocks.2.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
366
+ "vlm.vision_encoder.transformer.resblocks.2.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
367
+ "vlm.vision_encoder.transformer.resblocks.2.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
368
+ "vlm.vision_encoder.transformer.resblocks.2.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
369
+ "vlm.vision_encoder.transformer.resblocks.20.attn.in_proj_bias": "model-00001-of-00004.safetensors",
370
+ "vlm.vision_encoder.transformer.resblocks.20.attn.in_proj_weight": "model-00001-of-00004.safetensors",
371
+ "vlm.vision_encoder.transformer.resblocks.20.attn.out_proj.bias": "model-00001-of-00004.safetensors",
372
+ "vlm.vision_encoder.transformer.resblocks.20.attn.out_proj.weight": "model-00001-of-00004.safetensors",
373
+ "vlm.vision_encoder.transformer.resblocks.20.ln_1.bias": "model-00001-of-00004.safetensors",
374
+ "vlm.vision_encoder.transformer.resblocks.20.ln_1.weight": "model-00001-of-00004.safetensors",
375
+ "vlm.vision_encoder.transformer.resblocks.20.ln_2.bias": "model-00001-of-00004.safetensors",
376
+ "vlm.vision_encoder.transformer.resblocks.20.ln_2.weight": "model-00001-of-00004.safetensors",
377
+ "vlm.vision_encoder.transformer.resblocks.20.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
378
+ "vlm.vision_encoder.transformer.resblocks.20.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
379
+ "vlm.vision_encoder.transformer.resblocks.20.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
380
+ "vlm.vision_encoder.transformer.resblocks.20.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
381
+ "vlm.vision_encoder.transformer.resblocks.21.attn.in_proj_bias": "model-00001-of-00004.safetensors",
382
+ "vlm.vision_encoder.transformer.resblocks.21.attn.in_proj_weight": "model-00001-of-00004.safetensors",
383
+ "vlm.vision_encoder.transformer.resblocks.21.attn.out_proj.bias": "model-00001-of-00004.safetensors",
384
+ "vlm.vision_encoder.transformer.resblocks.21.attn.out_proj.weight": "model-00001-of-00004.safetensors",
385
+ "vlm.vision_encoder.transformer.resblocks.21.ln_1.bias": "model-00001-of-00004.safetensors",
386
+ "vlm.vision_encoder.transformer.resblocks.21.ln_1.weight": "model-00001-of-00004.safetensors",
387
+ "vlm.vision_encoder.transformer.resblocks.21.ln_2.bias": "model-00001-of-00004.safetensors",
388
+ "vlm.vision_encoder.transformer.resblocks.21.ln_2.weight": "model-00001-of-00004.safetensors",
389
+ "vlm.vision_encoder.transformer.resblocks.21.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
390
+ "vlm.vision_encoder.transformer.resblocks.21.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
391
+ "vlm.vision_encoder.transformer.resblocks.21.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
392
+ "vlm.vision_encoder.transformer.resblocks.21.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
393
+ "vlm.vision_encoder.transformer.resblocks.22.attn.in_proj_bias": "model-00001-of-00004.safetensors",
394
+ "vlm.vision_encoder.transformer.resblocks.22.attn.in_proj_weight": "model-00001-of-00004.safetensors",
395
+ "vlm.vision_encoder.transformer.resblocks.22.attn.out_proj.bias": "model-00001-of-00004.safetensors",
396
+ "vlm.vision_encoder.transformer.resblocks.22.attn.out_proj.weight": "model-00001-of-00004.safetensors",
397
+ "vlm.vision_encoder.transformer.resblocks.22.ln_1.bias": "model-00001-of-00004.safetensors",
398
+ "vlm.vision_encoder.transformer.resblocks.22.ln_1.weight": "model-00001-of-00004.safetensors",
399
+ "vlm.vision_encoder.transformer.resblocks.22.ln_2.bias": "model-00001-of-00004.safetensors",
400
+ "vlm.vision_encoder.transformer.resblocks.22.ln_2.weight": "model-00001-of-00004.safetensors",
401
+ "vlm.vision_encoder.transformer.resblocks.22.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
402
+ "vlm.vision_encoder.transformer.resblocks.22.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
403
+ "vlm.vision_encoder.transformer.resblocks.22.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
404
+ "vlm.vision_encoder.transformer.resblocks.22.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
405
+ "vlm.vision_encoder.transformer.resblocks.23.attn.in_proj_bias": "model-00001-of-00004.safetensors",
406
+ "vlm.vision_encoder.transformer.resblocks.23.attn.in_proj_weight": "model-00001-of-00004.safetensors",
407
+ "vlm.vision_encoder.transformer.resblocks.23.attn.out_proj.bias": "model-00001-of-00004.safetensors",
408
+ "vlm.vision_encoder.transformer.resblocks.23.attn.out_proj.weight": "model-00001-of-00004.safetensors",
409
+ "vlm.vision_encoder.transformer.resblocks.23.ln_1.bias": "model-00001-of-00004.safetensors",
410
+ "vlm.vision_encoder.transformer.resblocks.23.ln_1.weight": "model-00001-of-00004.safetensors",
411
+ "vlm.vision_encoder.transformer.resblocks.23.ln_2.bias": "model-00001-of-00004.safetensors",
412
+ "vlm.vision_encoder.transformer.resblocks.23.ln_2.weight": "model-00001-of-00004.safetensors",
413
+ "vlm.vision_encoder.transformer.resblocks.23.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
414
+ "vlm.vision_encoder.transformer.resblocks.23.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
415
+ "vlm.vision_encoder.transformer.resblocks.23.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
416
+ "vlm.vision_encoder.transformer.resblocks.23.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
417
+ "vlm.vision_encoder.transformer.resblocks.24.attn.in_proj_bias": "model-00001-of-00004.safetensors",
418
+ "vlm.vision_encoder.transformer.resblocks.24.attn.in_proj_weight": "model-00001-of-00004.safetensors",
419
+ "vlm.vision_encoder.transformer.resblocks.24.attn.out_proj.bias": "model-00001-of-00004.safetensors",
420
+ "vlm.vision_encoder.transformer.resblocks.24.attn.out_proj.weight": "model-00001-of-00004.safetensors",
421
+ "vlm.vision_encoder.transformer.resblocks.24.ln_1.bias": "model-00001-of-00004.safetensors",
422
+ "vlm.vision_encoder.transformer.resblocks.24.ln_1.weight": "model-00001-of-00004.safetensors",
423
+ "vlm.vision_encoder.transformer.resblocks.24.ln_2.bias": "model-00001-of-00004.safetensors",
424
+ "vlm.vision_encoder.transformer.resblocks.24.ln_2.weight": "model-00001-of-00004.safetensors",
425
+ "vlm.vision_encoder.transformer.resblocks.24.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
426
+ "vlm.vision_encoder.transformer.resblocks.24.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
427
+ "vlm.vision_encoder.transformer.resblocks.24.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
428
+ "vlm.vision_encoder.transformer.resblocks.24.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
429
+ "vlm.vision_encoder.transformer.resblocks.25.attn.in_proj_bias": "model-00001-of-00004.safetensors",
430
+ "vlm.vision_encoder.transformer.resblocks.25.attn.in_proj_weight": "model-00001-of-00004.safetensors",
431
+ "vlm.vision_encoder.transformer.resblocks.25.attn.out_proj.bias": "model-00001-of-00004.safetensors",
432
+ "vlm.vision_encoder.transformer.resblocks.25.attn.out_proj.weight": "model-00001-of-00004.safetensors",
433
+ "vlm.vision_encoder.transformer.resblocks.25.ln_1.bias": "model-00001-of-00004.safetensors",
434
+ "vlm.vision_encoder.transformer.resblocks.25.ln_1.weight": "model-00001-of-00004.safetensors",
435
+ "vlm.vision_encoder.transformer.resblocks.25.ln_2.bias": "model-00001-of-00004.safetensors",
436
+ "vlm.vision_encoder.transformer.resblocks.25.ln_2.weight": "model-00001-of-00004.safetensors",
437
+ "vlm.vision_encoder.transformer.resblocks.25.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
438
+ "vlm.vision_encoder.transformer.resblocks.25.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
439
+ "vlm.vision_encoder.transformer.resblocks.25.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
440
+ "vlm.vision_encoder.transformer.resblocks.25.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
441
+ "vlm.vision_encoder.transformer.resblocks.26.attn.in_proj_bias": "model-00001-of-00004.safetensors",
442
+ "vlm.vision_encoder.transformer.resblocks.26.attn.in_proj_weight": "model-00001-of-00004.safetensors",
443
+ "vlm.vision_encoder.transformer.resblocks.26.attn.out_proj.bias": "model-00001-of-00004.safetensors",
444
+ "vlm.vision_encoder.transformer.resblocks.26.attn.out_proj.weight": "model-00001-of-00004.safetensors",
445
+ "vlm.vision_encoder.transformer.resblocks.26.ln_1.bias": "model-00001-of-00004.safetensors",
446
+ "vlm.vision_encoder.transformer.resblocks.26.ln_1.weight": "model-00001-of-00004.safetensors",
447
+ "vlm.vision_encoder.transformer.resblocks.26.ln_2.bias": "model-00001-of-00004.safetensors",
448
+ "vlm.vision_encoder.transformer.resblocks.26.ln_2.weight": "model-00001-of-00004.safetensors",
449
+ "vlm.vision_encoder.transformer.resblocks.26.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
450
+ "vlm.vision_encoder.transformer.resblocks.26.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
451
+ "vlm.vision_encoder.transformer.resblocks.26.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
452
+ "vlm.vision_encoder.transformer.resblocks.26.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
453
+ "vlm.vision_encoder.transformer.resblocks.27.attn.in_proj_bias": "model-00001-of-00004.safetensors",
454
+ "vlm.vision_encoder.transformer.resblocks.27.attn.in_proj_weight": "model-00001-of-00004.safetensors",
455
+ "vlm.vision_encoder.transformer.resblocks.27.attn.out_proj.bias": "model-00001-of-00004.safetensors",
456
+ "vlm.vision_encoder.transformer.resblocks.27.attn.out_proj.weight": "model-00001-of-00004.safetensors",
457
+ "vlm.vision_encoder.transformer.resblocks.27.ln_1.bias": "model-00001-of-00004.safetensors",
458
+ "vlm.vision_encoder.transformer.resblocks.27.ln_1.weight": "model-00001-of-00004.safetensors",
459
+ "vlm.vision_encoder.transformer.resblocks.27.ln_2.bias": "model-00001-of-00004.safetensors",
460
+ "vlm.vision_encoder.transformer.resblocks.27.ln_2.weight": "model-00001-of-00004.safetensors",
461
+ "vlm.vision_encoder.transformer.resblocks.27.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
462
+ "vlm.vision_encoder.transformer.resblocks.27.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
463
+ "vlm.vision_encoder.transformer.resblocks.27.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
464
+ "vlm.vision_encoder.transformer.resblocks.27.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
465
+ "vlm.vision_encoder.transformer.resblocks.28.attn.in_proj_bias": "model-00001-of-00004.safetensors",
466
+ "vlm.vision_encoder.transformer.resblocks.28.attn.in_proj_weight": "model-00001-of-00004.safetensors",
467
+ "vlm.vision_encoder.transformer.resblocks.28.attn.out_proj.bias": "model-00001-of-00004.safetensors",
468
+ "vlm.vision_encoder.transformer.resblocks.28.attn.out_proj.weight": "model-00001-of-00004.safetensors",
469
+ "vlm.vision_encoder.transformer.resblocks.28.ln_1.bias": "model-00001-of-00004.safetensors",
470
+ "vlm.vision_encoder.transformer.resblocks.28.ln_1.weight": "model-00001-of-00004.safetensors",
471
+ "vlm.vision_encoder.transformer.resblocks.28.ln_2.bias": "model-00001-of-00004.safetensors",
472
+ "vlm.vision_encoder.transformer.resblocks.28.ln_2.weight": "model-00001-of-00004.safetensors",
473
+ "vlm.vision_encoder.transformer.resblocks.28.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
474
+ "vlm.vision_encoder.transformer.resblocks.28.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
475
+ "vlm.vision_encoder.transformer.resblocks.28.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
476
+ "vlm.vision_encoder.transformer.resblocks.28.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
477
+ "vlm.vision_encoder.transformer.resblocks.29.attn.in_proj_bias": "model-00001-of-00004.safetensors",
478
+ "vlm.vision_encoder.transformer.resblocks.29.attn.in_proj_weight": "model-00001-of-00004.safetensors",
479
+ "vlm.vision_encoder.transformer.resblocks.29.attn.out_proj.bias": "model-00001-of-00004.safetensors",
480
+ "vlm.vision_encoder.transformer.resblocks.29.attn.out_proj.weight": "model-00001-of-00004.safetensors",
481
+ "vlm.vision_encoder.transformer.resblocks.29.ln_1.bias": "model-00001-of-00004.safetensors",
482
+ "vlm.vision_encoder.transformer.resblocks.29.ln_1.weight": "model-00001-of-00004.safetensors",
483
+ "vlm.vision_encoder.transformer.resblocks.29.ln_2.bias": "model-00001-of-00004.safetensors",
484
+ "vlm.vision_encoder.transformer.resblocks.29.ln_2.weight": "model-00001-of-00004.safetensors",
485
+ "vlm.vision_encoder.transformer.resblocks.29.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
486
+ "vlm.vision_encoder.transformer.resblocks.29.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
487
+ "vlm.vision_encoder.transformer.resblocks.29.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
488
+ "vlm.vision_encoder.transformer.resblocks.29.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
489
+ "vlm.vision_encoder.transformer.resblocks.3.attn.in_proj_bias": "model-00001-of-00004.safetensors",
490
+ "vlm.vision_encoder.transformer.resblocks.3.attn.in_proj_weight": "model-00001-of-00004.safetensors",
491
+ "vlm.vision_encoder.transformer.resblocks.3.attn.out_proj.bias": "model-00001-of-00004.safetensors",
492
+ "vlm.vision_encoder.transformer.resblocks.3.attn.out_proj.weight": "model-00001-of-00004.safetensors",
493
+ "vlm.vision_encoder.transformer.resblocks.3.ln_1.bias": "model-00001-of-00004.safetensors",
494
+ "vlm.vision_encoder.transformer.resblocks.3.ln_1.weight": "model-00001-of-00004.safetensors",
495
+ "vlm.vision_encoder.transformer.resblocks.3.ln_2.bias": "model-00001-of-00004.safetensors",
496
+ "vlm.vision_encoder.transformer.resblocks.3.ln_2.weight": "model-00001-of-00004.safetensors",
497
+ "vlm.vision_encoder.transformer.resblocks.3.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
498
+ "vlm.vision_encoder.transformer.resblocks.3.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
499
+ "vlm.vision_encoder.transformer.resblocks.3.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
500
+ "vlm.vision_encoder.transformer.resblocks.3.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
501
+ "vlm.vision_encoder.transformer.resblocks.30.attn.in_proj_bias": "model-00001-of-00004.safetensors",
502
+ "vlm.vision_encoder.transformer.resblocks.30.attn.in_proj_weight": "model-00001-of-00004.safetensors",
503
+ "vlm.vision_encoder.transformer.resblocks.30.attn.out_proj.bias": "model-00001-of-00004.safetensors",
504
+ "vlm.vision_encoder.transformer.resblocks.30.attn.out_proj.weight": "model-00001-of-00004.safetensors",
505
+ "vlm.vision_encoder.transformer.resblocks.30.ln_1.bias": "model-00001-of-00004.safetensors",
506
+ "vlm.vision_encoder.transformer.resblocks.30.ln_1.weight": "model-00001-of-00004.safetensors",
507
+ "vlm.vision_encoder.transformer.resblocks.30.ln_2.bias": "model-00001-of-00004.safetensors",
508
+ "vlm.vision_encoder.transformer.resblocks.30.ln_2.weight": "model-00001-of-00004.safetensors",
509
+ "vlm.vision_encoder.transformer.resblocks.30.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
510
+ "vlm.vision_encoder.transformer.resblocks.30.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
511
+ "vlm.vision_encoder.transformer.resblocks.30.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
512
+ "vlm.vision_encoder.transformer.resblocks.30.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
513
+ "vlm.vision_encoder.transformer.resblocks.31.attn.in_proj_bias": "model-00001-of-00004.safetensors",
514
+ "vlm.vision_encoder.transformer.resblocks.31.attn.in_proj_weight": "model-00001-of-00004.safetensors",
515
+ "vlm.vision_encoder.transformer.resblocks.31.attn.out_proj.bias": "model-00001-of-00004.safetensors",
516
+ "vlm.vision_encoder.transformer.resblocks.31.attn.out_proj.weight": "model-00001-of-00004.safetensors",
517
+ "vlm.vision_encoder.transformer.resblocks.31.ln_1.bias": "model-00001-of-00004.safetensors",
518
+ "vlm.vision_encoder.transformer.resblocks.31.ln_1.weight": "model-00001-of-00004.safetensors",
519
+ "vlm.vision_encoder.transformer.resblocks.31.ln_2.bias": "model-00001-of-00004.safetensors",
520
+ "vlm.vision_encoder.transformer.resblocks.31.ln_2.weight": "model-00001-of-00004.safetensors",
521
+ "vlm.vision_encoder.transformer.resblocks.31.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
522
+ "vlm.vision_encoder.transformer.resblocks.31.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
523
+ "vlm.vision_encoder.transformer.resblocks.31.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
524
+ "vlm.vision_encoder.transformer.resblocks.31.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
525
+ "vlm.vision_encoder.transformer.resblocks.4.attn.in_proj_bias": "model-00001-of-00004.safetensors",
526
+ "vlm.vision_encoder.transformer.resblocks.4.attn.in_proj_weight": "model-00001-of-00004.safetensors",
527
+ "vlm.vision_encoder.transformer.resblocks.4.attn.out_proj.bias": "model-00001-of-00004.safetensors",
528
+ "vlm.vision_encoder.transformer.resblocks.4.attn.out_proj.weight": "model-00001-of-00004.safetensors",
529
+ "vlm.vision_encoder.transformer.resblocks.4.ln_1.bias": "model-00001-of-00004.safetensors",
530
+ "vlm.vision_encoder.transformer.resblocks.4.ln_1.weight": "model-00001-of-00004.safetensors",
531
+ "vlm.vision_encoder.transformer.resblocks.4.ln_2.bias": "model-00001-of-00004.safetensors",
532
+ "vlm.vision_encoder.transformer.resblocks.4.ln_2.weight": "model-00001-of-00004.safetensors",
533
+ "vlm.vision_encoder.transformer.resblocks.4.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
534
+ "vlm.vision_encoder.transformer.resblocks.4.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
535
+ "vlm.vision_encoder.transformer.resblocks.4.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
536
+ "vlm.vision_encoder.transformer.resblocks.4.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
537
+ "vlm.vision_encoder.transformer.resblocks.5.attn.in_proj_bias": "model-00001-of-00004.safetensors",
538
+ "vlm.vision_encoder.transformer.resblocks.5.attn.in_proj_weight": "model-00001-of-00004.safetensors",
539
+ "vlm.vision_encoder.transformer.resblocks.5.attn.out_proj.bias": "model-00001-of-00004.safetensors",
540
+ "vlm.vision_encoder.transformer.resblocks.5.attn.out_proj.weight": "model-00001-of-00004.safetensors",
541
+ "vlm.vision_encoder.transformer.resblocks.5.ln_1.bias": "model-00001-of-00004.safetensors",
542
+ "vlm.vision_encoder.transformer.resblocks.5.ln_1.weight": "model-00001-of-00004.safetensors",
543
+ "vlm.vision_encoder.transformer.resblocks.5.ln_2.bias": "model-00001-of-00004.safetensors",
544
+ "vlm.vision_encoder.transformer.resblocks.5.ln_2.weight": "model-00001-of-00004.safetensors",
545
+ "vlm.vision_encoder.transformer.resblocks.5.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
546
+ "vlm.vision_encoder.transformer.resblocks.5.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
547
+ "vlm.vision_encoder.transformer.resblocks.5.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
548
+ "vlm.vision_encoder.transformer.resblocks.5.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
549
+ "vlm.vision_encoder.transformer.resblocks.6.attn.in_proj_bias": "model-00001-of-00004.safetensors",
550
+ "vlm.vision_encoder.transformer.resblocks.6.attn.in_proj_weight": "model-00001-of-00004.safetensors",
551
+ "vlm.vision_encoder.transformer.resblocks.6.attn.out_proj.bias": "model-00001-of-00004.safetensors",
552
+ "vlm.vision_encoder.transformer.resblocks.6.attn.out_proj.weight": "model-00001-of-00004.safetensors",
553
+ "vlm.vision_encoder.transformer.resblocks.6.ln_1.bias": "model-00001-of-00004.safetensors",
554
+ "vlm.vision_encoder.transformer.resblocks.6.ln_1.weight": "model-00001-of-00004.safetensors",
555
+ "vlm.vision_encoder.transformer.resblocks.6.ln_2.bias": "model-00001-of-00004.safetensors",
556
+ "vlm.vision_encoder.transformer.resblocks.6.ln_2.weight": "model-00001-of-00004.safetensors",
557
+ "vlm.vision_encoder.transformer.resblocks.6.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
558
+ "vlm.vision_encoder.transformer.resblocks.6.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
559
+ "vlm.vision_encoder.transformer.resblocks.6.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
560
+ "vlm.vision_encoder.transformer.resblocks.6.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
561
+ "vlm.vision_encoder.transformer.resblocks.7.attn.in_proj_bias": "model-00001-of-00004.safetensors",
562
+ "vlm.vision_encoder.transformer.resblocks.7.attn.in_proj_weight": "model-00001-of-00004.safetensors",
563
+ "vlm.vision_encoder.transformer.resblocks.7.attn.out_proj.bias": "model-00001-of-00004.safetensors",
564
+ "vlm.vision_encoder.transformer.resblocks.7.attn.out_proj.weight": "model-00001-of-00004.safetensors",
565
+ "vlm.vision_encoder.transformer.resblocks.7.ln_1.bias": "model-00001-of-00004.safetensors",
566
+ "vlm.vision_encoder.transformer.resblocks.7.ln_1.weight": "model-00001-of-00004.safetensors",
567
+ "vlm.vision_encoder.transformer.resblocks.7.ln_2.bias": "model-00001-of-00004.safetensors",
568
+ "vlm.vision_encoder.transformer.resblocks.7.ln_2.weight": "model-00001-of-00004.safetensors",
569
+ "vlm.vision_encoder.transformer.resblocks.7.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
570
+ "vlm.vision_encoder.transformer.resblocks.7.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
571
+ "vlm.vision_encoder.transformer.resblocks.7.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
572
+ "vlm.vision_encoder.transformer.resblocks.7.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
573
+ "vlm.vision_encoder.transformer.resblocks.8.attn.in_proj_bias": "model-00001-of-00004.safetensors",
574
+ "vlm.vision_encoder.transformer.resblocks.8.attn.in_proj_weight": "model-00001-of-00004.safetensors",
575
+ "vlm.vision_encoder.transformer.resblocks.8.attn.out_proj.bias": "model-00001-of-00004.safetensors",
576
+ "vlm.vision_encoder.transformer.resblocks.8.attn.out_proj.weight": "model-00001-of-00004.safetensors",
577
+ "vlm.vision_encoder.transformer.resblocks.8.ln_1.bias": "model-00001-of-00004.safetensors",
578
+ "vlm.vision_encoder.transformer.resblocks.8.ln_1.weight": "model-00001-of-00004.safetensors",
579
+ "vlm.vision_encoder.transformer.resblocks.8.ln_2.bias": "model-00001-of-00004.safetensors",
580
+ "vlm.vision_encoder.transformer.resblocks.8.ln_2.weight": "model-00001-of-00004.safetensors",
581
+ "vlm.vision_encoder.transformer.resblocks.8.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
582
+ "vlm.vision_encoder.transformer.resblocks.8.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
583
+ "vlm.vision_encoder.transformer.resblocks.8.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
584
+ "vlm.vision_encoder.transformer.resblocks.8.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
585
+ "vlm.vision_encoder.transformer.resblocks.9.attn.in_proj_bias": "model-00001-of-00004.safetensors",
586
+ "vlm.vision_encoder.transformer.resblocks.9.attn.in_proj_weight": "model-00001-of-00004.safetensors",
587
+ "vlm.vision_encoder.transformer.resblocks.9.attn.out_proj.bias": "model-00001-of-00004.safetensors",
588
+ "vlm.vision_encoder.transformer.resblocks.9.attn.out_proj.weight": "model-00001-of-00004.safetensors",
589
+ "vlm.vision_encoder.transformer.resblocks.9.ln_1.bias": "model-00001-of-00004.safetensors",
590
+ "vlm.vision_encoder.transformer.resblocks.9.ln_1.weight": "model-00001-of-00004.safetensors",
591
+ "vlm.vision_encoder.transformer.resblocks.9.ln_2.bias": "model-00001-of-00004.safetensors",
592
+ "vlm.vision_encoder.transformer.resblocks.9.ln_2.weight": "model-00001-of-00004.safetensors",
593
+ "vlm.vision_encoder.transformer.resblocks.9.mlp.c_fc.bias": "model-00001-of-00004.safetensors",
594
+ "vlm.vision_encoder.transformer.resblocks.9.mlp.c_fc.weight": "model-00001-of-00004.safetensors",
595
+ "vlm.vision_encoder.transformer.resblocks.9.mlp.c_proj.bias": "model-00001-of-00004.safetensors",
596
+ "vlm.vision_encoder.transformer.resblocks.9.mlp.c_proj.weight": "model-00001-of-00004.safetensors",
597
+ "vlm.vision_tokenizer.latents": "model-00001-of-00004.safetensors",
598
+ "vlm.vision_tokenizer.layers.0.0.norm_latents.bias": "model-00001-of-00004.safetensors",
599
+ "vlm.vision_tokenizer.layers.0.0.norm_latents.weight": "model-00001-of-00004.safetensors",
600
+ "vlm.vision_tokenizer.layers.0.0.norm_media.bias": "model-00001-of-00004.safetensors",
601
+ "vlm.vision_tokenizer.layers.0.0.norm_media.weight": "model-00001-of-00004.safetensors",
602
+ "vlm.vision_tokenizer.layers.0.0.to_kv.weight": "model-00001-of-00004.safetensors",
603
+ "vlm.vision_tokenizer.layers.0.0.to_out.weight": "model-00001-of-00004.safetensors",
604
+ "vlm.vision_tokenizer.layers.0.0.to_q.weight": "model-00001-of-00004.safetensors",
605
+ "vlm.vision_tokenizer.layers.0.1.0.bias": "model-00001-of-00004.safetensors",
606
+ "vlm.vision_tokenizer.layers.0.1.0.weight": "model-00001-of-00004.safetensors",
607
+ "vlm.vision_tokenizer.layers.0.1.1.weight": "model-00001-of-00004.safetensors",
608
+ "vlm.vision_tokenizer.layers.0.1.3.weight": "model-00001-of-00004.safetensors",
609
+ "vlm.vision_tokenizer.layers.1.0.norm_latents.bias": "model-00001-of-00004.safetensors",
610
+ "vlm.vision_tokenizer.layers.1.0.norm_latents.weight": "model-00001-of-00004.safetensors",
611
+ "vlm.vision_tokenizer.layers.1.0.norm_media.bias": "model-00001-of-00004.safetensors",
612
+ "vlm.vision_tokenizer.layers.1.0.norm_media.weight": "model-00001-of-00004.safetensors",
613
+ "vlm.vision_tokenizer.layers.1.0.to_kv.weight": "model-00001-of-00004.safetensors",
614
+ "vlm.vision_tokenizer.layers.1.0.to_out.weight": "model-00001-of-00004.safetensors",
615
+ "vlm.vision_tokenizer.layers.1.0.to_q.weight": "model-00001-of-00004.safetensors",
616
+ "vlm.vision_tokenizer.layers.1.1.0.bias": "model-00001-of-00004.safetensors",
617
+ "vlm.vision_tokenizer.layers.1.1.0.weight": "model-00001-of-00004.safetensors",
618
+ "vlm.vision_tokenizer.layers.1.1.1.weight": "model-00001-of-00004.safetensors",
619
+ "vlm.vision_tokenizer.layers.1.1.3.weight": "model-00001-of-00004.safetensors",
620
+ "vlm.vision_tokenizer.layers.2.0.norm_latents.bias": "model-00001-of-00004.safetensors",
621
+ "vlm.vision_tokenizer.layers.2.0.norm_latents.weight": "model-00001-of-00004.safetensors",
622
+ "vlm.vision_tokenizer.layers.2.0.norm_media.bias": "model-00001-of-00004.safetensors",
623
+ "vlm.vision_tokenizer.layers.2.0.norm_media.weight": "model-00001-of-00004.safetensors",
624
+ "vlm.vision_tokenizer.layers.2.0.to_kv.weight": "model-00001-of-00004.safetensors",
625
+ "vlm.vision_tokenizer.layers.2.0.to_out.weight": "model-00001-of-00004.safetensors",
626
+ "vlm.vision_tokenizer.layers.2.0.to_q.weight": "model-00001-of-00004.safetensors",
627
+ "vlm.vision_tokenizer.layers.2.1.0.bias": "model-00001-of-00004.safetensors",
628
+ "vlm.vision_tokenizer.layers.2.1.0.weight": "model-00001-of-00004.safetensors",
629
+ "vlm.vision_tokenizer.layers.2.1.1.weight": "model-00001-of-00004.safetensors",
630
+ "vlm.vision_tokenizer.layers.2.1.3.weight": "model-00001-of-00004.safetensors",
631
+ "vlm.vision_tokenizer.layers.3.0.norm_latents.bias": "model-00001-of-00004.safetensors",
632
+ "vlm.vision_tokenizer.layers.3.0.norm_latents.weight": "model-00001-of-00004.safetensors",
633
+ "vlm.vision_tokenizer.layers.3.0.norm_media.bias": "model-00001-of-00004.safetensors",
634
+ "vlm.vision_tokenizer.layers.3.0.norm_media.weight": "model-00001-of-00004.safetensors",
635
+ "vlm.vision_tokenizer.layers.3.0.to_kv.weight": "model-00001-of-00004.safetensors",
636
+ "vlm.vision_tokenizer.layers.3.0.to_out.weight": "model-00001-of-00004.safetensors",
637
+ "vlm.vision_tokenizer.layers.3.0.to_q.weight": "model-00001-of-00004.safetensors",
638
+ "vlm.vision_tokenizer.layers.3.1.0.bias": "model-00001-of-00004.safetensors",
639
+ "vlm.vision_tokenizer.layers.3.1.0.weight": "model-00001-of-00004.safetensors",
640
+ "vlm.vision_tokenizer.layers.3.1.1.weight": "model-00001-of-00004.safetensors",
641
+ "vlm.vision_tokenizer.layers.3.1.3.weight": "model-00001-of-00004.safetensors",
642
+ "vlm.vision_tokenizer.layers.4.0.norm_latents.bias": "model-00001-of-00004.safetensors",
643
+ "vlm.vision_tokenizer.layers.4.0.norm_latents.weight": "model-00001-of-00004.safetensors",
644
+ "vlm.vision_tokenizer.layers.4.0.norm_media.bias": "model-00001-of-00004.safetensors",
645
+ "vlm.vision_tokenizer.layers.4.0.norm_media.weight": "model-00001-of-00004.safetensors",
646
+ "vlm.vision_tokenizer.layers.4.0.to_kv.weight": "model-00001-of-00004.safetensors",
647
+ "vlm.vision_tokenizer.layers.4.0.to_out.weight": "model-00001-of-00004.safetensors",
648
+ "vlm.vision_tokenizer.layers.4.0.to_q.weight": "model-00001-of-00004.safetensors",
649
+ "vlm.vision_tokenizer.layers.4.1.0.bias": "model-00001-of-00004.safetensors",
650
+ "vlm.vision_tokenizer.layers.4.1.0.weight": "model-00001-of-00004.safetensors",
651
+ "vlm.vision_tokenizer.layers.4.1.1.weight": "model-00001-of-00004.safetensors",
652
+ "vlm.vision_tokenizer.layers.4.1.3.weight": "model-00001-of-00004.safetensors",
653
+ "vlm.vision_tokenizer.layers.5.0.norm_latents.bias": "model-00001-of-00004.safetensors",
654
+ "vlm.vision_tokenizer.layers.5.0.norm_latents.weight": "model-00001-of-00004.safetensors",
655
+ "vlm.vision_tokenizer.layers.5.0.norm_media.bias": "model-00001-of-00004.safetensors",
656
+ "vlm.vision_tokenizer.layers.5.0.norm_media.weight": "model-00001-of-00004.safetensors",
657
+ "vlm.vision_tokenizer.layers.5.0.to_kv.weight": "model-00001-of-00004.safetensors",
658
+ "vlm.vision_tokenizer.layers.5.0.to_out.weight": "model-00001-of-00004.safetensors",
659
+ "vlm.vision_tokenizer.layers.5.0.to_q.weight": "model-00001-of-00004.safetensors",
660
+ "vlm.vision_tokenizer.layers.5.1.0.bias": "model-00001-of-00004.safetensors",
661
+ "vlm.vision_tokenizer.layers.5.1.0.weight": "model-00001-of-00004.safetensors",
662
+ "vlm.vision_tokenizer.layers.5.1.1.weight": "model-00001-of-00004.safetensors",
663
+ "vlm.vision_tokenizer.layers.5.1.3.weight": "model-00001-of-00004.safetensors",
664
+ "vlm.vision_tokenizer.norm.bias": "model-00001-of-00004.safetensors",
665
+ "vlm.vision_tokenizer.norm.weight": "model-00001-of-00004.safetensors",
666
+ "vlm.vision_tokenizer.projection.bias": "model-00001-of-00004.safetensors",
667
+ "vlm.vision_tokenizer.projection.weight": "model-00001-of-00004.safetensors",
668
+ "vlm.vision_tokenizer.text_projection.0.bias": "model-00001-of-00004.safetensors",
669
+ "vlm.vision_tokenizer.text_projection.0.weight": "model-00001-of-00004.safetensors",
670
+ "vlm.vision_tokenizer.text_projection.2.bias": "model-00001-of-00004.safetensors",
671
+ "vlm.vision_tokenizer.text_projection.2.weight": "model-00001-of-00004.safetensors"
672
+ }
673
+ }
modeling_blip_3.py ADDED
@@ -0,0 +1,110 @@
1
+ from transformers import PreTrainedModel, AutoModelForCausalLM
2
+ import torch
3
+ import open_clip
4
+ from typing import List, Optional, Tuple, Union
5
+ from .utils import check_embedding_fns
6
+ from .vlm import InstructPerceiverResampler, KosmosInstruct
7
+ from .configuration_blip_3 import Blip3VisionEncoderConfig, Blip3VisionTokenizerConfig, Blip3Config
8
+
9
+ class Blip3VisionEncoder(PreTrainedModel):
10
+ main_input_name = "pixel_values"
11
+ config_class = Blip3VisionEncoderConfig
12
+
13
+ def __init__(self, config: Blip3VisionEncoderConfig):
14
+ super().__init__(config)
15
+ if config.model_name != 'ViT-H-14-378-quickgelu':
16
+ raise ValueError(f"Unsupported model {config.model_name}. New vision models will be added soon.")
17
+ self.model, _, _ = open_clip.create_model_and_transforms(
18
+ model_name = config.model_name,
19
+ force_image_size=config.force_image_size
20
+ )
21
+
22
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
23
+ # assert pixel_values.ndim == 4, f"Expected 4D tensor (bs, c, h, w), got {pixel_values.ndim}"
24
+ return self.model.encode_image(pixel_values)
25
+
26
+
27
+ # vision tokenizer
28
+ class Blip3VisionTokenizer(PreTrainedModel):
29
+ config_class = Blip3VisionTokenizerConfig
30
+ def __init__(self, config: Blip3VisionTokenizerConfig):
31
+ super().__init__(config)
32
+ self.model = InstructPerceiverResampler(
33
+ dim_llm=config.lang_embedding_dim,
34
+ dim=config.vis_feature_dim,
35
+ dim_inner=config.lang_embedding_dim,
36
+ num_latents=config.num_vis_tokens,
37
+ repeat_latents=config.repeat_latents
38
+ )
39
+
40
+ def forward(self,
41
+ vision_features: torch.Tensor,
42
+ vision_attn_masks: torch.Tensor):
43
+ return self.model(vision_features, vision_attn_masks)
44
+
45
+ # Blip3 model
46
+ class Blip3ModelForConditionalGeneration(PreTrainedModel):
47
+ config_class = Blip3Config
48
+
49
+ def __init__(self, config: Blip3Config):
50
+ super().__init__(config)
51
+
52
+ # vision encoder initialization
53
+ vision_encoder = Blip3VisionEncoder(config.vision_encoder_config).model
54
+ vision_encoder.visual.output_tokens = True
55
+ vision_encoder = vision_encoder.visual
56
+
57
+ # language model initialization
58
+ language_model = AutoModelForCausalLM.from_config(config.text_config)
59
+ check_embedding_fns(language_model)
60
+ # Update _tied_weights_keys using the base model used.
61
+ if language_model._tied_weights_keys is not None:
62
+ self._tied_weights_keys = [f"language_model.{k}" for k in language_model._tied_weights_keys]
63
+
64
+ # vision tokenizer initialization
65
+ if config.vision_tokenizer_config.lang_embedding_dim != language_model.get_input_embeddings().weight.shape[1]:
66
+ overwrite = language_model.get_input_embeddings().weight.shape[1]
67
+ config.vision_tokenizer_config.lang_embedding_dim = overwrite
68
+ print(f"Warning: The language embedding dimension in the vision tokenizer config is different from the language model's embedding dimension. Overwriting the language embedding dimension in the vision tokenizer config to {overwrite}.")
69
+
70
+ vision_tokenizer = Blip3VisionTokenizer(config.vision_tokenizer_config).model
71
+
72
+ self.vlm = KosmosInstruct(
73
+ vision_encoder=vision_encoder,
74
+ vision_tokenizer=vision_tokenizer,
75
+ lang_model=language_model,
76
+ initial_tokenizer_len = config.text_config.initial_tokenizer_len,
77
+ pad_token_id = config.text_config.pad_token_id,
78
+ image_aspect_ratio = config.vision_encoder_config.image_aspect_ratio,
79
+ anyres_patch_sampling = config.vision_encoder_config.anyres_patch_sampling
80
+ )
81
+ # Initialize weights and apply final processing
82
+ self.post_init()
83
+
84
+ @torch.no_grad()
85
+ def generate(
86
+ self,
87
+ pixel_values: torch.FloatTensor,
88
+ input_ids: Optional[torch.LongTensor] = None,
89
+ attention_mask: Optional[torch.LongTensor] = None,
90
+ **generate_kwargs,
91
+ ) -> torch.LongTensor:
92
+ self.vlm = self.vlm.eval()
93
+ return self.vlm.generate(
94
+ vision_x = pixel_values,
95
+ lang_x = input_ids,
96
+ attention_mask = attention_mask,
97
+ **generate_kwargs)
98
+
99
+ def update_special_tokens(self, tokenizer):
100
+ tokenizer.add_special_tokens(
101
+ {"additional_special_tokens": list(self.vlm.special_tokens.values())}
102
+ )
103
+ self.vlm.lang_model.config.vocab_size = len(tokenizer)
104
+ self.vlm.set_special_token_ids(
105
+ {
106
+ v: tokenizer.convert_tokens_to_ids(v) for v in self.vlm.special_tokens.values()
107
+ }
108
+ )
109
+ return tokenizer
110
+
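A minimal end-to-end usage sketch for the classes above (not part of the commit). The repository id, the AutoModelForVision2Seq mapping, the Blip3ImageProcessor call signature, and the "<image>" placeholder are assumptions, since config.json and image_processing_blip_3.py are defined elsewhere in this release; adjust them to the actual files.

import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoTokenizer, AutoImageProcessor

# "path/to/this-repo" is a hypothetical placeholder for this repository.
model_id = "path/to/this-repo"
model = AutoModelForVision2Seq.from_pretrained(model_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)
image_processor = AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

# Register the VLM's special tokens on the tokenizer (see update_special_tokens above).
tokenizer = model.update_special_tokens(tokenizer)

image = Image.open("test_samples/images/112.jpg").convert("RGB")
pixel_values = image_processor([image], return_tensors="pt")["pixel_values"]  # assumed processor API

# "<image>" marks where the visual tokens go; the exact placeholder is defined by the VLM code.
prompt = "<|user|>\n<image> Is there any reflection of zebra in water?<|end|>\n<|assistant|>\n"
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    out = model.generate(pixel_values=pixel_values,
                         input_ids=inputs["input_ids"],
                         attention_mask=inputs["attention_mask"],
                         max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))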
preprocessor_config.json ADDED
@@ -0,0 +1,23 @@
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_blip_3.Blip3ImageProcessor"
4
+ },
5
+ "do_resize": true,
6
+ "image_mean": [
7
+ 0.48145466,
8
+ 0.4578275,
9
+ 0.40821073
10
+ ],
11
+ "image_processor_type": "Blip3ImageProcessor",
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "interpolation_mode": "bicubic",
18
+ "resize_mode": "squash",
19
+ "size": [
20
+ 378,
21
+ 378
22
+ ]
23
+ }
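The auto_map above routes AutoImageProcessor to the custom Blip3ImageProcessor (defined in image_processing_blip_3.py), using CLIP normalization statistics, bicubic interpolation, and a 378x378 "squash" resize. A minimal loading sketch, with a placeholder repository id:

from transformers import AutoImageProcessor

# trust_remote_code is required because Blip3ImageProcessor ships inside this repo;
# "path/to/this-repo" is a hypothetical placeholder.
image_processor = AutoImageProcessor.from_pretrained("path/to/this-repo", trust_remote_code=True)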
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<pad>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
test_samples/images/112.jpg ADDED
test_samples/images/51.jpg ADDED
test_samples/images/76.jpg ADDED
test_samples/test.json ADDED
@@ -0,0 +1,20 @@
1
+ [
2
+ {
3
+ "question": "Which room is bigger, the double garage or the living room?",
4
+ "answer": "double garage",
5
+ "category": "ocr,spat,math",
6
+ "image_path": "./test_samples/images/51.jpg"
7
+ },
8
+ {
9
+ "question": "what is the green logo on the car?",
10
+ "answer": "monster",
11
+ "category": "rec",
12
+ "image_path": "./test_samples/images/76.jpg"
13
+ },
14
+ {
15
+ "question": "Is there any reflection of zebra in water?",
16
+ "answer": "yes",
17
+ "category": "rec",
18
+ "image_path": "./test_samples/images/112.jpg"
19
+ }
20
+ ]
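These bundled samples are a small smoke test: each entry pairs an image path with a question, a reference answer, and a capability tag (e.g. ocr, spat, math, rec). A sketch of iterating over them:

import json

with open("test_samples/test.json") as f:
    samples = json.load(f)

for sample in samples:
    print(sample["image_path"], "|", sample["question"], "->", sample["answer"])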
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
tokenizer_config.json ADDED
@@ -0,0 +1,137 @@
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "added_tokens_decoder": {
5
+ "0": {
6
+ "content": "<unk>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "1": {
14
+ "content": "<s>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "2": {
22
+ "content": "</s>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": true,
26
+ "single_word": false,
27
+ "special": false
28
+ },
29
+ "32000": {
30
+ "content": "<|endoftext|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "32001": {
38
+ "content": "<|assistant|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": true,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "32002": {
46
+ "content": "<|placeholder1|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": true,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "32003": {
54
+ "content": "<|placeholder2|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": true,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "32004": {
62
+ "content": "<|placeholder3|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": true,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "32005": {
70
+ "content": "<|placeholder4|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": true,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "32006": {
78
+ "content": "<|system|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": true,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "32007": {
86
+ "content": "<|end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": true,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "32008": {
94
+ "content": "<|placeholder5|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": true,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "32009": {
102
+ "content": "<|placeholder6|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": true,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "32010": {
110
+ "content": "<|user|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": true,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "32011": {
118
+ "content": "<pad>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": true
124
+ }
125
+ },
126
+ "bos_token": "<s>",
127
+ "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
128
+ "clean_up_tokenization_spaces": false,
129
+ "eos_token": "<|endoftext|>",
130
+ "model_max_length": 4096,
131
+ "pad_token": "<pad>",
132
+ "padding_side": "left",
133
+ "sp_model_kwargs": {},
134
+ "tokenizer_class": "LlamaTokenizer",
135
+ "unk_token": "<unk>",
136
+ "use_default_system_prompt": false
137
+ }
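The chat_template above wraps each user turn in <|user|> ... <|end|> and opens an <|assistant|> turn automatically, so no add_generation_prompt flag is needed. A sketch of building a prompt with it (where the image placeholder goes is decided by the VLM's special tokens in vlm.py, not by the template; the repository id is a placeholder):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/this-repo", use_fast=False)
messages = [{"role": "user", "content": "What is in this picture?"}]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
# prompt == "<s><|user|>\nWhat is in this picture?<|end|>\n<|assistant|>\n"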
utils.py ADDED
@@ -0,0 +1,383 @@
1
+ import torch
2
+ import ast
3
+ import math
4
+ from PIL import Image
5
+
6
+
7
+ def has_fn(model, fn_name):
8
+ """Check if model has a function fn_name"""
9
+ return callable(getattr(model, fn_name, None))
10
+
11
+ def exists(val):
12
+ return val is not None
13
+
14
+ def num_params(module, filter_to_trainable=False):
15
+ """Returns the number of parameters in the module, or optionally only the trainable parameters"""
16
+ if filter_to_trainable:
17
+ return sum(p.numel() for p in module.parameters() if p.requires_grad)
18
+ else:
19
+ return sum(p.numel() for p in module.parameters())
20
+
21
+ def hasattr_recursive(obj, att):
22
+ """
23
+ Check if obj has nested attribute
24
+ Example: hasattr_recursive(obj, 'a.b.c') is equivalent to hasattr(obj, 'a') and hasattr(obj.a, 'b') and hasattr(obj.a.b, 'c')
25
+ """
26
+ if att == "":
27
+ return True
28
+ i = att.find(".")
29
+ if i < 0:
30
+ return hasattr(obj, att)
31
+ else:
32
+ try:
33
+ return hasattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
34
+ except:
35
+ return False
36
+
37
+ def getattr_recursive(obj, att):
38
+ """
39
+ Return nested attribute of obj
40
+ Example: getattr_recursive(obj, 'a.b.c') is equivalent to obj.a.b.c
41
+ """
42
+ if att == "":
43
+ return obj
44
+ i = att.find(".")
45
+ if i < 0:
46
+ return getattr(obj, att)
47
+ else:
48
+ return getattr_recursive(getattr(obj, att[:i]), att[i + 1 :])
49
+
50
+
51
+ def setattr_recursive(obj, att, val):
52
+ """
53
+ Set nested attribute of obj
54
+ Example: setattr_recursive(obj, 'a.b.c', val) is equivalent to obj.a.b.c = val
55
+ """
56
+ if "." in att:
57
+ obj = getattr_recursive(obj, ".".join(att.split(".")[:-1]))
58
+ setattr(obj, att.split(".")[-1], val)
59
+
60
+
61
+ def stack_with_padding(list_of_tensors, padding_value=0, padding_side="right"):
62
+ """
63
+ Stack a list of tensors with padding on one side
64
+ Args:
65
+ list_of_tensors (list[torch.Tensor]): List of tensors to stack
66
+ padding_value (int, optional): Value to pad with. Defaults to 0.
67
+ padding_side (str, optional): Side to pad on. Defaults to "right".
68
+ Returns:
69
+ torch.Tensor: Stacked tensors
70
+ """
71
+ max_tokens = max(tensor.size(0) for tensor in list_of_tensors)
72
+ padded_tensors = []
73
+ for tensor in list_of_tensors:
74
+ num_tokens = tensor.size(0)
75
+ if len(tensor.size()) == 1:
76
+ padding = torch.full(
77
+ (max_tokens - num_tokens,),
78
+ padding_value,
79
+ dtype=tensor.dtype,
80
+ device=tensor.device,
81
+ )
82
+ else:
83
+ padding = torch.full(
84
+ (max_tokens - num_tokens, tensor.size(1)),
85
+ padding_value,
86
+ dtype=tensor.dtype,
87
+ device=tensor.device,
88
+ )
89
+ padded_tensor = (
90
+ torch.cat((tensor, padding), dim=0)
91
+ if padding_side == "right"
92
+ else torch.cat((padding, tensor), dim=0)
93
+ )
94
+ padded_tensors.append(padded_tensor)
95
+ return torch.stack(padded_tensors)
96
+
97
+
98
+ def check_embedding_fns(lang_model):
99
+ """Checks for and attempts to set {get/set}_{input/output}_embeddings functions to the model"""
100
+ if not has_fn(lang_model, "get_input_embeddings"):
101
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
102
+ lang_model.get_input_embeddings = lambda: lang_model.transformer.wte
103
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
104
+ lang_model.get_input_embeddings = lambda: lang_model.model.decoder.embed_tokens
105
+ else:
106
+ raise ValueError(
107
+ "We require the language encoder to have a get_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
108
+ )
109
+
110
+ if not has_fn(lang_model, "set_input_embeddings"):
111
+ if hasattr_recursive(lang_model, "transformer.wte"): # MPT
112
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
113
+ lang_model, "transformer.wte", x
114
+ )
115
+ elif hasattr_recursive(lang_model, "model.decoder.embed_tokens"): # OPT
116
+ lang_model.set_input_embeddings = lambda x: setattr_recursive(
117
+ lang_model, "model.decoder.embed_tokens", x
118
+ )
119
+ else:
120
+ raise ValueError(
121
+ "We require the language encoder to have a set_input_embeddings method but we couldn't determine the name of the input embeddings attribute. Please supply this manually in factory.py."
122
+ )
123
+
124
+ if not has_fn(lang_model, "get_output_embeddings"):
125
+ if hasattr_recursive(lang_model, "lm_head"):
126
+ lang_model.get_output_embeddings = lambda: lang_model.lm_head
127
+ else:
128
+ raise ValueError(
129
+ "We require the language encoder to have a get_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
130
+ )
131
+
132
+ if not has_fn(lang_model, "set_output_embeddings"):
133
+ if hasattr_recursive(lang_model, "lm_head"):
134
+ lang_model.set_output_embeddings = lambda x: setattr_recursive(
135
+ lang_model, "lm_head", x
136
+ )
137
+ else:
138
+ raise ValueError(
139
+ "We require the language encoder to have a set_output_embeddings method but we couldn't determine the name of the output embeddings attribute. Please supply this manually in factory.py."
140
+ )
141
+
142
+
143
+ def has_fn(model, fn_name):
144
+ """Check if model has a function fn_name"""
145
+ return callable(getattr(model, fn_name, None))
146
+
147
+
148
+ # Adapted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
149
+ #
150
+ # Licensed under the Apache License, Version 2.0 (the "License");
151
+ # you may not use this file except in compliance with the License.
152
+ # You may obtain a copy of the License at
153
+ #
154
+ # http://www.apache.org/licenses/LICENSE-2.0
155
+ #
156
+ # Unless required by applicable law or agreed to in writing, software
157
+ # distributed under the License is distributed on an "AS IS" BASIS,
158
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
159
+ # See the License for the specific language governing permissions and
160
+ # limitations under the License.
161
+
162
+ def unpad_image(tensor, original_size, keep_original_shape=False):
163
+ """
164
+ Unpads a PyTorch tensor of a padded and resized image.
165
+
166
+ Args:
167
+ tensor (torch.Tensor): The image tensor, assumed to be in CxHxW format.
168
+ original_size (tuple): The original size of the image (width, height).
169
+
170
+ Returns:
171
+ torch.Tensor: The unpadded image tensor.
172
+ """
173
+ original_width, original_height = original_size
174
+ current_height, current_width = tensor.shape[1:]
175
+
176
+ original_aspect_ratio = original_width / original_height
177
+ current_aspect_ratio = current_width / current_height
178
+
179
+ if original_aspect_ratio > current_aspect_ratio:
180
+ scale_factor = current_width / original_width
181
+ new_height = int(original_height * scale_factor)
182
+ padding = (current_height - new_height) // 2
183
+ if keep_original_shape:
184
+ attention_mask = torch.ones((current_height, current_width), device=tensor.device)
185
+ attention_mask[:padding, :] = 0
186
+ attention_mask[current_height - padding:, :] = 0
187
+ return tensor, attention_mask
188
+ else:
189
+ unpadded_tensor = tensor[:, padding:current_height - padding, :]
190
+ return unpadded_tensor, None
191
+ else:
192
+ scale_factor = current_height / original_height
193
+ new_width = int(original_width * scale_factor)
194
+ padding = (current_width - new_width) // 2
195
+ if keep_original_shape:
196
+ attention_mask = torch.ones((current_height, current_width), device=tensor.device)
197
+ attention_mask[:, :padding] = 0
198
+ attention_mask[:, current_width - padding:] = 0
199
+ return tensor, attention_mask
200
+ else:
201
+ unpadded_tensor = tensor[:, :, padding:current_width - padding]
202
+ return unpadded_tensor, None
203
+
204
+
205
+ def select_best_resolution(original_size, possible_resolutions):
206
+ """
207
+ Selects the best resolution from a list of possible resolutions based on the original size.
208
+
209
+ Args:
210
+ original_size (tuple): The original size of the image in the format (width, height).
211
+ possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
212
+
213
+ Returns:
214
+ tuple: The best fit resolution in the format (width, height).
215
+ """
216
+ original_width, original_height = original_size
217
+ best_fit = None
218
+ max_effective_resolution = 0
219
+ min_wasted_resolution = float('inf')
220
+
221
+ for width, height in possible_resolutions:
222
+ scale = min(width / original_width, height / original_height)
223
+ downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
224
+ effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
225
+ wasted_resolution = (width * height) - effective_resolution
226
+
227
+ if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
228
+ max_effective_resolution = effective_resolution
229
+ min_wasted_resolution = wasted_resolution
230
+ best_fit = (width, height)
231
+
232
+ return best_fit
233
+
234
+
235
+ def resize_and_pad_image(image, target_resolution):
236
+ """
237
+ Resize and pad an image to a target resolution while maintaining aspect ratio.
238
+
239
+ Args:
240
+ image (PIL.Image.Image): The input image.
241
+ target_resolution (tuple): The target resolution (width, height) of the image.
242
+
243
+ Returns:
244
+ PIL.Image.Image: The resized and padded image.
245
+ """
246
+ original_width, original_height = image.size
247
+ target_width, target_height = target_resolution
248
+
249
+ scale_w = target_width / original_width
250
+ scale_h = target_height / original_height
251
+
252
+ if scale_w < scale_h:
253
+ new_width = target_width
254
+ new_height = min(math.ceil(original_height * scale_w), target_height)
255
+ else:
256
+ new_height = target_height
257
+ new_width = min(math.ceil(original_width * scale_h), target_width)
258
+
259
+ # Resize the image
260
+ resized_image = image.resize((new_width, new_height))
261
+
262
+ new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
263
+ paste_x = (target_width - new_width) // 2
264
+ paste_y = (target_height - new_height) // 2
265
+ new_image.paste(resized_image, (paste_x, paste_y))
266
+
267
+ return new_image
268
+
269
+
270
+ def divide_to_patches(image, patch_size):
271
+ """
272
+ Divides an image into patches of a specified size.
273
+
274
+ Args:
275
+ image (PIL.Image.Image): The input image.
276
+ patch_size (int): The size of each patch.
277
+
278
+ Returns:
279
+ list: A list of PIL.Image.Image objects representing the patches.
280
+ """
281
+ patches = []
282
+ width, height = image.size
283
+ for i in range(0, height, patch_size):
284
+ for j in range(0, width, patch_size):
285
+ box = (j, i, j + patch_size, i + patch_size)
286
+ patch = image.crop(box)
287
+ patches.append(patch)
288
+
289
+ return patches
290
+
291
+
292
+ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
293
+ """
294
+ Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
295
+
296
+ Args:
297
+ image_size (tuple): The size of the input image in the format (width, height).
298
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
299
+ patch_size (int): The size of each image patch.
300
+
301
+ Returns:
302
+ tuple: The shape of the image patch grid in the format (width, height).
303
+ """
304
+ if type(grid_pinpoints) is list:
305
+ possible_resolutions = grid_pinpoints
306
+ else:
307
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
308
+ width, height = select_best_resolution(image_size, possible_resolutions)
309
+ return width // patch_size, height // patch_size
310
+
311
+
312
+ def process_anyres_image(image, processor, grid_pinpoints):
313
+ """
314
+ Process an image with variable resolutions.
315
+
316
+ Args:
317
+ image (PIL.Image.Image): The input image to be processed.
318
+ processor: The image processor object.
319
+ grid_pinpoints (str): A string representation of a list of possible resolutions.
320
+
321
+ Returns:
322
+ torch.Tensor: A tensor containing the processed image patches.
323
+ """
324
+ # FIXME: determine grid_pinpoints from image sizes.
325
+ if type(grid_pinpoints) is list:
326
+ possible_resolutions = grid_pinpoints
327
+ else:
328
+ possible_resolutions = ast.literal_eval(grid_pinpoints)
329
+ best_resolution = select_best_resolution(image.size, possible_resolutions)
330
+ image_padded = resize_and_pad_image(image, best_resolution)
331
+
332
+ processor_size = processor.transforms[0].size
333
+ patches = divide_to_patches(image_padded, processor_size[0])
334
+
335
+ image_original_resize = image.resize((processor_size[0], processor_size[0]))
336
+
337
+ image_patches = [image_original_resize] + patches
338
+ image_patches = [processor(image_patch)
339
+ for image_patch in image_patches]
340
+ return torch.stack(image_patches, dim=0)
341
+
342
+
343
+ def expand2square(pil_img, background_color):
344
+ width, height = pil_img.size
345
+ if width == height:
346
+ return pil_img
347
+ elif width > height:
348
+ result = Image.new(pil_img.mode, (width, width), background_color)
349
+ result.paste(pil_img, (0, (width - height) // 2))
350
+ return result
351
+ else:
352
+ result = Image.new(pil_img.mode, (height, height), background_color)
353
+ result.paste(pil_img, ((height - width) // 2, 0))
354
+ return result
355
+
356
+
357
+ def process_images(images, image_processor, model_cfg):
358
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
359
+ new_images = []
360
+ if image_aspect_ratio == 'pad':
361
+ for image in images:
362
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.transforms[-1].mean))
363
+ image = image_processor(image)
364
+ new_images.append(image)
365
+ elif image_aspect_ratio in ["anyres", "anyres-legacy"]:
366
+ base_img_size = image_processor.transforms[0].size[0]
367
+ for image in images:
368
+ image = process_anyres_image(image, image_processor, [[base_img_size,base_img_size*2],
369
+ [base_img_size*2,base_img_size],
370
+ [base_img_size*2,base_img_size*2],
371
+ [base_img_size*3,base_img_size],
372
+ [base_img_size,base_img_size*3]])
373
+
374
+ # Debug any res inference by only using 672x672.
375
+ # image = process_anyres_image(image, image_processor, [[base_img_size*2,base_img_size*2]])
376
+ new_images.append(image)
377
+ else:
378
+ return image_processor(images)
379
+ if all(x.shape == new_images[0].shape for x in new_images):
380
+ new_images = torch.stack(new_images, dim=0)
381
+ return new_images
382
+
383
+
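A worked example of the any-resolution helpers above, using the 378-px base size from preprocessor_config.json and the same pinpoint grid that process_images builds (a sketch; the import path depends on how the module is packaged):

from utils import select_best_resolution, get_anyres_image_grid_shape  # assumed import path

base = 378
pinpoints = [[base, base * 2], [base * 2, base], [base * 2, base * 2],
             [base * 3, base], [base, base * 3]]  # same list process_images() passes

# For a 1000x750 (width, height) image, the 2x2 grid keeps the most effective pixels.
print(select_best_resolution((1000, 750), pinpoints))             # (756, 756)
print(get_anyres_image_grid_shape((1000, 750), pinpoints, base))  # (2, 2)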
vlm.py ADDED
@@ -0,0 +1,1531 @@
1
+
2
+ import torch
3
+ from torch import einsum, nn
4
+ from einops import rearrange, repeat
5
+ from einops_exts import rearrange_many
6
+ from einops import rearrange
7
+ from typing import List, Optional, Tuple, Union
8
+ import torch.nn.functional as F
9
+ from transformers.modeling_outputs import CausalLMOutputWithPast
10
+ from dataclasses import dataclass
11
+ from transformers import CLIPVisionModel
12
+ import transformers
13
+
14
+ from .utils import exists, num_params, getattr_recursive, stack_with_padding, get_anyres_image_grid_shape, unpad_image
15
+
16
+
17
+ class VisionTokenizer(nn.Module):
18
+ def __init__(self, dim_media, num_tokens_per_media):
19
+ super().__init__()
20
+ self.dim_media = dim_media
21
+ self.num_tokens_per_media = num_tokens_per_media
22
+
23
+ class PerceiverAttention(nn.Module):
24
+ def __init__(self, *, dim, dim_head=64, heads=8):
25
+ super().__init__()
26
+ self.scale = dim_head**-0.5
27
+ self.heads = heads
28
+ inner_dim = dim_head * heads
29
+
30
+ self.norm_media = nn.LayerNorm(dim)
31
+ self.norm_latents = nn.LayerNorm(dim)
32
+
33
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
34
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
35
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
36
+
37
+ def forward(self, x, latents, vision_attn_masks=None):
38
+ """
39
+ Args:
40
+ x (torch.Tensor): image features
41
+ shape (b, T, n1, D)
42
+ latents (torch.Tensor): latent features
43
+ shape (b, T, n2, D)
44
+ """
45
+ x = self.norm_media(x)
46
+ latents = self.norm_latents(latents)
47
+
48
+ h = self.heads
49
+
50
+ q = self.to_q(latents)
51
+ kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.
52
+ if vision_attn_masks is not None:
53
+ vision_attn_masks = torch.cat((vision_attn_masks,
54
+ torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),
55
+ dim=-1)
56
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
57
+ q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)
58
+ q = q * self.scale
59
+
60
+ # attention
61
+ sim = einsum("... i d, ... j d -> ... i j", q, k)
62
+ # Apply vision attention mask here.
63
+ # Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
64
+ if vision_attn_masks is not None:
65
+ attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)
66
+ vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))
67
+ attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))
68
+ sim += attn_bias
69
+
70
+ sim = sim - sim.amax(dim=-1, keepdim=True).detach()
71
+ attn = sim.softmax(dim=-1)
72
+
73
+
74
+ out = einsum("... i j, ... j d -> ... i d", attn, v)
75
+ out = rearrange(out, "b h t n d -> b t n (h d)", h=h)
76
+ return self.to_out(out)
77
+
78
+
79
+ def FeedForward(dim, mult=4):
80
+ inner_dim = int(dim * mult)
81
+ return nn.Sequential(
82
+ nn.LayerNorm(dim),
83
+ nn.Linear(dim, inner_dim, bias=False),
84
+ nn.GELU(),
85
+ nn.Linear(inner_dim, dim, bias=False),
86
+ )
87
+
88
+
89
+ class InstructPerceiverResampler(VisionTokenizer):
90
+ def __init__(
91
+ self,
92
+ *,
93
+ dim_llm,
94
+ dim,
95
+ dim_inner=None,
96
+ depth=6,
97
+ dim_head=96,
98
+ heads=16,
99
+ num_latents=64,
100
+ repeat_latents=False,
101
+ max_num_media=None,
102
+ max_num_frames=None,
103
+ ff_mult=4,
104
+ ):
105
+ """
106
+ Perceiver module which takes in image features and outputs image tokens.
107
+ Args:
108
+ dim (int): dimension of the incoming image features
109
+ dim_inner (int, optional): final dimension to project the incoming image features to;
110
+ also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
111
+ depth (int, optional): number of layers. Defaults to 6.
112
+ dim_head (int, optional): dimension of each head. Defaults to 96.
113
+ heads (int, optional): number of heads. Defaults to 16.
114
+ num_latents (int, optional): number of latent tokens to use in the Perceiver;
115
+ also corresponds to number of tokens per sequence to output. Defaults to 64.
116
+ max_num_media (int, optional): maximum number of media per sequence to input into the Perceiver
117
+ and keep positional embeddings for. If None, no positional embeddings are used.
118
+ max_num_frames (int, optional): maximum number of frames to input into the Perceiver
119
+ and keep positional embeddings for. If None, no positional embeddings are used.
120
+ ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4.
121
+ """
122
+ if dim_inner is not None:
123
+ projection = nn.Linear(dim, dim_inner)
124
+ else:
125
+ projection = None
126
+ dim_inner = dim
127
+ super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
128
+ self.projection = projection
129
+
130
+ # Text embedding projection.
131
+ # self.text_projection = nn.Linear(dim_llm, dim)
132
+ modules = [nn.Linear(dim_llm, dim)]
133
+ for _ in range(1, 2):
134
+ modules.append(nn.GELU())
135
+ modules.append(nn.Linear(dim, dim))
136
+ self.text_projection = nn.Sequential(*modules)
137
+
138
+ self.latents = nn.Parameter(torch.randn(num_latents, dim))
139
+ self.repeat_latents = repeat_latents
140
+ # positional embeddings
141
+ self.frame_embs = (
142
+ nn.Parameter(torch.randn(max_num_frames, dim))
143
+ if exists(max_num_frames)
144
+ else None
145
+ )
146
+ self.media_time_embs = (
147
+ nn.Parameter(torch.randn(max_num_media, 1, dim))
148
+ if exists(max_num_media)
149
+ else None
150
+ )
151
+
152
+ self.layers = nn.ModuleList([])
153
+ for _ in range(depth):
154
+ self.layers.append(
155
+ nn.ModuleList(
156
+ [
157
+ PerceiverAttention(
158
+ dim=dim, dim_head=dim_head, heads=heads
159
+ ),
160
+ FeedForward(dim=dim, mult=ff_mult),
161
+ ]
162
+ )
163
+ )
164
+
165
+ self.norm = nn.LayerNorm(dim)
166
+ # TODO: write a new forward function that takes in text input and append them to the query tokens.
167
+ def forward(self, x, text_embeds=None):
168
+ """
169
+ Args:
170
+ x (torch.Tensor): image features
171
+ shape (b, T, F, v, D)
172
+ Returns:
173
+ shape (b, T, n, D) where n is self.num_latents
174
+ """
175
+ b, T, F, v = x.shape[:4]
176
+
177
+ # frame and media time embeddings
178
+ if exists(self.frame_embs):
179
+ frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
180
+ x = x + frame_embs
181
+ x = rearrange(
182
+ x, "b T F v d -> b T (F v) d"
183
+ ) # flatten the frame and spatial dimensions
184
+ if exists(self.media_time_embs):
185
+ x = x + self.media_time_embs[:T]
186
+
187
+ # blocks
188
+ # FIXME: extending query tokens proportional to the vision sequence length. Hard-coded as dfn5b token_len=729.
189
+ if self.repeat_latents:
190
+ r = v // 729 # Repeat the query tokens for r times.
191
+ latents = repeat(self.latents, "n d -> (n repeat) d", repeat=r)
192
+ else:
193
+ latents = self.latents
194
+ latents = repeat(latents, "n d -> b T n d", b=b, T=T)
195
+ # latents = repeat(self.latents, "n d -> b T n d", b=b, T=T)
196
+ # Append text embedding.
197
+ if exists(text_embeds):
198
+ text_embeds = self.text_projection(text_embeds)
199
+ text_embeds = text_embeds[:, None, :, :]
200
+ latents = torch.cat((latents, text_embeds), dim=2) # FIXME: check latents shape.
201
+
202
+ for attn, ff in self.layers:
203
+ latents = attn(x, latents) + latents
204
+ latents = ff(latents) + latents
205
+
206
+ # Truncate latents to only keep query tokens.
207
+ if exists(text_embeds):
208
+ latents = latents[:, :, :self.latents.shape[0], :]
209
+
210
+ if exists(self.projection):
211
+ return self.projection(self.norm(latents))
212
+ else:
213
+ return self.norm(latents)
214
+
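A shape sketch for the resampler above; the dimensions are illustrative, not the release configuration (the real values come from Blip3VisionTokenizerConfig):

import torch

# (b, T, F, v, D_vis) vision features -> (b, T, num_latents, dim_inner) vision tokens.
resampler = InstructPerceiverResampler(dim_llm=3072, dim=1024, dim_inner=3072,
                                       num_latents=128, repeat_latents=False)
feats = torch.randn(2, 1, 1, 729, 1024)   # batch=2, one image, one frame, 729 patch tokens
tokens = resampler(feats)                 # torch.Size([2, 1, 128, 3072])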
215
+ class DecoupledEmbedding(nn.Embedding):
216
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/sparse.html#Embedding
217
+ """
218
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the embeddings. In practice, the
219
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `num_additional_embeddings` > 0,
220
+ then it will create `num_additional_embeddings` additional parameters that are always trained. If
221
+ `num_additional_embeddings=0`, then the module defaults back to the regular behavior of `nn.Embedding`.
222
+ """
223
+
224
+ def __init__(
225
+ self,
226
+ max_original_id: int,
227
+ num_additional_embeddings: int = 0,
228
+ _weight: torch.Tensor = None,
229
+ num_original_embeddings: int = None,
230
+ embedding_dim: int = None,
231
+ partially_freeze=True,
232
+ device=None,
233
+ dtype=None,
234
+ pad_token_id=None,
235
+ ) -> None:
236
+ """
237
+ Args:
238
+ max_original_id (`int`):
239
+ The largest token id that should be embedded using the regular embedding (regular `weight`).
240
+ This is usually len(tokenizer) - 1 before additional tokens are added.
241
+ Note that this may not equal self.weight.shape[0]
242
+ num_additional_embeddings (`int`):
243
+ Number of additional tokens to initialize an Embedding matrix for (`additional_weight`).
244
+ _weight (`torch.Tensor`, *optional*, defaults to `None`): The regular weight tensor.
245
+ If provided, this sets the `num_original_embeddings` and `embedding_dim` parameters.
246
+ num_original_embeddings (`int`):
247
+ self.weight.shape[0]
248
+ embedding_dim (`int`):
249
+ The size of each embedding vector
250
+ partially_freeze (`bool`, *optional*, defaults to `True`):
251
+ If `True`, the regular `weight` will be frozen. `additional_weight` is never frozen.
252
+ pad_token_id (`int`, *optional*):
253
+ The padding index (needs to be less than num_embeddings)
254
+
255
+ Note: there are a lot of other parameters to initialize a standard `nn.Embedding` such as `padding_idx`,
256
+ `max_norm` or `norm_type`. We are not supporting these.
257
+ """
258
+ # validate args
259
+ if pad_token_id is not None and pad_token_id > max_original_id:
260
+ raise ValueError(
261
+ f"pad_token_id must be <= max_original_id. Got {pad_token_id} and {max_original_id}."
262
+ + "If the original tokenizer does not have a pad_token_id, use pad_token_id=None."
263
+ )
264
+ if _weight is not None:
265
+ assert (num_original_embeddings is None) or (
266
+ _weight.shape[0] == num_original_embeddings
267
+ ), f"num_original_embeddings={num_original_embeddings} but _weight.shape[0]={_weight.shape[0]}"
268
+ assert (embedding_dim is None) or (
269
+ _weight.shape[1] == embedding_dim
270
+ ), f"embedding_dim={embedding_dim} but _weight.shape[1]={_weight.shape[1]}"
271
+ num_original_embeddings = _weight.shape[0]
272
+ embedding_dim = _weight.shape[1]
273
+ else:
274
+ assert (
275
+ num_original_embeddings is not None
276
+ ), "num_original_embeddings must be provided if _weight is not provided"
277
+ assert (
278
+ embedding_dim is not None
279
+ ), "embedding_dim must be provided if _weight is not provided"
280
+
281
+ super().__init__(
282
+ num_embeddings=num_original_embeddings,
283
+ embedding_dim=embedding_dim,
284
+ device=device,
285
+ dtype=dtype,
286
+ padding_idx=pad_token_id,
287
+ _weight=_weight,
288
+ )
289
+ self.max_original_id = max_original_id
290
+ self.padding_idx = pad_token_id
291
+ self.num_additional_embeddings = num_additional_embeddings
292
+ if self.num_additional_embeddings > 0:
293
+ self.additional_embedding = nn.Embedding(
294
+ num_embeddings=self.num_additional_embeddings,
295
+ embedding_dim=embedding_dim,
296
+ device=device,
297
+ dtype=dtype,
298
+ )
299
+ self.set_requires_grad(
300
+ require_regular_grad=not partially_freeze, require_additional_grad=True
301
+ )
302
+
303
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
304
+ """
305
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
306
+ """
307
+ self.weight.requires_grad_(require_regular_grad)
308
+ self.additional_embedding.requires_grad_(require_additional_grad)
309
+
310
+ def forward(self, input_ids):
311
+ """
312
+ we have 2 embeddings, with different indices - one pretrained self.weight and another
313
+ self.additional_embedding.weight that is being trained.
314
+
315
+ in order to make a lookup of the input ids, we:
316
+ 1. find out the indices of the entries belonging to the 2nd embedding
317
+ 2. extract those values while subtracting the size of the first embedding (num_embeddings), since the 2nd
318
+ embedding starts from 0 and not num_embeddings
319
+ 3. perform the 2nd embedding lookup
320
+ 4. now we handle the 1st embedding, we overwrite indices belonging to the 2nd embedding with a padding index
321
+ 5. perform the 1st embedding lookup
322
+ 6. now we overwrite the values in the 1st embedding lookup with the values of the 2nd embedding lookup
323
+
324
+ note: for the 1st embedding lookup we could have looked up only the low indices and skipped the padding, but
325
+ then we would have to create a new tensor and populate it from 2 tensors spread out across various indices -
326
+ i.e. not a simple concat. We have not benchmarked whether that more complex approach is faster; since sequence
327
+ lengths are usually fairly short, it is probably not faster (or not by much), but it would be worth
328
+ measuring.
329
+
330
+ """
331
+ if self.num_additional_embeddings == 0:
332
+ return F.embedding(input_ids, self.weight)
333
+
334
+ # Clone so that we don't modify the original input_ids later on
335
+ input_ids = input_ids.clone()
336
+ additional_vocab_indices = torch.where(input_ids > self.max_original_id)
337
+ input_ids_additional_vocab = input_ids[additional_vocab_indices]
338
+ additional_embeddings = self.additional_embedding(
339
+ input_ids_additional_vocab - self.max_original_id - 1
340
+ )
341
+
342
+ # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
343
+ input_ids[additional_vocab_indices] = 0
344
+ full_vector = F.embedding(input_ids, self.weight)
345
+
346
+ # overwrite the records with high indices
347
+ full_vector[additional_vocab_indices] = additional_embeddings
348
+
349
+ return full_vector
350
+
351
+ def extra_repr(self) -> str:
352
+ return "num_original_embeddings={}, num_additional_embeddings={}, embedding_dim={}, partially_freeze={}".format(
353
+ self.max_original_id + 1,
354
+ self.num_additional_embeddings,
355
+ self.embedding_dim,
356
+ (not self.weight.requires_grad),
357
+ )
358
+
359
+
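A small sketch of the decoupled lookup above, with illustrative sizes: ids at or below max_original_id hit the (optionally frozen) pretrained table, while higher ids are re-indexed into the trainable additional_embedding:

import torch

emb = DecoupledEmbedding(max_original_id=31999, num_additional_embeddings=3,
                         num_original_embeddings=32064, embedding_dim=8)
ids = torch.tensor([[5, 31999, 32000, 32002]])
out = emb(ids)                 # shape (1, 4, 8)
# 32000 and 32002 are looked up as additional_embedding rows 0 and 2.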
360
+ class DecoupledLinear(nn.Linear):
361
+ # Derived from https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear
362
+ """
363
+ Implements a decoupling of parameters to allow freezing (or not) a subset of the parameters. In practice, the
364
+ regular `weight` can be trained or frozen (i.e. `partially_freeze=True`), and if `additional_out_features` > 0,
365
+ then it will create `additional_out_features * in_features` additional parameters that are always trained. If
366
+ `additional_out_features=0`, then the module defaults back to the regular behavior of `nn.Linear`.
367
+ """
368
+
369
+ def __init__(
370
+ self,
371
+ max_original_id: int,
372
+ additional_out_features: int = 0,
373
+ _weight: torch.Tensor = None,
374
+ _bias: torch.Tensor = None,
375
+ in_features: int = None,
376
+ original_out_features: int = None,
377
+ bias: bool = True,
378
+ partially_freeze: bool = True,
379
+ device=None,
380
+ dtype=None,
381
+ ) -> None:
382
+ """
383
+ Args:
384
+ max_original_id (`int`): The largest token id that should be extracted from the regular weight.
385
+ This is usually len(tokenizer) - 1 before additional tokens are added.
386
+ Note that this may not equal original_out_features - 1
387
+ _weight: torch.Tensor, *optional*, defaults to `None`. The regular weight tensor.
388
+ If provided, this sets the `in_features` and `original_out_features` parameters.
389
+ _bias: torch.Tensor, *optional*, defaults to `None`. The regular bias tensor.
390
+ in_features: int. Input hidden size.
391
+ original_out_features: int. Original out_features of the language model's get_output_embeddings() function.
392
+ additional_out_features: int. Number of additional trainable dimensions.
393
+ bias: bool. Whether to include a bias term.
394
+ partially_freeze: bool, *optional*, defaults to `True`. If `True`, the regular `weight` will be frozen.
395
+ """
396
+ # argument validation
397
+ if _weight is not None:
398
+ assert (_weight.shape[0] == original_out_features) or (
399
+ original_out_features is None
400
+ ), f"original_out_features={original_out_features} but _weight.shape[0]={_weight.shape[0]}"
401
+ assert (_weight.shape[1] == in_features) or (
402
+ in_features is None
403
+ ), f"in_features={in_features} but _weight.shape[1]={_weight.shape[1]}"
404
+ in_features = _weight.shape[1]
405
+ original_out_features = _weight.shape[0]
406
+ else:
407
+ assert (
408
+ in_features is not None
409
+ ), "in_features must be provided if _weight is not provided"
410
+ assert (
411
+ original_out_features is not None
412
+ ), "original_out_features must be provided if _weight is not provided"
413
+
414
+ if _bias is not None:
415
+ assert bias is True, "bias must be True if _bias is provided"
416
+
417
+ # initialize original linear
418
+ super().__init__(
419
+ in_features,
420
+ original_out_features,
421
+ bias,
422
+ device,
423
+ dtype)
424
+
425
+ # set weight and bias manually
426
+ if _weight is not None:
427
+ self.weight = nn.Parameter(_weight)
428
+ if _bias is not None:
429
+ self.bias = nn.Parameter(_bias)
430
+
431
+ self.in_features = in_features
432
+ self.original_out_features = original_out_features
433
+ self.max_original_id = max_original_id
434
+
435
+ # initialize additional linear
436
+ self.additional_out_features = additional_out_features
437
+ self.has_bias = bias
438
+ if additional_out_features > 0:
439
+ self.additional_fc = nn.Linear(
440
+ in_features=in_features,
441
+ out_features=additional_out_features,
442
+ bias=self.has_bias,
443
+ device=device,
444
+ dtype=dtype,
445
+ )
446
+ self.set_requires_grad(
447
+ require_regular_grad=not partially_freeze, require_additional_grad=True
448
+ )
449
+
450
+ def set_requires_grad(self, require_regular_grad, require_additional_grad):
451
+ """
452
+ Helper function to separately set the requires_grad flag for the regular weight and the additional weight.
453
+ """
454
+ self.weight.requires_grad_(require_regular_grad)
455
+ if self.has_bias:
456
+ self.bias.requires_grad_(require_regular_grad)
457
+ self.additional_fc.requires_grad_(require_additional_grad)
458
+
459
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
460
+ output = F.linear(input, self.weight, self.bias)
461
+ output = output[..., : self.max_original_id + 1]
462
+
463
+ if self.additional_out_features > 0:
464
+ additional_features = F.linear(
465
+ input, self.additional_fc.weight, self.additional_fc.bias
466
+ )
467
+ output = torch.cat((output, additional_features), -1)
468
+ return output
469
+
470
+ def extra_repr(self) -> str:
471
+ """Overwriting `nn.Linear.extra_repr` to include new parameters."""
472
+ return "in_features={}, out_features={}, additional_out_features={}, bias={}, partially_freeze={}".format(
473
+ self.in_features,
474
+ self.max_original_id + 1,
475
+ self.additional_out_features,
476
+ self.bias is not None,
477
+ (not self.weight.requires_grad or not self.bias.requires_grad),
478
+ )
479
+
480
+ class VLM(nn.Module):
481
+ """
482
+ Generic vision-language model (VLM) class.
483
+ A VLM consists of four components:
484
+ 1. A vision encoder that extracts features from pixels, e.g. CLIP
485
+ input: (B, T_img, F, C, H, W)
486
+ output: (B, T_img, F, v, d)
487
+ 2. A vision tokenizer that converts these features to visual token-like embeddings, e.g. Perceiver, or a linear projection head
488
+ input: (B, T_img, F, v, d)
489
+ output: (B, T_img, n, d)
490
+ 3. A fusion method that allows the language model to attend to these tokens, e.g. cross-attention, or placing the tokens directly in the language model's input sequence
491
+ 4. A language model
492
+ """
493
+
494
+ def __init__(
495
+ self,
496
+ vision_encoder: nn.Module,
497
+ vision_tokenizer: nn.Module,
498
+ lang_model: nn.Module,
499
+ initial_tokenizer_len: int,
500
+ pad_token_id: int,
501
+ gradient_checkpointing: bool = False,
502
+ ):
503
+ """
504
+ Args:
505
+ vision_encoder (nn.Module): e.g. CLIP
506
+ vision_tokenizer (nn.Module): e.g. PerceiverResampler
507
+ lang_model (nn.Module): e.g. MPT
508
+ initial_tokenizer_len (int): size of the original tokenizer vocab
509
+ pad_token_id (int): id of the pad token
510
+ gradient_checkpointing (bool, optional): Whether to use gradient checkpointing. Defaults to False.
511
+ """
512
+ super().__init__()
513
+
514
+ # save dimension information
515
+ self.lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
516
+ if hasattr(lang_model.config, "d_model"):
517
+ self.lang_hidden_dim = lang_model.config.d_model # mpt uses d_model
518
+ else:
519
+ self.lang_hidden_dim = lang_model.config.hidden_size
520
+ self.vis_embedding_dim = vision_tokenizer.dim_media
521
+ self.num_tokens_per_vis = vision_tokenizer.num_tokens_per_media
522
+
523
+ # core components
524
+ self.vision_encoder = vision_encoder
525
+ self.vision_tokenizer = vision_tokenizer
526
+ self.lang_model = lang_model
527
+
528
+ # lm embeddings
529
+ self.pad_token_id = pad_token_id
530
+ self.initial_tokenizer_len = initial_tokenizer_len
531
+ input_embeds = DecoupledEmbedding(
532
+ max_original_id=initial_tokenizer_len - 1,
533
+ num_additional_embeddings=len(self.special_tokens),
534
+ _weight=self.lang_model.get_input_embeddings().weight,
535
+ pad_token_id=self.pad_token_id,
536
+ )
537
+ if hasattr(input_embeds, "additional_embedding"):
538
+ input_embeds.additional_embedding.weight.data.normal_(
539
+ mean=0.0,
540
+ std=self.lang_model.config.initializer_range
541
+ if hasattr(self.lang_model.config, "initializer_range")
542
+ else 0.02,
543
+ )
544
+ self.lang_model.set_input_embeddings(input_embeds)
545
+
546
+ out_embeds = DecoupledLinear(
547
+ max_original_id=initial_tokenizer_len - 1,
548
+ additional_out_features=len(self.special_tokens),
549
+ _weight=self.lang_model.get_output_embeddings().weight,
550
+ _bias=self.lang_model.get_output_embeddings().bias if hasattr(self.lang_model.get_output_embeddings(), "bias") else None,
551
+ )
552
+ if hasattr(out_embeds, "additional_fc"):
553
+ out_embeds.additional_fc.weight.data.normal_(
554
+ mean=0.0,
555
+ std=self.lang_model.config.initializer_range
556
+ if hasattr(self.lang_model.config, "initializer_range")
557
+ else 0.02,
558
+ )
559
+ self.lang_model.set_output_embeddings(out_embeds)
560
+
561
+ # gradient checkpointing
562
+ self.vision_tokenizer._use_gradient_checkpointing = gradient_checkpointing
563
+
564
+ def forward(
565
+ self,
566
+ vision_x: Optional[torch.Tensor],
567
+ lang_x: torch.Tensor,
568
+ attention_mask: Optional[torch.Tensor] = None,
569
+ labels: Optional[torch.Tensor] = None,
570
+ past_key_values: Optional[
571
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
572
+ ] = None,
573
+ past_media_locations: Optional[torch.Tensor] = None,
574
+ past_vision_tokens: Optional[torch.Tensor] = None,
575
+ use_cache: Optional[bool] = False,
576
+ **kwargs,
577
+ ):
578
+ """
579
+ Args:
580
+ vision_x: Vision input
581
+ shape (B, T_img, F, C, H, W) with F=1
582
+ only F = 1 is supported (single-frame videos)
583
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
584
+ only the leading images, up to the number of media tokens in lang_x, are used
585
+ lang_x: Language input ids, with media tokens denoting where
586
+ visual media should be inserted.
587
+ shape (B, T_txt)
588
+ attention_mask: Attention mask. Defaults to None.
589
+ labels: Labels. Defaults to None.
590
+ shape (B, T_txt)
591
+ past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
592
+ list of length = number of decoder layers in the LM
593
+ exact implementation depends on LM, see Hugging Face docs
594
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
595
+ shape (B, T_txt)
596
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
597
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
598
+ If True, includes key_values, media_locations, and vision_tokens in the output.
599
+ """
600
+ assert not (past_vision_tokens is None) ^ (
601
+ past_media_locations is None
602
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
603
+
604
+ # convert pixels to vision tokens
605
+ if vision_x is not None:
606
+ vision_features = self._encode_vision_x(vision_x=vision_x)
607
+ vision_tokens = self.vision_tokenizer(vision_features)
608
+ else:
609
+ vision_tokens = None
610
+
611
+ # fuse the vision and language tokens
612
+ new_inputs = self._prepare_inputs_for_forward(
613
+ vision_tokens=vision_tokens,
614
+ lang_x=lang_x,
615
+ attention_mask=attention_mask,
616
+ labels=labels,
617
+ past_key_values=past_key_values,
618
+ past_media_locations=past_media_locations,
619
+ padding_side="right",
620
+ past_vision_tokens=past_vision_tokens,
621
+ )
622
+ output = self.lang_model(
623
+ **new_inputs,
624
+ use_cache=use_cache,
625
+ past_key_values=past_key_values,
626
+ **kwargs,
627
+ )
628
+
629
+ # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
630
+ # or to add the past_vision_tokens and past_media_locations to the output
631
+ output = self._postprocess_outputs_from_forward(
632
+ output=output,
633
+ lang_x=lang_x,
634
+ vision_tokens=vision_tokens,
635
+ use_cache=use_cache,
636
+ past_vision_tokens=past_vision_tokens,
637
+ past_media_locations=past_media_locations,
638
+ )
639
+
640
+ # postforward hooks
641
+ self._post_forward_hook()
642
+ return output
643
+
644
+ def _encode_vision_x_anyres(self, samples, device):
645
+ image_raw = samples["image"] # list of per-image patch tensors, each of shape [1, N_patch, C, H, W]
646
+ image_sizes = samples["image_size"]
647
+
648
+ # concatenate the per-image patch lists into one batch of patches for any-res encoding.
649
+ images = [x.squeeze(0) for x in image_raw] # [N_patch, C, H, W]
650
+ image = torch.cat(images, dim=0) # [\sum_{B}(N_patch_i), C, H, W]
651
+ image = image.to(device)
652
+
653
+ with torch.no_grad():
654
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
655
+ image_embeds = self.vision_encoder.trunk.forward_features(image)
656
+ elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
657
+ image_embeds = self.vision_encoder(image).last_hidden_state
658
+ else:
659
+ image_embeds = self.vision_encoder(image)[1] # OpenCLIP returns tuples
660
+
661
+ if isinstance(self.vision_encoder, CLIPVisionModel):
662
+ base_img_size = self.vision_encoder.config.image_size
663
+ else:
664
+ base_img_size = self.vision_encoder.image_size[0]
665
+
666
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
667
+ grid_size = self.vision_encoder.trunk.patch_embed.grid_size
668
+ elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
669
+ grid_size_base = self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size
670
+ grid_size = (grid_size_base, grid_size_base)
671
+ else:
672
+ grid_size = self.vision_encoder.grid_size
673
+ height, width = grid_size
674
+
675
+ if not image_embeds.shape[1] == height * width:
676
+ assert image_embeds.shape[1] == height * width + 1 # For vision encoders that have a [CLS] token.
677
+ image_embeds = image_embeds[:, 1:, :] # Drop the cls token for each patch.
678
+ n_vis_token_per_patch = image_embeds.shape[1]
679
+
680
+ # Split encoded patches and merge patch features
681
+ # 1. Get the raw sizes from samples, and split the image embeds [\sum_{B}(N_patch_i), N_tok(16*16), C]
682
+ split_sizes = [image.shape[0] for image in images]
683
+ image_embeds = torch.split(image_embeds, split_sizes, dim=0)
684
+ # 2. For each image (consisting of a list of patches), merge the patches spatially (of shape [C, n_patch_height, n_patch_width])
685
+ new_image_embeds = []
686
+ patch_attn_masks = []
687
+ max_n_img_token = -1
688
+ for idx, patch_embeds in enumerate(image_embeds):
689
+ if patch_embeds.shape[0] > 1:
690
+ # 3. Flatten the patch features and get [C, n_patch_height * (n_patch_width+1)]
691
+ base_patch_embeds = patch_embeds[0] # TODO: prepend the CLS token for the base patch embeds (of the resized entire image).
692
+ patch_embeds = patch_embeds[1:]
693
+
694
+ assert height * width == base_patch_embeds.shape[0]
695
+
696
+ num_patch_width, num_patch_height = get_anyres_image_grid_shape(image_sizes[idx],
697
+ [[base_img_size,base_img_size*2],
698
+ [base_img_size*2,base_img_size],
699
+ [base_img_size*2,base_img_size*2],
700
+ [base_img_size*3,base_img_size],
701
+ [base_img_size,base_img_size*3]],
702
+ base_img_size) # Hardcoded grid_pinpoints.
703
+ patch_embeds = patch_embeds.view(num_patch_height, num_patch_width, height, width, -1)
704
+
705
+ patch_embeds = patch_embeds.permute(4, 0, 2, 1, 3).contiguous()
706
+ patch_embeds = patch_embeds.flatten(1, 2).flatten(2, 3)
707
+ # TODO: add an option that return masked patch_embeds instead of trimmed.
708
+ patch_embeds, patch_attn_mask = unpad_image(patch_embeds, image_sizes[idx], self.anyres_patch_sampling)
709
+ if hasattr(self, 'image_newline'):
710
+ patch_embeds = torch.cat((
711
+ patch_embeds,
712
+ self.image_newline[:, None, None].expand(*patch_embeds.shape[:-1], 1)
713
+ ), dim=-1)
714
+ if self.anyres_patch_sampling:
715
+ patch_embeds = patch_embeds.view(-1, num_patch_height, num_patch_width, height*width)
716
+ patch_embeds = patch_embeds.flatten(1, 2).permute(1, 2, 0)
717
+ assert patch_attn_mask is not None
718
+ patch_attn_mask = patch_attn_mask.view(num_patch_height, num_patch_width, height*width)
719
+ patch_attn_mask = patch_attn_mask.flatten(0, 1)
720
+ patch_embeds = torch.cat((base_patch_embeds.unsqueeze(0), patch_embeds), dim=0)
721
+ patch_attn_mask = torch.cat((torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0), patch_attn_mask), dim=0)
722
+ else:
723
+ patch_embeds = patch_embeds.flatten(1, 2).transpose(0, 1)
724
+ patch_embeds = torch.cat((base_patch_embeds, patch_embeds), dim=0)
725
+ else:
726
+ patch_embeds = patch_embeds[0].unsqueeze(0) if self.anyres_patch_sampling else patch_embeds[0]
727
+ patch_attn_mask = torch.ones(n_vis_token_per_patch, device=patch_embeds.device).unsqueeze(0) if self.anyres_patch_sampling else None
728
+ if hasattr(self, 'image_newline'):
729
+ patch_embeds = torch.cat((
730
+ patch_embeds,
731
+ self.image_newline[None]
732
+ ), dim=0)
733
+ if not self.anyres_patch_sampling:
734
+ max_n_img_token = max(patch_embeds.shape[0], max_n_img_token)
735
+
736
+ new_image_embeds.append(patch_embeds)
737
+ patch_attn_masks.append(patch_attn_mask)
738
+
739
+ if self.anyres_patch_sampling:
740
+ # Return individual patches for independent token downsampling.
741
+ return new_image_embeds, patch_attn_masks
742
+
743
+ # 4. Pad and concat the list of image_embeds [N_tok_i, C] together into a batch. Also modify the query attention mask.
744
+ image_embeds = []
745
+ image_atts = []
746
+ for image_embed in new_image_embeds:
747
+ n_img_token = image_embed.shape[0]
748
+ img_attn = torch.ones((max_n_img_token), dtype=torch.long, device=image_embed.device)
749
+ if n_img_token < max_n_img_token:
750
+ padded_embed = torch.zeros((max_n_img_token, image_embed.shape[-1]), dtype=image_embed.dtype, device=image_embed.device)
751
+ padded_embed[:n_img_token, :] = image_embed
752
+ img_attn[n_img_token:] = 0 # Mask out the padded entries.
753
+ else:
754
+ padded_embed = image_embed
755
+ image_embeds.append(padded_embed)
756
+ image_atts.append(img_attn)
757
+ image_embeds = torch.stack(image_embeds, dim=0) # Shape [B, N_tok_longest, C_dim]
758
+ image_atts = torch.stack(image_atts, dim=0) # Shape [B, N_tok_longest]
759
+ # TODO: reshape image_embeds and image_atts to "b T F v d"
760
+ image_embeds = image_embeds[:, None, None, :, :]
761
+ # image_atts = image_atts[:, None, None, :, :]
762
+
763
+ return image_embeds, image_atts
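Step 4 above (pad each image's merged token sequence to the longest one in the batch and mask the padding) in isolation, with tiny made-up sizes:

import torch

embeds = [torch.randn(5, 4), torch.randn(8, 4)]   # two images with 5 and 8 tokens, feature dim 4
max_len = max(e.shape[0] for e in embeds)
padded, masks = [], []
for e in embeds:
    pad = torch.zeros(max_len, e.shape[-1], dtype=e.dtype)
    pad[: e.shape[0]] = e
    mask = torch.zeros(max_len, dtype=torch.long)
    mask[: e.shape[0]] = 1                         # 1 = real token, 0 = padding
    padded.append(pad)
    masks.append(mask)
image_embeds = torch.stack(padded)   # (B, N_tok_longest, C)
image_atts = torch.stack(masks)      # (B, N_tok_longest)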
764
+
765
+ def _encode_vision_x(self, vision_x: torch.Tensor):
766
+ """
767
+ Compute vision features from the vision input by passing it through the vision encoder.
768
+ Args:
769
+ vision_x: Vision input
770
+ shape (B, T_img, F, C, H, W)
771
+ Images in the same chunk are collated along T_img, and frames are collated along F
772
+ Currently only F=1 is supported (single-frame videos)
773
+
774
+ rearrange code based on https://github.com/dhansmair/flamingo-mini
775
+ """
776
+ assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
777
+ b, T, F = vision_x.shape[:3]
778
+
779
+ vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
780
+ with torch.no_grad():
781
+ if self.vision_encoder.__class__.__name__ == "TimmModel":
782
+ vision_x = self.vision_encoder.trunk.forward_features(vision_x)
783
+ elif self.vision_encoder.__class__.__name__ == 'CLIPVisionModel':
784
+ vision_x = self.vision_encoder(vision_x).last_hidden_state
785
+ else:
786
+ vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
787
+ vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
788
+ return vision_x
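The collate/uncollate around the encoder is an einops round-trip; a stand-alone sketch with made-up sizes (b=2, T=3, F=1, and a fake encoder output with v=16, d=8):

import torch
from einops import rearrange

x = torch.randn(2, 3, 1, 3, 224, 224)                               # (b, T, F, c, h, w)
flat = rearrange(x, "b T F c h w -> (b T F) c h w")                 # (6, 3, 224, 224)
feats = torch.randn(flat.shape[0], 16, 8)                           # stand-in for the encoder output (v, d)
out = rearrange(feats, "(b T F) v d -> b T F v d", b=2, T=3, F=1)
assert out.shape == (2, 3, 1, 16, 8)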
789
+
790
+ def _concat_vision_cache(
791
+ self, lang_x, vision_tokens, past_vision_tokens, past_media_locations, use_cache
792
+ ):
793
+ """
794
+ Helper function to include the past vision tokens and past media locations in the output.
795
+ """
796
+ if use_cache:
797
+ if past_media_locations is not None and past_vision_tokens is not None:
798
+ if vision_tokens is not None:
799
+ updated_vision_tokens = torch.cat(
800
+ [
801
+ past_vision_tokens,
802
+ vision_tokens,
803
+ ],
804
+ dim=1,
805
+ )
806
+ else:
807
+ updated_vision_tokens = past_vision_tokens
808
+ updated_media_locations = torch.cat(
809
+ [
810
+ past_media_locations,
811
+ lang_x == self.media_token_id,
812
+ ],
813
+ dim=1,
814
+ )
815
+ else:
816
+ updated_vision_tokens = vision_tokens
817
+ updated_media_locations = lang_x == self.media_token_id
818
+
819
+ else:
820
+ updated_vision_tokens = None
821
+ updated_media_locations = None
822
+
823
+ return updated_vision_tokens, updated_media_locations
824
+
825
+ def generate(
826
+ self,
827
+ vision_x: torch.Tensor,
828
+ lang_x: torch.Tensor,
829
+ attention_mask: torch.Tensor = None,
830
+ past_key_values: Optional[
831
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
832
+ ] = None,
833
+ past_media_locations: Optional[torch.Tensor] = None,
834
+ past_vision_tokens: Optional[torch.Tensor] = None,
835
+ **kwargs,
836
+ ):
837
+ """
838
+ Generate text conditioned on vision and language inputs.
839
+ Args:
840
+ vision_x (torch.Tensor): Vision input
841
+ shape (B, T_img, F, C, H, W)
842
+ see documentation for forward
843
+ lang_x (torch.Tensor): Language input
844
+ shape (B, T_txt)
845
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
846
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
847
+ Returns:
848
+ torch.Tensor: lang_x with generated tokens appended to it
849
+ """
850
+ num_beams = kwargs.pop("num_beams", 1)
851
+
852
+ # convert pixels to vision tokens
853
+ if vision_x is not None:
854
+ vision_features = self._encode_vision_x(vision_x=vision_x)
855
+ vision_tokens = self.vision_tokenizer(vision_features)
856
+ else:
857
+ vision_tokens = None
858
+
859
+ # fuse the vision and language tokens
860
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
861
+ # the total batch size is B * num_beams
862
+ new_inputs = self._prepare_inputs_for_forward(
863
+ vision_tokens=vision_tokens,
864
+ lang_x=lang_x,
865
+ attention_mask=attention_mask,
866
+ past_key_values=past_key_values,
867
+ past_media_locations=past_media_locations,
868
+ past_vision_tokens=past_vision_tokens,
869
+ padding_side="left",
870
+ num_beams=num_beams,
871
+ )
872
+ output = self.lang_model.generate(
873
+ **new_inputs,
874
+ past_key_values=past_key_values,
875
+ num_beams=num_beams,
876
+ use_cache=True,
877
+ **kwargs,
878
+ )
879
+ self._post_forward_hook()
880
+ return output
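A note on the padding sides: generate pads on the left so that every sequence's final position is a real token for the model to continue from, while the training forward pads on the right (padded positions are simply masked out of the loss). A tiny illustration, assuming pad id 0:

import torch

# prompts of length 3 and 2, left-padded for generation
input_ids = torch.tensor([[5, 6, 7], [0, 8, 9]])
attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]])
# the last column is a real token for both rows, so generated tokens append cleanly
assert bool(attention_mask[:, -1].all())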
881
+
882
+ @property
883
+ def num_trainable_params(self):
884
+ """Return the number of trainable parameters"""
885
+ return num_params(self, filter_to_trainable=True)
886
+
887
+ def set_trainable(self):
888
+ """
889
+ Freeze appropriate parameters in the model.
890
+ """
891
+ raise NotImplementedError
892
+
893
+ def group_params_by_weight_decay(self):
894
+ """
895
+ Return a tuple of (params to optimize w/ weight decay, params to optimize w/o weight decay)
896
+ """
897
+ params_with_wd, params_without_wd = [], []
898
+ for n, p in self.named_parameters():
899
+ if p.requires_grad:
900
+ if self._should_apply_weight_decay(n):
901
+ params_with_wd.append(p)
902
+ else:
903
+ params_without_wd.append(p)
904
+ return params_with_wd, params_without_wd
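A sketch of the intended consumer of these two groups, assuming `model` is an instantiated subclass; the optimizer is built elsewhere in the training code, and the lr / weight_decay values below are placeholders rather than the released configuration:

import torch

params_with_wd, params_without_wd = model.group_params_by_weight_decay()
optimizer = torch.optim.AdamW(
    [
        {"params": params_with_wd, "weight_decay": 0.01},
        {"params": params_without_wd, "weight_decay": 0.0},
    ],
    lr=1e-4,
)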
905
+
906
+ def _should_apply_weight_decay(self, parameter_name):
907
+ """
908
+ Return whether weight decay should be applied to a parameter.
909
+ """
910
+ raise NotImplementedError
911
+
912
+ @property
913
+ def special_tokens(self):
914
+ """
915
+ Returns a dict mapping from the attribute name of a special token to its string format,
916
+ e.g. "media_token": "<image>"
917
+ """
918
+ assert (
919
+ "media_token" in self._special_tokens
920
+ ), "VLMs need to request that the tokenizer add a media_token and call set_special_token_ids to set self.media_token_id"
921
+ return self._special_tokens
922
+
923
+ @property
924
+ def special_token_ids(self):
925
+ """
926
+ Returns a list of the special token ids
927
+ """
928
+ return [getattr(self, f"{att_name}_id") for att_name in self.special_tokens]
929
+
930
+ def set_special_token_ids(self, string_to_ids):
931
+ """
932
+ Args:
933
+ string_to_ids (dict): mapping from token string to id
934
+ """
935
+ assert set(self.special_tokens.values()).issubset(set(string_to_ids.keys()))
936
+ for att_name, token_str in self.special_tokens.items():
937
+ token_id = string_to_ids[token_str]
938
+ setattr(self, f"{att_name}_id", token_id)
939
+ setattr(self.lang_model, f"{att_name}_id", token_id)
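A sketch of the expected wiring between a Hugging Face tokenizer and this method, assuming an already-instantiated `model`; the real setup lives in the factory code, so the variable names here are assumptions:

tokenizer.add_special_tokens(
    {"additional_special_tokens": list(model.special_tokens.values())}
)
model.set_special_token_ids(
    {tok: tokenizer.convert_tokens_to_ids(tok) for tok in model.special_tokens.values()}
)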
940
+
941
+ def init_gradient_checkpointing(self):
942
+ from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
943
+ checkpoint_wrapper,
944
+ CheckpointWrapper,
945
+ CheckpointImpl,
946
+ apply_activation_checkpointing,
947
+ )
948
+ from functools import partial
949
+
950
+ non_reentrant_wrapper = partial(
951
+ checkpoint_wrapper,
952
+ checkpoint_impl=CheckpointImpl.NO_REENTRANT,
953
+ )
954
+ apply_activation_checkpointing(
955
+ self,
956
+ checkpoint_wrapper_fn=non_reentrant_wrapper,
957
+ check_fn=lambda m: getattr(m, "_use_gradient_checkpointing", False)
958
+ and not isinstance(m, CheckpointWrapper),
959
+ )
960
+
961
+ @dataclass
962
+ class VLMOutputWithPast(CausalLMOutputWithPast):
963
+ """
964
+ VLMOutputWithPast is a wrapper around CausalLMOutputWithPast that adds the following attributes:
965
+ past_media_locations: Optional[torch.Tensor] = None,
966
+ past_vision_tokens: Optional[torch.Tensor] = None,
967
+ """
968
+
969
+ past_media_locations: Optional[torch.Tensor] = None
970
+ past_vision_tokens: Optional[torch.Tensor] = None
971
+
972
+
973
+ def exists(val):
974
+ return val is not None
975
+
976
+
977
+ def FeedForward(dim, mult=4):
978
+ inner_dim = int(dim * mult)
979
+ return nn.Sequential(
980
+ nn.LayerNorm(dim),
981
+ nn.Linear(dim, inner_dim, bias=False),
982
+ nn.GELU(),
983
+ nn.Linear(inner_dim, dim, bias=False),
984
+ )
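The block maps (..., dim) back to (..., dim); a quick shape check using the FeedForward defined above, with arbitrary sizes:

import torch

ff = FeedForward(dim=32, mult=4)                       # inner width 128
assert ff(torch.randn(2, 7, 32)).shape == (2, 7, 32)   # shape-preserving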
985
+
986
+ class VLMWithLanguageStream(VLM):
987
+ """
988
+ VLM that fuses modalities by inserting vision tokens directly into the language stream.
989
+ """
990
+
991
+ def __init__(
992
+ self,
993
+ vision_encoder: nn.Module,
994
+ vision_tokenizer: nn.Module,
995
+ lang_model: nn.Module,
996
+ initial_tokenizer_len: int,
997
+ pad_token_id: int,
998
+ decoder_layers_attr_name: str = None,
999
+ gradient_checkpointing: bool = False,
1000
+ ):
1001
+ super().__init__(
1002
+ vision_encoder=vision_encoder,
1003
+ vision_tokenizer=vision_tokenizer,
1004
+ lang_model=lang_model,
1005
+ initial_tokenizer_len=initial_tokenizer_len,
1006
+ pad_token_id=pad_token_id,
1007
+ gradient_checkpointing=gradient_checkpointing,
1008
+ )
1009
+ self.decoder_layers_attr_name = decoder_layers_attr_name
1010
+ if decoder_layers_attr_name is not None:
1011
+ for block in getattr_recursive(self.lang_model, self.decoder_layers_attr_name):
1012
+ block._use_gradient_checkpointing = gradient_checkpointing
1013
+
1014
+ def _prepare_inputs_for_forward(
1015
+ self,
1016
+ vision_tokens: torch.Tensor,
1017
+ lang_x: torch.Tensor,
1018
+ attention_mask: torch.Tensor,
1019
+ labels: torch.Tensor = None,
1020
+ past_key_values=None,
1021
+ vision_attention_mask: Optional[torch.Tensor] = None,
1022
+ past_media_locations: torch.Tensor = None,
1023
+ past_vision_tokens: torch.Tensor = None,
1024
+ padding_side: str = "left",
1025
+ num_beams: int = 1,
1026
+ ):
1027
+ """
1028
+ Insert the vision tokens directly into the language stream.
1029
+ This requires us to modify the input_ids, attention_mask, and labels.
1030
+ """
1031
+ if past_key_values is not None:
1032
+ past_len = past_key_values[0][0].shape[2]
1033
+ assert attention_mask.shape[1] == past_len + lang_x.shape[1], (
1034
+ "Attention_mask must be as long as the entire past len (including image tokens) and current input IDs. "
1035
+ + "Check that you've expanded the attention mask to account for past image tokens."
1036
+ )
1037
+
1038
+ if vision_tokens is None:
1039
+ return {
1040
+ "input_ids": lang_x,
1041
+ "attention_mask": attention_mask,
1042
+ "labels": labels,
1043
+ }
1044
+
1045
+ # get the language embeddings
1046
+ lang_embeds = self.lang_model.get_input_embeddings()(lang_x)
1047
+
1048
+ # build up the multimodal embeddings
1049
+ B = lang_x.shape[0]
1050
+ has_labels = labels is not None
1051
+ multimodal_embeds = []
1052
+ multimodal_attention_mask = []
1053
+ multimodal_labels = [] if has_labels else None
1054
+ for i in range(B):
1055
+ # get index of <image> tokens in lang_x[i]
1056
+ image_token_idxs = torch.where(lang_x[i] == self.media_token_id)[0]
1057
+
1058
+ if len(image_token_idxs) == 0:
1059
+ multimodal_embeds.append(lang_embeds[i].clone())
1060
+ multimodal_attention_mask.append(attention_mask[i].clone())
1061
+ if has_labels:
1062
+ multimodal_labels.append(labels[i].clone())
1063
+ continue
1064
+
1065
+ # since an image is represented by self.num_tokens_per_vis tokens, we need to offset the image_token_idxs
1066
+ for j, img_idx in enumerate(image_token_idxs):
1067
+ image_token_idxs[j] += (self.num_tokens_per_vis - 1) * j # FIXME: different offset for any-resolution encoding when there are multiple images.
1068
+
1069
+ # loop through the image_token_idxs and insert the vision tokens
1070
+ new_embed = lang_embeds[i].clone()
1071
+ new_attention_mask = (
1072
+ attention_mask[i].clone() if attention_mask is not None else None
1073
+ )
1074
+ if has_labels:
1075
+ new_label = labels[i].clone()
1076
+
1077
+ for img_num, img_idx in enumerate(image_token_idxs):
1078
+ if img_num > 0:
1079
+ # FIXME: hardcoded as such to avoid assertion error, but this only works for single image samples.
1080
+ break
1081
+ # Get vision token attention mask for padded llava-style any resolution image tokens.
1082
+ if self.image_aspect_ratio == 'anyres':
1083
+ num_vis_tokens = vision_tokens[i][img_num].shape[0]
1084
+ if vision_attention_mask is not None:
1085
+ vis_attention_mask = vision_attention_mask[i]
1086
+ else:
1087
+ vis_attention_mask = torch.ones(
1088
+ num_vis_tokens, dtype=torch.long
1089
+ ).to(attention_mask.device)
1090
+ else:
1091
+ assert (
1092
+ vision_tokens[i][img_num].shape[0] == self.num_tokens_per_vis
1093
+ ), f"vision token number mismatch: image embedding ({vision_tokens[i][img_num].shape[0]}) \
1094
+ vs. model.num_tokens_per_vis ({self.num_tokens_per_vis})"
1095
+ # By default, vision tokens are not padded.
1096
+ num_vis_tokens = self.num_tokens_per_vis
1097
+ vis_attention_mask = torch.ones(
1098
+ num_vis_tokens, dtype=torch.long
1099
+ ).to(attention_mask.device)
1100
+
1101
+
1102
+ new_embed = torch.cat(
1103
+ (
1104
+ new_embed[:img_idx],
1105
+ vision_tokens[i][img_num],
1106
+ new_embed[img_idx + 1 :],
1107
+ ),
1108
+ dim=0,
1109
+ )
1110
+ new_attention_mask = torch.cat(
1111
+ (
1112
+ new_attention_mask[:img_idx],
1113
+ vis_attention_mask,
1114
+ new_attention_mask[img_idx + 1 :],
1115
+ ),
1116
+ dim=0,
1117
+ )
1118
+ if has_labels:
1119
+ new_label = torch.cat(
1120
+ (
1121
+ new_label[:img_idx],
1122
+ torch.ones(num_vis_tokens, dtype=torch.long).to(
1123
+ labels.device
1124
+ )
1125
+ * -100,
1126
+ new_label[img_idx + 1 :],
1127
+ ),
1128
+ dim=0,
1129
+ )
1130
+ multimodal_embeds.append(new_embed)
1131
+ multimodal_attention_mask.append(new_attention_mask)
1132
+ if has_labels:
1133
+ multimodal_labels.append(new_label)
1134
+
1135
+ # stack
1136
+ multimodal_embeds = stack_with_padding(
1137
+ multimodal_embeds,
1138
+ padding_value=self.pad_token_id,
1139
+ padding_side=padding_side,
1140
+ )
1141
+ multimodal_attention_mask = stack_with_padding(
1142
+ multimodal_attention_mask,
1143
+ padding_value=0,
1144
+ padding_side=padding_side,
1145
+ )
1146
+ if has_labels:
1147
+ multimodal_labels = stack_with_padding(
1148
+ multimodal_labels,
1149
+ padding_value=-100,
1150
+ padding_side=padding_side,
1151
+ )
1152
+
1153
+ return {
1154
+ "inputs_embeds": multimodal_embeds,
1155
+ "attention_mask": multimodal_attention_mask,
1156
+ "labels": multimodal_labels,
1157
+ }
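Length bookkeeping for the insertion above: each <image> token is replaced by num_tokens_per_vis vision-token embeddings, so a sequence grows by (num_tokens_per_vis - 1) per image before padding. A worked example with made-up sizes:

T_txt, num_tokens_per_vis, n_images = 32, 64, 1
fused_len = T_txt + n_images * (num_tokens_per_vis - 1)
assert fused_len == 95   # _postprocess_outputs_from_forward later contracts this back to T_txt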
1158
+
1159
+ def _postprocess_outputs_from_forward(
1160
+ self,
1161
+ output: CausalLMOutputWithPast,
1162
+ lang_x: torch.Tensor,
1163
+ vision_tokens: torch.Tensor,
1164
+ past_vision_tokens: torch.Tensor,
1165
+ past_media_locations: torch.Tensor,
1166
+ use_cache: bool = False,
1167
+ ):
1168
+ # Include the past vision tokens and past media locations in the output
1169
+ updated_vision_tokens, updated_media_locations = self._concat_vision_cache(
1170
+ lang_x=lang_x,
1171
+ vision_tokens=vision_tokens,
1172
+ past_vision_tokens=past_vision_tokens,
1173
+ past_media_locations=past_media_locations,
1174
+ use_cache=use_cache,
1175
+ )
1176
+
1177
+ # return logits that are the same shape as the original input_ids
1178
+ logits = output.logits
1179
+ batch_logits = []
1180
+ B, T_txt = lang_x.shape
1181
+ for i in range(B):
1182
+ sequence_logits = []
1183
+ logits_j = 0
1184
+ for j in range(T_txt):
1185
+ if lang_x[i, j] != self.media_token_id:
1186
+ sequence_logits.append(logits[i, logits_j])
1187
+ logits_j += 1
1188
+ else:
1189
+ # append the logit for the first image token, then skip over the rest
1190
+ # note: the model actually learns to predict <im_patch>, not <image>
1191
+ sequence_logits.append(logits[i, logits_j])
1192
+ logits_j += self.num_tokens_per_vis
1193
+ sequence_logits = torch.stack(sequence_logits, dim=0) # (T_txt, vocab_size)
1194
+ batch_logits.append(sequence_logits)
1195
+
1196
+ batch_logits = torch.stack(batch_logits, dim=0) # (B, T_txt, vocab_size)
1197
+ # The final logits shape should be the same as the original input_ids shape
1198
+ assert batch_logits.shape[:2] == (B, T_txt)
1199
+
1200
+ # assemble the output
1201
+ output = VLMOutputWithPast(
1202
+ loss=output.loss,
1203
+ logits=batch_logits,
1204
+ past_key_values=output.past_key_values,
1205
+ hidden_states=output.hidden_states,
1206
+ attentions=output.attentions,
1207
+ past_media_locations=updated_media_locations,
1208
+ past_vision_tokens=updated_vision_tokens,
1209
+ )
1210
+
1211
+ return output
1212
+
1213
+ def _post_forward_hook(self):
1214
+ pass
1215
+
1216
+
1217
+ @property
1218
+ def num_params_per_module(self):
1219
+ """Return a string summarizing the number of parameters per module in the model"""
1220
+ return "\n".join(
1221
+ [
1222
+ f"Vision encoder: {num_params(self.vision_encoder):,} parameters",
1223
+ f"Vision tokenizer: {num_params(self.vision_tokenizer):,} parameters",
1224
+ f"Language model: {num_params(self.lang_model):,} parameters",
1225
+ ]
1226
+ )
1227
+
1228
+ @property
1229
+ def num_trainable_params_per_module(self):
1230
+ """Return a string summarizing the number of trainable parameters per module in the model"""
1231
+ return "\n".join(
1232
+ [
1233
+ f"Vision encoder: {num_params(self.vision_encoder, filter_to_trainable=True):,} trainable parameters",
1234
+ f"Vision tokenizer: {num_params(self.vision_tokenizer, filter_to_trainable=True):,} trainable parameters",
1235
+ f"Language model: {num_params(self.lang_model, filter_to_trainable=True):,} trainable parameters",
1236
+ ]
1237
+ )
1238
+
1239
+
1240
+ class KosmosInstruct(VLMWithLanguageStream):
1241
+ def __init__(
1242
+ self,
1243
+ vision_encoder: nn.Module,
1244
+ vision_tokenizer: nn.Module,
1245
+ lang_model: nn.Module,
1246
+ initial_tokenizer_len: int,
1247
+ pad_token_id: int,
1248
+ decoder_layers_attr_name: str = None,
1249
+ gradient_checkpointing: bool = False,
1250
+ image_aspect_ratio: str = 'pad',
1251
+ anyres_patch_sampling: bool = False
1252
+ ):
1253
+ """
1254
+ Args:
1255
+ vision_encoder (nn.Module): HF CLIPModel
1256
+ lang_model (nn.Module): HF causal language model
1257
+ vision_tokenizer (nn.Module): module that maps vision-encoder features to visual tokens, e.g. PerceiverResampler
1258
+ initial_tokenizer_len (int): size of the tokenizer vocab
1259
+ pad_token_id (int): id of the padding token. None if no padding token; then a padding token
1260
+ will be inserted into self.special_tokens, which factory.py fills after creating new tokens
1261
+ decoder_layers_attr_name (str, optional): name of the decoder layers attribute. Defaults to None.
1262
+ gradient_checkpointing (bool, optional): whether to use gradient checkpointing. Defaults to False.
1263
+ """
1264
+ self._special_tokens = {
1265
+ "media_token": "<image>",
1266
+ "image_placeholder_token": "<image placeholder>",
1267
+ "end_of_trunk_token": "<|endofchunk|>",
1268
+ }
1269
+ lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
1270
+ super().__init__(
1271
+ vision_encoder=vision_encoder,
1272
+ vision_tokenizer=vision_tokenizer,
1273
+ lang_model=lang_model,
1274
+ initial_tokenizer_len=initial_tokenizer_len,
1275
+ gradient_checkpointing=gradient_checkpointing,
1276
+ decoder_layers_attr_name=decoder_layers_attr_name,
1277
+ pad_token_id=pad_token_id,
1278
+ )
1279
+ self.image_aspect_ratio = image_aspect_ratio
1280
+ self.anyres_patch_sampling = anyres_patch_sampling
1281
+
1282
+ def set_trainable(self):
1283
+ """
1284
+ Unfreeze everything except the vision_encoder
1285
+ """
1286
+ self.requires_grad_(True)
1287
+ self.vision_encoder.requires_grad_(False)
1288
+
1289
+ def _should_apply_weight_decay(self, parameter_name):
1290
+ """
1291
+ Kosmos applies 0.01 weight decay to everything
1292
+ """
1293
+ return True
1294
+
1295
+ def forward(
1296
+ self,
1297
+ vision_x: Optional[torch.Tensor],
1298
+ lang_x: torch.Tensor,
1299
+ attention_mask: Optional[torch.Tensor] = None,
1300
+ labels: Optional[torch.Tensor] = None,
1301
+ image_size: Optional[Tuple] = None,
1302
+ past_key_values: Optional[
1303
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1304
+ ] = None,
1305
+ past_media_locations: Optional[torch.Tensor] = None,
1306
+ past_vision_tokens: Optional[torch.Tensor] = None,
1307
+ use_cache: Optional[bool] = False,
1308
+ **kwargs,
1309
+ ):
1310
+ """
1311
+ Args:
1312
+ vision_x: Vision input
1313
+ shape (B, T_img, F, C, H, W) with F=1
1314
+ only F = 1 is supported (single-frame videos)
1315
+ if T_img > the number of media tokens in the corresponding input_ids (lang_x),
1316
+ only the leading images, up to the number of media tokens in lang_x, are used
1317
+ lang_x: Language input ids, with media tokens denoting where
1318
+ visual media should be inserted.
1319
+ shape (B, T_txt)
1320
+ attention_mask: Attention mask. Defaults to None.
1321
+ labels: Labels. Defaults to None.
1322
+ shape (B, T_txt)
1323
+ past_key_values (Tuple[torch.Tensor]], optional): Past key value pairs for each of the T_txt previous tokens in the language model. Defaults to None.
1324
+ list of length = number of decoder layers in the LM
1325
+ exact implementation depends on LM, see Hugging Face docs
1326
+ past_media_locations (torch.Tensor, optional): boolean mask denoting which of the previous T_txt tokens were media tokens. Defaults to None.
1327
+ shape (B, T_txt)
1328
+ past_vision_tokens (torch.Tensor, optional): Previous vision tokens. Defaults to None.
1329
+ use_cache (Optional[bool], optional): Whether to use cache. Defaults to False.
1330
+ If True, includes key_values, media_locations, and vision_tokens in the output.
1331
+ """
1332
+ assert not (past_vision_tokens is None) ^ (
1333
+ past_media_locations is None
1334
+ ), "past_vision_tokens and past_media_locations must both be None or both be not None"
1335
+
1336
+ # convert pixels to vision tokens
1337
+ vision_attention_mask = None
1338
+ if vision_x is not None:
1339
+ if self.image_aspect_ratio == 'anyres':
1340
+ input_dict = dict(image=vision_x, image_size=image_size)
1341
+ vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
1342
+ else:
1343
+ vision_features = self._encode_vision_x(vision_x=vision_x)
1344
+ vision_attn_masks = None
1345
+ if self.anyres_patch_sampling:
1346
+ split_sizes = [feature.shape[0] for feature in vision_features]
1347
+ vision_features = torch.cat(vision_features, dim=0)
1348
+ vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
1349
+ vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
1350
+ # Prepare text embeds for instruction-aware image query sampling.
1351
+ # FIXME: for debugging purposes, truncate the text input to the vision tokenizer to at most 256 tokens.
1352
+ lang_x_truncated = lang_x[:, :256]
1353
+ text_embeds = self.lang_model.get_input_embeddings()(lang_x_truncated)
1354
+ # Repeat text_embeds to match the number of patches for each image patch group.
1355
+ if self.anyres_patch_sampling:
1356
+ repeated_text_embeds = []
1357
+ for i, np in enumerate(split_sizes):
1358
+ repeated_text_embeds.append(text_embeds[i].repeat(np, 1, 1))
1359
+ text_embeds = torch.cat(repeated_text_embeds, dim=0)
1360
+ vision_tokens = self.vision_tokenizer(vision_features, text_embeds)
1361
+
1362
+ # Post-processing: Split the batches into groups of patches and concatenate them together.
1363
+ if self.anyres_patch_sampling:
1364
+ vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
1365
+ max_n_vis_token = max([vis.shape[0]*vis.shape[-2] for vis in vision_token_groups])
1366
+ # Padding.
1367
+ padded_vision_tokens = []
1368
+ padded_attn_masks = []
1369
+ for patch_vis_tokens in vision_token_groups:
1370
+ patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
1371
+ n_vis_token = patch_vis_tokens.shape[0]
1372
+ patch_attn = torch.ones((max_n_vis_token), dtype=torch.long, device=patch_vis_tokens.device)
1373
+ if n_vis_token < max_n_vis_token:
1374
+ padded_vis_token = torch.zeros((max_n_vis_token, patch_vis_tokens.shape[-1]),
1375
+ dtype=patch_vis_tokens.dtype, device=patch_vis_tokens.device)
1376
+ padded_vis_token[:n_vis_token, :] = patch_vis_tokens
1377
+ patch_attn[n_vis_token:] = 0
1378
+ else:
1379
+ padded_vis_token = patch_vis_tokens
1380
+ padded_vision_tokens.append(padded_vis_token)
1381
+ padded_attn_masks.append(patch_attn)
1382
+ vision_tokens = torch.stack(padded_vision_tokens, dim=0)
1383
+ vision_attention_mask = torch.stack(padded_attn_masks, dim=0)
1384
+ vision_tokens = vision_tokens[:, None, :, :]
1385
+ else:
1386
+ vision_tokens = None
1387
+
1388
+ # fuse the vision and language tokens
1389
+ new_inputs = self._prepare_inputs_for_forward(
1390
+ vision_tokens=vision_tokens,
1391
+ lang_x=lang_x,
1392
+ attention_mask=attention_mask,
1393
+ vision_attention_mask=vision_attention_mask,
1394
+ labels=labels,
1395
+ past_key_values=past_key_values,
1396
+ past_media_locations=past_media_locations,
1397
+ padding_side="right",
1398
+ past_vision_tokens=past_vision_tokens,
1399
+ )
1400
+ output = self.lang_model(
1401
+ **new_inputs,
1402
+ use_cache=use_cache,
1403
+ past_key_values=past_key_values,
1404
+ **kwargs,
1405
+ )
1406
+
1407
+ # postprocessing may be needed, e.g. to remove extra tokens from logits that were inserted into the language stream
1408
+ # or to add the past_vision_tokens and past_media_locations to the output
1409
+ output = self._postprocess_outputs_from_forward(
1410
+ output=output,
1411
+ lang_x=lang_x,
1412
+ vision_tokens=vision_tokens,
1413
+ use_cache=use_cache,
1414
+ past_vision_tokens=past_vision_tokens,
1415
+ past_media_locations=past_media_locations,
1416
+ )
1417
+
1418
+ # postforward hooks
1419
+ self._post_forward_hook()
1420
+ return output
1421
+
1422
+ def generate(
1423
+ self,
1424
+ vision_x: torch.Tensor,
1425
+ lang_x: torch.Tensor,
1426
+ image_size: Optional[Tuple] = None,
1427
+ attention_mask: torch.Tensor = None,
1428
+ past_key_values: Optional[
1429
+ List[Union[torch.Tensor, Tuple[torch.Tensor]]]
1430
+ ] = None,
1431
+ past_media_locations: Optional[torch.Tensor] = None,
1432
+ past_vision_tokens: Optional[torch.Tensor] = None,
1433
+ **kwargs,
1434
+ ):
1435
+ """
1436
+ Generate text conditioned on vision and language inputs.
1437
+ Args:
1438
+ vision_x (torch.Tensor): Vision input
1439
+ shape (B, T_img, F, C, H, W)
1440
+ see documentation for forward
1441
+ lang_x (torch.Tensor): Language input
1442
+ shape (B, T_txt)
1443
+ attention_mask (torch.Tensor, optional): Attention mask. Defaults to None.
1444
+ **kwargs: see generate documentation in Hugging Face CausalLM models.
1445
+ Returns:
1446
+ torch.Tensor: lang_x with generated tokens appended to it
1447
+ """
1448
+ num_beams = kwargs.pop("num_beams", 1)
1449
+
1450
+ # convert pixels to vision tokens
1451
+ vision_attention_mask = None
1452
+ if vision_x is not None:
1453
+ if self.image_aspect_ratio == 'anyres':
1454
+ input_dict = dict(image=vision_x, image_size=image_size)
1455
+ vision_features, vision_attn_masks = self._encode_vision_x_anyres(input_dict, lang_x.device)
1456
+ else:
1457
+ vision_features = self._encode_vision_x(vision_x=vision_x)
1458
+ vision_attn_masks = None
1459
+ if self.anyres_patch_sampling:
1460
+ split_sizes = [feature.shape[0] for feature in vision_features]
1461
+ vision_features = torch.cat(vision_features, dim=0)
1462
+ vision_features = vision_features[:, None, None, :, :] # Expand dimensions.
1463
+ vision_attn_masks = torch.cat(vision_attn_masks, dim=0)
1464
+ # Prepare text embeds for instruction-aware image query sampling.
1465
+ lang_x_truncated = lang_x[:, :256]
1466
+ text_embeds = self.lang_model.get_input_embeddings()(lang_x_truncated) # FIXME: check function calling.
1467
+ # Repeat text_embeds to match the number of patches for each image patch group.
1468
+ if self.anyres_patch_sampling:
1469
+ repeated_text_embeds = []
1470
+ for i, np in enumerate(split_sizes):
1471
+ repeated_text_embeds.append(text_embeds[i].repeat(np, 1, 1))
1472
+ text_embeds = torch.cat(repeated_text_embeds, dim=0)
1473
+ vision_tokens = self.vision_tokenizer(vision_features, text_embeds)
1474
+
1475
+ # Post-processing: Split the batches into groups of patches and concatenate them together.
1476
+ if self.anyres_patch_sampling:
1477
+ vision_token_groups = torch.split(vision_tokens, split_sizes, dim=0)
1478
+ max_n_vis_token = max([vis.shape[0]*vis.shape[-2] for vis in vision_token_groups])
1479
+ # Padding.
1480
+ padded_vision_tokens = []
1481
+ padded_attn_masks = []
1482
+ for patch_vis_tokens in vision_token_groups:
1483
+ patch_vis_tokens = patch_vis_tokens.flatten(0, 2) # [Np, 1, v, d] -> [Np*v, d]
1484
+ n_vis_token = patch_vis_tokens.shape[0]
1485
+ patch_attn = torch.ones((max_n_vis_token), dtype=torch.long, device=patch_vis_tokens.device)
1486
+ if n_vis_token < max_n_vis_token:
1487
+ padded_vis_token = torch.zeros((max_n_vis_token, patch_vis_tokens.shape[-1]),
1488
+ dtype=patch_vis_tokens.dtype, device=patch_vis_tokens.device)
1489
+ padded_vis_token[:n_vis_token, :] = patch_vis_tokens
1490
+ patch_attn[n_vis_token:] = 0
1491
+ else:
1492
+ padded_vis_token = patch_vis_tokens
1493
+ padded_vision_tokens.append(padded_vis_token)
1494
+ padded_attn_masks.append(patch_attn)
1495
+ vision_tokens = torch.stack(padded_vision_tokens, dim=0)
1496
+ vision_attention_mask = torch.stack(padded_attn_masks, dim=0)
1497
+ vision_tokens = vision_tokens[:, None, :, :]
1498
+ else:
1499
+ vision_tokens = None
1500
+
1501
+ # fuse the vision and language tokens
1502
+ # for xattn, vision_x and media_location are repeat_interleaved s.t.
1503
+ # the total batch size is B * num_beams
1504
+ new_inputs = self._prepare_inputs_for_forward(
1505
+ vision_tokens=vision_tokens,
1506
+ lang_x=lang_x,
1507
+ attention_mask=attention_mask,
1508
+ vision_attention_mask=vision_attention_mask,
1509
+ past_key_values=past_key_values,
1510
+ past_media_locations=past_media_locations,
1511
+ past_vision_tokens=past_vision_tokens,
1512
+ padding_side="left",
1513
+ num_beams=num_beams,
1514
+ )
1515
+ if transformers.__version__ == '4.41.0.dev0':
1516
+ output = self.lang_model.generate(
1517
+ **new_inputs,
1518
+ num_beams=num_beams,
1519
+ use_cache=True,
1520
+ **kwargs,
1521
+ )
1522
+ else:
1523
+ output = self.lang_model.generate(
1524
+ **new_inputs,
1525
+ past_key_values=past_key_values,
1526
+ num_beams=num_beams,
1527
+ use_cache=True,
1528
+ **kwargs,
1529
+ )
1530
+ self._post_forward_hook()
1531
+ return output
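Finally, an end-to-end sketch of calling generate, assuming `model`, `tokenizer`, and an image-preprocessing function `preprocess` have been built by the accompanying factory/processing code; the prompt template, preprocessing, and decoding details below are placeholders rather than the released recipe:

import torch

prompt = "<image> Describe this image in detail."
inputs = tokenizer(prompt, return_tensors="pt")
vision_x = preprocess(image)              # assumed to return (1, T_img, 1, C, H, W), or a patch list for 'anyres'
with torch.no_grad():
    generated = model.generate(
        vision_x=vision_x,
        lang_x=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        image_size=[image.size],          # only consulted when image_aspect_ratio == 'anyres'
        max_new_tokens=128,
        num_beams=1,
    )
print(tokenizer.decode(generated[0], skip_special_tokens=True))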