|
import torch |
|
import torch.nn as nn |
|
|
|
from typing import Optional, Union, Tuple |
|
|
|
from transformers.models.vision_encoder_decoder.modeling_vision_encoder_decoder import ( |
|
shift_tokens_right, |
|
VisionEncoderDecoderModel |
|
) |
|
from transformers.modeling_outputs import Seq2SeqLMOutput |
|
from transformers import PreTrainedModel |
|
from transformers.models.pixtral.modeling_pixtral import apply_rotary_pos_emb, PixtralAttention, PixtralVisionModel |
|
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask |
|
from transformers.modeling_outputs import BaseModelOutput |
|
|
|
from pixtral_encoder_decoder.configuration import PixtralVisionModelBatchConfig, VisionPixtralEncoderDecoderConfig |
|
|
|
|
|
def position_ids_in_meshgrid_batch(patch_embeds, max_width):
    """Return row-major 2-D position ids for a batch of patch grids.

    The patch grid is read from the last two dimensions of ``patch_embeds``
    (``height, width``). Patch ``(r, c)`` receives id ``r * max_width + c``, and
    the same ``(1, height * width)`` row of ids is repeated once per entry of
    the leading (batch) dimension.

    Args:
        patch_embeds: tensor whose first dimension is the batch and whose last
            two dimensions are the patch-grid height and width (presumably the
            ``(B, C, h, w)`` patch-convolution output -- confirm at callers).
        max_width: row stride of the id layout. Ids are only unique when
            ``width <= max_width``.

    Returns:
        int64 tensor of shape ``(batch, height * width)`` on
        ``patch_embeds.device``.
    """
    height, width = patch_embeds.shape[-2:]
    # Allocate on the embeddings' device so the positional-embedding lookup
    # that consumes these ids does not need a host<->device copy.
    device = patch_embeds.device
    rows = torch.arange(height, device=device).unsqueeze(1)  # (height, 1)
    cols = torch.arange(width, device=device).unsqueeze(0)  # (1, width)
    # Broadcasting yields ids[r, c] = r * max_width + c; flatten row-major.
    ids = (rows * max_width + cols).reshape(1, -1)
    return ids.repeat(patch_embeds.shape[0], 1)
|
|
|
|
|
def create_attention_mask_batch(w, h, image_sizes, patch_size):
    """Build a float mask marking the valid (non-padded) patches of each image.

    Args:
        w: number of patch columns in the padded grid.
        h: number of patch rows in the padded grid.
        image_sizes: iterable of per-image ``(height, width)`` pairs, in the
            same units as ``patch_size`` (pixels, presumably).
        patch_size: side length of one patch.

    Returns:
        float32 tensor of shape ``(len(image_sizes), h, w)`` holding 1.0 where
        ``row < height // patch_size`` and ``col < width // patch_size`` for
        that image, 0.0 elsewhere.
    """
    row_idx = torch.arange(h).unsqueeze(1)  # (h, 1)
    col_idx = torch.arange(w).unsqueeze(0)  # (1, w)

    per_image = []
    for size in image_sizes:
        valid_rows = size[0] // patch_size
        valid_cols = size[1] // patch_size
        # Broadcasting (h, 1) against (1, w) gives the (h, w) top-left block.
        inside = (row_idx < valid_rows) & (col_idx < valid_cols)
        per_image.append(inside.float())
    return torch.stack(per_image, dim=0)
