|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import LayerNorm, CrossEntropyLoss, L1Loss |
|
from torch.nn import functional as F |
|
|
|
from transformers import PreTrainedModel, AutoTokenizer, GenerationMixin, logging |
|
from transformers.models.t5.modeling_t5 import T5Stack |
|
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput |
|
from transformers.file_utils import ModelOutput |
|
|
|
from timm.models.layers import trunc_normal_ |
|
from typing import Any, Dict, Optional, Tuple |
|
import warnings |
|
import random |
|
import yaml |
|
import copy |
|
from easydict import EasyDict |
|
|
|
from configuration_visfocus import VisFocusConfig |
|
from modeling_vilmaswin import VilmaSwinTransformerV2 |
|
from image_processing_visfocus import VisFocusImageProcessor |
|
from processing_visfocus import VisFocusProcessor |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
def get_vision_model(config): |
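    """Builds the ViLMA Swin-V2 vision encoder from the fields of the vision config."""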
|
vision_model = VilmaSwinTransformerV2( |
|
img_size=config.image_size, |
|
patch_size=config.patch_size, |
|
in_chans=config.in_chans, |
|
embed_dim=config.embed_dim, |
|
depths=config.depths, |
|
num_heads=config.num_heads, |
|
window_size=config.window_size, |
|
mlp_ratio=config.mlp_ratio, |
|
qkv_bias=config.qkv_bias, |
|
drop_rate=config.drop_rate, |
|
drop_path_rate=config.drop_path_rate, |
|
ape=config.ape, |
|
patch_norm=config.patch_norm, |
|
use_checkpoint=config.use_checkpoint, |
|
pretrained_window_sizes=config.pretrained_window_sizes, |
|
do_shift=config.do_shift, |
|
vl_cross_attn_layers=config.vl_cross_attn_layers, |
|
vl_alpha=config.vl_alpha, |
|
lm_d_model=config.lm_d_model, |
|
input_type=config.input_type, |
|
vl_learned_ape=config.vl_learned_ape) |
|
return vision_model |
|
|
|
|
|
def load_vision_pretrained(configs, model): |
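    """Loads a pretrained vision checkpoint into ``model.vision_model``.

    Index buffers (relative_position_index, relative_coords_table, attn_mask) are dropped
    because they are rebuilt at initialization; relative-position bias tables and absolute
    position embeddings are bicubically resized when the pretrained and current geometries differ.
    """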
|
logger.info("Loading vision model from %s", configs.model.vision_resume_from) |
|
if configs.model.vision_resume_from.startswith("https"): |
|
checkpoint = torch.hub.load_state_dict_from_url( |
|
configs.model.vision_resume_from, map_location="cpu", check_hash=True |
|
) |
|
else: |
|
checkpoint = torch.load(configs.model.vision_resume_from, map_location="cpu") |
|
|
|
state_dict = checkpoint["model"] |
|
|
|
if "swin" in configs.model.type: |
|
|
|
relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k] |
|
for k in relative_position_index_keys: |
|
del state_dict[k] |
|
|
|
|
|
relative_position_index_keys = [k for k in state_dict.keys() if "relative_coords_table" in k] |
|
for k in relative_position_index_keys: |
|
del state_dict[k] |
|
|
|
|
|
attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k] |
|
for k in attn_mask_keys: |
|
del state_dict[k] |
|
|
|
|
|
relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] |
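        # bicubically resize pretrained relative-position bias tables when their length
        # differs from the current model's (e.g. a different window size)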
|
for k in relative_position_bias_table_keys: |
|
relative_position_bias_table_pretrained = state_dict[k] |
|
relative_position_bias_table_current = model.vision_model.state_dict()[k] |
|
L1, nH1 = relative_position_bias_table_pretrained.size() |
|
L2, nH2 = relative_position_bias_table_current.size() |
|
if nH1 != nH2: |
|
logger.warning(f"Error in loading {k}, passing......") |
|
else: |
|
if L1 != L2: |
|
|
|
S1 = int(L1 ** 0.5) |
|
S2 = int(L2 ** 0.5) |
|
relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate( |
|
relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2), |
|
mode='bicubic') |
|
state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0) |
|
|
|
|
|
absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k] |
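        # bicubically resize pretrained absolute position embeddings when the grid size differs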
|
for k in absolute_pos_embed_keys: |
|
|
|
absolute_pos_embed_pretrained = state_dict[k] |
|
absolute_pos_embed_current = model.vision_model.state_dict()[k] |
|
_, L1, C1 = absolute_pos_embed_pretrained.size() |
|
_, L2, C2 = absolute_pos_embed_current.size() |
|
            if C1 != C2:
                logger.warning(f"Error loading {k}; skipping.")
|
else: |
|
if L1 != L2: |
|
S1 = int(L1 ** 0.5) |
|
S2 = int(L2 ** 0.5) |
|
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1) |
|
absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2) |
|
absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate( |
|
absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic') |
|
absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1) |
|
absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2) |
|
state_dict[k] = absolute_pos_embed_pretrained_resized |
|
|
|
        if model.vision_model.patch_embed.proj.weight.shape != state_dict['patch_embed.proj.weight'].shape:
            # a shape mismatch is expected when the vision model consumes flattened patches
            assert model.vision_model.input_type == 'flattened_patches'
            logger.warning("PatchEmbed (patch_embed) was not loaded because input_type is 'flattened_patches'.")
            del state_dict['patch_embed.proj.weight']
|
|
|
|
|
|
|
msg = model.vision_model.load_state_dict(state_dict, strict=False) |
|
|
|
|
|
    # keys expected to be missing (newly added VL cross-attention layers and the
    # relative-position tensors rebuilt at initialization) are filtered out so that
    # only genuinely missing keys are reported
    expected_missing_keys = {k for k in msg.missing_keys
                             if 'vl_cross_attn_layers' in k
                             or 'relative_position' in k}

    logger.warning(f'Missing keys: {set(msg.missing_keys) - expected_missing_keys}')
|
logger.warning(f'Unexpected keys: {msg.unexpected_keys}') |
|
|
|
|
|
|
|
logger.info("Loaded model successfully from %s", configs.model.vision_resume_from) |
|
|
|
del checkpoint |
|
torch.cuda.empty_cache() |
|
|
|
|
|
class T5_Encoder(nn.Module): |
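    """T5 encoder used to embed text prompts; the decoder is discarded and the encoder is frozen by default."""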
|
def __init__(self, t5_variant='base', freeze=True): |
|
from transformers import T5Tokenizer, T5Model |
|
super().__init__() |
|
self.tokenizer = T5Tokenizer.from_pretrained(f'{t5_variant}') |
|
model = T5Model.from_pretrained(f'{t5_variant}') |
|
del model.decoder |
|
self.encoder = model.encoder |
|
if freeze: |
|
for p in self.encoder.parameters(): |
|
p.requires_grad = False |
|
|
|
def forward(self, input_ids): |
|
encoder_outputs = self.encoder( |
|
input_ids=input_ids, |
|
return_dict=True, |
|
) |
|
return encoder_outputs[0] |
|
|
|
|
|
class SpatialEmbeddings(nn.Module): |
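    """Embeds 2D bounding boxes by summing left/upper/right/lower coordinate embeddings with
    height and width embeddings, followed by LayerNorm and dropout."""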
|
def __init__(self, config): |
|
super().__init__() |
|
|
|
self.x_position_embeddings = nn.Embedding( |
|
config.max_2d_position_embeddings, config.hidden_size |
|
) |
|
self.y_position_embeddings = nn.Embedding( |
|
config.max_2d_position_embeddings, config.hidden_size |
|
) |
|
self.h_position_embeddings = nn.Embedding( |
|
config.max_2d_position_embeddings, config.hidden_size |
|
) |
|
self.w_position_embeddings = nn.Embedding( |
|
config.max_2d_position_embeddings, config.hidden_size |
|
) |
|
self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) |
|
self.dropout = nn.Dropout(config.hidden_dropout_prob) |
|
|
|
self.config = config |
|
|
|
def forward( |
|
self, |
|
bbox, |
|
): |
|
seq_length = bbox.size(1) |
|
|
|
left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0]) |
|
upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1]) |
|
right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2]) |
|
lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3]) |
|
h_position_embeddings = self.h_position_embeddings( |
|
bbox[:, :, 3] - bbox[:, :, 1] |
|
) |
|
w_position_embeddings = self.w_position_embeddings( |
|
bbox[:, :, 2] - bbox[:, :, 0] |
|
) |
|
embeddings = ( |
|
left_position_embeddings |
|
+ upper_position_embeddings |
|
+ right_position_embeddings |
|
+ lower_position_embeddings |
|
+ h_position_embeddings |
|
+ w_position_embeddings |
|
) |
|
|
|
embeddings = self.LayerNorm(embeddings) |
|
embeddings = self.dropout(embeddings) |
|
return embeddings |
|
|
|
|
|
class EmbedMatcher(nn.Module): |
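    """Projects vision features to the language-model hidden size with a small MLP
    (Linear -> ReLU -> Dropout -> Linear -> Dropout), using truncated-normal weight initialization."""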
|
def __init__(self, input_dim, inner_dim, output_dim, dropout_rate=0.1): |
|
super().__init__() |
|
self.embedd_matcher = nn.Sequential( |
|
nn.Linear(input_dim, inner_dim, bias=True), |
|
nn.ReLU(inplace=True), |
|
nn.Dropout(dropout_rate), |
|
nn.Linear(inner_dim, output_dim, bias=False), |
|
nn.Dropout(dropout_rate) |
|
) |
|
|
|
self.apply(self._init_weights) |
|
|
|
def _init_weights(self, m): |
|
if isinstance(m, nn.Linear): |
|
trunc_normal_(m.weight, std=.02) |
|
if isinstance(m, nn.Linear) and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
|
|
def forward(self, x): |
|
x = self.embedd_matcher(x) |
|
return x |
|
|
|
|
|
class MLP(nn.Module): |
|
""" Very simple multi-layer perceptron (also called FFN)""" |
|
|
|
def __init__(self, input_dim, hidden_dim, output_dim, num_layers): |
|
super().__init__() |
|
self.num_layers = num_layers |
|
h = [hidden_dim] * (num_layers - 1) |
|
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) |
|
|
|
def forward(self, x): |
|
for i, layer in enumerate(self.layers): |
|
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) |
|
return x |
|
|
|
|
|
class VisFocusPreTrainedModel(PreTrainedModel, GenerationMixin): |
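    """Base VisFocus model: a ViLMA Swin-V2 vision encoder whose output sequence is projected to
    the T5 hidden size and fed as ``inputs_embeds`` to a T5 encoder-decoder topped with an LM head."""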
|
config_class = VisFocusConfig |
|
|
|
def __init__(self, config): |
|
super().__init__(config.lm_config) |
|
self.set_task_name('ocr') |
|
self.model_arch = 'visfocus' |
|
self.config = config |
|
self.lm_config = config.lm_config |
|
self.vision_config = config.vision_config |
|
|
|
self.vision_model = get_vision_model(self.vision_config) |
|
|
|
input_dim = self.vision_model.num_features |
|
matcher = MATCHER_MAP[self.config.matcher_type] |
|
|
|
|
|
        # model_dim and a shared input embedding follow the standard T5ForConditionalGeneration
        # layout; both are referenced later (self.shared in _prepare_encoder_inputs /
        # concat_task_token, self.model_dim when scaling the decoder output before the LM head)
        self.model_dim = self.lm_config.d_model
        self.shared = nn.Embedding(self.lm_config.vocab_size, self.lm_config.d_model)

        encoder_config = copy.deepcopy(self.lm_config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config, self.shared)
|
|
|
decoder_config = copy.deepcopy(self.lm_config) |
|
decoder_config.is_decoder = True |
|
decoder_config.is_encoder_decoder = False |
|
decoder_config.num_layers = self.lm_config.num_decoder_layers |
|
        self.decoder = T5Stack(decoder_config, self.shared)
|
self.lm_head = nn.Linear(self.lm_config.d_model, self.lm_config.vocab_size, bias=False) |
|
|
|
if hasattr(self.vision_model, 'last_ds'): |
|
input_dim = self.vision_model.last_ds.norm.normalized_shape[0] |
|
|
|
self.vision_embed_matcher = matcher( |
|
input_dim, |
|
config.lm_config.hidden_size, |
|
config.lm_config.hidden_size, |
|
config.hidden_dropout_prob |
|
) |
|
|
|
|
|
self.loss_fct = CrossEntropyLoss(ignore_index=-100) |
|
|
|
self.init_weights() |
|
|
|
if self.config.lora is not None: |
|
self.apply_lora() |
|
|
|
if self.config.vl_l1_loss: |
|
self.vl_l1_loss_fct = L1Loss() |
|
|
|
def encoder_decoder_forward( |
|
self, |
|
input_ids=None, |
|
attention_mask=None, |
|
decoder_input_ids=None, |
|
decoder_attention_mask=None, |
|
head_mask=None, |
|
decoder_head_mask=None, |
|
encoder_outputs=None, |
|
past_key_values=None, |
|
inputs_embeds=None, |
|
decoder_inputs_embeds=None, |
|
labels=None, |
|
use_cache=None, |
|
output_attentions=None, |
|
output_hidden_states=None, |
|
return_dict=None, |
|
**kwargs, |
|
): |
|
r""" |
|
https://huggingface.co/transformers/v4.5.1/_modules/transformers/modeling_t5.html#T5ForConditionalGeneration.forward |
|
or https://huggingface.co/transformers/_modules/transformers/modeling_t5.html#T5ForConditionalGeneration.forward |
|
""" |
|
|
|
if "lm_labels" in kwargs: |
|
warnings.warn( |
|
"The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.", |
|
FutureWarning, |
|
) |
|
labels = kwargs.pop("lm_labels") |
|
if "decoder_past_key_value_states" in kwargs: |
|
warnings.warn( |
|
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", |
|
FutureWarning, |
|
) |
|
past_key_values = kwargs.pop("decoder_past_key_value_states") |
|
if "decoder_past_key_values" in kwargs: |
|
warnings.warn( |
|
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.", |
|
FutureWarning, |
|
) |
|
past_key_values = kwargs.pop("decoder_past_key_values") |
|
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}." |
|
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
|
|
if encoder_outputs is None: |
|
|
|
encoder_outputs = self.encoder( |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
inputs_embeds=inputs_embeds, |
|
head_mask=head_mask, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): |
|
encoder_outputs = BaseModelOutput( |
|
last_hidden_state=encoder_outputs[0], |
|
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, |
|
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, |
|
) |
|
|
|
hidden_states = encoder_outputs[0] |
|
|
|
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: |
|
|
|
decoder_input_ids = self._shift_right(labels) |
|
|
|
|
|
|
|
if past_key_values is not None: |
|
assert labels is None, "Decoder should not use cached key value states when training." |
|
if decoder_input_ids is not None: |
|
decoder_input_ids = decoder_input_ids[:, -1:] |
|
if decoder_inputs_embeds is not None: |
|
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:] |
|
|
|
|
|
decoder_outputs = self.decoder( |
|
input_ids=decoder_input_ids, |
|
attention_mask=decoder_attention_mask, |
|
inputs_embeds=decoder_inputs_embeds, |
|
past_key_values=past_key_values, |
|
encoder_hidden_states=hidden_states, |
|
encoder_attention_mask=attention_mask, |
|
head_mask=head_mask, |
|
use_cache=use_cache, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
sequence_output = decoder_outputs[0] |
|
|
|
|
|
sequence_output = sequence_output * (self.model_dim ** -0.5) |
|
lm_logits = self.lm_head(sequence_output) |
|
|
|
loss = None |
|
if labels is not None: |
|
loss = self.loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) |
|
|
|
if self.config.vl_l1_loss: |
|
labels_ = labels.clone() |
|
labels_[labels_ == -100] = 0 |
|
with torch.no_grad(): |
|
target = self.encoder(input_ids=labels_).last_hidden_state |
|
                if target.shape[1] != hidden_states.shape[1]:
                    # align the visual sequence length with the target length before the L1 loss
                    v_encoder_intrp = F.interpolate(hidden_states.permute(0, 2, 1), size=target.shape[1], mode='linear').permute(0, 2, 1)
                else:
                    v_encoder_intrp = hidden_states
                vl_loss = 50 * self.vl_l1_loss_fct(v_encoder_intrp, target)  # fixed auxiliary-loss weight
                loss += vl_loss
|
|
|
if not return_dict: |
|
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs |
|
if loss is not None: |
|
output = ((loss,) + output) |
|
|
|
return output |
|
|
|
seq2seq_output = Seq2SeqLMOutput( |
|
loss=loss, |
|
logits=lm_logits, |
|
past_key_values=decoder_outputs.past_key_values, |
|
decoder_hidden_states=decoder_outputs.hidden_states, |
|
decoder_attentions=decoder_outputs.attentions, |
|
cross_attentions=decoder_outputs.cross_attentions, |
|
encoder_last_hidden_state=encoder_outputs.last_hidden_state, |
|
encoder_hidden_states=encoder_outputs.hidden_states, |
|
encoder_attentions=encoder_outputs.attentions, |
|
) |
|
|
|
return seq2seq_output |
|
|
|
def forward(self, |
|
input_ids=None, |
|
bbox=None, |
|
image=None, |
|
attention_mask=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
encoder_hidden_states=None, |
|
encoder_attention_mask=None, |
|
labels=None, |
|
**kwargs): |
|
|
|
|
|
if not kwargs.get('encoder_outputs'): |
|
_, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=None, image=image) |
|
else: |
|
|
|
assert kwargs.get('decoder_input_ids') is not None |
|
_ = vision_embeds = attention_mask = None |
|
|
|
return self.encoder_decoder_forward(input_ids=None, |
|
attention_mask=attention_mask, |
|
encoder_outputs=kwargs.get('encoder_outputs'), |
|
decoder_input_ids=kwargs.get('decoder_input_ids'), |
|
decoder_attention_mask=None, |
|
head_mask=head_mask, |
|
decoder_head_mask=None, |
|
past_key_values=kwargs.get('past_key_values'), |
|
inputs_embeds=vision_embeds, |
|
decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'), |
|
labels=labels, |
|
use_cache=True, |
|
output_attentions=kwargs.get('output_attentions'), |
|
output_hidden_states=kwargs.get('output_hidden_states'), |
|
return_dict=kwargs.get('return_dict') |
|
) |
|
|
|
|
|
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]: |
|
if kwargs.get('encoder_outputs') is not None: |
|
return {'attention_mask': kwargs.get('attention_mask'), |
|
'encoder_outputs': kwargs.get('encoder_outputs'), |
|
'decoder_input_ids': input_ids, |
|
'past_key_values': kwargs.get('past'), |
|
} |
|
else: |
|
raise ValueError( |
|
"Make sure that encoder_outputs is already computed when preapring inputs for generation. --y.x.") |
|
|
|
def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None): |
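        """Encodes the image with the vision model, projects it to the LM hidden size, appends the
        task-token embeddings and builds a matching all-ones attention mask."""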
|
|
|
batch_size = image.shape[0] |
|
|
|
if input_ids is not None: |
|
text_embeds = self.shared(input_ids) |
|
text_seq_length = text_embeds.shape[1] |
|
else: |
|
text_embeds = None |
|
text_seq_length = 0 |
|
|
|
assert self.config.vision is not None |
|
|
|
vision_embeds = self.vision_model(image) |
|
vision_embeds = self.vision_embed_matcher(vision_embeds) |
|
vision_seq_length = vision_embeds.shape[1] |
|
|
|
vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length) |
|
attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device) |
|
return text_embeds, vision_embeds, attention_mask |
|
|
|
def concat_task_token(self, embeds, text_seq_length=0): |
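        """Appends the embedded task token(s) of the current task to the visual sequence, when
        task-token ids are available."""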
|
|
|
        # task-token ids are optional; when absent, the visual sequence is returned unchanged
        if hasattr(self, 'task_token_ids') and self.task_name in self.task_token_ids:
|
B = embeds.shape[0] |
|
task_embeds = self.shared(self.task_token_ids[self.task_name]) |
|
text_seq_length += task_embeds.shape[0] |
|
return torch.cat((embeds, task_embeds.repeat((B, 1, 1))), dim=1), text_seq_length |
|
else: |
|
|
|
return embeds, text_seq_length |
|
|
|
def _prepare_model_inputs( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
bos_token_id: Optional[int] = None, |
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: |
|
""" |
|
This function extracts the model-specific `inputs` for generation. |
|
""" |
|
input_name = 'inputs_embeds' |
|
_, vision_embeds, attention_mask = self._prepare_encoder_inputs(image=model_kwargs['image']) |
|
model_kwargs['attention_mask'] = attention_mask |
|
|
|
inputs = vision_embeds |
|
|
|
|
|
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) |
|
return inputs, input_name, model_kwargs |
|
|
|
def _prepare_encoder_decoder_kwargs_for_generation( |
|
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None |
|
) -> Dict[str, Any]: |
|
assert "encoder_outputs" not in model_kwargs |
|
|
|
|
|
encoder = self.get_encoder() |
|
|
|
|
|
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"] |
|
        irrelevant_fields = ['input_ids', 'attention_mask', 'inputs_embeds', 'image', 'bbox', 'line_coordinates',
                             'adj', 'lm_labels', 'banned_token_ids', 'questions', 'answers', 'labels', 'task_name']
|
encoder_kwargs = { |
|
argument: value |
|
for argument, value in model_kwargs.items() |
|
            if not any(argument.startswith(p) for p in irrelevant_prefix) and argument not in irrelevant_fields
|
} |
|
|
|
|
|
encoder_kwargs["return_dict"] = True |
|
model_kwargs["encoder_outputs"]: ModelOutput = encoder( |
|
input_ids=None, attention_mask=model_kwargs['attention_mask'], |
|
inputs_embeds=inputs_tensor, **encoder_kwargs) |
|
|
|
return model_kwargs |
|
|
|
def set_task_name(self, task_name): |
|
if task_name: |
|
self.task_name = task_name |
|
|
|
def get_trivial_mask(self, inp): |
|
return torch.ones((inp.shape[:2]), dtype=torch.int32).to(self.device) |
|
|
|
|
|
class VisFocusModelForLocalizedMaskedLanguageModeling(VisFocusPreTrainedModel): |
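    """VisFocus variant for localized masked language modeling pre-training (task 'mpm'):
    a frozen T5 encoder embeds the text prompt, which conditions the vision encoder through
    its ``context_prompts`` argument."""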
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.set_task_name('mpm') |
|
self.text_embedder = T5_Encoder(self.vision_config.text_embedder, freeze=True) |
|
|
|
def forward(self, |
|
input_ids=None, |
|
bbox=None, |
|
image=None, |
|
attention_mask=None, |
|
head_mask=None, |
|
inputs_embeds=None, |
|
encoder_hidden_states=None, |
|
encoder_attention_mask=None, |
|
labels=None, |
|
**kwargs): |
|
if not kwargs.get('encoder_outputs'): |
|
if self.task_name == 'ocr': |
|
|
|
input_ids = None |
|
if not hasattr(self, 'prompt_embeds'): |
|
prompt = 'what is written in this document?' |
|
prompt_ids = self.input_tokenizer.encode(prompt) |
|
B = image.shape[0] |
|
prompt_ids = torch.tensor(prompt_ids).expand(B, len(prompt_ids)).to(self.device) |
|
setattr(self, 'prompt_embeds', self.text_embedder(prompt_ids).detach()) |
|
_, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=input_ids, image=image) |
|
else: |
|
|
|
assert kwargs.get('decoder_input_ids') is not None |
|
_ = vision_embeds = attention_mask = None |
|
|
|
return self.encoder_decoder_forward(input_ids=None, |
|
attention_mask=attention_mask, |
|
encoder_outputs=kwargs.get('encoder_outputs'), |
|
decoder_input_ids=kwargs.get('decoder_input_ids'), |
|
decoder_attention_mask=None, |
|
head_mask=head_mask, |
|
decoder_head_mask=None, |
|
past_key_values=kwargs.get('past_key_values'), |
|
inputs_embeds=vision_embeds, |
|
decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'), |
|
labels=labels, |
|
use_cache=True, |
|
output_attentions=kwargs.get('output_attentions'), |
|
output_hidden_states=kwargs.get('output_hidden_states'), |
|
return_dict=kwargs.get('return_dict') |
|
) |
|
|
|
def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None): |
|
batch_size = image.shape[0] |
|
|
|
|
|
if self.task_name == 'ocr': |
|
assert input_ids is None |
|
text_embeds = self.prompt_embeds |
|
else: |
|
assert input_ids is not None |
|
if self.text_embedder == self.encoder: |
|
with torch.no_grad(): |
|
text_embeds = self.encoder(input_ids).last_hidden_state |
|
else: |
|
text_embeds = self.text_embedder(input_ids) |
|
|
|
text_embeds = text_embeds.detach() |
|
|
|
text_seq_length = text_embeds.shape[1] if self.task_name == 'pm_vqa_concat' else 0 |
|
assert self.config.vision is not None |
|
|
|
vision_embeds = self.vision_model(image, context_prompts=text_embeds) |
|
if self.vision_model.model_name in ["swin_v2"]: |
|
vision_embeds = self.vision_embed_matcher(vision_embeds) |
|
vision_seq_length = vision_embeds.shape[1] |
|
|
|
vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length=text_seq_length) |
|
attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device) |
|
return text_embeds, vision_embeds, attention_mask |
|
|
|
def _prepare_model_inputs( |
|
self, |
|
inputs: Optional[torch.Tensor] = None, |
|
bos_token_id: Optional[int] = None, |
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None, |
|
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: |
|
""" |
|
This function extracts the model-specific `inputs` for generation. |
|
""" |
|
|
|
input_name = 'inputs_embeds' |
|
_, vision_embeds, attention_mask = self._prepare_encoder_inputs(image=model_kwargs['image'], input_ids=model_kwargs['input_ids']) |
|
model_kwargs['attention_mask'] = attention_mask |
|
inputs = vision_embeds |
|
|
|
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) |
|
return inputs, input_name, model_kwargs |
|
|
|
|
|
class VisFocusModelForImageTextToText(VisFocusModelForLocalizedMaskedLanguageModeling): |
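    """VisFocus variant for document VQA (task 'pm_vqa_concat'): question embeddings serve as
    context prompts for the vision encoder and, re-embedded with the shared LM embedding, are
    concatenated with the visual sequence before the T5 encoder."""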
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.set_task_name('pm_vqa_concat') |
|
|
|
def forward(self, questions=None, answers=None, image=None, labels=None, **kwargs): |
|
if kwargs.get('encoder_outputs') is None: |
|
text_embeds, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=questions['input_ids'], image=image) |
|
inputs_embeds = torch.concat((text_embeds, vision_embeds), dim=1) |
|
attention_mask = self.get_trivial_mask(inputs_embeds) |
|
else: |
|
|
|
assert kwargs.get('decoder_input_ids') is not None |
|
assert kwargs.get('encoder_outputs') is not None |
|
inputs_embeds = kwargs.get('encoder_outputs') |
|
text_embeds = vision_embeds = attention_mask = None |
|
|
|
return self.encoder_decoder_forward(input_ids=None, |
|
attention_mask=attention_mask, |
|
encoder_outputs=kwargs.get('encoder_outputs'), |
|
decoder_input_ids=kwargs.get('decoder_input_ids'), |
|
decoder_attention_mask=None, |
|
head_mask=None, |
|
decoder_head_mask=None, |
|
past_key_values=kwargs.get('past_key_values'), |
|
inputs_embeds=inputs_embeds, |
|
decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'), |
|
labels=labels, |
|
use_cache=True, |
|
output_attentions=kwargs.get('output_attentions'), |
|
output_hidden_states=kwargs.get('output_hidden_states'), |
|
return_dict=kwargs.get('return_dict') |
|
) |
|
|
|
def _prepare_model_inputs(self, inputs=None, bos_token_id=None, model_kwargs=None ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]: |
|
""" |
|
This function extracts the model-specific `inputs` for generation. |
|
""" |
|
input_name = 'inputs_embeds' |
|
text_embeds, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=model_kwargs['questions']['input_ids'], image=model_kwargs['image']) |
|
model_kwargs['attention_mask'] = attention_mask |
|
inputs_embeds = torch.concat((text_embeds, vision_embeds), dim=1) |
|
inputs = inputs_embeds |
|
|
|
inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) |
|
model_kwargs['attention_mask'] = self.get_trivial_mask(inputs) |
|
return inputs, input_name, model_kwargs |
|
|
|
def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None): |
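        """Like the parent implementation, but additionally returns the question embeddings
        (re-embedded with the shared LM embedding) so that forward() can concatenate them with
        the visual sequence."""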
|
batch_size = image.shape[0] |
|
assert input_ids is not None |
|
if self.text_embedder == self.encoder: |
|
with torch.no_grad(): |
|
text_embeds = self.encoder(input_ids).last_hidden_state |
|
else: |
|
text_embeds = self.text_embedder(input_ids) |
|
|
|
text_embeds = text_embeds.detach() |
|
|
|
text_seq_length = text_embeds.shape[1] if self.task_name == 'pm_vqa_concat' else 0 |
|
assert self.config.vision is not None |
|
|
|
vision_embeds = self.vision_model(image, context_prompts=text_embeds) |
|
if self.vision_model.model_name in ["swin_v2"]: |
|
vision_embeds = self.vision_embed_matcher(vision_embeds) |
|
vision_seq_length = vision_embeds.shape[1] |
|
|
|
vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length=text_seq_length) |
|
attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device) |
|
text_embeds = self.shared(input_ids) |
|
return text_embeds, vision_embeds, attention_mask |
|
|
|
|
|
def _to_cuda(sample, device=torch.device('cuda')): |
|
if isinstance(sample, torch.Tensor): |
|
return sample.to(device) |
|
elif isinstance(sample, list): |
|
return sample |
|
else: |
|
for k in sample.keys(): |
|
sample[k] = _to_cuda(sample[k], device) |
|
return sample |
|
|
|
|
|
def fetch_sample(ds, ds_for_vis): |
|
idx = random.randint(50, 100) |
|
for i in range(idx): |
|
inputs = next(ds) |
|
inputs_to_vis = next(ds_for_vis) |
|
return inputs, inputs_to_vis |
|
|
|
|
|
MATCHER_MAP = { |
|
'default': EmbedMatcher, |
|
} |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
with open('configs/test_expts/vf_base_finetune_docvqa__v2_accum4_f32_V5__mpm_altConcat__vilma_concat_V1/vqa_model_args.yaml', 'r') as f: |
|
model_args = EasyDict(yaml.safe_load(f)) |
|
|
|
DEVICE = 'cpu' |
|
|
|
|
|
last_ckpt = None |
|
|
|
|
|
|
|
|
|
cfg = VisFocusConfig.from_pretrained('configs/config.json') |
|
cfg.push_to_hub('ofirab/visfocus-base-docvqa') |
|
model = VisFocusModelForImageTextToText(cfg) |
|
|
|
VisFocusConfig.register_for_auto_class() |
|
VisFocusPreTrainedModel.register_for_auto_class("AutoModel") |
|
VisFocusModelForImageTextToText.register_for_auto_class("AutoModelForImageTextToText") |
|
|
|
model.push_to_hub('ofirab/visfocus-base-docvqa') |
|
pr = VisFocusImageProcessor(is_train=False) |
|
tokenizer = AutoTokenizer.from_pretrained('ofirab/visfocus-base-docvqa') |
|
prr = VisFocusProcessor(pr, tokenizer) |
|
model.to(DEVICE) |
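
    # Minimal inference sketch, kept commented out: it assumes a local image file and that the
    # processor yields a pixel tensor plus tokenized question ids; the key names used below
    # ('image', 'questions') mirror what forward()/_prepare_model_inputs expect in this file,
    # but the exact processor output format is an assumption.
    #
    # from PIL import Image
    # enc = prr(images=Image.open('document.png'), text='what is the total amount?',
    #           return_tensors='pt')
    # with torch.no_grad():
    #     out_ids = model.generate(image=enc['image'],
    #                              questions={'input_ids': enc['input_ids']},
    #                              max_new_tokens=32)
    # print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))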
|
|