# visfocus-base-docvqa / modeling_visfocus.py
import torch
from torch import nn
from torch.nn import LayerNorm, CrossEntropyLoss, L1Loss
from torch.nn import functional as F
from transformers import PreTrainedModel, AutoTokenizer, GenerationMixin, logging
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.file_utils import ModelOutput
from timm.models.layers import trunc_normal_
from typing import Any, Dict, Optional, Tuple
import warnings
import random
import yaml
import copy
from easydict import EasyDict
from configuration_visfocus import VisFocusConfig
from modeling_vilmaswin import VilmaSwinTransformerV2
from image_processing_visfocus import VisFocusImageProcessor
from processing_visfocus import VisFocusProcessor
logger = logging.get_logger(__name__)
def get_vision_model(config):
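    """Build the ViLMA Swin-V2 vision backbone from the vision config."""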
vision_model = VilmaSwinTransformerV2(
img_size=config.image_size,
patch_size=config.patch_size,
in_chans=config.in_chans,
embed_dim=config.embed_dim,
depths=config.depths,
num_heads=config.num_heads,
window_size=config.window_size,
mlp_ratio=config.mlp_ratio,
qkv_bias=config.qkv_bias,
drop_rate=config.drop_rate,
drop_path_rate=config.drop_path_rate,
ape=config.ape,
patch_norm=config.patch_norm,
use_checkpoint=config.use_checkpoint,
pretrained_window_sizes=config.pretrained_window_sizes,
do_shift=config.do_shift,
vl_cross_attn_layers=config.vl_cross_attn_layers,
vl_alpha=config.vl_alpha,
lm_d_model=config.lm_d_model,
input_type=config.input_type,
vl_learned_ape=config.vl_learned_ape)
return vision_model
def load_vision_pretrained(configs, model):
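    """Load a pretrained Swin checkpoint into `model.vision_model`.

    Buffers that are always re-initialized (relative position indices, coords tables,
    attention masks) are dropped, and relative position bias tables / absolute position
    embeddings are bicubically interpolated when the pretrained and current sizes differ.
    """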
logger.info("Loading vision model from %s", configs.model.vision_resume_from)
if configs.model.vision_resume_from.startswith("https"):
checkpoint = torch.hub.load_state_dict_from_url(
configs.model.vision_resume_from, map_location="cpu", check_hash=True
)
else:
checkpoint = torch.load(configs.model.vision_resume_from, map_location="cpu")
state_dict = checkpoint["model"]
if "swin" in configs.model.type:
# delete relative_position_index since we always re-init it
relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
for k in relative_position_index_keys:
del state_dict[k]
# delete relative_coords_table since we always re-init it
        relative_coords_table_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
        for k in relative_coords_table_keys:
            del state_dict[k]
# delete attn_mask since we always re-init it
attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
for k in attn_mask_keys:
del state_dict[k]
# bicubic interpolate relative_position_bias_table if not match
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
for k in relative_position_bias_table_keys:
relative_position_bias_table_pretrained = state_dict[k]
relative_position_bias_table_current = model.vision_model.state_dict()[k]
L1, nH1 = relative_position_bias_table_pretrained.size()
L2, nH2 = relative_position_bias_table_current.size()
if nH1 != nH2:
logger.warning(f"Error in loading {k}, passing......")
else:
if L1 != L2:
# bicubic interpolate relative_position_bias_table if not match
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
mode='bicubic')
state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
# bicubic interpolate absolute_pos_embed if not match
absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
for k in absolute_pos_embed_keys:
# dpe
absolute_pos_embed_pretrained = state_dict[k]
absolute_pos_embed_current = model.vision_model.state_dict()[k]
_, L1, C1 = absolute_pos_embed_pretrained.size()
_, L2, C2 = absolute_pos_embed_current.size()
            if C1 != C2:
logger.warning(f"Error in loading {k}, passing......")
else:
if L1 != L2:
S1 = int(L1 ** 0.5)
S2 = int(L2 ** 0.5)
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
state_dict[k] = absolute_pos_embed_pretrained_resized
        if model.vision_model.patch_embed.proj.weight.shape != state_dict['patch_embed.proj.weight'].shape:
            # shape mismatch: the vision model consumes flattened patches, so the pretrained patch embedding is skipped
            logger.warning("PatchEmbed (patch_embed) was not loaded because input_type is flattened_patches.")
            del state_dict['patch_embed.proj.weight']
msg = model.vision_model.load_state_dict(state_dict, strict=False)
    # do not print keys that are expected to be missing (vl cross-attn layers and relative-position tensors are re-initialized)
    filtered_missing_keys = {k for k in msg.missing_keys
                             if 'vl_cross_attn_layers' in k
                             or 'relative_position' in k}
    logger.warning(f'Missing keys: {set(msg.missing_keys) - filtered_missing_keys}')
logger.warning(f'Unexpected keys: {msg.unexpected_keys}')
# logger.warning(msg)
logger.info("Loaded model successfully from %s", configs.model.vision_resume_from)
del checkpoint
torch.cuda.empty_cache()
class T5_Encoder(nn.Module):
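    """Frozen T5 encoder used as a standalone text (prompt) embedder.

    `t5_variant` is passed directly to `from_pretrained`, so it should be a full model
    identifier such as 't5-base'.
    """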
def __init__(self, t5_variant='base', freeze=True):
from transformers import T5Tokenizer, T5Model
super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(t5_variant)
        model = T5Model.from_pretrained(t5_variant)
del model.decoder
self.encoder = model.encoder
if freeze:
for p in self.encoder.parameters():
p.requires_grad = False
def forward(self, input_ids):
encoder_outputs = self.encoder(
input_ids=input_ids,
return_dict=True,
)
return encoder_outputs[0]
class SpatialEmbeddings(nn.Module):
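    """2D spatial embeddings computed from token bounding boxes (left, top, right, bottom, width and height)."""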
def __init__(self, config):
super().__init__()
self.x_position_embeddings = nn.Embedding(
config.max_2d_position_embeddings, config.hidden_size
)
self.y_position_embeddings = nn.Embedding(
config.max_2d_position_embeddings, config.hidden_size
)
self.h_position_embeddings = nn.Embedding(
config.max_2d_position_embeddings, config.hidden_size
)
self.w_position_embeddings = nn.Embedding(
config.max_2d_position_embeddings, config.hidden_size
)
self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config
def forward(
self,
bbox,
):
seq_length = bbox.size(1)
left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
h_position_embeddings = self.h_position_embeddings(
bbox[:, :, 3] - bbox[:, :, 1]
)
w_position_embeddings = self.w_position_embeddings(
bbox[:, :, 2] - bbox[:, :, 0]
)
embeddings = (
left_position_embeddings
+ upper_position_embeddings
+ right_position_embeddings
+ lower_position_embeddings
+ h_position_embeddings
+ w_position_embeddings
)
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
class EmbedMatcher(nn.Module):
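    """Two-layer MLP that projects vision features to the language-model hidden size."""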
def __init__(self, input_dim, inner_dim, output_dim, dropout_rate=0.1):
super().__init__()
self.embedd_matcher = nn.Sequential(
nn.Linear(input_dim, inner_dim, bias=True),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(inner_dim, output_dim, bias=False),
nn.Dropout(dropout_rate)
)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.embedd_matcher(x)
return x
class MLP(nn.Module):
""" Very simple multi-layer perceptron (also called FFN)"""
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
def forward(self, x):
for i, layer in enumerate(self.layers):
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
return x
class VisFocusPreTrainedModel(PreTrainedModel, GenerationMixin):
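    """VisFocus base model: a ViLMA Swin-V2 vision encoder whose features are projected
    into a T5 encoder-decoder, with an optional vision-language L1 alignment loss.
    """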
config_class = VisFocusConfig
def __init__(self, config):
super().__init__(config.lm_config)
self.set_task_name('ocr')
self.model_arch = 'visfocus'
self.config = config
self.lm_config = config.lm_config
self.vision_config = config.vision_config
self.vision_model = get_vision_model(self.vision_config)
input_dim = self.vision_model.num_features
matcher = MATCHER_MAP[self.config.matcher_type]
# load T5 encoder and decoder
encoder_config = copy.deepcopy(self.lm_config)
encoder_config.is_decoder = False
encoder_config.use_cache = False
encoder_config.is_encoder_decoder = False
self.encoder = T5Stack(encoder_config)
decoder_config = copy.deepcopy(self.lm_config)
decoder_config.is_decoder = True
decoder_config.is_encoder_decoder = False
decoder_config.num_layers = self.lm_config.num_decoder_layers
self.decoder = T5Stack(decoder_config)
self.lm_head = nn.Linear(self.lm_config.d_model, self.lm_config.vocab_size, bias=False)
if hasattr(self.vision_model, 'last_ds'):
input_dim = self.vision_model.last_ds.norm.normalized_shape[0]
self.vision_embed_matcher = matcher(
input_dim,
config.lm_config.hidden_size,
config.lm_config.hidden_size,
config.hidden_dropout_prob
)
# losses
self.loss_fct = CrossEntropyLoss(ignore_index=-100)
self.init_weights()
if self.config.lora is not None:
self.apply_lora()
if self.config.vl_l1_loss:
self.vl_l1_loss_fct = L1Loss()
def encoder_decoder_forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
encoder_outputs=None,
past_key_values=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
r"""
https://huggingface.co/transformers/v4.5.1/_modules/transformers/modeling_t5.html#T5ForConditionalGeneration.forward
or https://huggingface.co/transformers/_modules/transformers/modeling_t5.html#T5ForConditionalGeneration.forward
"""
if "lm_labels" in kwargs:
warnings.warn(
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("lm_labels")
if "decoder_past_key_value_states" in kwargs:
warnings.warn(
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_value_states")
if "decoder_past_key_values" in kwargs:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_values")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
hidden_states = encoder_outputs[0]
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
# get decoder inputs from shifting lm labels to the right
decoder_input_ids = self._shift_right(labels)
# If decoding with past key value states, only the last tokens
# should be given as an input
if past_key_values is not None:
assert labels is None, "Decoder should not use cached key value states when training."
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = decoder_outputs[0]
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim ** -0.5)
lm_logits = self.lm_head(sequence_output)
loss = None
if labels is not None:
loss = self.loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
if self.config.vl_l1_loss:
labels_ = labels.clone()
labels_[labels_ == -100] = 0 # -> replace the ignore_index with the pad_token id to calculate the text target for the vl loss
with torch.no_grad():
target = self.encoder(input_ids=labels_).last_hidden_state
if target.shape[1] != hidden_states.shape[1]:
v_encoder_intrp = F.interpolate(hidden_states.permute(0,2,1), size=target.shape[1], mode='linear').permute(0,2,1)
vl_loss = (50 * self.vl_l1_loss_fct(v_encoder_intrp, target))
loss += vl_loss
if not return_dict:
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
if loss is not None:
output = ((loss,) + output)
return output
seq2seq_output = Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
return seq2seq_output
def forward(self,
input_ids=None,
bbox=None,
image=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
**kwargs):
# see https://huggingface.co/transformers/v2.10.0/_modules/transformers/modeling_t5.html#T5Model.forward
if not kwargs.get('encoder_outputs'):
_, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=None, image=image)
else:
# for generation mode
assert kwargs.get('decoder_input_ids') is not None
_ = vision_embeds = attention_mask = None
return self.encoder_decoder_forward(input_ids=None,
attention_mask=attention_mask,
encoder_outputs=kwargs.get('encoder_outputs'),
decoder_input_ids=kwargs.get('decoder_input_ids'),
decoder_attention_mask=None,
head_mask=head_mask,
decoder_head_mask=None,
past_key_values=kwargs.get('past_key_values'),
inputs_embeds=vision_embeds,
decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
labels=labels,
use_cache=True,
output_attentions=kwargs.get('output_attentions'),
output_hidden_states=kwargs.get('output_hidden_states'),
return_dict=kwargs.get('return_dict')
)
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
if kwargs.get('encoder_outputs') is not None:
return {'attention_mask': kwargs.get('attention_mask'),
'encoder_outputs': kwargs.get('encoder_outputs'),
'decoder_input_ids': input_ids,
'past_key_values': kwargs.get('past'),
}
else:
raise ValueError(
"Make sure that encoder_outputs is already computed when preapring inputs for generation. --y.x.")
def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
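        """Encode the image, project it to the LM hidden size, append the task token and build the attention mask."""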
# text embedding
batch_size = image.shape[0]
if input_ids is not None:
text_embeds = self.shared(input_ids)
text_seq_length = text_embeds.shape[1]
else:
text_embeds = None
text_seq_length = 0
assert self.config.vision is not None
# vision embedding
vision_embeds = self.vision_model(image)
vision_embeds = self.vision_embed_matcher(vision_embeds)
vision_seq_length = vision_embeds.shape[1]
# add task token (e.g <OCR> for ocr)
vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length)
attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
return text_embeds, vision_embeds, attention_mask
def concat_task_token(self, embeds, text_seq_length=0):
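        """Append the embedding of the current task token (if one is registered) to the sequence."""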
# add task token (e.g <OCR> for ocr)
if self.task_name in self.task_token_ids.keys():
B = embeds.shape[0]
task_embeds = self.shared(self.task_token_ids[self.task_name])
text_seq_length += task_embeds.shape[0]
return torch.cat((embeds, task_embeds.repeat((B, 1, 1))), dim=1), text_seq_length
else:
# no such task token exists
return embeds, text_seq_length
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
input_name = 'inputs_embeds'
_, vision_embeds, attention_mask = self._prepare_encoder_inputs(image=model_kwargs['image'])
model_kwargs['attention_mask'] = attention_mask
inputs = vision_embeds
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
def _prepare_encoder_decoder_kwargs_for_generation(
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
) -> Dict[str, Any]:
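        """Run the encoder once on the prepared input embeddings and cache the result in `model_kwargs['encoder_outputs']`."""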
assert "encoder_outputs" not in model_kwargs
# 1. get encoder
encoder = self.get_encoder()
# 2. prepare encoder args and encoder kwargs from model kwargs
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        irrelevant_fields = ['input_ids', 'attention_mask', 'inputs_embeds', 'image', 'bbox', 'line_coordinates',
                             'adj', 'lm_labels', 'banned_token_ids', 'questions', 'answers', 'labels', 'task_name']
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix) and argument not in irrelevant_fields
        }
# 3. make sure that encoder returns `ModelOutput`
encoder_kwargs["return_dict"] = True
model_kwargs["encoder_outputs"]: ModelOutput = encoder(
input_ids=None, attention_mask=model_kwargs['attention_mask'],
inputs_embeds=inputs_tensor, **encoder_kwargs)
return model_kwargs
def set_task_name(self, task_name):
if task_name:
self.task_name = task_name
def get_trivial_mask(self, inp):
return torch.ones((inp.shape[:2]), dtype=torch.int32).to(self.device)
class VisFocusModelForLocalizedMaskedLanguageModeling(VisFocusPreTrainedModel):
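    """VisFocus variant for localized masked prompt modeling (MPM) pre-training.

    Prompts are embedded with a frozen T5 encoder and passed to the vision model as context.
    """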
def __init__(self, config):
super().__init__(config)
self.set_task_name('mpm')
self.text_embedder = T5_Encoder(self.vision_config.text_embedder, freeze=True)
def forward(self,
input_ids=None,
bbox=None,
image=None,
attention_mask=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
**kwargs):
if not kwargs.get('encoder_outputs'):
if self.task_name == 'ocr':
# NOTE: not supported yet
input_ids = None
if not hasattr(self, 'prompt_embeds'):
prompt = 'what is written in this document?'
prompt_ids = self.input_tokenizer.encode(prompt)
B = image.shape[0]
prompt_ids = torch.tensor(prompt_ids).expand(B, len(prompt_ids)).to(self.device)
setattr(self, 'prompt_embeds', self.text_embedder(prompt_ids).detach())
_, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=input_ids, image=image)
else:
# for generation mode
assert kwargs.get('decoder_input_ids') is not None
_ = vision_embeds = attention_mask = None
return self.encoder_decoder_forward(input_ids=None,
attention_mask=attention_mask,
encoder_outputs=kwargs.get('encoder_outputs'),
decoder_input_ids=kwargs.get('decoder_input_ids'),
decoder_attention_mask=None,
head_mask=head_mask,
decoder_head_mask=None,
past_key_values=kwargs.get('past_key_values'),
inputs_embeds=vision_embeds,
decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
labels=labels,
use_cache=True,
output_attentions=kwargs.get('output_attentions'),
output_hidden_states=kwargs.get('output_hidden_states'),
return_dict=kwargs.get('return_dict')
)
def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
batch_size = image.shape[0]
        # if the prompt is constant (the ocr task uses a fixed prompt)
if self.task_name == 'ocr':
assert input_ids is None
text_embeds = self.prompt_embeds
else:
assert input_ids is not None
if self.text_embedder == self.encoder:
with torch.no_grad():
text_embeds = self.encoder(input_ids).last_hidden_state
else:
text_embeds = self.text_embedder(input_ids)
text_embeds = text_embeds.detach()
text_seq_length = text_embeds.shape[1] if self.task_name == 'pm_vqa_concat' else 0
assert self.config.vision is not None
# vision embedding
vision_embeds = self.vision_model(image, context_prompts=text_embeds)
if self.vision_model.model_name in ["swin_v2"]:
vision_embeds = self.vision_embed_matcher(vision_embeds)
vision_seq_length = vision_embeds.shape[1]
# add task token (e.g <OCR> for ocr)
vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length=text_seq_length)
attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
return text_embeds, vision_embeds, attention_mask
def _prepare_model_inputs(
self,
inputs: Optional[torch.Tensor] = None,
bos_token_id: Optional[int] = None,
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
input_name = 'inputs_embeds'
_, vision_embeds, attention_mask = self._prepare_encoder_inputs(image=model_kwargs['image'], input_ids=model_kwargs['input_ids'])
model_kwargs['attention_mask'] = attention_mask
inputs = vision_embeds
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
return inputs, input_name, model_kwargs
class VisFocusModelForImageTextToText(VisFocusModelForLocalizedMaskedLanguageModeling):
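    """VisFocus variant for VQA-style image+text-to-text generation.

    Question embeddings are concatenated with the projected vision features before the T5 encoder.
    """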
def __init__(self, config):
super().__init__(config)
self.set_task_name('pm_vqa_concat')
def forward(self, questions=None, answers=None, image=None, labels=None, **kwargs):
if kwargs.get('encoder_outputs') is None:
text_embeds, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=questions['input_ids'], image=image)
inputs_embeds = torch.concat((text_embeds, vision_embeds), dim=1)
            attention_mask = self.get_trivial_mask(inputs_embeds)  # when a different tokenizer is used for ViLMA/concat, the attention mask must be recomputed
else:
# for generation mode (image encoding happens before)
assert kwargs.get('decoder_input_ids') is not None
assert kwargs.get('encoder_outputs') is not None
inputs_embeds = kwargs.get('encoder_outputs')
text_embeds = vision_embeds = attention_mask = None
return self.encoder_decoder_forward(input_ids=None,
attention_mask=attention_mask,
encoder_outputs=kwargs.get('encoder_outputs'),
decoder_input_ids=kwargs.get('decoder_input_ids'),
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
past_key_values=kwargs.get('past_key_values'),
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
labels=labels,
use_cache=True,
output_attentions=kwargs.get('output_attentions'),
output_hidden_states=kwargs.get('output_hidden_states'),
return_dict=kwargs.get('return_dict')
)
def _prepare_model_inputs(self, inputs=None, bos_token_id=None, model_kwargs=None ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
"""
This function extracts the model-specific `inputs` for generation.
"""
input_name = 'inputs_embeds'
text_embeds, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=model_kwargs['questions']['input_ids'], image=model_kwargs['image'])
model_kwargs['attention_mask'] = attention_mask
inputs_embeds = torch.concat((text_embeds, vision_embeds), dim=1)
inputs = inputs_embeds
# 4. if `inputs` is still None, try to create `input_ids` from BOS token
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
model_kwargs['attention_mask'] = self.get_trivial_mask(inputs)
return inputs, input_name, model_kwargs
def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
batch_size = image.shape[0]
assert input_ids is not None
if self.text_embedder == self.encoder:
with torch.no_grad():
text_embeds = self.encoder(input_ids).last_hidden_state
else:
text_embeds = self.text_embedder(input_ids)
text_embeds = text_embeds.detach()
text_seq_length = text_embeds.shape[1] if self.task_name == 'pm_vqa_concat' else 0
assert self.config.vision is not None
# vision embedding
vision_embeds = self.vision_model(image, context_prompts=text_embeds)
if self.vision_model.model_name in ["swin_v2"]:
vision_embeds = self.vision_embed_matcher(vision_embeds)
vision_seq_length = vision_embeds.shape[1]
# add task token (e.g <OCR> for ocr)
vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length=text_seq_length)
attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
        text_embeds = self.shared(input_ids)  # for concat, use the T5 nn.Embedding directly
return text_embeds, vision_embeds, attention_mask
def _to_cuda(sample, device=torch.device('cuda')):
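    """Recursively move tensors in a (possibly nested) sample to `device`; lists are returned unchanged."""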
if isinstance(sample, torch.Tensor):
return sample.to(device)
elif isinstance(sample, list):
return sample
else:
for k in sample.keys():
sample[k] = _to_cuda(sample[k], device)
return sample
def fetch_sample(ds, ds_for_vis):
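    """Advance both iterators by the same random number of steps and return the last pair of samples."""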
idx = random.randint(50, 100)
for i in range(idx):
inputs = next(ds)
inputs_to_vis = next(ds_for_vis)
return inputs, inputs_to_vis
MATCHER_MAP = {
'default': EmbedMatcher,
}
# vqa
if __name__ == '__main__':
# load yaml
with open('configs/test_expts/vf_base_finetune_docvqa__v2_accum4_f32_V5__mpm_altConcat__vilma_concat_V1/vqa_model_args.yaml', 'r') as f:
model_args = EasyDict(yaml.safe_load(f))
    DEVICE = 'cpu'
## load pretrained if needed
last_ckpt = None # get_last_checkpoint(dirname(model_args.model_config_path))
##
# model = get_model_class(model_args, last_ckpt=last_ckpt)
cfg = VisFocusConfig.from_pretrained('configs/config.json')
cfg.push_to_hub('ofirab/visfocus-base-docvqa')
model = VisFocusModelForImageTextToText(cfg)
VisFocusConfig.register_for_auto_class()
VisFocusPreTrainedModel.register_for_auto_class("AutoModel")
VisFocusModelForImageTextToText.register_for_auto_class("AutoModelForImageTextToText")
model.push_to_hub('ofirab/visfocus-base-docvqa')
pr = VisFocusImageProcessor(is_train=False)
tokenizer = AutoTokenizer.from_pretrained('ofirab/visfocus-base-docvqa')
prr = VisFocusProcessor(pr, tokenizer)
model.to(DEVICE)
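    # Minimal inference sketch (an illustrative assumption, not part of the original script):
    # it presumes a local 'doc.png' and that VisFocusProcessor returns the 'image' and
    # 'questions' fields consumed by _prepare_model_inputs.
    # from PIL import Image
    # inputs = prr(images=Image.open('doc.png'), text='what is the title?', return_tensors='pt')
    # generated = model.generate(**inputs)
    # print(tokenizer.batch_decode(generated, skip_special_tokens=True))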