jw2yang committed
Commit 18e9ab4
1 Parent(s): 25537d2

add modeling_magma

Files changed (1)
  1. modeling_magma.py +1460 -0
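
For reference, a minimal usage sketch of the model code added in this commit is shown below. It mirrors the example in the `forward()` docstring of this file; the checkpoint id `microsoft/magma-8b-hf` comes from that docstring, while loading through `AutoModelForCausalLM`/`AutoProcessor` with `trust_remote_code=True` is an assumption about how Hub-hosted custom code like this is typically consumed, not something this commit guarantees.

```python
# Hedged sketch: checkpoint id taken from the forward() docstring example below;
# the trust_remote_code loading path is an assumption for Hub-hosted custom code.
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

model = AutoModelForCausalLM.from_pretrained("microsoft/magma-8b-hf", trust_remote_code=True)
processor = AutoProcessor.from_pretrained("microsoft/magma-8b-hf", trust_remote_code=True)

prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=image, return_tensors="pt")
generate_ids = model.generate(**inputs, max_length=30)
print(processor.batch_decode(generate_ids, skip_special_tokens=True)[0])
```
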
modeling_magma.py ADDED
@@ -0,0 +1,1460 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Magma model."""
16
+
17
+ import math
18
+ import re
19
+ import os
20
+ from dataclasses import dataclass
21
+ from typing import List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torch.utils.checkpoint
26
+ from torch import nn
27
+ import wandb
28
+ import torch.distributed as dist
29
+ from transformers.modeling_utils import PreTrainedModel
30
+ from transformers.activations import ACT2FN
31
+ from transformers.cache_utils import Cache, DynamicCache
32
+ from transformers.utils import ModelOutput
33
+ from transformers.utils import (
34
+ add_code_sample_docstrings,
35
+ add_start_docstrings,
36
+ add_start_docstrings_to_model_forward,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from transformers import AutoConfig, AutoModelForCausalLM
41
+ from .configuration_magma import MagmaConfig
42
+ from .image_tower_magma import MagmaImageTower
43
+
44
+ logger = logging.get_logger(__name__)
45
+
46
+ _CONFIG_FOR_DOC = "MagmaConfig"
47
+
48
+ @dataclass
49
+ # Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Magma
50
+ class MagmaCausalLMOutputWithPast(ModelOutput):
51
+ """
52
+ Base class for Magma causal language model (or autoregressive) outputs.
53
+
54
+ Args:
55
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
56
+ Language modeling loss (for next-token prediction).
57
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
58
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
59
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
60
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
61
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`)
62
+
63
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
64
+ `past_key_values` input) to speed up sequential decoding.
65
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
66
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
67
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
68
+
69
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
70
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
71
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
72
+ sequence_length)`.
73
+
74
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
75
+ heads.
76
+ image_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
77
+ Tuple of `torch.FloatTensor` (one for the output of the image embeddings, `(batch_size, num_images,
78
+ sequence_length, hidden_size)`.
79
+
80
+ image_hidden_states of the model produced by the vision encoder, and optionally by the perceiver
81
+ """
82
+
83
+ loss: Optional[torch.FloatTensor] = None
84
+ logits: torch.FloatTensor = None
85
+ past_key_values: Optional[List[torch.FloatTensor]] = None
86
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
87
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
88
+ image_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
89
+
90
+
91
+ class MagmaMultiModalProjector(nn.Module):
92
+ def __init__(self, config):
93
+ super().__init__()
94
+ self.config = config
95
+
96
+ dim_vision = {'base': 640, 'large': 768, 'xxlarge': 1024}
97
+ vision_backbone = config.get('vision_backbone', 'convnextxxlarge')
98
+ vision_backbone_size = vision_backbone.replace('convnext', '')
99
+ projector_type = config.get('mm_projector_type', 'linear')
100
+ mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu_segtokv(\d+)$', projector_type)
101
+ if mlp_gelu_match:
102
+ mlp_depth = int(mlp_gelu_match.group(1))
103
+ modules = [nn.Linear(config['mm_hidden_size'], config['hidden_size'])]
104
+ for _ in range(1, mlp_depth):
105
+ modules.append(nn.GELU())
106
+ modules.append(nn.Linear(config['hidden_size'], config['hidden_size']))
107
+ self.proj = nn.Sequential(*modules)
108
+
109
+ # define a row seperator
110
+ self.row_seperator = nn.Parameter(torch.zeros(1, 1, config['hidden_size']))
111
+ if config.get('mm_use_im_start_end', False):
112
+ self.img_start_seperator = nn.Parameter(torch.zeros(1, config['hidden_size']))
113
+ self.img_end_seperator = nn.Parameter(torch.zeros(1, config['hidden_size']))
114
+
115
+ def forward(self, x):
116
+ return self.proj(x)
117
+
118
+
119
+ MAGMA_START_DOCSTRING = r"""
120
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
121
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
122
+ etc.)
123
+
124
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
125
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
126
+ and behavior.
127
+
128
+ Parameters:
129
+ config ([`MagmaConfig`] or [`MagmaVisionConfig`]):
130
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
131
+ load the weights associated with the model, only the configuration. Check out the
132
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
133
+ """
134
+
135
+
136
+ @add_start_docstrings(
137
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
138
+ MAGMA_START_DOCSTRING,
139
+ )
140
+
141
+ class MagmaPreTrainedModel(PreTrainedModel):
142
+ config_class = MagmaConfig
143
+ base_model_prefix = "model"
144
+ supports_gradient_checkpointing = True
145
+ _no_split_modules = ["MagmaVisionAttention"]
146
+ _skip_keys_device_placement = "past_key_values"
147
+ _supports_flash_attn_2 = True
148
+
149
+ def _init_weights(self, module):
150
+ std = (
151
+ self.config.initializer_range
152
+ if hasattr(self.config, "initializer_range")
153
+ else self.config.text_config.initializer_range
154
+ )
155
+
156
+ if hasattr(module, "class_embedding"):
157
+ module.class_embedding.data.normal_(mean=0.0, std=std)
158
+
159
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
160
+ module.weight.data.normal_(mean=0.0, std=std)
161
+ if module.bias is not None:
162
+ module.bias.data.zero_()
163
+ elif isinstance(module, nn.Embedding):
164
+ module.weight.data.normal_(mean=0.0, std=std)
165
+ if module.padding_idx is not None:
166
+ module.weight.data[module.padding_idx].zero_()
167
+
168
+ @property
169
+ def _supports_sdpa(self):
170
+ """
171
+ Retrieve language_model's attribute to check whether the model supports
172
+ SDPA or not.
173
+ """
174
+ return self.language_model._supports_sdpa
175
+
176
+
177
+ MAGMA_INPUTS_DOCSTRING = r"""
178
+ Args:
179
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
180
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
181
+ it.
182
+
183
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
184
+ [`PreTrainedTokenizer.__call__`] for details.
185
+
186
+ [What are input IDs?](../glossary#input-ids)
187
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
188
+ The tensors corresponding to the input images. Pixel values can be obtained using
189
+ [`AutoImageProcessor`]. See [`MagmaImageProcessor.__call__`] for details. [`MagmaProcessor`] uses
190
+ [`MagmaImageProcessor`] for processing images.
191
+ image_sizes (`torch.LongTensor` of shape `(batch_size, 2)`, *optional*):
192
+ The sizes of the images in the batch, being (height, width) for each image.
193
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
194
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
195
+
196
+ - 1 for tokens that are **not masked**,
197
+ - 0 for tokens that are **masked**.
198
+
199
+ [What are attention masks?](../glossary#attention-mask)
200
+
201
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
202
+ [`PreTrainedTokenizer.__call__`] for details.
203
+
204
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
205
+ `past_key_values`).
206
+
207
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
208
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
209
+ information on the default strategy.
210
+
211
+ - 1 indicates the head is **not masked**,
212
+ - 0 indicates the head is **masked**.
213
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
214
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
215
+ config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids)
216
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
217
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
218
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
219
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
220
+
221
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
222
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
223
+
224
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
225
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
226
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
227
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
228
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
229
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
230
+ model's internal embedding lookup matrix.
231
+ vision_feature_layer (`int`, *optional*, defaults to -2):
232
+ The index of the layer to select the vision feature.
233
+ vision_feature_select_strategy (`str`, *optional*, defaults to `"default"`):
234
+ The feature selection strategy used to select the vision feature from the vision backbone.
235
+ Can be one of `"default"` or `"full"`. If `"default"`, the CLS token is removed from the vision features.
236
+ If `"full"`, the full vision features are used.
237
+ use_cache (`bool`, *optional*):
238
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
239
+ `past_key_values`).
240
+ output_attentions (`bool`, *optional*):
241
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
242
+ tensors for more detail.
243
+ output_hidden_states (`bool`, *optional*):
244
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
245
+ more detail.
246
+ return_dict (`bool`, *optional*):
247
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
248
+ """
249
+
250
+ @add_start_docstrings(
251
+ """The Magma model which consists of a vision backbone and a language model.""",
252
+ MAGMA_START_DOCSTRING,
253
+ )
254
+ class MagmaForForCausalLM(MagmaPreTrainedModel):
255
+ def __init__(self, config: MagmaConfig):
256
+ super().__init__(config)
257
+
258
+ self.vision_tower = MagmaImageTower(config.vision_config, require_pretrained=False)
259
+ config.vision_config['mm_hidden_size'] = config.vision_config['mm_hidden_size'] \
260
+ if 'mm_hidden_size' in config.vision_config else self.vision_tower.hidden_size
261
+ config.vision_config['hidden_size'] = config.vision_config['hidden_size'] \
262
+ if 'hidden_size' in config.vision_config else self.config.text_config.hidden_size
263
+ self.multi_modal_projector = MagmaMultiModalProjector(config.vision_config)
264
+
265
+ self.vocab_size = config.text_config.vocab_size
266
+ if hasattr(config.text_config, 'auto_map'):
267
+ del config.text_config.auto_map
268
+
269
+ try:
270
+ self.language_model = AutoModelForCausalLM.from_config(
271
+ config.text_config,
272
+ # attn_implementation=config._attn_implementation,
273
+ trust_remote_code=True
274
+ )
275
+ except:
276
+ self.language_model = AutoModelForCausalLM.from_pretrained(
277
+ config.text_config._name_or_path,
278
+ # attn_implementation=config._attn_implementation,
279
+ trust_remote_code=True
280
+ )
281
+
282
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
283
+ self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
284
+
285
+ try:
286
+ if dist.get_rank() == 0:
287
+ wandb.init(project=os.environ['WANDB_PROJECT'])
288
+ except:
289
+ pass
290
+
291
+ self.post_init()
292
+
293
+ # def from_pretrained(self, pretrained_model_name_or_path, *model_args, **kwargs):
294
+ # import pdb; pdb.set_trace()
295
+ # kwargs["_from_auto"] = True
296
+ # return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
297
+
298
+ @property
299
+ def padding_side(self):
300
+ return self._padding_side
301
+
302
+ @padding_side.setter
303
+ def padding_side(self, padding_side: str):
304
+ if padding_side not in ["left", "right"]:
305
+ raise ValueError(f"{padding_side} is not `left` or `right`.")
306
+ self._padding_side = padding_side
307
+
308
+ def get_input_embeddings(self):
309
+ return self.language_model.get_input_embeddings()
310
+
311
+ def set_input_embeddings(self, value):
312
+ self.language_model.set_input_embeddings(value)
313
+
314
+ def get_output_embeddings(self):
315
+ return self.language_model.get_output_embeddings()
316
+
317
+ def set_output_embeddings(self, new_embeddings):
318
+ self.language_model.set_output_embeddings(new_embeddings)
319
+
320
+ def set_decoder(self, decoder):
321
+ self.language_model.set_decoder(decoder)
322
+
323
+ def get_decoder(self):
324
+ return self.language_model.get_decoder()
325
+
326
+ def tie_weights(self):
327
+ return self.language_model.tie_weights()
328
+
329
+ def load_special_module_from_ckpt(self, ckpt_path, torch_dtype=None):
330
+ from deepspeed.runtime.zero import Init
331
+ from deepspeed import zero
332
+ # Defer initialization for ZeRO-3 compatibility
333
+ # with Init(data_parallel_group=None):
334
+ # # Initialize the special module
335
+ # self.vision_tower = MagmaImageTower(self.config.vision_config, require_pretrained=False)
336
+
337
+ # Load checkpoint weights into the special module
338
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
339
+ state_dict = {k.replace('visual.', ''): v for k, v in checkpoint.items() if 'visual.' in k}
340
+
341
+ # Convert checkpoint weights to match model's parameter dtype
342
+ if torch_dtype is None:
343
+ model_dtype = next(self.vision_tower.clip_vision_model.parameters()).dtype
344
+ for k, v in state_dict.items():
345
+ state_dict[k] = v.to(model_dtype)
346
+ else:
347
+ for k, v in state_dict.items():
348
+ state_dict[k] = v.to(torch_dtype)
349
+
350
+ # Temporarily gather parameters for loading (if ZeRO-3 is active)
351
+ with zero.GatheredParameters(list(self.vision_tower.parameters()), modifier_rank=0):
352
+ # Load the state dictionary
353
+ self.vision_tower.clip_vision_model.load_state_dict(state_dict, strict=False)
354
+ # After loading, ensure the module is on the correct device
355
+ for param in self.vision_tower.parameters():
356
+ param.data = param.data.to(self.device).to(torch_dtype)
357
+
358
+ # import pdb; pdb.set_trace()
359
+ # If using a DeepSpeed engine, attach the updated module
360
+ if hasattr(self, "deepspeed_engine"):
361
+ self.deepspeed_engine.module = self
362
+
363
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
364
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
365
+ # update vocab size
366
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
367
+ self.vocab_size = model_embeds.num_embeddings
368
+ return model_embeds
369
+
370
+ def _merge_input_ids_with_image_features(
371
+ self,
372
+ image_features,
373
+ feature_lens,
374
+ inputs_embeds,
375
+ input_ids,
376
+ attention_mask,
377
+ position_ids=None,
378
+ labels=None,
379
+ image_token_index=None,
380
+ ignore_index=-100,
381
+ ):
382
+ """
383
+ Merge input_ids with image features into final embeddings
384
+
385
+ Args:
386
+ image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
387
+ All vision vectors of all images in the batch
388
+ feature_lens (`torch.LongTensor` of shape `(num_images)`):
389
+ The length of visual embeddings of each image as stacked in `image_features`
390
+ inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
391
+ Token embeddings before merging with visual embeddings
392
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
393
+ Input_ids of tokens, possibly filled with image token
394
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
395
+ Mask to avoid performing attention on padding token indices.
396
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
397
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
398
+ config.n_positions - 1]`.
399
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
400
+ Labels need to be recalculated to support training (if provided)
401
+ image_token_index (`int`, *optional*)
402
+ Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
403
+ ignore_index (`int`, *optional*)
404
+ Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
405
+ Returns:
406
+ final_embedding, final_attention_mask, position_ids, final_labels
407
+
408
+ Explanation:
409
+ each image has variable length embeddings, with length specified by feature_lens
410
+ image_features is concatenation of all visual embed vectors
411
+ task: fill each <image> with the correct number of visual embeddings
412
+ Example:
413
+ X (5 patches), Y (3 patches), Z (8)
414
+ X, Y are in the same sequence (in-context learning)
415
+ if right padding
416
+ input_ids: [
417
+ a b c d e f X g h i j k Y l m
418
+ o p q r Z s t u v _ _ _ _ _ _
419
+ ]
420
+ input_ids should be: [
421
+ a b c d e f X X X X X g h i j k Y Y Y l m
422
+ o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
423
+ ]
424
+ labels should be: [
425
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
426
+ o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
427
+ ]
428
+ elif left padding
429
+ input_ids: [
430
+ a b c d e f X g h i j k Y l m
431
+ _ _ _ _ _ _ o p q r Z s t u v
432
+ ]
433
+ input_ids should be: [
434
+ a b c d e f X X X X X g h i j k Y Y Y l m
435
+ _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
436
+ ]
437
+ labels should be: [
438
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
439
+ _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
440
+ ]
441
+ Edge cases:
442
+ * If tokens are same but image token sizes are different, then cannot infer left or right padding
443
+
444
+ input_ids: [
445
+ a b c d X g h
446
+ i j Y k l m n
447
+ ]
448
+ where X is 3 tokens while Y is 5, this means after merging
449
+ if left-padding (batched generation)
450
+ input_ids should be: [
451
+ _ _ a b c d X X X g h
452
+ i j Y Y Y Y Y k l m n
453
+ ]
454
+ elif (right padding) (training)
455
+ input_ids should be: [
456
+ a b c d X X X g h _ _
457
+ i j Y Y Y Y Y k l m n
458
+ ]
459
+ """
460
+ image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
461
+ ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
462
+
463
+ with torch.no_grad():
464
+ num_images = feature_lens.size(0)
465
+ num_image_features, embed_dim = image_features.shape
466
+ if feature_lens.sum() != num_image_features:
467
+ raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
468
+ batch_size = input_ids.shape[0]
469
+ _left_padding = torch.any(attention_mask[:, 0] == 0)
470
+ _right_padding = torch.any(attention_mask[:, -1] == 0)
471
+
472
+ left_padding = True
473
+ if batch_size > 1:
474
+ if _left_padding and not _right_padding:
475
+ left_padding = True
476
+ elif not _left_padding and _right_padding:
477
+ left_padding = False
478
+ elif not _left_padding and not _right_padding:
479
+ # both sides are 1, so we cannot tell
480
+ left_padding = self.padding_side == "left"
481
+ else:
482
+ # invalid attention_mask
483
+ raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
484
+
485
+ # Whether to turn off right padding
486
+ # 1. Create a mask to know where special image tokens are
487
+ special_image_token_mask = input_ids == image_token_index
488
+ # special_image_token_mask: [bsz, seqlen]
489
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
490
+ # num_special_image_tokens: [bsz]
491
+ # Reserve for padding of num_images
492
+ total_num_special_image_tokens = torch.sum(special_image_token_mask)
493
+ if total_num_special_image_tokens != num_images:
494
+ raise ValueError(
495
+ f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
496
+ )
497
+ # Compute the maximum embed dimension
498
+ # max_image_feature_lens is max_feature_lens per batch
499
+ feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
500
+ feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
501
+ embed_sequence_lengths = (
502
+ (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
503
+ )
504
+ max_embed_dim = embed_sequence_lengths.max()
505
+
506
+ batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
507
+ # 2. Compute the positions where text should be written
508
+ # Calculate new positions for text tokens in merged image-text sequence.
509
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images` text tokens.
510
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
511
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
512
+ # ! instead of special_image_token_mask * (num_image_patches - 1)
513
+ # special_image_token_mask * (num_feature_len - 1)
514
+ special_image_token_mask = special_image_token_mask.long()
515
+ special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
516
+ new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
517
+ if left_padding:
518
+ # shift token positions to the right so that they end at the same final position
519
+ # the below here was incorrect? new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]
520
+ new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
521
+
522
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
523
+
524
+ # 3. Create the full embedding, already padded to the maximum position
525
+ final_embedding = torch.zeros(
526
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
527
+ )
528
+ final_attention_mask = torch.zeros(
529
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
530
+ )
531
+ final_labels = None
532
+ if labels is not None:
533
+ # NOTE: this is a bug in the original code!!!
534
+ final_labels = torch.full_like(final_attention_mask.long(), ignore_index).to(torch.long)
535
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
536
+ # set the corresponding tensors into their correct target device.
537
+ target_device = inputs_embeds.device
538
+ batch_indices, non_image_indices, text_to_overwrite = (
539
+ batch_indices.to(target_device),
540
+ non_image_indices.to(target_device),
541
+ text_to_overwrite.to(target_device),
542
+ )
543
+ attention_mask = attention_mask.to(target_device)
544
+
545
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
546
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
547
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
548
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
549
+ if labels is not None:
550
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
551
+
552
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
553
+ with torch.no_grad():
554
+ image_to_overwrite = torch.full(
555
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
556
+ )
557
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
558
+ embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
559
+ embed_indices = embed_indices.expand(batch_size, max_embed_dim)
560
+ embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
561
+
562
+ if left_padding:
563
+ # exclude padding on the left
564
+ val = (max_embed_dim - embed_indices) <= embed_seq_lens
565
+ else:
566
+ # exclude padding on the right
567
+ val = embed_indices < embed_seq_lens
568
+ image_to_overwrite &= val
569
+
570
+ if image_to_overwrite.sum() != num_image_features:
571
+ raise ValueError(
572
+ f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
573
+ f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
574
+ f" the number of image given to the model is {num_images}. "
575
+ f"This prevents correct indexing and breaks batch generation."
576
+ )
577
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
578
+ final_attention_mask |= image_to_overwrite
579
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
580
+
581
+ return final_embedding, final_attention_mask, position_ids, final_labels
582
+
583
+ @add_start_docstrings_to_model_forward(MAGMA_INPUTS_DOCSTRING)
584
+ @replace_return_docstrings(output_type=MagmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
585
+ def forward(
586
+ self,
587
+ input_ids: torch.LongTensor = None,
588
+ pixel_values: Union[torch.FloatTensor, List[torch.FloatTensor], List[List[torch.FloatTensor]]] = None,
589
+ image_sizes: Union[torch.LongTensor, List[torch.LongTensor], List[List[torch.LongTensor]]] = None,
590
+ attention_mask: Optional[torch.Tensor] = None,
591
+ position_ids: Optional[torch.LongTensor] = None,
592
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
593
+ inputs_embeds: Optional[torch.FloatTensor] = None,
594
+ vision_feature_layer: Optional[int] = None,
595
+ vision_feature_select_strategy: Optional[str] = None,
596
+ labels: Optional[torch.LongTensor] = None,
597
+ use_cache: Optional[bool] = None,
598
+ output_attentions: Optional[bool] = None,
599
+ output_hidden_states: Optional[bool] = None,
600
+ return_dict: Optional[bool] = None,
601
+ ) -> Union[Tuple, MagmaCausalLMOutputWithPast]:
602
+ r"""
603
+ Args:
604
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
605
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
606
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
607
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
608
+
609
+ Returns:
610
+
611
+ Example:
612
+
613
+ ```python
614
+ >>> from PIL import Image
615
+ >>> import requests
616
+ >>> from transformers import AutoProcessor, MagmaForConditionalGeneration
617
+
618
+ >>> model = MagmaForConditionalGeneration.from_pretrained("microsoft/magma-8b-hf")
619
+ >>> processor = AutoProcessor.from_pretrained("microsoft/magma-8b-hf")
620
+
621
+ >>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
622
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
623
+ >>> image = Image.open(requests.get(url, stream=True).raw)
624
+
625
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
626
+
627
+ >>> # Generate
628
+ >>> generate_ids = model.generate(**inputs, max_length=30)
629
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
630
+ "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
631
+ ```"""
632
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
633
+ output_hidden_states = (
634
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
635
+ )
636
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
637
+ vision_feature_layer = (
638
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_config['vision_feature_layer']
639
+ )
640
+
641
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
642
+
643
+ if inputs_embeds is None:
644
+ # 1. Extract the input embeddings
645
+ # In case image_token_index is not in the embeddings (extra token but embedding don't have it)
646
+ for_inputs_embeds_ids = input_ids.clone()
647
+ for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
648
+ inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
649
+
650
+ # 2. Merge text and images
651
+ if pixel_values is not None and input_ids.shape[1] != 1 and len(pixel_values) > 0:
652
+ # ! infer image_num_patches from image_sizes
653
+ if type(pixel_values) == list:
654
+ # nested list of pixel_values: each element is a list of pixel_values for one training instance; there can be multiple images per instance for video or interleaved settings
655
+ # e.g., pixel_values = [[img1, img2], [img1, img2, img3]]
656
+ n_imgs_per_sample = [len(pv) for pv in pixel_values]
657
+ pixels_values_list = sum(pixel_values, [])
658
+ image_sizes_list = sum(image_sizes, [])
659
+ else:
660
+ image_num_patches = [(imsize[imsize.sum(1) > 0,0] * imsize[imsize.sum(1) > 0,1]).tolist() for imsize in image_sizes]
661
+ # image_num_patches = [(imsize[:,0]*imsize[:,1]).tolist() for imsize in image_sizes]
662
+ # figure out if pixel_values is concatenated or stacked
663
+ if pixel_values.dim() == 5:
664
+ # stacking when input is (batch_size, num_patches, num_channels, height, width)
665
+ _pixel_values_list = [
666
+ pix_val[:sum(num_patch)].split(num_patch, dim=0) for pix_val, num_patch in zip(pixel_values, image_num_patches)
667
+ ]
668
+ _image_sizes_list = [image_size[image_size.sum(-1) > 0].tolist() for image_size in image_sizes]
669
+ elif pixel_values.dim() != 4:
670
+ # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
671
+ raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
672
+
673
+ if self.config.vision_config['img_anyres_strategy'] == "global":
674
+ selected_image_features = []
675
+ # NOTE: both _image_sizes_list and _pixel_values_list are lists of lists, each item represents an training instance with one or multiple images
676
+ for idx, (image_size_for_instance, pixel_values_for_instance) in enumerate(zip(_image_sizes_list, _pixel_values_list)):
677
+ assert len(image_size_for_instance) == len(pixel_values_for_instance), f"{len(image_size_for_instance)} != {len(pixel_values_for_instance)}"
678
+ for image_size, pixel_values_for_image in zip(image_size_for_instance, pixel_values_for_instance):
679
+ pixel_values_for_image = pixel_values_for_image.view(image_size[0], image_size[1], *pixel_values_for_image.shape[1:])
680
+ pixel_values_for_image = pixel_values_for_image.permute(2, 0, 3, 1, 4).flatten(3, 4).flatten(1, 2).unsqueeze(0)
681
+ image_features = self.vision_tower(pixel_values_for_image)
682
+ selected_image_feature = image_features[vision_feature_layer][0].permute(1, 2, 0)
683
+ selected_image_feature = self.multi_modal_projector((selected_image_feature, None))
684
+ selected_image_feature = torch.cat((selected_image_feature, self.multi_modal_projector.row_seperator.repeat(selected_image_feature.shape[0],1,1)), dim=1)
685
+ selected_image_features.append(selected_image_feature.flatten(0, 1))
686
+ elif self.config.vision_config['img_anyres_strategy'] == "crop":
687
+ # calculate number of crops for each instance in the batch given _image_sizes_list
688
+ _image_sizes_list_temp = sum(_image_sizes_list, [])
689
+ # concatenate all images in _pixel_values_list
690
+ _pixel_values_list_temp = sum(_pixel_values_list, ())
691
+ _pixel_values_list_temp = torch.cat(_pixel_values_list_temp, dim=0)
692
+ image_features = self.vision_tower(_pixel_values_list_temp)[vision_feature_layer].permute(0, 2, 3, 1)
693
+ image_features = self.multi_modal_projector((image_features, None))
694
+
695
+ num_crops_list = [_image_size[0]*_image_size[1] for _image_size in _image_sizes_list_temp]
696
+ image_features_split = torch.split(image_features, num_crops_list, dim=0)
697
+ selected_image_features = []
698
+ for image_feature, image_size in zip(image_features_split, _image_sizes_list_temp):
699
+ image_feature = image_feature.view(image_size[0], image_size[1], *image_feature.shape[1:])
700
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).flatten(2, 3).flatten(0, 1)
701
+ image_feature = torch.cat((image_feature, self.multi_modal_projector.row_seperator.repeat(image_feature.shape[0],1,1)), dim=1)
702
+ selected_image_features.append(image_feature.flatten(0, 1))
703
+
704
+ # raise NotImplementedError("crop strategy is not implemented yet")
705
+ # image_features = self.vision_tower(pixel_values)
706
+ # selected_image_feature = image_features[vision_feature_layer]
707
+ # image_features = torch.split(image_features, image_num_patches, dim=0)
708
+
709
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
710
+ feature_lens = [elem.shape[0] for elem in selected_image_features]
711
+ image_features = torch.cat(selected_image_features, 0)
712
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
713
+
714
+ # inputs_embeds = inputs_embeds.to(image_features.dtype)
715
+ inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
716
+ image_features,
717
+ feature_lens,
718
+ inputs_embeds,
719
+ input_ids,
720
+ attention_mask,
721
+ position_ids,
722
+ labels=labels,
723
+ )
724
+
725
+ # pixel_values is not None but is empty ---> text only cases
726
+ elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
727
+ # there are no images
728
+ pass
729
+
730
+ # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
731
+ # generation with cache
732
+ elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
733
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
734
+ # that are set to 0
735
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
736
+
737
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
738
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
739
+
740
+ # Get the target length
741
+ target_length = input_ids.shape[1]
742
+ past_length = first_layer_past_key_value.shape[-1]
743
+
744
+ extended_attention_mask = torch.ones(
745
+ (attention_mask.shape[0], past_length),
746
+ dtype=attention_mask.dtype,
747
+ device=attention_mask.device,
748
+ )
749
+
750
+ # Filter out only the tokens that can be un-attended, this can happen
751
+ # if one uses Llava + Fused modules where the cache on the
752
+ # first iteration is already big enough, or if one passes custom cache
753
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
754
+ new_batch_index = batch_index[valid_indices]
755
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
756
+
757
+ # Zero-out the places where we don't need to attend
758
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
759
+
760
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
761
+
762
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
763
+
764
+ # outputs = self.language_model(
765
+ # attention_mask=attention_mask,
766
+ # position_ids=position_ids,
767
+ # past_key_values=past_key_values,
768
+ # inputs_embeds=inputs_embeds,
769
+ # use_cache=use_cache,
770
+ # output_attentions=output_attentions,
771
+ # output_hidden_states=output_hidden_states,
772
+ # return_dict=return_dict,
773
+ # )
774
+
775
+ # logits = outputs[0]
776
+ # loss = None
777
+ # if labels is not None:
778
+ # # Shift so that tokens < n predict n
779
+ # if attention_mask is not None:
780
+ # shift_attention_mask = attention_mask[..., 1:]
781
+ # shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
782
+ # shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
783
+ # else:
784
+ # shift_logits = logits[..., :-1, :].contiguous()
785
+ # shift_labels = labels[..., 1:].contiguous()
786
+ # # Flatten the tokens
787
+ # loss_fct = nn.CrossEntropyLoss()
788
+ # loss = loss_fct(
789
+ # shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
790
+ # )
791
+
792
+ outputs = self.language_model.model(
793
+ attention_mask=attention_mask,
794
+ position_ids=position_ids,
795
+ past_key_values=past_key_values,
796
+ inputs_embeds=inputs_embeds,
797
+ use_cache=use_cache,
798
+ output_attentions=output_attentions,
799
+ output_hidden_states=output_hidden_states,
800
+ return_dict=return_dict
801
+ )
802
+
803
+ hidden_states = outputs[0]
804
+
805
+ loss = None
806
+
807
+ if labels is not None and self.training:
808
+ valid_mask = labels[..., 1:] != -100
809
+ shift_logits = self.language_model.lm_head(hidden_states[:,:-1][valid_mask]).contiguous()
810
+ shift_logits = shift_logits.view(-1, self.language_model.config.vocab_size)
811
+ logits = shift_logits # dummy logits
812
+ shift_labels = labels[..., 1:][valid_mask].contiguous()
813
+ shift_labels = shift_labels.to(shift_logits.device)
814
+ loss_fct = nn.CrossEntropyLoss()
815
+ loss = loss_fct(shift_logits, shift_labels)
816
+
817
+ # localize the positions for shift_labels where the id is between [config.tokenizer_vocab_size-256, config.tokenizer_vocab_size]
818
+ valid_indices = (shift_labels<self.config.tokenizer_vocab_size) & (shift_labels>=self.config.tokenizer_vocab_size-256)
819
+ if valid_indices.sum() > 0:
820
+ action_labels = shift_labels[valid_indices]
821
+ action_logits = shift_logits[valid_indices]
822
+ # calculate the accuracy
823
+ action_accuracy = (action_logits.argmax(-1) == action_labels).float().mean()
824
+ # log the action accuracy
825
+ else:
826
+ action_accuracy = torch.tensor(0.0).to(shift_logits.device)
827
+ # torch distributed gather the action accuracy across all devices
828
+ action_accuracy = action_accuracy.unsqueeze(0)
829
+ # gather the action accuracy across all devices
830
+ action_accuracy_gather = [torch.zeros_like(action_accuracy) for _ in range(dist.get_world_size())]
831
+ dist.all_gather(action_accuracy_gather, action_accuracy)
832
+ # concatenate the action accuracy across all devices
833
+ action_accuracy = torch.cat(action_accuracy_gather)
834
+
835
+ if dist.get_rank() == 0:
836
+ # remove zero values
837
+ if action_accuracy.mean() == 0:
838
+ wandb.log({"action_accuracy": action_accuracy.mean().item()})
839
+ else:
840
+ action_accuracy = action_accuracy[action_accuracy != 0]
841
+ wandb.log({"action_accuracy": action_accuracy.mean().item()})
842
+ else:
843
+ logits = self.language_model.lm_head(hidden_states)
844
+ logits = logits.float()
845
+
846
+ if not return_dict:
847
+ output = (logits,) + outputs[1:]
848
+ return (loss,) + output if loss is not None else output
849
+
850
+ return MagmaCausalLMOutputWithPast(
851
+ loss=loss,
852
+ logits=logits,
853
+ past_key_values=outputs.past_key_values,
854
+ hidden_states=outputs.hidden_states,
855
+ attentions=outputs.attentions,
856
+ )
857
+
858
+ def prepare_inputs_for_generation(
859
+ self,
860
+ input_ids,
861
+ past_key_values=None,
862
+ inputs_embeds=None,
863
+ pixel_values=None,
864
+ image_sizes=None,
865
+ attention_mask=None,
866
+ **kwargs,
867
+ ):
868
+ if past_key_values is not None:
869
+ if isinstance(past_key_values, Cache):
870
+ cache_length = past_key_values.get_seq_length()
871
+ past_length = past_key_values.seen_tokens
872
+ else:
873
+ cache_length = past_length = past_key_values[0][0].shape[2]
874
+
875
+ # Keep only the unprocessed tokens:
876
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
877
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
878
+ # input)
879
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
880
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
881
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
882
+ # input_ids based on the past_length.
883
+ elif past_length < input_ids.shape[1]:
884
+ input_ids = input_ids[:, past_length:]
885
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
886
+ elif self.config.image_token_index in input_ids:
887
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
888
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
889
+ # older attention values, as their corresponding values are not part of the input.
890
+ if cache_length < past_length and attention_mask is not None:
891
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
892
+
893
+ position_ids = kwargs.get("position_ids", None)
894
+ if attention_mask is not None and position_ids is None:
895
+ # create position_ids on the fly for batch generation
896
+ position_ids = attention_mask.long().cumsum(-1) - 1
897
+ position_ids.masked_fill_(attention_mask == 0, 1)
898
+ if past_key_values:
899
+ position_ids = position_ids[:, -input_ids.shape[1] :]
900
+
901
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
902
+ if inputs_embeds is not None and past_key_values is None:
903
+ model_inputs = {"inputs_embeds": inputs_embeds}
904
+ else:
905
+ model_inputs = {"input_ids": input_ids}
906
+
907
+ model_inputs.update(
908
+ {
909
+ "position_ids": position_ids,
910
+ "past_key_values": past_key_values,
911
+ "use_cache": kwargs.get("use_cache"),
912
+ "attention_mask": attention_mask,
913
+ "pixel_values": pixel_values,
914
+ "image_sizes": image_sizes,
915
+ }
916
+ )
917
+ return model_inputs
918
+
919
+ def _reorder_cache(self, *args, **kwargs):
920
+ return self.language_model._reorder_cache(*args, **kwargs)
921
+
922
+ @add_start_docstrings(
923
+ """The Magma model which consists of a vision backbone and a language model.""",
924
+ MAGMA_START_DOCSTRING,
925
+ )
926
+ class MagmaForConditionalGeneration(MagmaPreTrainedModel):
927
+ def __init__(self, config: MagmaConfig):
928
+ super().__init__(config)
929
+
930
+ self.vision_tower = MagmaImageTower(config.vision_config, require_pretrained=('magma' not in config.name_or_path))
931
+ self.multi_modal_projector = MagmaMultiModalProjector(config.vision_config)
932
+
933
+ self.vocab_size = config.text_config.vocab_size
934
+ self.language_model = AutoModelForCausalLM.from_config(
935
+ config.text_config,
936
+ # attn_implementation=config._attn_implementation,
937
+ trust_remote_code=True
938
+ )
939
+
940
+ self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
941
+ self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
942
+
943
+ self.post_init()
944
+
945
+ @property
946
+ def padding_side(self):
947
+ return self._padding_side
948
+
949
+ @padding_side.setter
950
+ def padding_side(self, padding_side: str):
951
+ if padding_side not in ["left", "right"]:
952
+ raise ValueError(f"{padding_side} is not `left` or `right`.")
953
+ self._padding_side = padding_side
954
+
955
+ def get_input_embeddings(self):
956
+ return self.language_model.get_input_embeddings()
957
+
958
+ def set_input_embeddings(self, value):
959
+ self.language_model.set_input_embeddings(value)
960
+
961
+ def get_output_embeddings(self):
962
+ return self.language_model.get_output_embeddings()
963
+
964
+ def set_output_embeddings(self, new_embeddings):
965
+ self.language_model.set_output_embeddings(new_embeddings)
966
+
967
+ def set_decoder(self, decoder):
968
+ self.language_model.set_decoder(decoder)
969
+
970
+ def get_decoder(self):
971
+ return self.language_model.get_decoder()
972
+
973
+ def tie_weights(self):
974
+ return self.language_model.tie_weights()
975
+
976
+ def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
977
+ model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
978
+ # update vocab size
979
+ self.config.text_config.vocab_size = model_embeds.num_embeddings
980
+ self.vocab_size = model_embeds.num_embeddings
981
+ return model_embeds
982
+
983
+ def _merge_input_ids_with_image_features(
984
+ self,
985
+ image_features,
986
+ feature_lens,
987
+ inputs_embeds,
988
+ input_ids,
989
+ attention_mask,
990
+ position_ids=None,
991
+ labels=None,
992
+ image_token_index=None,
993
+ ignore_index=-100,
994
+ ):
995
+ """
996
+ Merge input_ids with image features into final embeddings
997
+
998
+ Args:
999
+ image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`):
1000
+ All vision vectors of all images in the batch
1001
+ feature_lens (`torch.LongTensor` of shape `(num_images)`):
1002
+ The length of visual embeddings of each image as stacked in `image_features`
1003
+ inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`):
1004
+ Token embeddings before merging with visual embeddings
1005
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1006
+ Input_ids of tokens, possibly filled with image token
1007
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1008
+ Mask to avoid performing attention on padding token indices.
1009
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1010
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
1011
+ config.n_positions - 1]`.
1012
+ labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*)
1013
+ Labels need to be recalculated to support training (if provided)
1014
+ image_token_index (`int`, *optional*)
1015
+ Token id used to indicate the special "image" token. Defaults to `config.image_token_index`
1016
+ ignore_index (`int`, *optional*)
1017
+ Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100.
1018
+ Returns:
1019
+ final_embedding, final_attention_mask, position_ids, final_labels
1020
+
1021
+ Explanation:
1022
+ each image has variable length embeddings, with length specified by feature_lens
1023
+ image_features is concatenation of all visual embed vectors
1024
+ task: fill each <image> with the correct number of visual embeddings
1025
+ Example:
1026
+ X (5 patches), Y (3 patches), Z (8)
1027
+ X, Y are in the same sequence (in-context learning)
1028
+ if right padding
1029
+ input_ids: [
1030
+ a b c d e f X g h i j k Y l m
1031
+ o p q r Z s t u v _ _ _ _ _ _
1032
+ ]
1033
+ input_ids should be: [
1034
+ a b c d e f X X X X X g h i j k Y Y Y l m
1035
+ o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _
1036
+ ]
1037
+ labels should be: [
1038
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
1039
+ o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _
1040
+ ]
1041
+ elif left padding
1042
+ input_ids: [
1043
+ a b c d e f X g h i j k Y l m
1044
+ _ _ _ _ _ _ o p q r Z s t u v
1045
+ ]
1046
+ input_ids should be: [
1047
+ a b c d e f X X X X X g h i j k Y Y Y l m
1048
+ _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v
1049
+ ]
1050
+ labels should be: [
1051
+ a b c d e f _ _ _ _ _ g h i j k _ _ _ l m
1052
+ _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v
1053
+ ]
1054
+ Edge cases:
1055
+ * If tokens are same but image token sizes are different, then cannot infer left or right padding
1056
+
1057
+ input_ids: [
1058
+ a b c d X g h
1059
+ i j Y k l m n
1060
+ ]
1061
+ where X is 3 tokens while Y is 5, this means after merging
1062
+ if left-padding (batched generation)
1063
+ input_ids should be: [
1064
+ _ _ a b c d X X X g h
1065
+ i j Y Y Y Y Y k l m n
1066
+ ]
1067
+ elif (right padding) (training)
1068
+ input_ids should be: [
1069
+ a b c d X X X g h _ _
1070
+ i j Y Y Y Y Y k l m n
1071
+ ]
1072
+ """
1073
+ image_token_index = image_token_index if image_token_index is not None else self.config.image_token_index
1074
+ ignore_index = ignore_index if ignore_index is not None else self.config.ignore_index
1075
+
1076
+ with torch.no_grad():
1077
+ num_images = feature_lens.size(0)
1078
+ num_image_features, embed_dim = image_features.shape
1079
+ if feature_lens.sum() != num_image_features:
1080
+ raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}")
1081
+ batch_size = input_ids.shape[0]
1082
+ _left_padding = torch.any(attention_mask[:, 0] == 0)
1083
+ _right_padding = torch.any(attention_mask[:, -1] == 0)
1084
+
1085
+ left_padding = True
1086
+ if batch_size > 1:
1087
+ if _left_padding and not _right_padding:
1088
+ left_padding = True
1089
+ elif not _left_padding and _right_padding:
1090
+ left_padding = False
1091
+ elif not _left_padding and not _right_padding:
1092
+ # both sides are 1, so we cannot tell
1093
+ left_padding = self.padding_side == "left"
1094
+ else:
1095
+ # invalid attention_mask
1096
+ raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}")
1097
+
1098
+ # Whether to turn off right padding
1099
+ # 1. Create a mask to know where special image tokens are
1100
+ special_image_token_mask = input_ids == image_token_index
1101
+ # special_image_token_mask: [bsz, seqlen]
1102
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
1103
+ # num_special_image_tokens: [bsz]
1104
+ # Reserve for padding of num_images
1105
+ total_num_special_image_tokens = torch.sum(special_image_token_mask)
1106
+ if total_num_special_image_tokens != num_images:
1107
+ raise ValueError(
1108
+ f"Number of image tokens in input_ids ({total_num_special_image_tokens}) different from num_images ({num_images})."
1109
+ )
1110
+ # Compute the maximum embed dimension
1111
+ # max_image_feature_lens is max_feature_lens per batch
1112
+ feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0)
1113
+ feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=feature_lens.device)
1114
+ embed_sequence_lengths = (
1115
+ (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum
1116
+ )
1117
+ max_embed_dim = embed_sequence_lengths.max()
1118
+
1119
+ batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1))
1120
+ # 2. Compute the positions where text should be written
1121
+ # Calculate new positions for text tokens in merged image-text sequence.
1122
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by its own number of visual embeddings (`feature_len`).
1123
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
1124
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
1125
+ # ! instead of special_image_token_mask * (num_image_patches - 1),
1126
+ # we use special_image_token_mask * (feature_len - 1), since each image has a variable feature length
1127
+ special_image_token_mask = special_image_token_mask.long()
1128
+ special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1
1129
+ new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1
1130
+ if left_padding:
1131
+ # shift the token positions to the right so that every sequence ends at the same index
1132
+ # note: the earlier formulation `new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:]` appears to be incorrect, hence the use of max_embed_dim below
1133
+ new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:]
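+ # e.g. if max_embed_dim is 9 and a row's positions are [0, 1, 4, 5, 6], the shift is 9 - 1 - 6 = 2,
+ # giving [2, 3, 6, 7, 8], so every row ends at index max_embed_dim - 1 (editor's note)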
1134
+
1135
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
1136
+
1137
+ # 3. Create the full embedding, already padded to the maximum position
1138
+ final_embedding = torch.zeros(
1139
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
1140
+ )
1141
+ final_attention_mask = torch.zeros(
1142
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
1143
+ )
1144
+ final_labels = None
1145
+ if labels is not None:
1146
+ final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long)
1147
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
1148
+ # set the corresponding tensors into their correct target device.
1149
+ target_device = inputs_embeds.device
1150
+ batch_indices, non_image_indices, text_to_overwrite = (
1151
+ batch_indices.to(target_device),
1152
+ non_image_indices.to(target_device),
1153
+ text_to_overwrite.to(target_device),
1154
+ )
1155
+ attention_mask = attention_mask.to(target_device)
1156
+
1157
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
1158
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
1159
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
1160
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
1161
+ if labels is not None:
1162
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
1163
+
1164
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
1165
+ with torch.no_grad():
1166
+ image_to_overwrite = torch.full(
1167
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
1168
+ )
1169
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
1170
+ embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device)
1171
+ embed_indices = embed_indices.expand(batch_size, max_embed_dim)
1172
+ embed_seq_lens = embed_sequence_lengths[:, None].to(target_device)
1173
+
1174
+ if left_padding:
1175
+ # exclude padding on the left
1176
+ val = (max_embed_dim - embed_indices) <= embed_seq_lens
1177
+ else:
1178
+ # exclude padding on the right
1179
+ val = embed_indices < embed_seq_lens
1180
+ image_to_overwrite &= val
1181
+
1182
+ if image_to_overwrite.sum() != num_image_features:
1183
+ raise ValueError(
1184
+ f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. "
1185
+ f"The number of image tokens is {torch.sum(special_image_token_mask)} while"
1186
+ f" the number of image given to the model is {num_images}. "
1187
+ f"This prevents correct indexing and breaks batch generation."
1188
+ )
1189
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
1190
+ final_attention_mask |= image_to_overwrite
1191
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
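+ # e.g. a left-padded mask row [0, 0, 1, 1, 1] gives cumsum - 1 = [-1, -1, 0, 1, 2]; the masked_fill_
+ # then sets the padded slots to the dummy value 1, yielding position_ids [1, 1, 0, 1, 2] (editor's note)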
1192
+
1193
+ return final_embedding, final_attention_mask, position_ids, final_labels
1194
+
1195
+ @add_start_docstrings_to_model_forward(MAGMA_INPUTS_DOCSTRING)
1196
+ @replace_return_docstrings(output_type=MagmaCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
1197
+ def forward(
1198
+ self,
1199
+ input_ids: torch.LongTensor = None,
1200
+ pixel_values: torch.FloatTensor = None,
1201
+ image_sizes: Optional[torch.LongTensor] = None,
1202
+ attention_mask: Optional[torch.Tensor] = None,
1203
+ position_ids: Optional[torch.LongTensor] = None,
1204
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1205
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1206
+ vision_feature_layer: Optional[int] = None,
1207
+ vision_feature_select_strategy: Optional[str] = None,
1208
+ labels: Optional[torch.LongTensor] = None,
1209
+ use_cache: Optional[bool] = None,
1210
+ output_attentions: Optional[bool] = None,
1211
+ output_hidden_states: Optional[bool] = None,
1212
+ return_dict: Optional[bool] = None,
1213
+ ) -> Union[Tuple, MagmaCausalLMOutputWithPast]:
1214
+ r"""
1215
+ Args:
1216
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1217
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1218
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1219
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1220
+
1221
+ Returns:
1222
+
1223
+ Example:
1224
+
1225
+ ```python
1226
+ >>> from PIL import Image
1227
+ >>> import requests
1228
+ >>> from transformers import AutoProcessor, MagmaForConditionalGeneration
1229
+
1230
+ >>> model = MagmaForConditionalGeneration.from_pretrained("microsoft/magma-8b-hf")
1231
+ >>> processor = AutoProcessor.from_pretrained("microsoft/magma-8b-hf")
1232
+
1233
+ >>> prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
1234
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
1235
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1236
+
1237
+ >>> inputs = processor(text=prompt, images=image, return_tensors="pt")
1238
+
1239
+ >>> # Generate
1240
+ >>> generate_ids = model.generate(**inputs, max_length=30)
1241
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1242
+ "[INST] \nWhat is shown in this image? [/INST] The image appears to be a radar chart, which is a type of multi-dimensional plot (...)"
1243
+ ```"""
1244
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1245
+ output_hidden_states = (
1246
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1247
+ )
1248
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1249
+ vision_feature_layer = (
1250
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_config['vision_feature_layer']
1251
+ )
1252
+
1253
+ if inputs_embeds is None:
1254
+ # 1. Extract the input embeddings
1255
+ # In case image_token_index has no entry in the embedding matrix (an extra token the embeddings don't cover)
1256
+ for_inputs_embeds_ids = input_ids.clone()
1257
+ for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
1258
+ inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
1259
+
1260
+ # 2. Merge text and images
1261
+ if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
1262
+ # ! infer image_num_patches from image_sizes
1263
+ # figure out if pixel_values is concatenated or stacked
1264
+ if pixel_values.dim() == 5:
1265
+ image_num_patches = [(imsize[:,0]*imsize[:,1]).tolist() for imsize in image_sizes]
1266
+ # stacking when input is (batch_size, num_patches, num_channels, height, width)
1267
+ _pixel_values_list = [
1268
+ pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
1269
+ ]
1270
+ pixel_values = torch.cat(_pixel_values_list, dim=0)
1271
+ elif pixel_values.dim() != 4:
1272
+ # otherwise pixel_values must already be concatenated as (total_num_patches, num_channels, height, width)
1273
+ raise ValueError(f"pixel_values has shape {pixel_values.shape}, but 4 or 5 dimensions are expected")
1274
+
1275
+ if self.config.vision_config['img_anyres_strategy'] == "global":
1276
+ num_patches_for_images = [(imsize[0]*imsize[1]).item() for imsize in image_sizes]
1277
+ pixel_values_for_images = pixel_values.split(num_patches_for_images, dim=0)
1278
+ selected_image_features = []
1279
+ for idx, (image_size, pixel_values_for_image) in enumerate(zip(image_sizes, pixel_values_for_images)):
1280
+ pixel_values_for_image = pixel_values_for_image.view(image_size[0], image_size[1], *pixel_values_for_image.shape[1:])
1281
+ pixel_values_for_image = pixel_values_for_image.permute(2, 0, 3, 1, 4).flatten(3, 4).flatten(1, 2).unsqueeze(0)
1282
+ image_features = self.vision_tower(pixel_values_for_image)
1283
+ selected_image_feature = image_features[vision_feature_layer][0].permute(1, 2, 0)
1284
+ selected_image_feature = self.multi_modal_projector((selected_image_feature, None))
1285
+ selected_image_feature = torch.cat((selected_image_feature, self.multi_modal_projector.row_seperator.repeat(selected_image_feature.shape[0],1,1)), dim=1)
1286
+ selected_image_features.append(selected_image_feature)
1287
+ elif self.config.vision_config['img_anyres_strategy'] == "crop":
1288
+ image_features = self.vision_tower(pixel_values)[vision_feature_layer].permute(0, 2, 3, 1)
1289
+ image_features = self.multi_modal_projector((image_features, None))
1290
+ num_patches_for_images = [(imsize[0]*imsize[1]).item() for imsize in image_sizes]
1291
+ image_features_split = torch.split(image_features, num_patches_for_images, dim=0)
1292
+ selected_image_features = []
1293
+ for image_feature, image_size in zip(image_features_split, image_sizes):
1294
+ image_feature = image_feature.view(image_size[0], image_size[1], *image_feature.shape[1:])
1295
+ image_feature = image_feature.permute(0, 2, 1, 3, 4).flatten(2, 3).flatten(0, 1)
1296
+ image_feature = torch.cat((image_feature, self.multi_modal_projector.row_seperator.repeat(image_feature.shape[0],1,1)), dim=1)
1297
+ selected_image_features.append(image_feature)
1298
+
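+ # Editor's note on the two anyres strategies above: "global" stitches each image's crops back into one
+ # large image and runs a single vision-tower pass per image, while "crop" encodes every crop separately
+ # and re-assembles the per-crop feature maps into the image's (n_rows x n_cols) grid. In both branches a
+ # learned row-separator embedding is appended at the end of every feature row before flattening.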
1299
+ # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
1300
+ feature_lens = [elem.shape[0]*elem.shape[1] for elem in selected_image_features]
1301
+ image_features = torch.cat([elem.flatten(0, 1) for elem in selected_image_features], 0)
1302
+ feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device)
1303
+
1304
+ # inputs_embeds = inputs_embeds.to(image_features.dtype)
1305
+ inputs_embeds, attention_mask, position_ids, labels = self._merge_input_ids_with_image_features(
1306
+ image_features,
1307
+ feature_lens,
1308
+ inputs_embeds,
1309
+ input_ids,
1310
+ attention_mask,
1311
+ position_ids,
1312
+ labels=labels,
1313
+ )
1314
+
1315
+ # pixel_values is not None but is empty ---> text-only case
1316
+ elif pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0:
1317
+ # there are no images
1318
+ pass
1319
+
1320
+ # In case input_ids.shape[1] == 1 and past_key_values is not None, we are in the cached
1321
+ # generation step and need to extend the attention mask over the cached (image-expanded) sequence
1322
+ elif past_key_values is not None and pixel_values is not None and input_ids.shape[1] == 1:
1323
+ # Retrieve the first layer to inspect the logits and mask out the hidden states
1324
+ # that are set to 0
1325
+ first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]
1326
+
1327
+ # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
1328
+ batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)
1329
+
1330
+ # Get the target length
1331
+ target_length = input_ids.shape[1]
1332
+ past_length = first_layer_past_key_value.shape[-1]
1333
+
1334
+ extended_attention_mask = torch.ones(
1335
+ (attention_mask.shape[0], past_length),
1336
+ dtype=attention_mask.dtype,
1337
+ device=attention_mask.device,
1338
+ )
1339
+
1340
+ # Filter out only the tokens that can be un-attended, this can happen
1341
+ # if one uses Llava + Fused modules where the cache on the
1342
+ # first iteration is already big enough, or if one passes custom cache
1343
+ valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
1344
+ new_batch_index = batch_index[valid_indices]
1345
+ new_non_attended_tokens = non_attended_tokens[valid_indices]
1346
+
1347
+ # Zero-out the places where we don't need to attend
1348
+ extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
1349
+
1350
+ attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
1351
+
1352
+ position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
1353
+
1354
+ outputs = self.language_model(
1355
+ attention_mask=attention_mask,
1356
+ position_ids=position_ids,
1357
+ past_key_values=past_key_values,
1358
+ inputs_embeds=inputs_embeds,
1359
+ use_cache=use_cache,
1360
+ output_attentions=output_attentions,
1361
+ output_hidden_states=output_hidden_states,
1362
+ return_dict=return_dict,
1363
+ )
1364
+
1365
+ logits = outputs[0]
1366
+
1367
+ loss = None
1368
+ if labels is not None:
1369
+ # Shift so that tokens < n predict n
1370
+ if attention_mask is not None:
1371
+ shift_attention_mask = attention_mask[..., 1:]
1372
+ shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
1373
+ shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
1374
+ else:
1375
+ shift_logits = logits[..., :-1, :].contiguous()
1376
+ shift_labels = labels[..., 1:].contiguous()
1377
+ # Flatten the tokens
1378
+ loss_fct = nn.CrossEntropyLoss()
1379
+ loss = loss_fct(
1380
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
1381
+ )
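+ # e.g. for a row [t0, t1, t2, t3], the logits at positions 0..2 are scored against labels t1..t3, and
+ # positions whose shifted attention_mask is 0 (padding) are dropped before the cross-entropy (editor's note)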
1382
+
1383
+ if not return_dict:
1384
+ output = (logits,) + outputs[1:]
1385
+ return (loss,) + output if loss is not None else output
1386
+
1387
+ return MagmaCausalLMOutputWithPast(
1388
+ loss=loss,
1389
+ logits=logits,
1390
+ past_key_values=outputs.past_key_values,
1391
+ hidden_states=outputs.hidden_states,
1392
+ attentions=outputs.attentions,
1393
+ )
1394
+
1395
+ def prepare_inputs_for_generation(
1396
+ self,
1397
+ input_ids,
1398
+ past_key_values=None,
1399
+ inputs_embeds=None,
1400
+ pixel_values=None,
1401
+ image_sizes=None,
1402
+ attention_mask=None,
1403
+ **kwargs,
1404
+ ):
1405
+ if past_key_values is not None:
1406
+ if isinstance(past_key_values, Cache):
1407
+ cache_length = past_key_values.get_seq_length()
1408
+ past_length = past_key_values.seen_tokens
1409
+ else:
1410
+ cache_length = past_length = past_key_values[0][0].shape[2]
1411
+
1412
+ # Keep only the unprocessed tokens:
1413
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
1414
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
1415
+ # input)
1416
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
1417
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
1418
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
1419
+ # input_ids based on the past_length.
1420
+ elif past_length < input_ids.shape[1]:
1421
+ input_ids = input_ids[:, past_length:]
1422
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
1423
+ elif self.config.image_token_index in input_ids:
1424
+ input_ids = input_ids[:, input_ids.shape[1] - 1 :]
1425
+ # If the cache has seen more tokens than it can hold, then the cache has a size limit. Let's discard the
1426
+ # older attention values, as their corresponding values are not part of the input.
1427
+ if cache_length < past_length and attention_mask is not None:
1428
+ attention_mask = attention_mask[:, -(cache_length + input_ids.shape[1]) :]
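+ # Worked example (editor's note): if the cache already holds past_length = 10 tokens and generate()
+ # passes the full 11-token sequence, only input_ids[:, 10:] (the newest token) is forwarded; the other
+ # branches handle the case where the cached sequence, expanded with image features, is longer than input_ids.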
1429
+
1430
+ position_ids = kwargs.get("position_ids", None)
1431
+ if attention_mask is not None and position_ids is None:
1432
+ # create position_ids on the fly for batch generation
1433
+ position_ids = attention_mask.long().cumsum(-1) - 1
1434
+ position_ids.masked_fill_(attention_mask == 0, 1)
1435
+ if past_key_values:
1436
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1437
+
1438
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1439
+ if inputs_embeds is not None and past_key_values is None:
1440
+ model_inputs = {"inputs_embeds": inputs_embeds}
1441
+ else:
1442
+ model_inputs = {"input_ids": input_ids}
1443
+
1444
+ model_inputs.update(
1445
+ {
1446
+ "position_ids": position_ids,
1447
+ "past_key_values": past_key_values,
1448
+ "use_cache": kwargs.get("use_cache"),
1449
+ "attention_mask": attention_mask,
1450
+ "pixel_values": pixel_values,
1451
+ "image_sizes": image_sizes,
1452
+ }
1453
+ )
1454
+ return model_inputs
1455
+
1456
+ def _reorder_cache(self, *args, **kwargs):
1457
+ return self.language_model._reorder_cache(*args, **kwargs)
1458
+
1459
+ AutoConfig.register("magma", MagmaConfig)
1460
+ AutoModelForCausalLM.register(MagmaConfig, MagmaForConditionalGeneration)
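+ # Usage sketch (editor's note; assumes the checkpoint "microsoft/magma-8b-hf" from the docstring above
+ # bundles this remote code -- adjust the repo id as needed):
+ #
+ #   from transformers import AutoConfig, AutoModelForCausalLM
+ #   config = AutoConfig.from_pretrained("microsoft/magma-8b-hf", trust_remote_code=True)
+ #   model = AutoModelForCausalLM.from_pretrained("microsoft/magma-8b-hf", trust_remote_code=True)
+ #
+ # The register() calls above let the "magma" model_type resolve to MagmaConfig and
+ # MagmaForConditionalGeneration when loaded through the Auto classes.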