|
""" |
|
MIT License |
|
|
|
Copyright (c) 2023 Fixie.ai |
|
2024 Alex Hung |
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a copy |
|
of this software and associated documentation files (the "Software"), to deal |
|
in the Software without restriction, including without limitation the rights |
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
|
copies of the Software, and to permit persons to whom the Software is |
|
furnished to do so, subject to the following conditions: |
|
|
|
The above copyright notice and this permission notice shall be included in all |
|
copies or substantial portions of the Software. |
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
|
SOFTWARE. |
|
""" |
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint |
|
from torch import nn |
|
from transformers import AutoConfig, AutoModel, WhisperConfig |
|
from transformers.generation import GenerationMixin |
|
from transformers.modeling_outputs import (BaseModelOutput, |
|
CausalLMOutputWithPast) |
|
from transformers.modeling_utils import ModuleUtilsMixin |
|
from transformers.models.llama.modeling_llama import LlamaRMSNorm |
|
from transformers.models.mllama.modeling_mllama import ( |
|
MllamaForCausalLM, MllamaPreTrainedModel, MllamaVisionModel, |
|
_prepare_cross_attention_mask) |
|
from transformers.models.whisper.modeling_whisper import WhisperEncoder |
|
from transformers.utils import logging |
|
|
|
from .configuration_ocismllama import MllamaAudioConfig, OcisMllamaConfig |
|
|
|
# Module-level logger, obtained through the `transformers.utils.logging` helper imported above.
logger = logging.get_logger(__name__)
|
|
|
class OcisMllamaPreTrainedModel(MllamaPreTrainedModel):
    """Common base for the OcisMllama models.

    Only class-level metadata is declared here; everything else is inherited
    from `MllamaPreTrainedModel`. The configuration class is bound to
    `OcisMllamaConfig`.
    """

    config_class = OcisMllamaConfig
    base_model_prefix = "model"

    # Layer class names from the vision tower, the text decoder and the Whisper
    # audio tower. NOTE(review): presumably consumed by device-map placement so
    # that these layers are never split across devices -- confirm against the
    # installed transformers version.
    _no_split_modules = [
        "MllamaVisionEncoderLayer",
        "MllamaCrossAttentionDecoderLayer",
        "MllamaSelfAttentionDecoderLayer",
        "WhisperEncoderLayer",
        "WhisperDecoderLayer",
    ]

    # Capability flags advertised to the transformers runtime.
    supports_gradient_checkpointing = True
    _supports_cache_class = True
    _supports_static_cache = False
    _supports_sdpa = True
    _supports_quantized_cache = True
