import torch
import torch.nn as nn
from transformers import (
    PreTrainedModel,
    AutoModelForCausalLM,
    AutoModel,
    SiglipImageProcessor,
)

from .configuration_llamavision import LlamavisionConfig


class ProjectionModule(nn.Module):
    """Two-layer MLP that projects vision features into the language model's embedding space."""

    def __init__(self, mm_hidden_size=1152, hidden_size=4096):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)
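
    # Illustrative shape check (a sketch, not executed at import time; the
    # patch count of 729 assumes a SigLIP-so400m tower at 384 px input):
    #
    #     projector = ProjectionModule()
    #     features = torch.randn(1, 729, 1152)   # (batch, num_patches, mm_hidden_size)
    #     projector(features).shape              # torch.Size([1, 729, 4096])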


class Llamavision(PreTrainedModel):
    """A causal language model paired with a vision tower and an MLP projector.

    Image features from the vision model are projected into the language
    model's embedding space and spliced into the prompt at the "<image>"
    placeholder before generation.
    """

    config_class = LlamavisionConfig

    def __init__(self, config):
        super().__init__(config)

        self.text_model = AutoModelForCausalLM.from_config(config.text_config)
        self.vision_model = AutoModel.from_config(config.vision_config)
        self.processor = SiglipImageProcessor()
        self.mm_projector = ProjectionModule()

    @property
    def device(self):
        return self.text_model.device

    def tokenizer_image_token(
        self, prompt, tokenizer, image_token_index=-200, return_tensors=None
    ):
        # Tokenize the text on either side of the "<image>" placeholder, then
        # rejoin the chunks with `image_token_index` marking the image position.
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in prompt.split("<image>")
        ]

        def insert_separator(X, sep):
            return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

        input_ids = []
        offset = 0
        # Keep a single BOS token at the front and skip the BOS the tokenizer
        # prepends to every chunk.
        if (
            len(prompt_chunks) > 0
            and len(prompt_chunks[0]) > 0
            and prompt_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(prompt_chunks[0][0])

        for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
            input_ids.extend(x[offset:])

        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        return input_ids

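    # The ids produced above look like [bos, t1, ..., -200, ..., tn] (token
    # values are hypothetical): process_tensors finds the first -200
    # placeholder, splits the sequence there, and splices the projected image
    # features between the two halves of the text embeddings.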
    def process_tensors(self, input_ids, image_features, embedding_layer):
        # Column index of the first -200 placeholder in the (batch of 1) ids.
        split_index = (input_ids == -200).nonzero(as_tuple=True)[1][0]

        # Text tokens before and after the image placeholder.
        input_ids_1 = input_ids[:, :split_index]
        input_ids_2 = input_ids[:, split_index + 1 :]

        embeddings_1 = embedding_layer(input_ids_1)
        embeddings_2 = embedding_layer(input_ids_2)

        device = image_features.device
        token_embeddings_part1 = embeddings_1.to(device)
        token_embeddings_part2 = embeddings_2.to(device)

        # [text before] + [image patch embeddings] + [text after], concatenated
        # along the sequence dimension.
        concatenated_embeddings = torch.cat(
            [token_embeddings_part1, image_features, token_embeddings_part2], dim=1
        )

        attention_mask = torch.ones(
            concatenated_embeddings.shape[:2], dtype=torch.long, device=device
        )
        return concatenated_embeddings, attention_mask

    def answer_question(self, image, question, tokenizer, **kwargs):
        question = "<image>" + question

        # Llama-3 chat template for a single user turn, answered by the assistant.
        prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

        input_ids = (
            self.tokenizer_image_token(prompt, tokenizer, -200, return_tensors="pt")
            .unsqueeze(0)
            .to(self.device)
        )
        # Stop on either the tokenizer's EOS or the Llama-3 end-of-turn token.
        terminators = [
            tokenizer.eos_token_id,
            tokenizer.convert_tokens_to_ids("<|eot_id|>"),
        ]

        with torch.inference_mode():
            # Resize the image to 384x384 and run it through the vision tower.
            image_inputs = self.processor(
                images=[image],
                return_tensors="pt",
                do_resize=True,
                size={"height": 384, "width": 384},
            )

            image_inputs = image_inputs["pixel_values"].to(
                device=self.device, dtype=self.dtype
            )

            image_forward_outs = self.vision_model(
                image_inputs,
                output_hidden_states=True,
            )

            # Use the penultimate hidden state as the patch features.
            image_features = image_forward_outs.hidden_states[-2]

            # Project into the text embedding space and splice into the prompt.
            projected_embeddings = self.mm_projector(image_features).to(self.device)

            embedding_layer = self.text_model.get_input_embeddings()

            new_embeds, attn_mask = self.process_tensors(
                input_ids, projected_embeddings, embedding_layer
            )

            attn_mask = attn_mask.to(self.device)
            new_embeds = new_embeds.to(self.device)

            # With inputs_embeds (and no input_ids), generate() returns only the
            # newly generated token ids; callers decode them with the tokenizer.
            answer = self.text_model.generate(
                inputs_embeds=new_embeds,
                attention_mask=attn_mask,
                eos_token_id=terminators,
                temperature=0.2,
                do_sample=True,
                **kwargs,
            )[0]

        return answer
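
# Usage sketch (illustrative only; kept as a comment because this module is
# normally imported, not executed). It assumes a converted Llamavision
# checkpoint at a hypothetical local path, with the matching Llama-3 tokenizer
# saved alongside it:
#
#     from PIL import Image
#     from transformers import AutoTokenizer
#
#     checkpoint = "path/to/llamavision-checkpoint"  # hypothetical path
#     model = Llamavision.from_pretrained(checkpoint, torch_dtype=torch.float16).to("cuda")
#     tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#
#     image = Image.open("example.jpg").convert("RGB")
#     output_ids = model.answer_question(
#         image, "What is shown in this image?", tokenizer, max_new_tokens=128
#     )
#     print(tokenizer.decode(output_ids, skip_special_tokens=True))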