# Source: multi-modal-llama-tp1 / modeling_ocismllama.py
# Last change: "Update modeling_ocismllama.py" by AlexHung29629 (commit aa9a8fc, verified)
"""
MIT License
Copyright (c) 2023 Fixie.ai
2024 Alex Hung
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers import AutoConfig, AutoModel, WhisperConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import (BaseModelOutput,
CausalLMOutputWithPast)
from transformers.modeling_utils import ModuleUtilsMixin
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.mllama.modeling_mllama import (
MllamaForCausalLM, MllamaPreTrainedModel, MllamaVisionModel,
_prepare_cross_attention_mask)
from transformers.models.whisper.modeling_whisper import WhisperEncoder
from transformers.utils import logging
from .configuration_ocismllama import MllamaAudioConfig, OcisMllamaConfig
# Module-level logger, created through the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)
class OcisMllamaPreTrainedModel(MllamaPreTrainedModel):
    """Common pretrained-model base for the OcisMllama family.

    Only class-level configuration is declared here; all behavior is inherited
    from `MllamaPreTrainedModel`.
    """

    # Configuration class paired with this model family.
    config_class = OcisMllamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    # Attention / cache capability flags.
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = False  # static cache cannot have different shapes for each layer

    # Layer classes listed by name that must not be split.
    _no_split_modules = [
        "MllamaVisionEncoderLayer",
        "MllamaCrossAttentionDecoderLayer",
        "MllamaSelfAttentionDecoderLayer",
        "WhisperEncoderLayer",
        "WhisperDecoderLayer",
    ]
class OcisMllamaForConditionalGeneration(OcisMllamaPreTrainedModel, GenerationMixin):
    """Mllama conditional-generation model extended with an audio input branch.

    The model is composed of:
      * `vision_model` + `multi_modal_projector`: produce cross-attention states from images.
      * `audio_model` + `audio_projector`: produce audio embeddings that are written directly
        into the text embedding sequence at caller-specified positions.
      * `language_model`: the Mllama causal LM that consumes both.
    """

    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting

    def __init__(self, config: OcisMllamaConfig):
        """Build the vision, language and audio sub-models from `config`.

        NOTE: the Whisper configuration is loaded with `WhisperConfig.from_pretrained` using
        `config.audio_config.audio_model_id`, so construction may touch the filesystem/network.
        """
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

        self.vision_model = MllamaVisionModel._from_config(config.vision_config)
        self.language_model = MllamaForCausalLM._from_config(config.text_config)
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )

        whisper_config = WhisperConfig.from_pretrained(config.audio_config.audio_model_id)
        self.audio_model = ModifiedWhisperEncoder._from_config(whisper_config)
        self.audio_projector = UltravoxProjector(config.audio_config)
        self.post_init()

    # --- Thin delegations to the wrapped language model -------------------------------------

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audio_values: Optional[torch.FloatTensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_len: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[torch.Tensor] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the multimodal model: encode images/audio if given, then call the language model.

        Args:
            audio_values (`torch.FloatTensor`, *optional*):
                Audio features fed to the audio encoder. When given, `input_ids` are embedded here and
                the projected audio embeddings overwrite the text embeddings in the spans described by
                `audio_token_start_idx` / `audio_token_len`.
            audio_token_start_idx (`torch.Tensor`, *optional*):
                Per-sample start position of the audio span inside the text sequence. Required with `audio_values`.
            audio_token_len (`torch.Tensor`, *optional*):
                Per-sample number of positions to overwrite with audio embeddings. Required with `audio_values`.
            audio_len (`torch.Tensor`, *optional*):
                Per-sample audio length forwarded to the audio encoder to build its attention mask.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        Returns:
            Whatever `self.language_model` returns (a `CausalLMOutputWithPast` or a tuple, depending on `return_dict`).

        Raises:
            ValueError: on inconsistent combinations of `input_ids` / `inputs_embeds` / `pixel_values` /
                `cross_attention_states`, or when `pixel_values` is given without `aspect_ratio_ids`.
            AssertionError: when `audio_values` is given without matching `audio_token_start_idx` /
                `audio_token_len`, or when a requested audio span is longer than the produced embeddings.

        Example (image-only usage, taken from the upstream Mllama model):

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, MllamaForConditionalGeneration

        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
        >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
        >>> processor = AutoProcessor.from_pretrained(checkpoint)

        >>> prompt = "<|image|>If I had to write a haiku for this one"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> output = model.generate(**inputs, max_new_tokens=15)

        >>> prompt_len = inputs.input_ids.shape[-1]
        >>> generated_ids = output[:, prompt_len:]
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        >>> print(generated_text)
        [', it would be:.\\nA stop sign in Chinatown.\\n']
        ```
        """
        # After the prefill step the multimodal inputs are dropped.
        # `cache_position` defaults to None for direct (non-generate) calls, so it must be
        # checked before indexing; None is treated as "prefill".
        # NOTE(review): upstream Mllama keeps `cross_attention_mask` on decode steps; dropping it
        # here is preserved from the original code -- confirm this is intended.
        if cache_position is not None and cache_position[0] > 0:
            audio_values = None
            pixel_values = None
            cross_attention_mask = None

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and cross_attention_states is not None:
            raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

        if pixel_values is not None:
            if aspect_ratio_ids is None:
                raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")
            # get vision tokens from vision model
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
                return_dict=return_dict,
            )
            cross_attention_states = vision_outputs[0]
            cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
                -1, cross_attention_states.shape[-2], self.hidden_size
            )

        if cross_attention_mask is not None:
            cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                cross_attention_mask,
                num_vision_tokens=self.vision_model.num_patches,
                dtype=self.dtype,
            )
        else:
            full_text_row_masked_out_mask = None

        if cross_attention_mask is not None and cache_position is not None:
            cross_attention_mask = cross_attention_mask[:, :, cache_position]
            full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]

        if audio_values is not None:
            # Call the modules themselves (not `.forward`) so registered hooks run.
            inputs_embeds = self.get_input_embeddings()(input_ids)

            assert (
                audio_token_start_idx is not None and audio_token_len is not None
            ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
            assert (
                len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
            ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."

            # B x A/3200 x D
            audio_tower_output = self.audio_model(
                audio_values.to(self.audio_model.dtype),
                audio_len=audio_len,
            ).last_hidden_state
            audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)
            audio_embeds = self.audio_projector(audio_tower_output)

            # combine audio and text embeddings
            for i, (audio, start, length) in enumerate(
                zip(audio_embeds, audio_token_start_idx, audio_token_len)
            ):
                assert length <= audio.shape[0]
                inputs_embeds[i, start : start + length].copy_(audio[:length])

            # The language model accepts exactly one of `input_ids` / `inputs_embeds`.
            input_ids = None

        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        audio_values: Optional[torch.FloatTensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        audio_len: Optional[torch.Tensor] = None,
        pixel_values=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one `forward` call during generation.

        Text inputs are sliced to the not-yet-processed tokens; audio inputs are forwarded only
        while the current start position has not passed the last audio span; image inputs are
        forwarded only on the prefill step. A missing `cache_position` is treated as prefill
        (start index 0) everywhere in this method.

        Returns:
            dict: keyword arguments for `forward`.
        """
        # Overwritten -- in specific circumstances we don't want to forward image inputs to the model

        # Single source of truth for "where does this step start"; None means prefill.
        prefill_start_idx = 0 if cache_position is None else cache_position[0]
        is_prefill = cache_position is None or cache_position[0] == 0

        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None and cache_position is not None:
            if inputs_embeds is not None:  # Exception 1
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]

        # TODO: we have no attention_mask so this won't work, check if we really won't need attention mask and find another way
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

                # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and is_prefill:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            # The clone here is for the same reason as for `position_ids`.
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cross_attention_mask": cross_attention_mask,
            }
        )

        # Forward audio only while some audio span starts at or after the current start position;
        # start indices are shifted so they are relative to the tokens passed in this step.
        if (
            audio_values is not None
            and audio_token_start_idx is not None
            and prefill_start_idx <= torch.max(audio_token_start_idx)
        ):
            model_inputs["audio_values"] = audio_values
            model_inputs["audio_token_start_idx"] = (
                audio_token_start_idx - prefill_start_idx
            )
            model_inputs["audio_token_len"] = audio_token_len
            model_inputs["audio_len"] = audio_len

        # If we're in pre-fill or cacheless decoding step, then we need pixel_values and aspect ratios
        # to compute image hidden states, otherwise they are cached within each cross attn layer
        if is_prefill:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
            model_inputs["aspect_ratio_mask"] = aspect_ratio_mask

        return model_inputs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
        """Extend the parent update by growing `cross_attention_mask` by one text position.

        The mask row of the last position is repeated for the newly generated token.
        """
        cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )

        # add cross-attn mask for new token
        if cross_attention_mask_prev is not None:
            model_kwargs["cross_attention_mask"] = torch.cat(
                [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
            )
        return model_kwargs
class StackAudioFrames(nn.Module):
    """
    Shorten an audio embedding sequence by concatenating every `stack_factor` consecutive frames
    along the channel axis.

    For `T` input frames the output has `ceil(T / stack_factor) + 1` frames; missing frames are zero-padded.
    NOTE: the trailing extra frame is deliberate: if the processor over-estimates the number of audio tokens,
    it keeps `processor.audio_token_replacement` (i.e. EOS) from leaking into the middle of the embeddings.
    In most cases the model's forward function trims this padding again, so it has no effect.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Map a `(B, T, C)` tensor to `(B, ceil(T / stack_factor) + 1, C * stack_factor)`."""
        batch, frames, channels = audio_embeds.shape
        factor = self.stack_factor
        # ceil(frames / factor) real groups plus one all-padding group.
        groups = -(-frames // factor) + 1
        extra_frames = groups * factor - frames
        # Pad only at the end of the time axis (second-to-last dimension).
        padded = F.pad(audio_embeds, (0, 0, 0, extra_frames))
        return padded.view(batch, groups, channels * factor)
class RMSNorm(LlamaRMSNorm):
    """`LlamaRMSNorm` whose scale weight starts at a configurable constant."""

    def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6):
        super().__init__(hidden_size=hidden_size, eps=eps)
        # Overwrite every element of the weight created by the parent with `init`.
        nn.init.constant_(self.weight, init)
class SwiGLU(nn.Module):
    """Gated activation: split the last dimension in two and gate the first half with SiLU of the second."""

    def forward(self, x):
        value, gate = x.chunk(2, dim=-1)
        activated_gate = F.silu(gate)
        return activated_gate * value
class UltravoxProjector(nn.Sequential):
    """Project audio encoder features into the language model's embedding space.

    Pipeline: stack frames -> RMSNorm -> linear -> SwiGLU -> linear -> RMSNorm.
    SwiGLU halves the feature width, so the second linear layer takes `hidden_size // 2` inputs.
    """

    def __init__(self, config: MllamaAudioConfig):
        super().__init__()
        self.hidden_dim = config.hidden_size
        stacked_dim = config.input_hidden_size * config.stack_factor
        gated_dim = self.hidden_dim // 2

        # Sub-modules are created in the same order as they are applied in `forward`.
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        self.ln_pre = RMSNorm(stacked_dim, init=config.norm_init)
        self.linear_1 = nn.Linear(stacked_dim, self.hidden_dim, bias=False)
        self.act = SwiGLU()
        self.linear_2 = nn.Linear(gated_dim, config.output_hidden_size, bias=False)
        self.ln_post = RMSNorm(config.output_hidden_size, init=config.norm_init)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Apply the projector pipeline to `audio_features` and return the projected states."""
        stacked = self._pad_and_stack(audio_features)
        projected = self.linear_1(self.ln_pre(stacked))
        gated = self.act(projected)
        return self.ln_post(self.linear_2(gated))
class ModifiedWhisperEncoder(WhisperEncoder, ModuleUtilsMixin):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`
    3. an optional `audio_len` argument from which an attention mask over the encoder positions is built

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    """

    base_model_prefix = "model.encoder"
    _no_split_modules = ["WhisperEncoderLayer"]

    def forward(
        self,
        input_features,
        audio_len=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode `input_features` with the Whisper encoder stack.

        Args:
            input_features: input to the first convolution; its last dimension must not exceed
                `max_source_positions * conv1.stride * conv2.stride`.
            audio_len: optional per-sample lengths (tensor). They are converted with
                `_get_feat_extract_output_lengths` and positions at or beyond the converted length are
                masked out in attention. If None, no attention mask is used.
            head_mask: optional per-layer head mask; its first dimension must equal the number of layers.
            output_attentions / output_hidden_states / return_dict: default to the config values when None.

        Returns:
            `BaseModelOutput`, or a tuple of the non-None values when `return_dict` is falsy.

        Raises:
            ValueError: if `input_features` is longer than the expected sequence length.
        """
        expected_seq_length = (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice the positional table so shorter-than-maximum inputs are accepted.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        attention_mask = None
        # Identity check: `!=` on tensor/array-likes is not a reliable None test.
        if audio_len is not None:
            audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
            batch_size = hidden_states.shape[0]
            max_seq_len = hidden_states.shape[1]
            # True for valid positions (index < per-sample length). The lengths are moved to the
            # hidden-state device so a CPU `audio_len` does not cause a device mismatch.
            attention_mask = (
                torch.arange(max_seq_len, device=hidden_states.device)[None, :]
                .expand(batch_size, -1)
                .lt(audio_feature_len.to(hidden_states.device).view(batch_size, 1))
            )
            attention_mask = self.get_extended_attention_mask(
                attention_mask,
                None,
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:  # skip the layer
                    to_drop = True

            if to_drop:
                # A skipped layer leaves `hidden_states` untouched.
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
# Import-time side effects: calling `register_for_auto_class()` on the config and model classes
# (see the Transformers custom-model documentation for what this enables).
OcisMllamaConfig.register_for_auto_class()
OcisMllamaForConditionalGeneration.register_for_auto_class()
# Map the "ocismllama" model type to its config class in `AutoConfig`, and that config class to
# the model class in `AutoModel`.
AutoConfig.register("ocismllama", OcisMllamaConfig)
AutoModel.register(OcisMllamaConfig, OcisMllamaForConditionalGeneration)