# Copyright 2024 Microsoft. All rights reserved.
# Licensed under the MSRLA License. See LICENSE in the repo root for license information.

from typing import Optional, List, Tuple, Union

import torch
from torch.nn import Linear, Module, Sequential
from transformers import AutoBackbone, AutoModelForCausalLM, LlavaForConditionalGeneration, LlavaPreTrainedModel
from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
from transformers.activations import ACT2FN
from transformers.utils import check_min_version

from .configuration_maira2 import Maira2Config


class Maira2MultiModalProjector(Module):
    """
    This class implements the multimodal projector for MAIRA-2 model. It projects the image features to the text
    hidden size via a series of linear layers (4 layers in MAIRA-2).
    """

    def __init__(self, config: Maira2Config):
        """
        Build the projector as `config.projector_n_layers` linear layers with the activation
        `config.projector_hidden_act` between consecutive layers.

        The first layer maps `config.vision_config.hidden_size` to `config.text_config.hidden_size`; every
        following layer keeps the text hidden size.

        Raises:
            ValueError: If `config.projector_n_layers` is smaller than 1.
        """
        super().__init__()

        n_layers = config.projector_n_layers
        if n_layers < 1:
            raise ValueError(f"Number of layers should be at least 1, got {n_layers=}")
        text_hidden_size = config.text_config.hidden_size
        vision_hidden_size = config.vision_config.hidden_size
        _layers = [Linear(vision_hidden_size, text_hidden_size, bias=True)]
        for _ in range(n_layers - 1):
            _layers.append(ACT2FN[config.projector_hidden_act])
            _layers.append(Linear(text_hidden_size, text_hidden_size, bias=True))

        self.layers = Sequential(*_layers)

    def forward(self, image_features: torch.Tensor) -> torch.FloatTensor:
        """Apply the stack of projection layers to `image_features` and return the result."""
        hidden_states = self.layers(image_features)
        return hidden_states  # type: ignore[no-any-return]


