import torch
from torch import nn
from torch.nn import LayerNorm, CrossEntropyLoss, L1Loss
from torch.nn import functional as F
from transformers import PreTrainedModel, AutoTokenizer, GenerationMixin, logging
from transformers.models.t5.modeling_t5 import T5Stack
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.file_utils import ModelOutput
from timm.models.layers import trunc_normal_
from typing import Any, Dict, Optional, Tuple
import warnings
import random
import yaml
import copy
from easydict import EasyDict

from configuration_visfocus import VisFocusConfig
from modeling_vilmaswin import VilmaSwinTransformerV2
from image_processing_visfocus import VisFocusImageProcessor
from processing_visfocus import VisFocusProcessor

logger = logging.get_logger(__name__)


def get_vision_model(config):
    vision_model = VilmaSwinTransformerV2(
        img_size=config.image_size,
        patch_size=config.patch_size,
        in_chans=config.in_chans,
        embed_dim=config.embed_dim,
        depths=config.depths,
        num_heads=config.num_heads,
        window_size=config.window_size,
        mlp_ratio=config.mlp_ratio,
        qkv_bias=config.qkv_bias,
        drop_rate=config.drop_rate,
        drop_path_rate=config.drop_path_rate,
        ape=config.ape,
        patch_norm=config.patch_norm,
        use_checkpoint=config.use_checkpoint,
        pretrained_window_sizes=config.pretrained_window_sizes,
        do_shift=config.do_shift,
        vl_cross_attn_layers=config.vl_cross_attn_layers,
        vl_alpha=config.vl_alpha,
        lm_d_model=config.lm_d_model,
        input_type=config.input_type,
        vl_learned_ape=config.vl_learned_ape)
    return vision_model


def load_vision_pretrained(configs, model):
    logger.info("Loading vision model from %s", configs.model.vision_resume_from)
    if configs.model.vision_resume_from.startswith("https"):
        checkpoint = torch.hub.load_state_dict_from_url(
            configs.model.vision_resume_from, map_location="cpu", check_hash=True)
    else:
        checkpoint = torch.load(configs.model.vision_resume_from, map_location="cpu")
    state_dict = checkpoint["model"]

    if "swin" in configs.model.type:
        # delete relative_position_index since we always re-init it
        relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
        for k in relative_position_index_keys:
            del state_dict[k]

        # delete relative_coords_table since we always re-init it
        relative_coords_table_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
        for k in relative_coords_table_keys:
            del state_dict[k]

        # delete attn_mask since we always re-init it
        attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
        for k in attn_mask_keys:
            del state_dict[k]

        # bicubic interpolate relative_position_bias_table if not match
        relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
        for k in relative_position_bias_table_keys:
            relative_position_bias_table_pretrained = state_dict[k]
            relative_position_bias_table_current = model.vision_model.state_dict()[k]
            L1, nH1 = relative_position_bias_table_pretrained.size()
            L2, nH2 = relative_position_bias_table_current.size()
            if nH1 != nH2:
                logger.warning(f"Error in loading {k}, passing......")
            elif L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                    size=(S2, S2), mode='bicubic')
                state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

        # bicubic interpolate absolute_pos_embed if not match
        absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
        for k in absolute_pos_embed_keys:
            # dpe
            absolute_pos_embed_pretrained = state_dict[k]
            absolute_pos_embed_current = model.vision_model.state_dict()[k]
            _, L1, C1 = absolute_pos_embed_pretrained.size()
            _, L2, C2 = absolute_pos_embed_current.size()
            if C1 != C2:
                logger.warning(f"Error in loading {k}, passing......")
            elif L1 != L2:
                S1 = int(L1 ** 0.5)
                S2 = int(L2 ** 0.5)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
                absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
                state_dict[k] = absolute_pos_embed_pretrained_resized

    if model.vision_model.patch_embed.proj.weight.shape != state_dict['patch_embed.proj.weight'].shape:
        # the checkpoint patch embedding does not match the model; this is expected for flattened-patch inputs
        assert model.vision_model.input_type == 'flattened_patches'
        logger.warning("PatchEmbed (patch_embed) was not loaded, because input_type is flattened_patches.")
        del state_dict['patch_embed.proj.weight']

    msg = model.vision_model.load_state_dict(state_dict, strict=False)
    # do not report keys that are expected to be missing (vl cross-attention layers and
    # relative-position buffers are re-initialized rather than loaded)
    filtered_missing_keys = {k for k in msg.missing_keys if 'vl_cross_attn_layers' in k or 'relative_position' in k}
    logger.warning(f'Missing keys: {set(msg.missing_keys) - filtered_missing_keys}')
    logger.warning(f'Unexpected keys: {msg.unexpected_keys}')

    logger.info("Loaded model successfully from %s", configs.model.vision_resume_from)
    del checkpoint
    torch.cuda.empty_cache()


class T5_Encoder(nn.Module):
    def __init__(self, t5_variant='base', freeze=True):
        from transformers import T5Tokenizer, T5Model
        super().__init__()
        self.tokenizer = T5Tokenizer.from_pretrained(f'{t5_variant}')
        model = T5Model.from_pretrained(f'{t5_variant}')
        del model.decoder
        self.encoder = model.encoder
        if freeze:
            for p in self.encoder.parameters():
                p.requires_grad = False

    def forward(self, input_ids):
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            return_dict=True,
        )
        return encoder_outputs[0]


class SpatialEmbeddings(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.x_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
        self.y_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
        self.h_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
        self.w_position_embeddings = nn.Embedding(config.max_2d_position_embeddings, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.config = config

    def forward(self, bbox):
        seq_length = bbox.size(1)
        left_position_embeddings = self.x_position_embeddings(bbox[:, :, 0])
        upper_position_embeddings = self.y_position_embeddings(bbox[:, :, 1])
        right_position_embeddings = self.x_position_embeddings(bbox[:, :, 2])
        lower_position_embeddings = self.y_position_embeddings(bbox[:, :, 3])
        h_position_embeddings = self.h_position_embeddings(bbox[:, :, 3] - bbox[:, :, 1])
        w_position_embeddings = self.w_position_embeddings(bbox[:, :, 2] - bbox[:, :, 0])
        embeddings = (
            left_position_embeddings
            + upper_position_embeddings
            + right_position_embeddings
            + lower_position_embeddings
            + h_position_embeddings
            + w_position_embeddings
        )
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class EmbedMatcher(nn.Module):
    def __init__(self, input_dim, inner_dim, output_dim, dropout_rate=0.1):
        super().__init__()
        self.embedd_matcher = nn.Sequential(
            nn.Linear(input_dim, inner_dim, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(inner_dim, output_dim, bias=False),
            nn.Dropout(dropout_rate)
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.embedd_matcher(x)
        return x


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class VisFocusPreTrainedModel(PreTrainedModel, GenerationMixin):
    config_class = VisFocusConfig

    def __init__(self, config):
        super().__init__(config.lm_config)
        self.set_task_name('ocr')
        self.model_arch = 'visfocus'
        self.config = config
        self.lm_config = config.lm_config
        self.vision_config = config.vision_config

        self.vision_model = get_vision_model(self.vision_config)
        input_dim = self.vision_model.num_features
        matcher = MATCHER_MAP[self.config.matcher_type]

        # load T5 encoder and decoder
        encoder_config = copy.deepcopy(self.lm_config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False
        self.encoder = T5Stack(encoder_config)

        decoder_config = copy.deepcopy(self.lm_config)
        decoder_config.is_decoder = True
        decoder_config.is_encoder_decoder = False
        decoder_config.num_layers = self.lm_config.num_decoder_layers
        self.decoder = T5Stack(decoder_config)

        self.lm_head = nn.Linear(self.lm_config.d_model, self.lm_config.vocab_size, bias=False)

        if hasattr(self.vision_model, 'last_ds'):
            input_dim = self.vision_model.last_ds.norm.normalized_shape[0]
        self.vision_embed_matcher = matcher(
            input_dim,
            config.lm_config.hidden_size,
            config.lm_config.hidden_size,
            config.hidden_dropout_prob
        )

        # losses
        self.loss_fct = CrossEntropyLoss(ignore_index=-100)

        self.init_weights()

        if self.config.lora is not None:
            self.apply_lora()
        if self.config.vl_l1_loss:
            self.vl_l1_loss_fct = L1Loss()

    def encoder_decoder_forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        r"""
        https://huggingface.co/transformers/v4.5.1/_modules/transformers/modeling_t5.html#T5ForConditionalGeneration.forward or
        https://huggingface.co/transformers/_modules/transformers/modeling_t5.html#T5ForConditionalGeneration.forward
        """
        if "lm_labels" in kwargs:
            warnings.warn(
                "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("lm_labels")
        if "decoder_past_key_value_states" in kwargs:
            warnings.warn(
                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_value_states")
        if "decoder_past_key_values" in kwargs:
            warnings.warn(
                "The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
                FutureWarning,
            )
            past_key_values = kwargs.pop("decoder_past_key_values")
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # If decoding with past key value states, only the last tokens should be given as an input
        if past_key_values is not None:
            assert labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]
        # Rescale output before projecting on vocab
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
        sequence_output = sequence_output * (self.model_dim ** -0.5)
        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss = self.loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            if self.config.vl_l1_loss:
                labels_ = labels.clone()
                # replace the ignore_index with the pad token id to compute the text target for the vl loss
                labels_[labels_ == -100] = 0
                with torch.no_grad():
                    target = self.encoder(input_ids=labels_).last_hidden_state
                if target.shape[1] != hidden_states.shape[1]:
                    v_encoder_intrp = F.interpolate(hidden_states.permute(0, 2, 1), size=target.shape[1], mode='linear').permute(0, 2, 1)
                else:
                    v_encoder_intrp = hidden_states
                vl_loss = 50 * self.vl_l1_loss_fct(v_encoder_intrp, target)
                loss += vl_loss

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            if loss is not None:
                output = (loss,) + output
            return output

        seq2seq_output = Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
        return seq2seq_output

    def forward(self, input_ids=None, bbox=None, image=None, attention_mask=None, head_mask=None,
                inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, **kwargs):
        # see https://huggingface.co/transformers/v2.10.0/_modules/transformers/modeling_t5.html#T5Model.forward
        if not kwargs.get('encoder_outputs'):
            _, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=None, image=image)
        else:
            # for generation mode
            assert kwargs.get('decoder_input_ids') is not None
            _ = vision_embeds = attention_mask = None
        return self.encoder_decoder_forward(input_ids=None,
                                            attention_mask=attention_mask,
                                            encoder_outputs=kwargs.get('encoder_outputs'),
                                            decoder_input_ids=kwargs.get('decoder_input_ids'),
                                            decoder_attention_mask=None,
                                            head_mask=head_mask,
                                            decoder_head_mask=None,
                                            past_key_values=kwargs.get('past_key_values'),
                                            inputs_embeds=vision_embeds,
                                            decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
                                            labels=labels,
                                            use_cache=True,
                                            output_attentions=kwargs.get('output_attentions'),
                                            output_hidden_states=kwargs.get('output_hidden_states'),
                                            return_dict=kwargs.get('return_dict'))

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
        if kwargs.get('encoder_outputs') is not None:
            return {
                'attention_mask': kwargs.get('attention_mask'),
                'encoder_outputs': kwargs.get('encoder_outputs'),
                'decoder_input_ids': input_ids,
                'past_key_values': kwargs.get('past'),
            }
        else:
            raise ValueError(
                "Make sure that `encoder_outputs` is already computed when preparing inputs for generation.")

    def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
        # text embedding
        batch_size = image.shape[0]
        if input_ids is not None:
            text_embeds = self.shared(input_ids)
            text_seq_length = text_embeds.shape[1]
        else:
            text_embeds = None
            text_seq_length = 0

        assert self.config.vision is not None
        # vision embedding
        vision_embeds = self.vision_model(image)
        vision_embeds = self.vision_embed_matcher(vision_embeds)
        vision_seq_length = vision_embeds.shape[1]

        # add task token (e.g. for ocr)
        vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length)
        attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
        return text_embeds, vision_embeds, attention_mask

    def concat_task_token(self, embeds, text_seq_length=0):
        # add task token (e.g. for ocr)
        if self.task_name in self.task_token_ids.keys():
            B = embeds.shape[0]
            task_embeds = self.shared(self.task_token_ids[self.task_name])
            text_seq_length += task_embeds.shape[0]
            return torch.cat((embeds, task_embeds.repeat((B, 1, 1))), dim=1), text_seq_length
        else:
            # no such task token exists
            return embeds, text_seq_length

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        input_name = 'inputs_embeds'
        _, vision_embeds, attention_mask = self._prepare_encoder_inputs(image=model_kwargs['image'])
        model_kwargs['attention_mask'] = attention_mask
        inputs = vision_embeds
        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
        return inputs, input_name, model_kwargs

    def _prepare_encoder_decoder_kwargs_for_generation(
        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
    ) -> Dict[str, Any]:
        assert "encoder_outputs" not in model_kwargs
        # 1. get encoder
        encoder = self.get_encoder()

        # 2. prepare encoder args and encoder kwargs from model kwargs
        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
        irrelevant_fields = ['input_ids', 'attention_mask', 'inputs_embeds', 'image', 'bbox', 'line_coordinates',
                             'adj', 'lm_labels', 'banned_token_ids', 'questions', 'answers', 'labels', 'task_name']
        encoder_kwargs = {
            argument: value
            for argument, value in model_kwargs.items()
            if not any(argument.startswith(p) for p in irrelevant_prefix) and argument not in irrelevant_fields
        }

        # 3. make sure that encoder returns `ModelOutput`
        encoder_kwargs["return_dict"] = True
        model_kwargs["encoder_outputs"]: ModelOutput = encoder(
            input_ids=None,
            attention_mask=model_kwargs['attention_mask'],
            inputs_embeds=inputs_tensor,
            **encoder_kwargs)
        return model_kwargs

    def set_task_name(self, task_name):
        if task_name:
            self.task_name = task_name

    def get_trivial_mask(self, inp):
        return torch.ones((inp.shape[:2]), dtype=torch.int32).to(self.device)


class VisFocusModelForLocalizedMaskedLanguageModeling(VisFocusPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.set_task_name('mpm')
        self.text_embedder = T5_Encoder(self.vision_config.text_embedder, freeze=True)

    def forward(self, input_ids=None, bbox=None, image=None, attention_mask=None, head_mask=None,
                inputs_embeds=None, encoder_hidden_states=None, encoder_attention_mask=None, labels=None, **kwargs):
        if not kwargs.get('encoder_outputs'):
            if self.task_name == 'ocr':
                # NOTE: not supported yet
                input_ids = None
                if not hasattr(self, 'prompt_embeds'):
                    prompt = 'what is written in this document?'
                    prompt_ids = self.input_tokenizer.encode(prompt)
                    B = image.shape[0]
                    prompt_ids = torch.tensor(prompt_ids).expand(B, len(prompt_ids)).to(self.device)
                    setattr(self, 'prompt_embeds', self.text_embedder(prompt_ids).detach())
            _, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=input_ids, image=image)
        else:
            # for generation mode
            assert kwargs.get('decoder_input_ids') is not None
            _ = vision_embeds = attention_mask = None
        return self.encoder_decoder_forward(input_ids=None,
                                            attention_mask=attention_mask,
                                            encoder_outputs=kwargs.get('encoder_outputs'),
                                            decoder_input_ids=kwargs.get('decoder_input_ids'),
                                            decoder_attention_mask=None,
                                            head_mask=head_mask,
                                            decoder_head_mask=None,
                                            past_key_values=kwargs.get('past_key_values'),
                                            inputs_embeds=vision_embeds,
                                            decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
                                            labels=labels,
                                            use_cache=True,
                                            output_attentions=kwargs.get('output_attentions'),
                                            output_hidden_states=kwargs.get('output_hidden_states'),
                                            return_dict=kwargs.get('return_dict'))

    def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
        batch_size = image.shape[0]
        # if the prompt is constant
        if self.task_name == 'ocr':
            assert input_ids is None
            text_embeds = self.prompt_embeds
        else:
            assert input_ids is not None
            if self.text_embedder == self.encoder:
                with torch.no_grad():
                    text_embeds = self.encoder(input_ids).last_hidden_state
            else:
                text_embeds = self.text_embedder(input_ids)
            text_embeds = text_embeds.detach()
        text_seq_length = text_embeds.shape[1] if self.task_name == 'pm_vqa_concat' else 0

        assert self.config.vision is not None
        # vision embedding
        vision_embeds = self.vision_model(image, context_prompts=text_embeds)
        if self.vision_model.model_name in ["swin_v2"]:
            vision_embeds = self.vision_embed_matcher(vision_embeds)
        vision_seq_length = vision_embeds.shape[1]

        # add task token (e.g. for ocr)
        vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length=text_seq_length)
        attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
        return text_embeds, vision_embeds, attention_mask

    def _prepare_model_inputs(
        self,
        inputs: Optional[torch.Tensor] = None,
        bos_token_id: Optional[int] = None,
        model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        input_name = 'inputs_embeds'
        _, vision_embeds, attention_mask = self._prepare_encoder_inputs(image=model_kwargs['image'], input_ids=model_kwargs['input_ids'])
        model_kwargs['attention_mask'] = attention_mask
        inputs = vision_embeds
        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
        return inputs, input_name, model_kwargs


class VisFocusModelForImageTextToText(VisFocusModelForLocalizedMaskedLanguageModeling):
    def __init__(self, config):
        super().__init__(config)
        self.set_task_name('pm_vqa_concat')

    def forward(self, questions=None, answers=None, image=None, labels=None, **kwargs):
        if kwargs.get('encoder_outputs') is None:
            text_embeds, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=questions['input_ids'], image=image)
            inputs_embeds = torch.concat((text_embeds, vision_embeds), dim=1)
            # when a different tokenizer is used for ViLMA/concat, the attention mask must be re-computed
            attention_mask = self.get_trivial_mask(inputs_embeds)
        else:
            # for generation mode (image encoding happens before)
            assert kwargs.get('decoder_input_ids') is not None
            assert kwargs.get('encoder_outputs') is not None
            inputs_embeds = kwargs.get('encoder_outputs')
            text_embeds = vision_embeds = attention_mask = None
        return self.encoder_decoder_forward(input_ids=None,
                                            attention_mask=attention_mask,
                                            encoder_outputs=kwargs.get('encoder_outputs'),
                                            decoder_input_ids=kwargs.get('decoder_input_ids'),
                                            decoder_attention_mask=None,
                                            head_mask=None,
                                            decoder_head_mask=None,
                                            past_key_values=kwargs.get('past_key_values'),
                                            inputs_embeds=inputs_embeds,
                                            decoder_inputs_embeds=kwargs.get('decoder_inputs_embeds'),
                                            labels=labels,
                                            use_cache=True,
                                            output_attentions=kwargs.get('output_attentions'),
                                            output_hidden_states=kwargs.get('output_hidden_states'),
                                            return_dict=kwargs.get('return_dict'))

    def _prepare_model_inputs(self, inputs=None, bos_token_id=None, model_kwargs=None
                              ) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
        """
        This function extracts the model-specific `inputs` for generation.
        """
        input_name = 'inputs_embeds'
        text_embeds, vision_embeds, attention_mask = self._prepare_encoder_inputs(input_ids=model_kwargs['questions']['input_ids'], image=model_kwargs['image'])
        model_kwargs['attention_mask'] = attention_mask
        inputs_embeds = torch.concat((text_embeds, vision_embeds), dim=1)
        inputs = inputs_embeds
        # 4. if `inputs` is still None, try to create `input_ids` from BOS token
        inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
        model_kwargs['attention_mask'] = self.get_trivial_mask(inputs)
        return inputs, input_name, model_kwargs

    def _prepare_encoder_inputs(self, image, input_ids=None, bbox=None, attention_mask=None):
        batch_size = image.shape[0]
        assert input_ids is not None
        if self.text_embedder == self.encoder:
            with torch.no_grad():
                text_embeds = self.encoder(input_ids).last_hidden_state
        else:
            text_embeds = self.text_embedder(input_ids)
        text_embeds = text_embeds.detach()
        text_seq_length = text_embeds.shape[1] if self.task_name == 'pm_vqa_concat' else 0

        assert self.config.vision is not None
        # vision embedding
        vision_embeds = self.vision_model(image, context_prompts=text_embeds)
        if self.vision_model.model_name in ["swin_v2"]:
            vision_embeds = self.vision_embed_matcher(vision_embeds)
        vision_seq_length = vision_embeds.shape[1]

        # add task token (e.g. for ocr)
        vision_embeds, text_seq_length = self.concat_task_token(vision_embeds, text_seq_length=text_seq_length)
        attention_mask = torch.ones((batch_size, vision_seq_length + text_seq_length), dtype=torch.int32).to(self.device)
        # for concat, use the T5 nn.Embedding directly
        text_embeds = self.shared(input_ids)
        return text_embeds, vision_embeds, attention_mask


def _to_cuda(sample, device=torch.device('cuda')):
    if isinstance(sample, torch.Tensor):
        return sample.to(device)
    elif isinstance(sample, list):
        return sample
    else:
        for k in sample.keys():
            sample[k] = _to_cuda(sample[k], device)
        return sample


def fetch_sample(ds, ds_for_vis):
    idx = random.randint(50, 100)
    for i in range(idx):
        inputs = next(ds)
        inputs_to_vis = next(ds_for_vis)
    return inputs, inputs_to_vis


MATCHER_MAP = {
    'default': EmbedMatcher,
}


# vqa
if __name__ == '__main__':
    # load yaml
    with open('configs/test_expts/vf_base_finetune_docvqa__v2_accum4_f32_V5__mpm_altConcat__vilma_concat_V1/vqa_model_args.yaml', 'r') as f:
        model_args = EasyDict(yaml.safe_load(f))

    DEVICE = 'cpu'

    # load pretrained if needed
    last_ckpt = None  # get_last_checkpoint(dirname(model_args.model_config_path))
    # model = get_model_class(model_args, last_ckpt=last_ckpt)

    cfg = VisFocusConfig.from_pretrained('configs/config.json')
    cfg.push_to_hub('ofirab/visfocus-base-docvqa')
    model = VisFocusModelForImageTextToText(cfg)

    VisFocusConfig.register_for_auto_class()
    VisFocusPreTrainedModel.register_for_auto_class("AutoModel")
    VisFocusModelForImageTextToText.register_for_auto_class("AutoModelForImageTextToText")
    model.push_to_hub('ofirab/visfocus-base-docvqa')

    pr = VisFocusImageProcessor(is_train=False)
    tokenizer = AutoTokenizer.from_pretrained('ofirab/visfocus-base-docvqa')
    prr = VisFocusProcessor(pr, tokenizer)

    model.to(DEVICE)
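

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original training/eval code).
# It shows how a single DocVQA-style query could be run through
# VisFocusModelForImageTextToText and the processor built in the __main__ block
# above. Assumptions: VisFocusProcessor returns a dict whose keys include
# `image` and `questions` (with `input_ids`), i.e. what `_prepare_model_inputs`
# consumes, and exposes its tokenizer as `processor.tokenizer`; the processor
# call signature, image path and `max_new_tokens` value are all assumed.
def run_vqa_example(model, processor, image_path, question, device='cpu', max_new_tokens=32):
    """Answer a single question about a document image (hedged sketch)."""
    from PIL import Image

    image = Image.open(image_path).convert('RGB')
    # assumed processor call signature; adapt to the actual VisFocusProcessor API
    inputs = processor(images=image, text=question, return_tensors='pt')
    inputs = _to_cuda(inputs, device=torch.device(device))
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return processor.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

# Example call (hypothetical image path and question):
# answers = run_vqa_example(model, prr, 'sample_doc.png', 'What is the invoice date?')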