|
|
|
class OcisMllamaForConditionalGeneration(OcisMllamaPreTrainedModel, GenerationMixin):
    """Mllama conditional-generation model with an additional audio input branch.

    The model owns four trainable parts:

    * `vision_model` + `multi_modal_projector`: encode `pixel_values` into the
      cross-attention states consumed by the language model.
    * `audio_model` + `audio_projector`: encode `audio_values` into embeddings
      that overwrite a span of the text input embeddings.
    * `language_model`: an `MllamaForCausalLM` that produces the final output.
    """

    # Overrides the value inherited from `OcisMllamaPreTrainedModel`.
    _supports_quantized_cache = False

    def __init__(self, config: OcisMllamaConfig):
        """Build all sub-models from `config`.

        Args:
            config: Composite configuration providing `text_config`,
                `vision_config` and `audio_config`.
        """
        super().__init__(config)
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1

        self.vision_model = MllamaVisionModel._from_config(config.vision_config)
        self.language_model = MllamaForCausalLM._from_config(config.text_config)
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        # Only the Whisper *configuration* is resolved from `audio_model_id`; the
        # encoder itself is instantiated from that config via `_from_config`.
        whisper_config = WhisperConfig.from_pretrained(config.audio_config.audio_model_id)
        self.audio_model = ModifiedWhisperEncoder._from_config(whisper_config)
        self.audio_projector = UltravoxProjector(config.audio_config)
        self.post_init()

    def get_input_embeddings(self):
        """Return the input embedding module of the wrapped language model."""
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the input embedding module of the wrapped language model."""
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        """Return the output embedding module of the wrapped language model."""
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        """Replace the output embedding module of the wrapped language model."""
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        """Replace the decoder of the wrapped language model."""
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        """Return the decoder of the wrapped language model."""
        return self.language_model.get_decoder()

    def tie_weights(self):
        """Delegate weight tying to the wrapped language model."""
        return self.language_model.tie_weights()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audio_values: Optional[torch.FloatTensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_len: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        aspect_ratio_mask: Optional[torch.Tensor] = None,
        aspect_ratio_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_mask: Optional[torch.Tensor] = None,
        cross_attention_states: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Run the language model over the text input, optionally conditioned on image and audio inputs.

        Args:
            audio_values (`torch.FloatTensor`, *optional*):
                Input to the audio encoder (cast to the encoder's dtype before the call). When provided,
                `input_ids`, `audio_token_start_idx` and `audio_token_len` must be provided as well.
            audio_token_start_idx (`torch.Tensor`, *optional*):
                For each sample, the position in the text sequence at which the audio embeddings are written.
            audio_len (`torch.Tensor`, *optional*):
                Forwarded unchanged to the audio encoder as its `audio_len` argument.
            audio_token_len (`torch.Tensor`, *optional*):
                For each sample, how many text-embedding positions are overwritten with audio embeddings. It
                must not exceed the number of frames produced by the audio projector for that sample.
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

            num_logits_to_keep (`int`, *optional*):
                Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
                `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
                token can save memory, which becomes pretty significant for long sequences or large vocabulary size.

        When `cache_position` is given and its first entry is greater than zero, `audio_values`, `pixel_values`
        and `cross_attention_mask` are ignored for that call.

        Returns:
            The output of `self.language_model`, returned unchanged.

        Raises:
            ValueError: If not exactly one of `input_ids` / `inputs_embeds` is given, if `pixel_values` is combined
                with `inputs_embeds` or `cross_attention_states`, or if `pixel_values` is given without
                `aspect_ratio_ids`.
            AssertionError: If the audio arguments are incomplete or inconsistent in batch size, or if
                `audio_token_len` exceeds the number of available audio frames.

        Example (image-only flow, written against the upstream `MllamaForConditionalGeneration` class):

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, MllamaForConditionalGeneration

        >>> checkpoint = "meta-llama/Llama-3.2-11B-Vision"
        >>> model = MllamaForConditionalGeneration.from_pretrained(checkpoint)
        >>> processor = AutoProcessor.from_pretrained(checkpoint)

        >>> prompt = "<|image|>If I had to write a haiku for this one"
        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(text=prompt, images=image, return_tensors="pt")

        >>> # Generate
        >>> output = model.generate(**inputs, max_new_tokens=15)

        >>> prompt_len = inputs.input_ids.shape[-1]
        >>> generated_ids = output[:, prompt_len:]
        >>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        >>> print(generated_text)
        [', it would be:.\\nA stop sign in Chinatown.\\n']
        ```
        """
        # `cache_position` is optional (it is only supplied by the generation loop),
        # so it must be checked for None before being indexed.
        if cache_position is not None and cache_position[0] > 0:
            # Past the first position the multimodal inputs are dropped for this call.
            audio_values = None
            pixel_values = None
            cross_attention_mask = None

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and cross_attention_states is not None:
            raise ValueError("`pixel_values` and `cross_attention_states` cannot be provided simultaneously")

        if pixel_values is not None:
            if aspect_ratio_ids is None:
                raise ValueError("`aspect_ratio_ids` must be provided if `pixel_values` is provided")

            # Encode the image and project it to the text hidden size.
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                aspect_ratio_ids=aspect_ratio_ids,
                aspect_ratio_mask=aspect_ratio_mask,
                output_hidden_states=output_hidden_states,
                output_attentions=output_attentions,
                return_dict=return_dict,
            )
            cross_attention_states = vision_outputs[0]
            cross_attention_states = self.multi_modal_projector(cross_attention_states).reshape(
                -1, cross_attention_states.shape[-2], self.hidden_size
            )

        if cross_attention_mask is not None:
            cross_attention_mask, full_text_row_masked_out_mask = _prepare_cross_attention_mask(
                cross_attention_mask,
                num_vision_tokens=self.vision_model.num_patches,
                dtype=self.dtype,
            )
        else:
            full_text_row_masked_out_mask = None

        if cross_attention_mask is not None and cache_position is not None:
            # Keep only the mask rows for the positions processed in this call.
            cross_attention_mask = cross_attention_mask[:, :, cache_position]
            full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]

        if audio_values is not None:
            # Modules are invoked via `__call__` (not `.forward`) so registered hooks run.
            inputs_embeds = self.get_input_embeddings()(input_ids)
            assert (
                audio_token_start_idx is not None and audio_token_len is not None
            ), "audio_token_start_idx and audio_token_len must be provided if audio_values are provided."
            assert (
                len(audio_token_start_idx) == len(audio_token_len) == len(audio_values)
            ), "audio_token_start_idx, audio_token_len, and audio_values must have the same batch size."

            audio_tower_output = self.audio_model(
                audio_values.to(self.audio_model.dtype),
                audio_len=audio_len,
            ).last_hidden_state
            audio_tower_output = audio_tower_output.to(inputs_embeds.dtype)

            audio_embeds = self.audio_projector(audio_tower_output)

            # Overwrite the placeholder span of each sample with its audio embeddings.
            for i, (audio, start, length) in enumerate(
                zip(audio_embeds, audio_token_start_idx, audio_token_len)
            ):
                assert length <= audio.shape[0]
                inputs_embeds[i, start : start + length].copy_(audio[:length])

            # The language model accepts either ids or embeddings, never both.
            input_ids = None

        outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        return outputs

    def prepare_inputs_for_generation(
        self,
        input_ids=None,
        inputs_embeds=None,
        attention_mask=None,
        position_ids=None,
        audio_values: Optional[torch.FloatTensor] = None,
        audio_token_start_idx: Optional[torch.Tensor] = None,
        audio_token_len: Optional[torch.Tensor] = None,
        audio_len: Optional[torch.Tensor] = None,
        pixel_values=None,
        aspect_ratio_ids=None,
        aspect_ratio_mask=None,
        cross_attention_mask=None,
        past_key_values=None,
        use_cache=False,
        cache_position=None,
        num_logits_to_keep=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one `forward` call during generation.

        Text inputs are trimmed to the positions named by `cache_position`,
        `position_ids` are derived from `attention_mask` when missing, and the
        audio / image inputs are only included when they are still needed for
        the positions being processed. A `cache_position` of `None` is treated
        as the first (prefill) step.

        Returns:
            dict: keyword arguments to pass to `forward`.
        """
        # Trim the text inputs to the tokens that have not been processed yet.
        if past_key_values is not None:
            if inputs_embeds is not None:
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        if attention_mask is not None and position_ids is None:
            # Derive position ids from the attention mask (padding positions get 1).
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]
                # Clone into a contiguous tensor after slicing.
                position_ids = position_ids.clone(memory_format=torch.contiguous_format)

        # `cache_position` may be None (see `prefill_start_idx` below); treat that
        # case as the prefill step instead of failing on `cache_position[0]`.
        is_prefill = cache_position is None or bool(cache_position[0] == 0)

        # `inputs_embeds` are only used for the first generation step.
        if inputs_embeds is not None and is_prefill:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}

        if num_logits_to_keep is not None:
            model_inputs["num_logits_to_keep"] = num_logits_to_keep

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cross_attention_mask": cross_attention_mask,
            }
        )

        # Audio is forwarded only while the current window still starts at or
        # before the last audio span; start indices are made window-relative.
        prefill_start_idx = 0 if cache_position is None else cache_position[0]
        if (
            audio_values is not None
            and audio_token_start_idx is not None
            and prefill_start_idx <= torch.max(audio_token_start_idx)
        ):
            model_inputs["audio_values"] = audio_values
            model_inputs["audio_token_start_idx"] = (
                audio_token_start_idx - prefill_start_idx
            )
            model_inputs["audio_token_len"] = audio_token_len
            model_inputs["audio_len"] = audio_len

        # Image inputs are only forwarded on the prefill step.
        if is_prefill:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["aspect_ratio_ids"] = aspect_ratio_ids
            model_inputs["aspect_ratio_mask"] = aspect_ratio_mask

        return model_inputs

    def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
        """Extend the parent update by growing `cross_attention_mask` by one step.

        The last entry along dim 1 of the previous cross-attention mask is
        repeated and appended, so the mask gains one row per generated token.
        """
        cross_attention_mask_prev = model_kwargs.get("cross_attention_mask", None)
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )

        # Add a cross-attention mask row for the new token.
        if cross_attention_mask_prev is not None:
            model_kwargs["cross_attention_mask"] = torch.cat(
                [cross_attention_mask_prev, cross_attention_mask_prev[:, -1:, ...]], dim=1
            )
        return model_kwargs
|
|
|
class StackAudioFrames(nn.Module):
    """
    Shorten an audio embedding sequence by concatenating `stack_factor` consecutive frames
    along the feature dimension.

    For `T` input frames the output has `ceil(T / stack_factor) + 1` frames; the missing
    input frames are zero-padded at the end of the time axis.
    NOTE: the extra +1 is intentional: in case the number of audio tokens are over-estimated by the processor,
    we want to make sure `processor.audio_token_replacement` (i.e. EOS) doesn't get leaked into the middle of embeddings.
    In most cases this extra padding will get removed in the model's forward function so it has no effect.
    """

    def __init__(self, stack_factor: int = 8):
        super().__init__()
        self.stack_factor = stack_factor

    def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor:
        """Map a `(B, T, C)` tensor to `(B, ceil(T / stack_factor) + 1, C * stack_factor)`."""
        factor = self.stack_factor
        batch, num_frames, channels = audio_embeds.shape
        # Groups needed to hold every frame (ceiling division), plus one spare group.
        num_groups = -(-num_frames // factor) + 1
        trailing_pad = num_groups * factor - num_frames
        # Pad only the end of the time axis; the feature axis is left untouched.
        padded = F.pad(audio_embeds, (0, 0, 0, trailing_pad))
        return padded.view(batch, num_groups, channels * factor)
|
|
|
class RMSNorm(LlamaRMSNorm):
    """`LlamaRMSNorm` whose weight is filled with a configurable constant at construction."""

    def __init__(self, hidden_size: int, init: float = 1, eps: float = 1e-6):
        """
        Args:
            hidden_size: Forwarded to `LlamaRMSNorm`.
            init: Constant written into every element of the weight.
            eps: Forwarded to `LlamaRMSNorm`.
        """
        super().__init__(hidden_size=hidden_size, eps=eps)
        # Overwrite whatever the parent initialised the weight to.
        scale = self.weight.data
        scale.fill_(init)
|
|
|
class SwiGLU(nn.Module):
    """Gated activation: split the last dimension in two halves and gate the first with SiLU of the second."""

    def forward(self, x):
        """Return `first_half * silu(second_half)`; the last dimension is halved."""
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.silu(gate) * value
|
|
|
class UltravoxProjector(nn.Sequential):
    """Project audio encoder features to `config.output_hidden_size`.

    Pipeline: frame stacking -> RMSNorm -> linear -> SwiGLU -> linear -> RMSNorm.
    """

    def __init__(self, config: MllamaAudioConfig):
        """
        Args:
            config: Audio configuration; `hidden_size`, `stack_factor`,
                `input_hidden_size`, `output_hidden_size` and `norm_init` are read.
        """
        super().__init__()
        self.hidden_dim = config.hidden_size
        self._pad_and_stack = StackAudioFrames(config.stack_factor)
        # Stacking concatenates `stack_factor` frames along the feature axis.
        stacked_dim = config.input_hidden_size * config.stack_factor
        self.ln_pre = RMSNorm(stacked_dim, init=config.norm_init)
        self.linear_1 = nn.Linear(stacked_dim, self.hidden_dim, bias=False)
        self.act = SwiGLU()
        # SwiGLU halves the feature dimension.
        gated_dim = self.hidden_dim // 2
        self.linear_2 = nn.Linear(gated_dim, config.output_hidden_size, bias=False)
        self.ln_post = RMSNorm(config.output_hidden_size, init=config.norm_init)

    def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
        """Apply the projection pipeline to `audio_features` and return the result."""
        stacked = self.ln_pre(self._pad_and_stack(audio_features))
        gated = self.act(self.linear_1(stacked))
        return self.ln_post(self.linear_2(gated))
|
|
|
|
|
class ModifiedWhisperEncoder(WhisperEncoder, ModuleUtilsMixin):
    """
    Encoder portion of OpenAI's Whisper model.

    This implementation is a slightly modified version of HF Transformers' Whisper Encoder, with only a few fixes:
    1. base_model_prefix updated to allow for doing `.from_pretrained` directly on the encoder
    2. allow less than 30 second of audio padding to be passed in:
        - relaxed ValueError check for `input_features` length to be less than or equal to `expected_seq_length` instead of strictly equal
        - embed_pos is now sliced to match the length of `inputs_embeds`
    3. an optional `audio_len` argument from which an attention mask over the encoder frames is built

    Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py
    """

    base_model_prefix = "model.encoder"
    _no_split_modules = ["WhisperEncoderLayer"]

    def forward(
        self,
        input_features,
        audio_len=None,
        head_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Encode mel input features.

        Args:
            input_features: Mel features; the last dimension may be at most
                `max_source_positions * conv1.stride * conv2.stride` long.
            audio_len: Optional per-sample input lengths. When given, they are
                converted with `_get_feat_extract_output_lengths` and frames at
                or beyond the converted length are masked out of attention.
            head_mask: Optional per-layer head mask; its first dimension must
                equal the number of encoder layers.
            output_attentions: Whether to collect attention weights of every layer.
            output_hidden_states: Whether to collect the hidden states of every layer.
            return_dict: Whether to return a `BaseModelOutput` instead of a tuple.

        Returns:
            A `BaseModelOutput`, or a tuple of its non-None fields when
            `return_dict` is false.

        Raises:
            ValueError: If `input_features` is longer than the expected length.
            AssertionError: If `head_mask` does not cover every layer.
        """
        expected_seq_length = (
            self.config.max_source_positions
            * self.conv1.stride[0]
            * self.conv2.stride[0]
        )
        if input_features.shape[-1] > expected_seq_length:
            raise ValueError(
                f"Whisper expects the mel input features to be of length {expected_seq_length} or less, but found {input_features.shape[-1]}. Make sure to pad the input mel features to {expected_seq_length}."
            )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        inputs_embeds = nn.functional.gelu(self.conv1(input_features))
        inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))

        inputs_embeds = inputs_embeds.permute(0, 2, 1)
        # Slice the positional table so inputs shorter than the maximum are accepted.
        embed_pos = self.embed_positions.weight[: inputs_embeds.size(-2)]

        hidden_states = inputs_embeds + embed_pos
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        attention_mask = None
        # Identity check: `!=` on a tensor/array goes through rich comparison.
        if audio_len is not None:
            audio_feature_len = self._get_feat_extract_output_lengths(audio_len)
            batch_size = hidden_states.shape[0]
            max_seq_len = hidden_states.shape[1]
            # True for frames that lie inside each sample's valid length.
            attention_mask = (
                torch.arange(max_seq_len, device=hidden_states.device)[None, :]
                .expand(batch_size, -1)
                .lt(audio_feature_len.view(batch_size, 1))
            )
            attention_mask = self.get_extended_attention_mask(
                attention_mask,
                None,
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )

        # Check that head_mask has an entry for every layer.
        if head_mask is not None:
            assert head_mask.size()[0] == (
                len(self.layers)
            ), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)

            # LayerDrop: during training a layer is skipped with probability `layerdrop`.
            to_drop = False
            if self.training:
                dropout_probability = torch.rand([])
                if dropout_probability < self.layerdrop:
                    to_drop = True

            if to_drop:
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        encoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        (head_mask[idx] if head_mask is not None else None),
                        output_attentions,
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.layer_norm(hidden_states)
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, encoder_states, all_attentions]
                if v is not None
            )
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
        )
|
|
|
# Register the custom config/model with their auto classes. The auto-class
# names are spelled out; they are the library defaults for a config and a
# model respectively.
OcisMllamaConfig.register_for_auto_class("AutoConfig")
OcisMllamaForConditionalGeneration.register_for_auto_class("AutoModel")

# Make the "ocismllama" model type resolvable through AutoConfig / AutoModel
# in the current process.
AutoConfig.register("ocismllama", OcisMllamaConfig)
AutoModel.register(OcisMllamaConfig, OcisMllamaForConditionalGeneration)
|
|