class Maira2ForConditionalGeneration(LlavaForConditionalGeneration):
    """
    This model implements the multimodal model MAIRA-2. It consists of a vision backbone, a multimodal projector, and
    a language model. The model can be used for grounded and ungrounded report generation tasks as well as phrase
    grounding. This class inherits from `LlavaForConditionalGeneration`, defining a custom multimodal projector and
    changing image feature selection.
    """

    config_class = Maira2Config

    def __init__(self, config: Maira2Config) -> None:
        """
        Build the vision backbone, the multimodal projector and the language model from `config`.

        Raises an error (via `check_min_version`) when the installed transformers version is too old.
        """
        # Check transformers version is at least 4.46.0.dev0 otherwise the model fails
        # silently since get_image_features is not called in the forward pass
        check_min_version("4.46.0.dev0")

        # Start the MRO lookup *after* LlavaPreTrainedModel so that the Llava constructors are bypassed;
        # the sub-modules are created explicitly below instead.
        super(LlavaPreTrainedModel, self).__init__(config)
        self.vision_tower = AutoBackbone.from_config(config.vision_config)

        self.multi_modal_projector = Maira2MultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self.language_model = AutoModelForCausalLM.from_config(
            config.text_config,
            attn_implementation=config._attn_implementation,
        )
        # -1 is used as a sentinel when the config defines no padding token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        self.post_init()

    def get_image_features(
        self, pixel_values: torch.FloatTensor, vision_feature_layer: int, vision_feature_select_strategy: str
    ) -> torch.Tensor:
        """
        This method extracts the image features from the vision backbone using the specified feature layer and
        selection strategy. This is custom to MAIRA-2 model since we want to use the `feature_maps` from the
        Dinov2Backbone class instead of the `hidden_states` which are used in the default implementation of
        `get_image_features` in LlavaForConditionalGeneration. The feature_maps returned by Dinov2Backbone are the
        hidden_states with a layernorm applied to them.

        Args:
            pixel_values: Input images passed to the vision backbone.
            vision_feature_layer: Index into the backbone's `feature_maps`.
            vision_feature_select_strategy: `"default"` drops the first token along dimension 1 of the selected
                feature map, `"full"` keeps every token.

        Returns:
            The selected image features after the multimodal projector.

        Raises:
            ValueError: If `vision_feature_select_strategy` is neither `"default"` nor `"full"`.
        """
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_outputs.feature_maps[vision_feature_layer]

        if vision_feature_select_strategy == "default":
            selected_image_feature = selected_image_feature[:, 1:]
        elif vision_feature_select_strategy != "full":
            # Report the value that was actually rejected (the argument), not the config default.
            raise ValueError(f"Unexpected select feature strategy: {vision_feature_select_strategy}")

        image_features = self.multi_modal_projector(selected_image_feature)
        return image_features  # type: ignore[no-any-return]

    # modification from original, added forward from transformers 4.46 to prevent new preprocessing
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.


        Returns:

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, LlavaForConditionalGeneration

        >>> model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf")
        >>> processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

        >>> prompt = "USER: <image>\nWhat's the content of the image? ASSISTANT:"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "USER: \nWhat's the content of the image? ASSISTANT: The image features a busy city street with a stop sign prominently displayed"
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        legacy_processing = False
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
            # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
            # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
            legacy_processing = (
                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

        if pixel_values is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                vision_feature_layer=vision_feature_layer,
                vision_feature_select_strategy=vision_feature_select_strategy,
            )

            if legacy_processing:
                # prefill stage vs decoding stage (legacy behavior copied)
                if input_ids.shape[1] != 1:
                    inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
                        image_features, inputs_embeds, input_ids, attention_mask, labels
                    )
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Retrieve the first layer to inspect the logits and mask out the hidden states
                    # that are set to 0
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Sum all dimensions of head_dim (-2) to avoid random errors such as: https://github.com/huggingface/transformers/pull/28032#issuecomment-1863691941
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    # Get the target length
                    target_length = input_ids.shape[1]
                    past_length = first_layer_past_key_value.shape[-1]

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], past_length),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Filter out only the tokens that can be un-attended, this can happen
                    # if one uses Llava + Fused modules where the cache on the
                    # first iteration is already big enough, or if one passes custom cache
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    # Zero-out the places where we don't need to attend
                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0

                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
                        -target_length:
                    ]

            # TODO: @raushan retain only the new behavior after v4.47
            else:
                # New-style processing: the prompt already holds one placeholder token per image embedding,
                # so the projected features are scattered directly onto those positions.
                special_image_mask = (
                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
                )
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            if attention_mask is not None:
                shift_attention_mask = attention_mask[..., 1:]
                shift_logits = logits[..., :-1, :][shift_attention_mask.to(logits.device) != 0].contiguous()
                shift_labels = labels[..., 1:][shift_attention_mask.to(labels.device) != 0].contiguous()
            else:
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1).to(shift_logits.device)
            )

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            image_hidden_states=image_features if pixel_values is not None else None,
        )

    def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
        """
        Merge text embeddings and image features into one sequence (legacy processing path).

        Every image placeholder token in `input_ids` is expanded to `num_image_patches` positions that are filled
        with the corresponding image features; text embeddings, attention mask and labels are moved to their
        shifted positions.

        Args:
            image_features: Tensor unpacked as `(num_images, num_image_patches, embed_dim)`.
            inputs_embeds: Text embeddings looked up for `input_ids`.
            input_ids: Tensor unpacked as `(batch_size, sequence_length)`.
            attention_mask: Attention mask aligned with `input_ids`.
            labels: Optional labels aligned with `input_ids`.

        Returns:
            Tuple `(final_embedding, final_attention_mask, final_labels, position_ids)`; `final_labels` is `None`
            when `labels` is `None`.

        Raises:
            ValueError: If the number of image slots does not match the number of image feature vectors.
        """
        num_images, num_image_patches, embed_dim = image_features.shape
        batch_size, sequence_length = input_ids.shape
        # Padding is assumed to be on the left when no sequence ends with the pad token.
        left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
        # 1. Create a mask to know where special image tokens are
        special_image_token_mask = input_ids == self.config.image_token_index
        num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
        # Compute the maximum embed dimension
        max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
        batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)

        # 2. Compute the positions where text should be written
        # Calculate new positions for text tokens in merged image-text sequence.
        # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
        # `torch.cumsum` computes how each image token shifts subsequent text token positions.
        # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
        new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
        nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
        if left_padding:
            new_token_positions += nb_image_pad[:, None]  # offset for left padding
        text_to_overwrite = new_token_positions[batch_indices, non_image_indices]

        # 3. Create the full embedding, already padded to the maximum position
        final_embedding = torch.zeros(
            batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )
        final_attention_mask = torch.zeros(
            batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
        )
        if labels is not None:
            final_labels = torch.full(
                (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
            )
        # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
        # set the corresponding tensors into their correct target device.
        target_device = inputs_embeds.device
        batch_indices, non_image_indices, text_to_overwrite = (
            batch_indices.to(target_device),
            non_image_indices.to(target_device),
            text_to_overwrite.to(target_device),
        )
        attention_mask = attention_mask.to(target_device)

        # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
        # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
        final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
        final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
        if labels is not None:
            final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]

        # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
        image_to_overwrite = torch.full(
            (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
        )
        image_to_overwrite[batch_indices, text_to_overwrite] = False
        image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)

        if image_to_overwrite.sum() != image_features.shape[:-1].numel():
            raise ValueError(
                f"The input provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
                f" the number of image given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
            )

        final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
        final_attention_mask |= image_to_overwrite
        position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)

        # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
        batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
        indices_to_mask = new_token_positions[batch_indices, pad_indices]

        final_embedding[batch_indices, indices_to_mask] = 0

        if labels is None:
            final_labels = None

        return final_embedding, final_attention_mask, final_labels, position_ids