import torch
import torch.nn as nn
from typing import Optional, Union, Tuple
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import (
shift_tokens_right,
VisionEncoderDecoderModel
)
from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from transformers import PreTrainedModel
from transformers.models.pixtral.modeling_pixtral import apply_rotary_pos_emb, PixtralAttention, PixtralVisionModel
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from pixtral_encoder_decoder.configuration import PixtralVisionModelBatchConfig, VisionPixtralEncoderDecoderConfig
def position_ids_in_meshgrid_batch(patch_embeds, max_width):
"""get the position ids of the batch. """
# unlike flattened patch_embeds, we use the padded ones, which mean each entry has the same w/h and thus the same ids
height, width = patch_embeds.shape[-2:]
mesh = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij")
h_grid, v_grid = torch.stack(mesh, dim=-1).reshape(-1, 2).chunk(2, -1)
ids = h_grid * max_width + v_grid
# expand ids to batch size
ids = ids.reshape(1, -1).repeat(patch_embeds.shape[0], 1)
return ids
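# Worked example (illustrative only, not used by the model): for a padded 2x3 patch grid with
# max_width=4, the meshgrid walks rows then columns, so every batch entry gets row * max_width + col:
#
#   >>> dummy = torch.zeros(2, 8, 2, 3)            # (batch, hidden, height, width)
#   >>> position_ids_in_meshgrid_batch(dummy, max_width=4)
#   tensor([[0, 1, 2, 4, 5, 6],
#           [0, 1, 2, 4, 5, 6]])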
def create_attention_mask_batch(w, h, image_sizes, patch_size):
    """Build per-image padding masks on the padded (h, w) patch grid: 1 for real patches, 0 for padding."""
    def single_image_mask(num_rows, num_cols):
        return ((torch.arange(h).unsqueeze(1) < num_rows) & (torch.arange(w).unsqueeze(0) < num_cols)).float()
    mask = [single_image_mask(size[0] // patch_size, size[1] // patch_size) for size in image_sizes]
    return torch.stack(mask, dim=0)
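# Worked example (illustrative only): with a padded 2x2 patch grid (h = w = 2), patch_size=16 and
# image_sizes=[(32, 16)], the image covers 2x1 patches, so the single mask in the stack is
#   [[1., 0.],
#    [1., 0.]]
# i.e. 1 for real patches and 0 for padding introduced when batching differently sized images.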
def pixtral_attention_fix_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=1)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(batch_size, patches, -1)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
# monkey-patch a fix for the unsqueeze dim of the position embeddings (our input is batched, the original is not)
PixtralAttention.forward = pixtral_attention_fix_forward
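# Shape note (assumed shapes, for context): with batched position_ids of shape (batch, seq), the
# rotary embedding yields cos/sin of shape (batch, seq, head_dim); unsqueeze_dim=1 then broadcasts
# them over the head axis of the (batch, heads, seq, head_dim) query/key tensors, whereas the stock
# PixtralAttention was written for a single flattened sequence and unsqueezes a different dim.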
class PixtralVisionModelBatch(PixtralVisionModel):
config_class = PixtralVisionModelBatchConfig
def __init__(self, config):
super().__init__(config)
def forward(
self,
pixel_values: torch.Tensor,
image_sizes: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
*args,
**kwargs,
) -> Union[Tuple, BaseModelOutput]:
"""
Returns:
pixel_values: tensor of token features for
all tokens of all images of shape (N_toks, D)
"""
if attention_mask is None and image_sizes is None:
raise ValueError("Either `attention_mask` or `image_sizes` must be defined")
# pass images through initial convolution independently
patch_embeds = self.patch_conv(pixel_values)
# build attention mask based on image_sizes if not provided
if attention_mask is None:
h, w = patch_embeds.shape[-2:]
attention_mask = create_attention_mask_batch(w, h, image_sizes, self.patch_size).to(patch_embeds.device)
attention_mask = attention_mask.flatten(start_dim=-2)
# positional embeddings
position_ids = position_ids_in_meshgrid_batch(
patch_embeds, max_width=self.config.image_size // self.config.patch_size
)
position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)
        # flatten patch_embeds: (B, hidden, h, w) -> (B, h*w, hidden)
patch_embeds = patch_embeds.flatten(start_dim=-2).transpose(-1, -2)
attention_mask = _prepare_4d_attention_mask(attention_mask, torch.float)
patch_embeds = self.ln_pre(patch_embeds)
out = self.transformer(
patch_embeds,
attention_mask=attention_mask,
position_embeddings=position_embeddings,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
)
return out
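# Usage sketch (illustrative comments only; construction details are assumptions, e.g. that
# PixtralVisionModelBatchConfig keeps PixtralVisionConfig's default constructor):
#
#   config = PixtralVisionModelBatchConfig()
#   encoder = PixtralVisionModelBatch(config)
#   # two images padded to the same square canvas of side config.image_size
#   pixel_values = torch.zeros(2, config.num_channels, config.image_size, config.image_size)
#   image_sizes = torch.tensor([[512, 256], [384, 384]])   # original (height, width) per image
#   out = encoder(pixel_values=pixel_values, image_sizes=image_sizes, return_dict=True)
#   # out.last_hidden_state: (2, (config.image_size // config.patch_size) ** 2, hidden_size)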
class VisionPixtralEncoderDecoder(VisionEncoderDecoderModel):
config_class = VisionPixtralEncoderDecoderConfig
def __init__(self, config,
encoder: Optional[PixtralVisionModelBatch] = None,
decoder: Optional[PreTrainedModel] = None):
super().__init__(config, encoder, decoder)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
image_sizes: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.BoolTensor] = None,
encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# num_items_in_batch is only needed for loss computation
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
kwargs_decoder = {
argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
if encoder_outputs is None:
if pixel_values is None:
raise ValueError("You have to specify pixel_values")
if encoder_attention_mask is None and image_sizes is None:
raise ValueError("Either `encoder_attention_mask` or `image_sizes` must be defined")
if encoder_attention_mask is None:
                h, w = pixel_values.shape[-2:]
                # simulate the patch convolution to get the padded patch-grid size
                h = h // self.encoder.patch_size
                w = w // self.encoder.patch_size
encoder_attention_mask = create_attention_mask_batch(w, h, image_sizes, self.encoder.patch_size)
encoder_attention_mask = encoder_attention_mask.to(pixel_values.device)
encoder_attention_mask = encoder_attention_mask.flatten(start_dim=-2)
encoder_outputs = self.encoder(
pixel_values=pixel_values,
image_sizes=image_sizes,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs_encoder,
)
elif isinstance(encoder_outputs, tuple):
encoder_outputs = BaseModelOutput(*encoder_outputs)
encoder_hidden_states = encoder_outputs[0]
# optionally project encoder_hidden_states
if (
self.encoder.config.hidden_size != self.decoder.config.hidden_size
and self.decoder.config.cross_attention_hidden_size is None
):
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
# else:
# encoder_attention_mask = None
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
**kwargs_decoder,
)
        # compute the loss outside the decoder (some decoders shift the logits internally)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.decoder.config.vocab_size,
num_items_in_batch=num_items_in_batch,
)
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs
return Seq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
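# End-to-end sketch (illustrative comments only; how the pieces are wired together is an
# assumption about the accompanying configs, not a tested recipe):
#
#   model = VisionPixtralEncoderDecoder(config)       # config: VisionPixtralEncoderDecoderConfig
#   outputs = model(
#       pixel_values=pixel_values,                    # (B, 3, H_pad, W_pad), all images padded alike
#       image_sizes=image_sizes,                      # (B, 2) original (height, width) per image
#       labels=labels,                                # (B, T) target token ids
#   )
#   outputs.loss.backward()
#   # For inference, model.generate(pixel_values=..., image_sizes=...) should route image_sizes to
#   # the encoder via model kwargs, though that forwarding is an assumption of this sketch.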