import torch
import torch.nn as nn
from typing import Optional, Union, Tuple

from transformers import PreTrainedModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_outputs import Seq2SeqLMOutput
from transformers.models.pixtral.modeling_pixtral import (
    apply_rotary_pos_emb,
    PixtralAttention,
    PixtralVisionModel,
)
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import (
    shift_tokens_right,
    VisionEncoderDecoderModel,
)

from pixtral_encoder_decoder.configuration import (
    PixtralVisionModelBatchConfig,
    VisionPixtralEncoderDecoderConfig,
)


def position_ids_in_meshgrid_batch(patch_embeds, max_width):
    """Return rotary position ids for a batch of padded patch grids.

    The ids are computed from the padded (batched) patch grid rather than from
    per-image flattened embeddings: every batch entry shares the same
    ``(h, w)`` grid and therefore the same ids, so they are computed once and
    repeated along the batch dimension.

    Args:
        patch_embeds: tensor whose first dim is the batch size and whose last
            two dims are the patch-grid height and width.
        max_width: row stride used to linearise ``(row, col)`` into a single
            id, so ids stay unique across rows.

    Returns:
        Long tensor of shape ``(batch, h * w)`` on ``patch_embeds.device``,
        where the id of patch ``(row, col)`` (row-major order) is
        ``row * max_width + col``.
    """
    height, width = patch_embeds.shape[-2:]
    device = patch_embeds.device
    # Build directly on the embeddings' device so no host->device copy is needed later.
    rows = torch.arange(height, device=device).unsqueeze(1)
    cols = torch.arange(width, device=device).unsqueeze(0)
    ids = (rows * max_width + cols).reshape(1, -1)
    # expand ids to batch size
    return ids.repeat(patch_embeds.shape[0], 1)


def create_attention_mask_batch(w, h, image_sizes, patch_size, device=None):
    """Build a per-image validity mask over a padded ``h x w`` patch grid.

    Args:
        w: width of the padded patch grid (number of patch columns).
        h: height of the padded patch grid (number of patch rows).
        image_sizes: iterable of ``(height, width)`` pixel sizes, one per image
            (e.g. a list of pairs or a ``(batch, 2)`` tensor on any device).
        patch_size: side length of a square patch in pixels.
        device: optional device for the returned mask; defaults to the CPU,
            as before.

    Returns:
        Float tensor of shape ``(len(image_sizes), h, w)`` with ``1.0`` for
        patches that fall inside the image (row < height // patch_size and
        col < width // patch_size) and ``0.0`` for padding patches.
    """
    rows = torch.arange(h, device=device).unsqueeze(1)
    cols = torch.arange(w, device=device).unsqueeze(0)

    masks = []
    for size in image_sizes:
        # int() makes the comparison device-independent: comparing a CPU arange
        # with a 0-dim CUDA tensor (image_sizes on GPU) raises a device mismatch.
        valid_rows = int(size[0]) // patch_size
        valid_cols = int(size[1]) // patch_size
        masks.append(((rows < valid_rows) & (cols < valid_cols)).float())

    if not masks:
        # torch.stack([]) raises; an empty batch yields an empty mask instead.
        return torch.zeros((0, h, w), device=device)
    return torch.stack(masks, dim=0)


