""" MIT License Copyright (c) 2023 Fixie.ai 2024 Alex Hung Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ from typing import List, Optional, Tuple, Union import torch import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from transformers import AutoConfig, AutoModel, WhisperConfig from transformers.generation import GenerationMixin from transformers.modeling_outputs import (BaseModelOutput, CausalLMOutputWithPast) from transformers.modeling_utils import ModuleUtilsMixin from transformers.models.llama.modeling_llama import LlamaRMSNorm from transformers.models.mllama.modeling_mllama import ( MllamaForCausalLM, MllamaPreTrainedModel, MllamaVisionModel, _prepare_cross_attention_mask) from transformers.models.whisper.modeling_whisper import WhisperEncoder from transformers.utils import logging from .configuration_ocismllama import MllamaAudioConfig, OcisMllamaConfig logger = logging.get_logger(__name__) class OcisMllamaPreTrainedModel(MllamaPreTrainedModel): config_class = OcisMllamaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = [ "MllamaVisionEncoderLayer", "MllamaCrossAttentionDecoderLayer", "MllamaSelfAttentionDecoderLayer", "WhisperEncoderLayer", "WhisperDecoderLayer", ] _supports_cache_class = True _supports_static_cache = False # static cache cannot have different shapes for each layer _supports_sdpa = True _supports_quantized_cache = True class OcisMllamaForConditionalGeneration(OcisMllamaPreTrainedModel, GenerationMixin): _supports_quantized_cache = False # quant cache not supported in encoder-decoder setting def __init__(self, config: OcisMllamaConfig): super().__init__(config) self.vocab_size = config.text_config.vocab_size self.hidden_size = config.text_config.hidden_size self.max_num_tiles = config.vision_config.max_num_tiles self.vision_output_dim = config.vision_config.vision_output_dim self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 self.vision_model = MllamaVisionModel._from_config(config.vision_config) self.language_model = MllamaForCausalLM._from_config(config.text_config) self.multi_modal_projector = nn.Linear( config.vision_config.vision_output_dim, config.text_config.hidden_size, bias=True, ) whisper_config = WhisperConfig.from_pretrained(config.audio_config.audio_model_id) self.audio_model = ModifiedWhisperEncoder._from_config(whisper_config) self.audio_projector = UltravoxProjector(config.audio_config) self.post_init() def get_input_embeddings(self): return self.language_model.get_input_embeddings() def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) def get_output_embeddings(self): return self.language_model.get_output_embeddings() def set_output_embeddings(self, new_embeddings): self.language_model.set_output_embeddings(new_embeddings) def set_decoder(self, decoder): self.language_model.set_decoder(decoder) def get_decoder(self): return self.language_model.get_decoder() def tie_weights(self): return self.language_model.tie_weights() def forward( self, input_ids: Optional[torch.LongTensor] = None, audio_values: Optional[torch.FloatTensor] = None, audio_token_start_idx: Optional[torch.Tensor] = None, audio_len: Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, pixel_values: Optional[torch.FloatTensor] = None, aspect_ratio_mask: Optional[torch.Tensor] = None, aspect_ratio_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, num_logits_to_keep: int = 0, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. num_logits_to_keep (`int`, *optional*): Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that token can save memory, which becomes pretty significant for long sequences or large vocabulary size. Returns: Example: ```python >>> from PIL import Image >>> import requests >>> from transformers import AutoProcessor, MllamaForConditionalGeneration >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision" >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint) >>> processor = AutoProcessor.from_pretrained(checkpoint) >>> prompt = "<|image|>If I had to write a haiku for this one" >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> image = Image.open(requests.get(url, stream=True).raw) >>> inputs = processor(text=prompt, images=image, return_tensors="pt") >>> # Generate >>> output = model.generate(**inputs, max_new_tokens=15) >>> prompt_len = inputs.input_ids.shape[-1] >>> generated_ids = output[:, prompt_len:] >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) >>> print(generated_text) [', it would be:.\\nA stop sign in Chinatown.\\n'] ``` """ if cache_position[0] > 0: audio_values = None pixel_values = None cross_attention_mask = None output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") if pixel_values is not None and inputs_embeds is not None: raise ValueError( "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one" ) if pixel_values is not None and cross_attention_states is not None: raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously") if pixel_values is not None: if aspect_ratio_ids is None: raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided") # get vision tokens from vision model vision_outputs = self.vision_model( pixel_values=pixel_values, aspect_ratio_ids=aspect_ratio_ids, aspect_ratio_mask=aspect_ratio_mask, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, ) cross_attention_states = vision_outputs[0] cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape( -1, cross_attention_states.shape[-2], self.hidden_size ) if cross_attention_mask is not None: cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask( cross_attention_mask, num_vision_tokens=self.vision_model.num_patches, dtype=self.dtype, ) else: full_text_row_masked_out_mask = None if cross_attention_mask is not None and cache_position is not None: cross_attention_mask = cross_attention_mask[:, :, cache_position] full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position] if audio_values is not None: inputs_embeds = self.get_input_embeddings().forward(input_ids) assert ( audio_token_start_idx is not None and audio_token_len is not None ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided." assert ( len(audio_token_start_idx) == len(audio_token_len) == len(audio_values) ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size." # B x A/3200 x D audio_tower_output = self.audio_model.forward( audio_values.to(self.audio_model.dtype), audio_len = audio_len ).last_hidden_state audio_tower_output = audio_tower_output.to(inputs_embeds.dtype) audio_embeds = self.audio_projector.forward(audio_tower_output) # combine audio and text embeddings for i, (audio, start, length) in enumerate( zip(audio_embeds, audio_token_start_idx, audio_token_len) ): assert length <= audio.shape[0] inputs_embeds[i, start : start + length].copy_(audio[:length]) input_ids = None outputs = self.language_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, use_cache=use_cache, inputs_embeds=inputs_embeds, labels=labels, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_dict=return_dict, cache_position=cache_position, num_logits_to_keep=num_logits_to_keep, ) return outputs def prepare_inputs_for_generation( self, input_ids=None, inputs_embeds=None, attention_mask=None, position_ids=None, audio_values: Optional[torch.FloatTensor] = None, audio_token_start_idx: Optional[torch.Tensor] = None, audio_token_len: Optional[torch.Tensor] = None, audio_len: Optional[torch.Tensor] = None, pixel_values=None, aspect_ratio_ids=None, aspect_ratio_mask=None, cross_attention_mask=None, past_key_values=None, use_cache=False, cache_position=None, num_logits_to_keep=None, **kwargs, ): # Overwritten -- in specific circumstances we don't want to forward image inputs to the model # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens # Exception 1: when passing input_embeds, input_ids may be missing entries # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here if past_key_values is not None: if inputs_embeds is not None: # Exception 1 input_ids = input_ids[:, -cache_position.shape[0] :] elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) input_ids = input_ids[:, cache_position] # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. position_ids = position_ids.clone(memory_format=torch.contiguous_format) # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and cache_position[0] == 0: model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} else: # The clone here is for the same reason as for `position_ids`. model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} if num_logits_to_keep is not None: model_inputs["num_logits_to_keep"] = num_logits_to_keep model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, "use_cache": use_cache, "attention_mask": attention_mask, "cross_attention_mask": cross_attention_mask, } ) prefill_start_idx = 0 if cache_position is None else cache_position[0] if ( audio_values is not None and audio_token_start_idx is not None and prefill_start_idx <= torch.max(audio_token_start_idx) ): model_inputs["audio_values"] = audio_values model_inputs["audio_token_start_idx"] = ( audio_token_start_idx - prefill_start_idx ) model_inputs["audio_token_len"] = audio_token_len model_inputs["audio_len"] = audio_len # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios # to compute image hidden states, otherwise they are cached within each cross attn layer if cache_position[0] == 0: model_inputs["pixel_values"] = pixel_values model_inputs["aspect_ratio_ids"] = aspect_ratio_ids model_inputs["aspect_ratio_mask"] = aspect_ratio_mask return model_inputs def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs): cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None) model_kwargs = super()._update_model_kwargs_for_generation( outputs=outputs, model_kwargs=model_kwargs, is_encoder_decoder=is_encoder_decoder, **kwargs, ) # add cross-attn mask for new token if cross_attention_mask_prev is not None: model_kwargs["cross_attention_mask"] = torch.cat( [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1 ) return model_kwargs class StackAudioFrames(nn.Module): """ Stack the audio embedding frames to reduce the sequence length by a factor of `stack_factor`. The number of output frames will be `ceil(T / stack_factor) + 1` where `T` is the number of input frames. NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor, we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings. In most cases this extra padding will get removed in the model's forward function so it has no effect. """ def __init__(self, stack_factor: int = 8): super().__init__() self.stack_factor = stack_factor def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: B, T, C = audio_embeds.shape T_pad = (T + self.stack_factor - 1) // self.stack_factor * self.stack_factor audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T + self.stack_factor)) B, T, C = audio_embeds.shape audio_embeds = audio_embeds.view( B, T // self.stack_factor, C * self.stack_factor ) return audio_embeds class RMSNorm(LlamaRMSNorm): def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6): super().__init__(hidden_size=hidden_size, eps=eps) self.weight.data.fill_(init) class SwiGLU(nn.Module): def forward(self, x): x, gate = x.chunk(2, dim=-1) return F.silu(gate) * x class UltravoxProjector(nn.Sequential): def __init__(self, config: MllamaAudioConfig): super().__init__() self.hidden_dim = config.hidden_size self._pad_and_stack = StackAudioFrames(config.stack_factor) dim = config.input_hidden_size * config.stack_factor self.ln_pre = RMSNorm(dim, init=config.norm_init) self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) dim = self.hidden_dim self.act = SwiGLU() dim = dim // 2 self.linear_2 = nn.Linear(dim, config.output_hidden_size, bias=False) self.ln_post = RMSNorm(config.output_hidden_size, init=config.norm_init) def forward(self, audio_features: torch.Tensor) -> torch.Tensor: audio_features = self._pad_and_stack(audio_features) audio_features = self.ln_pre(audio_features) hidden_states = self.linear_1(audio_features) hidden_states = self.act(hidden_states) hidden_states = self.linear_2(hidden_states) hidden_states = self.ln_post(hidden_states) return hidden_states class ModifiedWhisperEncoder(WhisperEncoder, ModuleUtilsMixin): """ Encoder portion of OpenAI's Whisper model. This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes: 1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder 2. allow less than 30 second of audio padding to be passed in: - relaxed ValueError check for `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal - embed_pos is now sliced to match the length of `inputs_embeds` Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py """ base_model_prefix = "model.encoder" _no_split_modules = ["WhisperEncoderLayer"] def forward( self, input_features, audio_len=None, head_mask=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): expected_seq_length = ( self.config.max_source_positions * self.conv1.stride[0] * self.conv2.stride[0] ) if input_features.shape[-1] > expected_seq_length: raise ValueError( f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}." ) output_attentions = ( output_attentions if output_attentions is not None else self.config.output_attentions ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) inputs_embeds = nn.functional.gelu(self.conv1(input_features)) inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) inputs_embeds = inputs_embeds.permute(0, 2, 1) embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)] hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout( hidden_states, p=self.dropout, training=self.training ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None attention_mask = None if audio_len != None: audio_feature_len = self._get_feat_extract_output_lengths(audio_len) batch_size = hidden_states.shape[0] max_seq_len = hidden_states.shape[1] attention_mask = ( torch.arange(max_seq_len, device=hidden_states.device)[None, :] .expand(batch_size, -1) .lt(audio_feature_len.view(batch_size, 1)) ) attention_mask = self.get_extended_attention_mask( attention_mask, None, device=hidden_states.device, dtype=hidden_states.dtype, ) # check if head_mask has a correct number of layers specified if desired if head_mask is not None: assert head_mask.size()[0] == ( len(self.layers) ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}." for idx, encoder_layer in enumerate(self.layers): if output_hidden_states: encoder_states = encoder_states + (hidden_states,) # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) to_drop = False if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: # skip the layer to_drop = True if to_drop: layer_outputs = (None, None) else: if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( encoder_layer.__call__, hidden_states, attention_mask, (head_mask[idx] if head_mask is not None else None), output_attentions, ) else: layer_outputs = encoder_layer( hidden_states, attention_mask, layer_head_mask=( head_mask[idx] if head_mask is not None else None ), output_attentions=output_attentions, ) hidden_states = layer_outputs[0] if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) hidden_states = self.layer_norm(hidden_states) if output_hidden_states: encoder_states = encoder_states + (hidden_states,) if not return_dict: return tuple( v for v in [hidden_states, encoder_states, all_attentions] if v is not None ) return BaseModelOutput( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions, ) OcisMllamaConfig.register_for_auto_class() OcisMllamaForConditionalGeneration.register_for_auto_class() AutoConfig.register("ocismllama", OcisMllamaConfig) AutoModel.register(OcisMllamaConfig, OcisMllamaForConditionalGeneration)