llama-3-vision-alpha-hf / modeling_llamavision.py
qtnx's picture
Upload Llamavision
d8cc680 verified
raw
history blame
4.9 kB
import torch
import torch.nn as nn
from transformers import (
PreTrainedModel,
AutoModelForCausalLM,
AutoModel,
SiglipImageProcessor,
)
from .configuration_llamavision import LlamavisionConfig
class ProjectionModule(nn.Module):
    """Two-layer MLP used by Llamavision to project image features.

    Architecture: Linear(mm_hidden_size -> hidden_size) -> GELU ->
    Linear(hidden_size -> hidden_size).

    The layers are stored in ``self.model`` (an ``nn.Sequential``), so the
    parameter names are ``model.0.*`` and ``model.2.*``. Renaming the attribute
    or reordering the layers would change the state_dict keys.

    Args:
        mm_hidden_size: Width of the incoming features (last dimension).
        hidden_size: Width of the hidden and output features.
    """

    def __init__(self, mm_hidden_size=1152, hidden_size=4096):
        super().__init__()
        layers = [
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` along its last dimension (mm_hidden_size -> hidden_size)."""
        projected = self.model(x)
        return projected
class Llamavision(PreTrainedModel):
    """Vision-language model: vision encoder + MLP projector feeding a causal LM.

    The pipeline in ``answer_question`` works as follows:
    - Image features are taken from the vision encoder's penultimate hidden state.
    - They are projected by ``mm_projector``.
    - They are spliced into the text token embeddings at the image placeholder.
    - The result is passed to ``text_model.generate`` through ``inputs_embeds``.

    The submodule attribute names (``text_model``, ``vision_model``,
    ``mm_projector``) determine the state_dict keys and must not be renamed.
    """

    config_class = LlamavisionConfig

    # Placeholder id marking where image features are spliced into the token
    # sequence. It is negative, so it never collides with a real vocabulary id.
    # It must never reach the embedding layer.
    IMAGE_TOKEN_INDEX = -200

    def __init__(self, config):
        super().__init__(config)
        self.text_model = AutoModelForCausalLM.from_config(config.text_config)
        self.vision_model = AutoModel.from_config(config.vision_config)
        # Default SigLIP preprocessing; answer_question overrides the target
        # size per call.
        self.processor = SiglipImageProcessor()
        # NOTE(review): the projector uses ProjectionModule's default widths
        # (1152 -> 4096) rather than values from config.vision_config /
        # config.text_config. Presumably these match the shipped checkpoint;
        # confirm before using other backbones.
        self.mm_projector = ProjectionModule()

    @property
    def device(self):
        """Device of ``text_model``; used as the target device for all inputs."""
        return self.text_model.device

    def tokenizer_image_token(
        self,
        prompt,
        tokenizer,
        image_token_index=IMAGE_TOKEN_INDEX,
        return_tensors=None,
    ):
        """Tokenize ``prompt``, replacing each ``"<image>"`` marker with an id.

        The text chunks between markers are tokenized separately. If the first
        chunk starts with ``tokenizer.bos_token_id``, one BOS is kept at the
        front and the first token of every chunk is dropped. This assumes the
        tokenizer prepends BOS to each chunk.

        Args:
            prompt: Text containing zero or more ``"<image>"`` markers.
            tokenizer: Callable whose result has ``input_ids``; must expose
                ``bos_token_id``.
            image_token_index: Id inserted in place of each marker.
            return_tensors: Accepted for compatibility but ignored; a tensor is
                always returned.

        Returns:
            1-D ``torch.long`` tensor of token ids.
        """
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
        ]

        def insert_separator(X, sep):
            # Interleave ``sep`` between consecutive elements of ``X`` (not after
            # the last one).
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        if (
            len(prompt_chunks) > 0
            and len(prompt_chunks[0]) > 0
            and prompt_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        # The separator is padded with ``offset`` extra copies, so that slicing
        # it with ``x[offset:]`` (like every chunk) leaves exactly one id.
        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        return torch.tensor(input_ids, dtype=torch.long)

    def process_tensors(
        self,
        input_ids,
        image_features,
        embedding_layer,
        image_token_index=IMAGE_TOKEN_INDEX,
    ):
        """Embed ``input_ids`` and splice ``image_features`` in at the placeholder.

        Only the first occurrence of ``image_token_index`` is used. Its column
        position (searched across the whole batch) is applied to every row, so
        this effectively supports a single image per prompt.

        Args:
            input_ids: 2-D tensor of token ids (indexed as ``[:, i]``).
            image_features: Features concatenated along dim 1; presumably
                (batch, num_patches, hidden). TODO confirm.
            embedding_layer: Module mapping token ids to embeddings.
            image_token_index: Placeholder id marking the insertion point.

        Returns:
            ``(embeddings, attention_mask)``, both on ``image_features.device``.
            The mask is all ones with shape ``embeddings.shape[:2]``.

        Raises:
            ValueError: If ``input_ids`` contains no ``image_token_index``.
        """
        image_positions = (input_ids == image_token_index).nonzero(as_tuple=True)[1]
        if image_positions.numel() == 0:
            raise ValueError(
                f"input_ids contains no image placeholder token ({image_token_index})"
            )
        split_index = image_positions[0]

        # Embed the text on either side of the placeholder. The placeholder
        # itself is skipped, because its negative id is not a valid embedding
        # index.
        embeddings_1 = embedding_layer(input_ids[:, :split_index])
        embeddings_2 = embedding_layer(input_ids[:, split_index + 1 :])

        device = image_features.device
        concatenated_embeddings = torch.cat(
            [embeddings_1.to(device), image_features, embeddings_2.to(device)], dim=1
        )
        attention_mask = torch.ones(
            concatenated_embeddings.shape[:2], dtype=torch.long, device=device
        )
        return concatenated_embeddings, attention_mask

    def answer_question(self, image, question, tokenizer, **kwargs):
        """Generate a reply to ``question`` about ``image``.

        Builds a chat prompt (Llama-3 style header/eot special tokens) with one
        image placeholder before the question. It then encodes the image,
        splices the projected features into the prompt embeddings, and samples
        from ``text_model``.

        Args:
            image: Image passed to ``SiglipImageProcessor`` (resized to 384x384).
            question: User question text.
            tokenizer: Tokenizer for ``text_model``; must map ``"<|eot_id|>"``.
            **kwargs: Extra ``generate`` arguments. These override the defaults
                (``eos_token_id`` = [EOS, ``<|eot_id|>``], ``temperature=0.2``,
                ``do_sample=True``) instead of clashing with them.

        Returns:
            The first sequence returned by ``text_model.generate``.
        """
        question = "<image>" + question
        prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        input_ids = (
            self.tokenizer_image_token(
                prompt, tokenizer, self.IMAGE_TOKEN_INDEX, return_tensors="pt"
            )
            .unsqueeze(0)
            .to(self.device)
        )

        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]
        # Caller kwargs take precedence over the defaults. Passing them next to
        # hard-coded keywords used to raise
        # "TypeError: got multiple values for keyword argument".
        generation_kwargs = {
            "eos_token_id": terminators,
            "temperature": 0.2,
            "do_sample": True,
            **kwargs,
        }

        with torch.inference_mode():
            image_inputs = self.processor(
                images=[image],
                return_tensors="pt",
                do_resize=True,
                size={"height": 384, "width": 384},
            )
            image_inputs = image_inputs["pixel_values"].to(
                device=self.device, dtype=self.dtype
            )

            image_forward_outs = self.vision_model(
                image_inputs,
                output_hidden_states=True,
            )
            # The penultimate hidden state is used as the image representation.
            image_features = image_forward_outs.hidden_states[-2]
            projected_embeddings = self.mm_projector(image_features).to(self.device)

            embedding_layer = self.text_model.get_input_embeddings()
            new_embeds, attn_mask = self.process_tensors(
                input_ids, projected_embeddings, embedding_layer, self.IMAGE_TOKEN_INDEX
            )
            attn_mask = attn_mask.to(self.device)
            new_embeds = new_embeds.to(self.device)

            answer = self.text_model.generate(
                inputs_embeds=new_embeds,
                attention_mask=attn_mask,
                **generation_kwargs,
            )[0]

        return answer