|
|
|
|
|
def pixtral_attention_fix_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Replacement for ``PixtralAttention.forward`` that supports batched rotary tables.

    Input shape: Batch x Time x Channel.

    Args:
        self: the ``PixtralAttention`` instance (this function is installed as
            its ``forward`` below).
        hidden_states: ``(batch, patches, channels)`` input.
        attention_mask: additive mask broadcastable to
            ``(batch, heads, patches, patches)``, or ``None``.
        position_embeddings: ``(cos, sin)`` rotary tables. Either per-sample
            tables with a leading batch dimension (3-D) or one table shared by
            the whole batch (2-D).
        output_attentions: accepted for signature compatibility but not read;
            the attention weights are always returned.

    Returns:
        ``(attn_output, attn_weights)``: ``attn_output`` is
        ``(batch, patches, ...)`` after ``o_proj``; ``attn_weights`` is
        ``(batch, heads, patches, patches)``.

    Raises:
        ValueError: if ``position_embeddings`` is ``None``.
    """
    if position_embeddings is None:
        raise ValueError("`position_embeddings` (cos, sin) must be provided")

    batch_size, patches, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    # (batch, patches, heads * head_dim) -> (batch, heads, patches, head_dim)
    query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

    cos, sin = position_embeddings
    # This forward replaces the method for every PixtralAttention in the
    # process, so it must handle both table layouts. Assumes
    # apply_rotary_pos_emb inserts a broadcast axis into cos/sin at
    # `unsqueeze_dim` (confirm for the installed transformers version):
    # a batched (batch, patches, dim) table needs the head axis at 1 ->
    # (batch, 1, patches, dim); a shared (patches, dim) table needs a leading
    # axis at 0 -> (1, patches, dim). Hard-coding 1 broadcasts a 2-D table
    # against the wrong axes.
    unsqueeze_dim = 1 if cos.dim() == 3 else 0
    query_states, key_states = apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=unsqueeze_dim
    )

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale

    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask

    # Softmax in float32 for numerical stability, then return to the activation dtype.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_output = torch.matmul(attn_weights, value_states)

    # (batch, heads, patches, head_dim) -> (batch, patches, heads * head_dim)
    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(batch_size, patches, -1)

    attn_output = self.o_proj(attn_output)

    return attn_output, attn_weights


# Monkeypatch applied at import time: every PixtralAttention instance in this
# process (not only those inside PixtralVisionModelBatch) uses the forward
# above from now on.
PixtralAttention.forward = pixtral_attention_fix_forward
|
|
|
|
|
class PixtralVisionModelBatch(PixtralVisionModel):
    """Pixtral vision encoder that runs a padded batch of images in one pass.

    Padded patches are excluded from attention through a patch-level mask that
    is either supplied by the caller or derived from ``image_sizes``.
    """

    config_class = PixtralVisionModelBatchConfig

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        image_sizes: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        *args,
        **kwargs,
    ) -> Union[Tuple, BaseModelOutput]:
        """Encode a padded batch of images.

        Args:
            pixel_values: image batch fed to ``self.patch_conv``; assumed to be
                ``(batch, channels, H, W)`` with every image padded to the same
                ``H, W`` -- TODO confirm against the processor.
            image_sizes: per-image ``(height, width)`` used to build the patch
                mask when ``attention_mask`` is not given.
            attention_mask: patch-level mask, 1 for valid patches, either
                flattened ``(batch, h * w)`` or grid-shaped ``(batch, h, w)``.
            output_hidden_states, output_attentions, return_dict: forwarded
                unchanged to ``self.transformer``.
            *args, **kwargs: accepted and ignored.

        Returns:
            Whatever ``self.transformer`` returns (a tuple or a
            ``BaseModelOutput``), whose first element holds the per-patch
            features of shape ``(batch, h * w, hidden)``.

        Raises:
            ValueError: if neither ``attention_mask`` nor ``image_sizes`` is given.
        """
        if attention_mask is None and image_sizes is None:
            raise ValueError("Either `attention_mask` or `image_sizes` must be defined")

        # Match the conv weights' dtype so e.g. float32 pixels work with a
        # bfloat16/float16 model.
        pixel_values = pixel_values.to(dtype=self.patch_conv.weight.dtype)
        patch_embeds = self.patch_conv(pixel_values)

        if attention_mask is None:
            h, w = patch_embeds.shape[-2:]
            attention_mask = create_attention_mask_batch(w, h, image_sizes, self.patch_size)
        attention_mask = attention_mask.to(patch_embeds.device)
        if attention_mask.dim() == 3:
            # Grid-shaped (batch, h, w) -> (batch, h * w), matching the token order below.
            attention_mask = attention_mask.flatten(start_dim=-2)

        # Position ids use the config's maximum grid width as row stride, so a
        # patch keeps the same id regardless of how much padding surrounds it.
        position_ids = position_ids_in_meshgrid_batch(
            patch_embeds, max_width=self.config.image_size // self.config.patch_size
        ).to(patch_embeds.device)
        position_embeddings = self.patch_positional_embedding(patch_embeds, position_ids)

        # (batch, C, h, w) -> (batch, h * w, C): one token per patch, row-major.
        patch_embeds = patch_embeds.flatten(start_dim=-2).transpose(-1, -2)

        # Additive (batch, 1, N, N) mask in the activation dtype, so the
        # attention scores are not promoted to float32 when the model runs in
        # reduced precision.
        attention_mask = _prepare_4d_attention_mask(attention_mask, patch_embeds.dtype)

        patch_embeds = self.ln_pre(patch_embeds)

        return self.transformer(
            patch_embeds,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
        )
|
|
|
|
|
class VisionPixtralEncoderDecoder(VisionEncoderDecoderModel):
    """Vision encoder-decoder whose encoder is a batched Pixtral vision model.

    ``forward`` additionally accepts ``image_sizes`` / ``encoder_attention_mask``
    so padded patches are masked inside the encoder and in the decoder's
    cross-attention.
    """

    config_class = VisionPixtralEncoderDecoderConfig

    def __init__(self, config,
                 encoder: Optional[PixtralVisionModelBatch] = None,
                 decoder: Optional[PreTrainedModel] = None):
        super().__init__(config, encoder, decoder)

    def _encoder_mask_from_image_sizes(self, pixel_values, image_sizes):
        """Build a flattened ``(batch, h * w)`` patch mask on ``pixel_values.device``.

        ``h`` / ``w`` are the padded pixel height / width of ``pixel_values``
        divided (floor) by the encoder patch size.
        """
        h, w = pixel_values.shape[-2:]
        patch_size = self.encoder.patch_size
        mask = create_attention_mask_batch(w // patch_size, h // patch_size, image_sizes, patch_size)
        return mask.to(pixel_values.device).flatten(start_dim=-2)

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        """Run the encoder (unless ``encoder_outputs`` is given) and the decoder.

        Args:
            pixel_values: padded image batch for the encoder.
            image_sizes: per-image ``(height, width)``; used to derive the patch
                mask when no mask is supplied and ``pixel_values`` is available.
            encoder_attention_mask: patch mask (1 = valid), flattened
                ``(batch, h * w)`` or grid-shaped ``(batch, h, w)``. Used by the
                encoder and as the decoder's cross-attention mask. A bare
                ``attention_mask`` keyword is accepted as an alias.
            labels: target ids. When given without decoder inputs, they are
                shifted right to build ``decoder_input_ids``; a loss is computed.
            **kwargs: ``num_items_in_batch`` goes to the loss, ``decoder_*``
                keys go (prefix stripped) to the decoder, and everything else
                goes to the encoder.

        Returns:
            ``Seq2SeqLMOutput`` when ``return_dict`` is true. Otherwise a tuple:
            ``(loss?,) + decoder outputs + encoder outputs``.

        Raises:
            ValueError: if the encoder must run but ``pixel_values`` is missing,
                or neither a mask nor ``image_sizes`` is available.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        num_items_in_batch = kwargs.pop("num_items_in_batch", None)

        # A bare `attention_mask` is treated as the encoder patch mask
        # (presumably what generation utilities forward on each step -- verify
        # for the installed transformers version). Popping it also keeps it out
        # of `kwargs_encoder`, where it would clash with the explicit
        # `attention_mask=` passed to the encoder below.
        attention_mask = kwargs.pop("attention_mask", None)
        if encoder_attention_mask is None:
            encoder_attention_mask = attention_mask

        kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
        kwargs_decoder = {
            argument[len("decoder_"):]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # Derive the mask whenever possible (also when encoder_outputs is
        # precomputed) so cross-attention never attends to padded patches.
        if encoder_attention_mask is None and image_sizes is not None and pixel_values is not None:
            encoder_attention_mask = self._encoder_mask_from_image_sizes(pixel_values, image_sizes)
        elif encoder_attention_mask is not None and encoder_attention_mask.dim() == 3:
            # Grid-shaped (batch, h, w) -> (batch, h * w) as cross-attention expects.
            encoder_attention_mask = encoder_attention_mask.flatten(start_dim=-2)

        if encoder_outputs is None:
            if pixel_values is None:
                raise ValueError("You have to specify pixel_values")
            if encoder_attention_mask is None:
                # Only reachable when image_sizes is None as well.
                raise ValueError("Either `encoder_attention_mask` or `image_sizes` must be defined")

            encoder_outputs = self.encoder(
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                **kwargs_encoder,
            )
        elif isinstance(encoder_outputs, tuple):
            encoder_outputs = BaseModelOutput(*encoder_outputs)

        encoder_hidden_states = encoder_outputs[0]

        # Project encoder features to the decoder width when the decoder does
        # not declare its own cross-attention input size.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # Teacher forcing: derive decoder inputs from the labels when none are given.
        if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
            decoder_input_ids = shift_tokens_right(
                labels, self.config.pad_token_id, self.config.decoder_start_token_id
            )

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            use_cache=use_cache,
            past_key_values=past_key_values,
            return_dict=return_dict,
            **kwargs_decoder,
        )

        loss = None
        if labels is not None:
            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.decoder.config.vocab_size,
                num_items_in_batch=num_items_in_batch,
            )

        if not return_dict:
            # encoder_outputs may be a ModelOutput here (caller-supplied, or
            # converted from a tuple above); it is not a tuple, so plain `+`
            # concatenation would fail. Normalize both parts to tuples.
            decoder_tuple = (
                decoder_outputs.to_tuple() if hasattr(decoder_outputs, "to_tuple") else tuple(decoder_outputs)
            )
            encoder_tuple = (
                encoder_outputs.to_tuple() if hasattr(encoder_outputs, "to_tuple") else tuple(encoder_outputs)
            )
            outputs = decoder_tuple + encoder_tuple
            return ((loss,) + outputs) if loss is not None else outputs

        return Seq2SeqLMOutput(
            loss=loss,
            logits=decoder_outputs.logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
|
|