def pixtral_attention_fix_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Batched replacement for ``PixtralAttention.forward``.

    Input shape: Batch x Time x Channel.

    Multi-head scaled dot-product attention with rotary position embeddings.
    ``apply_rotary_pos_emb`` is called with ``unsqueeze_dim=1`` so that the
    per-sample ``cos``/``sin`` (presumably ``(batch, seq, head_dim)`` — TODO
    confirm against ``patch_positional_embedding``) broadcast over the head
    axis of the ``(batch, heads, seq, head_dim)`` query/key states.

    Args:
        hidden_states: ``(batch, patches, hidden)`` input.
        attention_mask: optional additive mask added to the scaled scores
            (must broadcast to ``(batch, heads, patches, patches)``).
        position_embeddings: ``(cos, sin)`` pair for the rotary embedding.
        output_attentions: accepted for interface compatibility; the
            attention weights are returned regardless of its value.

    Returns:
        ``(attn_output, attn_weights)`` where ``attn_output`` is
        ``(batch, patches, hidden)`` after ``o_proj`` and ``attn_weights`` is
        ``(batch, heads, patches, patches)``.
    """
    batch_size, patches, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

    cos, sin = position_embeddings
    # unsqueeze_dim=1: cos/sin carry a batch dim, so the head axis is inserted at dim 1.
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(batch_size, patches, -1)

    attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights


# Monkey patch: the stock forward applies rotary embeddings with unsqueeze_dim for
# un-batched position embeddings; our position embeddings are batched.
PixtralAttention.forward = pixtral_attention_fix_forward


class PixtralVisionModelBatch(PixtralVisionModel):
    """``PixtralVisionModel`` that encodes a padded batch of images in one pass.

    Images are expected to be padded to a common size; padding patches are
    excluded through an attention mask, either supplied directly or derived
    from ``image_sizes``.
    """

    config_class = PixtralVisionModelBatchConfig

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_sizes: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode a batch of padded images.

        Args:
            pixel_values: batched, padded images fed to ``patch_conv``.
            image_sizes: per-image ``(height, width)`` in pixels; used to build
                the attention mask when ``attention_mask`` is not given.
            attention_mask: optional patch-validity mask, either flattened
                ``(batch, h * w)`` or grid-shaped ``(batch, h, w)``; 1 = keep.
            output_hidden_states / output_attentions / return_dict: forwarded
                to the transformer.

        Returns:
            The output of ``self.transformer`` (a ``BaseModelOutput`` or a
            tuple, depending on ``return_dict``) whose first element holds the
            features of all ``h * w`` patches of every image in the batch.

        Raises:
            ValueError: if neither ``attention_mask`` nor ``image_sizes`` is given.
        """
        if attention_mask is None and image_sizes is None:
            raise ValueError("Either `attention_mask` or `image_sizes` must be defined")

        # pass images through initial convolution independently
        patch_embeds = self.patch_conv(pixel_values)

        # build attention mask based on image_sizes if not provided
        if attention_mask is None:
            h, w = patch_embeds.shape[-2:]
            attention_mask = create_attention_mask_batch(
                w, h, image_sizes, self.patch_size, device=patch_embeds.device
            )
        if attention_mask.dim() == 3:
            # (batch, h, w) -> (batch, h * w), the 2-D form _prepare_4d_attention_mask expects
            attention_mask = attention_mask.flatten(start_dim=-2)

        # positional embeddings
        position_ids = position_ids_in_meshgrid_batch(
            patch_embeds, max_width=self.config.image_size // self.config.patch_size
        )
        position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)

        # flatten patch_embeds
        # seq_len = (h*w); hidden x seq_len -> seq_len x hidden.
        patch_embeds = patch_embeds.flatten(start_dim=-2).transpose(-1, -2)

        # Expand the 2-D keep-mask into an additive 4-D mask for the attention scores.
        attention_mask = _prepare_4d_attention_mask(attention_mask, torch.float)

        patch_embeds = self.ln_pre(patch_embeds)

        return self.transformer(
            patch_embeds,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )


class VisionPixtralEncoderDecoder(VisionEncoderDecoderModel):
    """Vision encoder-decoder built around ``PixtralVisionModelBatch``.

    A patch-level attention mask (given or derived from ``image_sizes``) is
    passed both to the encoder and, as ``encoder_attention_mask``, to the
    decoder so that padding patches are ignored in cross-attention.
    """

    config_class = VisionPixtralEncoderDecoderConfig

    def __init__(self, config, encoder: Optional[PixtralVisionModelBatch] = None, decoder: Optional[PreTrainedModel] = None):
        super().__init__(config, encoder, decoder)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Run encoder (unless ``encoder_outputs`` is given), then decoder, then loss.

        ``encoder_attention_mask`` may be flattened ``(batch, h * w)`` or
        grid-shaped ``(batch, h, w)``; when absent it is built from
        ``image_sizes``. Extra ``kwargs`` prefixed with ``decoder_`` go to the
        decoder (prefix stripped); the rest go to the encoder.

        Raises:
            ValueError: if the encoder must run but ``pixel_values`` is missing,
                or if neither ``encoder_attention_mask`` nor ``image_sizes`` is given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # num_items_in_batch is only needed for loss computation
        num_items_in_batch = kwargs.pop("num_items_in_batch", None)

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        if encoder_attention_mask is not None and encoder_attention_mask.dim() == 3:
            # Accept grid-shaped masks; both encoder and decoder want (batch, h * w).
            encoder_attention_mask = encoder_attention_mask.flatten(start_dim=-2)

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")
            if encoder_attention_mask is None and image_sizes is None:
                raise ValueError("Either `encoder_attention_mask` or `image_sizes` must be defined")

            if encoder_attention_mask is None:
                patch_size = self.encoder.patch_size
                # Reproduce the patch-grid size of the encoder's patch convolution
                # (assumes kernel == stride == patch_size).
                h = pixel_values.shape[-2] // patch_size
                w = pixel_values.shape[-1] // patch_size
                encoder_attention_mask = create_attention_mask_batch(
                    w, h, image_sizes, patch_size, device=pixel_values.device
                ).flatten(start_dim=-2)

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # optionally project encoder_hidden_states
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        # Compute loss independent from decoder (as some shift the logits inside them)
        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.decoder.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
            )

        if not return_dict:
            if not isinstance(encoder_outputs, tuple):
                # A caller-supplied tuple was wrapped into BaseModelOutput above;
                # tuple concatenation needs a plain tuple again.
                encoder_outputs = encoder_outputs.to_tuple()
            if loss is not None:
                return (loss,) + decoder_outputs + encoder_outputs
            return decoder_outputs + encoder_outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )