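"""
BLIP-2 style vision-language model that pairs the BLIP-2 vision encoder and Q-Former with a
LLaMA causal language model. The Q-Former's query-token outputs are projected into the LLaMA
embedding space and written into a fixed slice of the prompt embeddings before the language
model runs.
"""
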
import logging
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    LlamaForCausalLM,
    Blip2PreTrainedModel,
    Blip2VisionModel,
    Blip2Config,
    Blip2QFormerModel,
    GenerationConfig,
)
from transformers.utils import ModelOutput

warnings.filterwarnings('ignore')

logger = logging.getLogger(__name__)

@dataclass
class Blip2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Class defining the outputs of [`Blip2ForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        # Flatten the nested sub-model outputs so the tuple view mirrors the dict view.
        return tuple(
            self[k]
            if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )

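# Typical access pattern for the output above (illustrative): `out = model(pixel_values, input_ids,
# labels=labels)`, then `out.loss` / `out.logits`, or `out.to_tuple()` for a flattened tuple view.
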
class Blip2LlaMAForConditionalGeneration(Blip2PreTrainedModel):
    config_class = Blip2Config
    main_input_name = "pixel_values"

    def __init__(self, config: Blip2Config):
        super().__init__(config)

        self.vision_model = Blip2VisionModel(config.vision_config)

        self.query_tokens = nn.Parameter(
            torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size)
        )
        self.qformer = Blip2QFormerModel(config.qformer_config)

        language_model = LlamaForCausalLM(config.text_config)
        self.language_model = language_model
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, language_model.config.hidden_size)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens that precede the image placeholder tokens; the projected
        # query features are written into positions [offset, offset + num_queries).
        self.offset = 5

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

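    # Shape flow through the vision branch (B = batch size):
    #   pixel_values           (B, 3, H, W)
    #   image_embeds           (B, num_image_tokens, vision_hidden_size)
    #   query_tokens / output  (B, num_query_tokens, qformer_hidden_size)
    #   language_model_inputs  (B, num_query_tokens, text_hidden_size)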
    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state

        language_model_inputs = self.language_projection(query_output)
        return language_model_inputs

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # Warn users about unexpected behavior when using multi-GPU + BLIP-2 + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. This may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models."
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

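    # `input_ids` is expected to already contain `num_queries` placeholder tokens starting at
    # position `offset`; their embeddings are overwritten below with the projected Q-Former
    # features. Prompt/image positions that should not contribute to the loss can be set to
    # -100 in `labels`, which `nn.CrossEntropyLoss` ignores by default.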
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, Blip2ForConditionalGenerationModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the Q-Former, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Prompt template: "Human: <img><IMAGE></img>. Give the describe Assistant:"
        # The <IMAGE> placeholder occupies positions [offset, offset + num_queries) and is
        # replaced here by the projected query features.
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device).to(torch.long)
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss(reduction="mean")
            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return Blip2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        language_model_inputs: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides the `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
                Input images to be processed.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which has the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.

        Returns:
            `torch.LongTensor`: The generated token ids.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        batch_size = pixel_values.shape[0]

        if language_model_inputs is None:
            image_embeds = self.vision_model(pixel_values, return_dict=True).last_hidden_state
            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_outputs = self.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_attention_mask,
                return_dict=True,
            )
            query_output = query_outputs.last_hidden_state

            language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(pixel_values.device)
            )
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # position of <IMAGE>: [offset, offset + num_queries)
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )

        return outputs
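

if __name__ == "__main__":
    # Minimal smoke-test sketch with tiny, randomly initialised weights. The config sizes,
    # query count, and dummy prompt below are illustrative assumptions, not values from any
    # released checkpoint; real usage would load pretrained weights and a tokenizer-built
    # prompt whose image placeholder spans positions [offset, offset + num_queries).
    from transformers import Blip2QFormerConfig, Blip2VisionConfig, LlamaConfig

    vision_config = Blip2VisionConfig(
        hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4,
        image_size=32, patch_size=8,
    )
    qformer_config = Blip2QFormerConfig(
        hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4,
        encoder_hidden_size=32, cross_attention_frequency=1,
    )
    text_config = LlamaConfig(
        vocab_size=1000, hidden_size=32, intermediate_size=64, num_hidden_layers=2,
        num_attention_heads=4,
    )
    config = Blip2Config.from_vision_qformer_text_configs(
        vision_config, qformer_config, text_config, num_query_tokens=8
    )

    model = Blip2LlaMAForConditionalGeneration(config).eval()

    pixel_values = torch.randn(1, 3, 32, 32)
    # Dummy prompt: long enough to cover offset (5) + num_query_tokens (8) placeholder positions.
    input_ids = torch.randint(0, text_config.vocab_size, (1, 16))

    with torch.no_grad():
        generated = model.generate(pixel_values, input_ids=input_ids, max_new_tokens=5)
    print(generated.shape)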