ofirab committed
Commit 3d1911b · verified · 1 Parent(s): f1f18b5

Upload model

Files changed (3):
  1. config.json +2 -4
  2. model.safetensors +1 -1
  3. modeling_visfocus.py +18 -25
config.json CHANGED
@@ -1,14 +1,12 @@
 {
   "architectures": [
-    "VisFocusModel",
-    "VisFocusForLocalizedMaskedLanguageModeling",
-    "VisFocusForImageTextToText"
+    "VisFocusModelForImageTextToText"
   ],
   "auto_map": {
     "AutoConfig": "configuration_visfocus.VisFocusConfig",
     "AutoModel": "configuration_visfocus.VisFocusPreTrainedModel",
     "AutoModelForConditionalGeneration": "configuration_visfocus.VisFocusModelForImageTextToText",
-    "AutoModelForImageTextToText": "configuration_visfocus.VisFocusModelForImageTextToText"
+    "AutoModelForImageTextToText": "modeling_visfocus.VisFocusModelForImageTextToText"
   },
   "cache_dir": null,
   "do_lower_case": true,
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f644d82b1150eba66c88fcb62fed2cdd1a871f0ea44bfb136ea0bc182b8c9fae
+oid sha256:e631e85b53c7ccd3df8c70a10c528c8582394914427f9cb0ba185b81d9b8ed22
 size 1047109288
modeling_visfocus.py CHANGED
@@ -4,7 +4,7 @@ from torch import nn
 from torch.nn import LayerNorm, CrossEntropyLoss, L1Loss
 from torch.nn import functional as F
 
-from transformers import PreTrainedModel, T5Tokenizer, T5Model, logging
+from transformers import PreTrainedModel, AutoTokenizer, GenerationMixin, logging
 from transformers.models.t5.modeling_t5 import T5Stack
 from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
 from transformers.file_utils import ModelOutput
@@ -17,8 +17,10 @@ import yaml
 import copy
 from easydict import EasyDict
 
-from .configuration_visfocus import VisFocusConfig
-from .modeling_vilmaswin import VilmaSwinTransformerV2
+from configuration_visfocus import VisFocusConfig
+from modeling_vilmaswin import VilmaSwinTransformerV2
+from image_processing_visfocus import VisFocusImageProcessor
+from processing_visfocus import VisFocusProcessor
 
 logger = logging.get_logger(__name__)
 
@@ -148,6 +150,7 @@ def load_vision_pretrained(configs, model):
 
 class T5_Encoder(nn.Module):
     def __init__(self, t5_variant='base', freeze=True):
+        from transformers import T5Tokenizer, T5Model
         super().__init__()
         self.tokenizer = T5Tokenizer.from_pretrained(f'{t5_variant}')
         model = T5Model.from_pretrained(f'{t5_variant}')
@@ -255,7 +258,7 @@ class MLP(nn.Module):
         return x
 
 
-class VisFocusModel(PreTrainedModel):
+class VisFocusPreTrainedModel(PreTrainedModel, GenerationMixin):
     config_class = VisFocusConfig
 
     def __init__(self, config):
@@ -413,7 +416,7 @@ class VisFocusModel(PreTrainedModel):
 
         if self.config.vl_l1_loss:
             labels_ = labels.clone()
-            labels_[labels_ == -100] = self.input_tokenizer.pad_token_id # -> replace the ignore_index with the pad_token id to calculate the text target for the vl loss
+            labels_[labels_ == -100] = 0 # -> replace the ignore_index with the pad_token id to calculate the text target for the vl loss
            with torch.no_grad():
                 target = self.encoder(input_ids=labels_).last_hidden_state
             if target.shape[1] != hidden_states.shape[1]:
@@ -567,15 +570,6 @@ class VisFocusModel(PreTrainedModel):
                 inputs_embeds=inputs_tensor, **encoder_kwargs)
 
         return model_kwargs
-
-    def add_task_tokens(self):
-        self.input_tokenizer.add_tokens('<OCR>', special_tokens=True)
-        self.task_token_ids = torch.nn.ParameterDict([['ocr', self.register_token('<OCR>')]])
-
-    def register_token(self, token: str):
-        self.input_tokenizer.add_tokens(token, special_tokens=True)
-        token_ids = self.input_tokenizer.encode(token)
-        return torch.nn.Parameter(torch.tensor(token_ids), requires_grad=False)
 
     def set_task_name(self, task_name):
         if task_name:
@@ -585,7 +579,7 @@ class VisFocusModel(PreTrainedModel):
         return torch.ones((inp.shape[:2]), dtype=torch.int32).to(self.device)
 
 
-class VisFocusModelForLocalizedMaskedLanguageModeling(VisFocusModel):
+class VisFocusModelForLocalizedMaskedLanguageModeling(VisFocusPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.set_task_name('mpm')
@@ -604,6 +598,7 @@ class VisFocusModelForLocalizedMaskedLanguageModeling(VisFocusModel):
                 **kwargs):
         if not kwargs.get('encoder_outputs'):
             if self.task_name == 'ocr':
+                # NOTE: not supported yet
                 input_ids = None
                 if not hasattr(self, 'prompt_embeds'):
                     prompt = 'what is written in this document?'
@@ -681,11 +676,6 @@ class VisFocusModelForLocalizedMaskedLanguageModeling(VisFocusModel):
         inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
         return inputs, input_name, model_kwargs
 
-    def add_task_tokens(self):
-        super().add_task_tokens()
-        self.input_tokenizer.add_tokens('<MPM>', special_tokens=True)
-        self.task_token_ids.update({'mpm': self.register_token('<MPM>')})
-
 
 class VisFocusModelForImageTextToText(VisFocusModelForLocalizedMaskedLanguageModeling):
     def __init__(self, config):
@@ -759,11 +749,6 @@ class VisFocusModelForImageTextToText(VisFocusModelForLocalizedMaskedLanguageMod
         text_embeds = self.shared(input_ids) # for concat, use direct the T5 nn.embeddings
         return text_embeds, vision_embeds, attention_mask
 
-    def add_task_tokens(self):
-        super().add_task_tokens()
-        self.input_tokenizer.add_tokens('<LMPM_VQA_CONCAT>', special_tokens=True)
-        self.task_token_ids.update({'pm_vqa_concat': self.register_token('<LMPM_VQA_CONCAT>')})
-
 
 def _to_cuda(sample, device=torch.device('cuda')):
     if isinstance(sample, torch.Tensor):
@@ -806,5 +791,13 @@ if __name__ == '__main__':
     cfg = VisFocusConfig.from_pretrained('configs/config.json')
     cfg.push_to_hub('ofirab/visfocus-base-docvqa')
     model = VisFocusModelForImageTextToText(cfg)
+
+    VisFocusConfig.register_for_auto_class()
+    VisFocusPreTrainedModel.register_for_auto_class("AutoModel")
+    VisFocusModelForImageTextToText.register_for_auto_class("AutoModelForImageTextToText")
+
     model.push_to_hub('ofirab/visfocus-base-docvqa')
+    pr = VisFocusImageProcessor(is_train=False)
+    tokenizer = AutoTokenizer.from_pretrained('ofirab/visfocus-base-docvqa')
+    prr = VisFocusProcessor(pr, tokenizer)
     model.to(DEVICE)
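
Two modeling changes are worth noting: `VisFocusPreTrainedModel` now inherits `GenerationMixin` explicitly (recent transformers releases no longer provide `generate()` through `PreTrainedModel` alone for custom models), and the `__main__` block registers the classes for the auto API before pushing. A hypothetical inference sketch built on the same objects; the `VisFocusProcessor` call signature is assumed to follow the standard HF processor convention, which this diff does not show:

```python
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoTokenizer

# Repo-local modules, as imported at the top of modeling_visfocus.py.
from image_processing_visfocus import VisFocusImageProcessor
from processing_visfocus import VisFocusProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModelForImageTextToText.from_pretrained(
    "ofirab/visfocus-base-docvqa", trust_remote_code=True
).to(device)

# Mirrors the __main__ block above: eval-mode image processor + tokenizer.
image_processor = VisFocusImageProcessor(is_train=False)
tokenizer = AutoTokenizer.from_pretrained("ofirab/visfocus-base-docvqa")
processor = VisFocusProcessor(image_processor, tokenizer)

# Assumed call convention (images/text/return_tensors); not shown in this diff.
inputs = processor(
    images=Image.open("page.png"),
    text="what is written in this document?",
    return_tensors="pt",
).to(device)
outputs = model.generate(**inputs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```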