# coding=utf-8
# Copyright 2024 HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss

from transformers.models.instructblip.configuration_instructblip import (
    InstructBlipConfig,
    InstructBlipQFormerConfig,
    InstructBlipVisionConfig,
)
from transformers.models.instructblip.modeling_instructblip import (
    InstructBlipAttention,
    InstructBlipEncoder,
    InstructBlipEncoderLayer,
    InstructBlipForConditionalGeneration,
    InstructBlipForConditionalGenerationModelOutput,
    InstructBlipMLP,
    InstructBlipPreTrainedModel,
    InstructBlipQFormerAttention,
    InstructBlipQFormerEmbeddings,
    InstructBlipQFormerEncoder,
    InstructBlipQFormerIntermediate,
    InstructBlipQFormerLayer,
    InstructBlipQFormerModel,
    InstructBlipQFormerOutput,
    InstructBlipQFormerSelfOutput,
    InstructBlipVisionEmbeddings,
    InstructBlipVisionModel,
)

from ...utils import logging


logger = logging.get_logger(__name__)
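

# The subclasses below only rename their InstructBlip counterparts for the video model; they add no new
# behavior. Only `forward` and `generate` are overridden further down, to handle the extra frames
# dimension of video inputs.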
class InstructBlipVideoVisionConfig(InstructBlipVisionConfig):
    pass


class InstructBlipVideoQFormerConfig(InstructBlipQFormerConfig):
    pass


class InstructBlipVideoConfig(InstructBlipConfig):
    pass


class InstructBlipVideoForConditionalGenerationModelOutput(InstructBlipForConditionalGenerationModelOutput):
    pass


class InstructBlipVideoVisionEmbeddings(InstructBlipVisionEmbeddings):
    pass


class InstructBlipVideoAttention(InstructBlipAttention):
    pass


class InstructBlipVideoMLP(InstructBlipMLP):
    pass


class InstructBlipVideoEncoderLayer(InstructBlipEncoderLayer):
    pass


class InstructBlipVideoPreTrainedModel(InstructBlipPreTrainedModel):
    pass


class InstructBlipVideoEncoder(InstructBlipEncoder):
    pass


class InstructBlipVideoVisionModel(InstructBlipVisionModel):
    pass


class InstructBlipVideoQFormerSelfOutput(InstructBlipQFormerSelfOutput):
    pass


class InstructBlipVideoQFormerAttention(InstructBlipQFormerAttention):
    pass


class InstructBlipVideoQFormerIntermediate(InstructBlipQFormerIntermediate):
    pass


class InstructBlipVideoQFormerOutput(InstructBlipQFormerOutput):
    pass


class InstructBlipVideoQFormerLayer(InstructBlipQFormerLayer):
    pass


class InstructBlipVideoQFormerEncoder(InstructBlipQFormerEncoder):
    pass


class InstructBlipVideoQFormerEmbeddings(InstructBlipQFormerEmbeddings):
    pass


class InstructBlipVideoQFormerModel(InstructBlipQFormerModel):
    pass


class InstructBlipVideoForConditionalGeneration(InstructBlipForConditionalGeneration):
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: torch.LongTensor,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
    ) -> Union[Tuple, InstructBlipVideoForConditionalGenerationModelOutput]:
r""" | |
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): | |
Labels for computing the language modeling loss. Indices should be in `[-100, 0, ..., config.vocab_size - | |
1]`. All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ..., | |
config.vocab_size]` | |
Returns: | |
Examples: | |
        ```python
        >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration
        >>> import torch
        >>> from huggingface_hub import hf_hub_download
        >>> import av
        >>> import numpy as np

        >>> def read_video_pyav(container, indices):
        ...     '''
        ...     Decode the video with PyAV decoder.
        ...     Args:
        ...         container (`av.container.input.InputContainer`): PyAV container.
        ...         indices (`List[int]`): List of frame indices to decode.
        ...     Returns:
        ...         result (np.ndarray): np array of decoded frames of shape (num_frames, height, width, 3).
        ...     '''
        ...     frames = []
        ...     container.seek(0)
        ...     start_index = indices[0]
        ...     end_index = indices[-1]
        ...     for i, frame in enumerate(container.decode(video=0)):
        ...         if i > end_index:
        ...             break
        ...         if i >= start_index and i in indices:
        ...             frames.append(frame)
        ...     return np.stack([x.to_ndarray(format="rgb24") for x in frames])

        >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained("Salesforce/instructblip-vicuna-7b", device_map="auto")
        >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")

        >>> file_path = hf_hub_download(
        ...     repo_id="nielsr/video-demo", filename="eating_spaghetti.mp4", repo_type="dataset"
        ... )
        >>> container = av.open(file_path)

        >>> # sample uniformly 4 frames from the video
        >>> total_frames = container.streams.video[0].frames
        >>> indices = np.arange(0, total_frames, total_frames / 4).astype(int)
        >>> clip = read_video_pyav(container, indices)

        >>> prompt = "What is happening in the video?"
        >>> inputs = processor(text=prompt, images=clip, return_tensors="pt").to(model.device)

        >>> outputs = model.generate(
        ...     **inputs,
        ...     do_sample=False,
        ...     num_beams=5,
        ...     max_length=256,
        ...     repetition_penalty=1.5,
        ...     length_penalty=1.0,
        ... )
        >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        >>> print(generated_text)
        "A person is eating a bowl of pasta, and they are using a fork to eat it. The person is sitting at a table, and the plate of pasta is on the table in front"
        ```"""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # we process in a batched way, later unbatch it back (video has frames=4 always)
        batch_size, frames, channel, height, width = pixel_values.shape
        pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width)
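        # e.g. pixel_values of shape (batch_size=1, frames=4, 3, 224, 224) becomes (4, 3, 224, 224),
        # so every frame is encoded independently by the vision tower (224 is the default vision image size)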
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # difference with BLIP-2 here: we also feed the instruction prompt to the Q-Former
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
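        # the instruction tokens are replicated once per frame, so the Q-Former sees
        # (batch_size * frames) sequences of [query tokens | instruction tokens]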
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0][:, : query_tokens.size(1), :]
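        # keep only the learned query positions; the instruction-token positions appended for the
        # Q-Former input are dropped before projecting into the language model embedding space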
        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)

        # unbatch inputs back, each video-frame gets `num_query_tokens` seq length
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
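        # e.g. with the default 32 query tokens and 4 frames, each video contributes
        # 128 visual embedding positions to the language model prompt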
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = torch.cat([language_model_attention_mask.to(attention_mask.device), attention_mask], dim=1)

        if self.config.use_decoder_only_language_model:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            logits = outputs.logits if return_dict else outputs[0]
            loss = None
            # we compute the loss here since we need to take into account the sequence length of the query embeds
            if labels is not None:
                labels = labels.to(logits.device)
                logits = logits[:, -labels.size(1) :, :]
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous().to(logits.device)

                # Flatten the tokens
                loss_fct = CrossEntropyLoss(reduction="mean")
                loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
        else:
            outputs = self.language_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                decoder_input_ids=decoder_input_ids,
                decoder_attention_mask=decoder_attention_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                labels=labels,
            )
            loss = outputs.loss if return_dict else outputs[0]
            logits = outputs.logits if return_dict else outputs[1]

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return InstructBlipVideoForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    def generate(
        self,
        pixel_values: torch.FloatTensor,
        qformer_input_ids: Optional[torch.LongTensor] = None,
        qformer_attention_mask: Optional[torch.LongTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides the `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_frames, num_channels, height, width)):
                Input videos to be processed.
            qformer_input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt to be fed to the Q-Former module.
            qformer_attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
                Whether to interpolate the positional encoding of the image embeddings.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
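
        Examples:

        A minimal, illustrative sketch: it reuses the checkpoint from the `forward` example above and feeds a
        random 4-frame placeholder clip; in practice decode real video frames as shown there.

        ```python
        >>> import numpy as np
        >>> from transformers import InstructBlipVideoProcessor, InstructBlipVideoForConditionalGeneration

        >>> processor = InstructBlipVideoProcessor.from_pretrained("Salesforce/instructblip-vicuna-7b")
        >>> model = InstructBlipVideoForConditionalGeneration.from_pretrained(
        ...     "Salesforce/instructblip-vicuna-7b", device_map="auto"
        ... )

        >>> clip = np.random.randint(0, 256, size=(4, 224, 224, 3), dtype=np.uint8)  # placeholder video frames
        >>> inputs = processor(text="Describe the video.", images=clip, return_tensors="pt").to(model.device)
        >>> outputs = model.generate(**inputs, do_sample=False, max_length=64)
        >>> generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
        ```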
""" | |
if hasattr(self, "hf_device_map"): | |
# preprocess for `accelerate` | |
self._preprocess_accelerate() | |
# we process in a batched way, later unbatch it back (video has frames=4) | |
batch_size, frames, channel, height, width = pixel_values.shape | |
pixel_values = pixel_values.reshape(batch_size * frames, channel, height, width) | |
        image_embeds = self.vision_model(
            pixel_values,
            return_dict=True,
            interpolate_pos_encoding=interpolate_pos_encoding,
        ).last_hidden_state
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_attention_mask = torch.ones(query_tokens.size()[:-1], dtype=torch.long, device=image_embeds.device)
        if qformer_attention_mask is None:
            qformer_attention_mask = torch.ones_like(qformer_input_ids)

        qformer_input_ids = qformer_input_ids.repeat_interleave(frames, dim=0)
        qformer_attention_mask = qformer_attention_mask.repeat_interleave(frames, dim=0)
        qformer_attention_mask = torch.cat([query_attention_mask, qformer_attention_mask], dim=1)
        query_outputs = self.qformer(
            input_ids=qformer_input_ids,
            attention_mask=qformer_attention_mask,
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state[:, : query_tokens.size(1), :]

        language_model_inputs = self.language_projection(query_output)

        # unbatch the embeddings back by moving frames to seq-len
        language_model_inputs = language_model_inputs.reshape(batch_size, self.config.num_query_tokens * frames, -1)
        language_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )

        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(image_embeds.device)
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = torch.cat([language_attention_mask, attention_mask.to(language_attention_mask.device)], dim=1)

        # concatenate query embeddings with prompt embeddings
        inputs_embeds = self.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
        # add the image embedding length to max_length, so that the final max_length is counted on text tokens only;
        # the -1 accounts for the BOS token that is prepended after `generate`
        if not self.language_model.config.is_encoder_decoder:
            generate_kwargs["max_length"] = generate_kwargs.get("max_length", 20) + language_model_inputs.shape[1] - 1
            generate_kwargs["min_length"] = generate_kwargs.get("min_length", 0) + language_model_inputs.shape[1]
        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )

        # this is a temporary workaround to be consistent with other generation models and
        # have BOS as the first token, even though under the hood we are calling LM with embeds
        if not self.language_model.config.is_encoder_decoder:
            # the InstructBLIP authors used inconsistent tokenizer/model files during training,
            # with the tokenizer's bos token being set to </s> which has ID=2,
            # whereas the model's text config has bos token id = 0
            bos_token_id = (
                2
                if self.config.text_config.architectures[0] == "LLaMAForCausalLM"
                else self.config.text_config.bos_token_id
            )
            bos_tokens = torch.LongTensor([[bos_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
            if not isinstance(outputs, torch.Tensor):
                outputs.sequences = torch.cat([bos_tokens, outputs.sequences], dim=-1)
            else:
                outputs = torch.cat([bos_tokens, outputs], dim=-1)

        return outputs