ucaslcl commited on
Commit
35a1672
·
verified ·
1 Parent(s): 5ab5846

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +64 -518
modeling_GOT.py CHANGED
@@ -1,145 +1,16 @@
1
- from transformers import Qwen2Config, Qwen2Model, Qwen2ForCausalLM, StoppingCriteria, TextStreamer
 
 
2
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
3
  from typing import List, Optional, Tuple, Union
4
- from transformers.cache_utils import Cache
5
- import requests
6
- from PIL import Image
7
- from io import BytesIO
8
  import torch
9
  import torch.nn as nn
 
10
  from torch.nn import CrossEntropyLoss
11
- from .got_vision_b import build_GOT_vit_b
12
- from torchvision import transforms
13
- from torchvision.transforms.functional import InterpolationMode
14
- import dataclasses
15
- from megfile import smart_open
16
-
17
- DEFAULT_IMAGE_TOKEN = "<image>"
18
- DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
19
- DEFAULT_IM_START_TOKEN = '<img>'
20
- DEFAULT_IM_END_TOKEN = '</img>'
21
-
22
- from enum import auto, Enum
23
- class SeparatorStyle(Enum):
24
- """Different separator style."""
25
- SINGLE = auto()
26
- TWO = auto()
27
- MPT = auto()
28
-
29
-
30
- @dataclasses.dataclass
31
- class Conversation:
32
- """A class that keeps all conversation history."""
33
- system: str
34
- roles: List[str]
35
- messages: List[List[str]]
36
- offset: int
37
- sep_style: SeparatorStyle = SeparatorStyle.SINGLE
38
- sep: str = "<|im_end|>"
39
- sep2: str = None
40
- version: str = "Unknown"
41
-
42
- skip_next: bool = False
43
-
44
- def get_prompt(self):
45
- if self.sep_style == SeparatorStyle.SINGLE:
46
- ret = self.system + self.sep + '\n'
47
- for role, message in self.messages:
48
- if message:
49
- if type(message) is tuple:
50
- message, _, _ = message
51
- ret += role + ": " + message + self.sep
52
- else:
53
- ret += role + ":"
54
- return ret
55
- elif self.sep_style == SeparatorStyle.TWO:
56
- seps = [self.sep, self.sep2]
57
- ret = self.system + seps[0]
58
- for i, (role, message) in enumerate(self.messages):
59
- if message:
60
- if type(message) is tuple:
61
- message, _, _ = message
62
- ret += role + ": " + message + seps[i % 2]
63
- else:
64
- ret += role + ":"
65
- return ret
66
- if self.sep_style == SeparatorStyle.MPT:
67
- if self.system:
68
- ret = self.system + self.sep
69
- else:
70
- ret = ''
71
- for role, message in self.messages:
72
- if message:
73
- if type(message) is tuple:
74
- message, _, _ = message
75
- ret += role + message + self.sep
76
- else:
77
- ret += role
78
- return ret
79
- else:
80
- raise ValueError(f"Invalid style: {self.sep_style}")
81
-
82
-
83
- def append_message(self, role, message):
84
- self.messages.append([role, message])
85
-
86
- def copy(self):
87
- return Conversation(
88
- system=self.system,
89
- roles=self.roles,
90
- messages=[[x, y] for x, y in self.messages],
91
- offset=self.offset,
92
- sep_style=self.sep_style,
93
- sep=self.sep,
94
- sep2=self.sep2)
95
-
96
-
97
-
98
- class KeywordsStoppingCriteria(StoppingCriteria):
99
- def __init__(self, keywords, tokenizer, input_ids):
100
- self.keywords = keywords
101
- self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
102
- self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
103
- self.tokenizer = tokenizer
104
- self.start_len = None
105
- self.input_ids = input_ids
106
-
107
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
108
- if self.start_len is None:
109
- self.start_len = self.input_ids.shape[1]
110
- else:
111
- for keyword_id in self.keyword_ids:
112
- if output_ids[0, -1] == keyword_id:
113
- return True
114
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
115
- for keyword in self.keywords:
116
- if keyword in outputs:
117
- return True
118
- return False
119
-
120
-
121
- class GOTImageEvalProcessor:
122
- def __init__(self, image_size=384, mean=None, std=None):
123
- if mean is None:
124
- mean = (0.48145466, 0.4578275, 0.40821073)
125
- if std is None:
126
- std = (0.26862954, 0.26130258, 0.27577711)
127
-
128
- self.normalize = transforms.Normalize(mean, std)
129
-
130
- self.transform = transforms.Compose(
131
- [
132
- transforms.Resize(
133
- (image_size, image_size), interpolation=InterpolationMode.BICUBIC
134
- ),
135
- transforms.ToTensor(),
136
- self.normalize,
137
- ]
138
- )
139
- def __call__(self, item):
140
- return self.transform(item)
141
-
142
-
143
 
144
  class GOTConfig(Qwen2Config):
145
  model_type = "GOT"
@@ -151,7 +22,7 @@ class GOTQwenModel(Qwen2Model):
151
  def __init__(self, config: Qwen2Config):
152
  super(GOTQwenModel, self).__init__(config)
153
 
154
- self.vision_tower_high = build_GOT_vit_b()
155
 
156
  self.mm_projector_vary = nn.Linear(1024, 1024)
157
 
@@ -167,8 +38,13 @@ class GOTQwenModel(Qwen2Model):
167
  device="cuda"
168
  ):
169
 
 
 
 
 
 
 
170
 
171
- image_processor_high = GOTImageEvalProcessor(image_size=1024)
172
 
173
  self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
174
 
@@ -179,17 +55,20 @@ class GOTQwenModel(Qwen2Model):
179
 
180
  self.config.vision_tower = vision_tower
181
  self.config.image_token_len = image_token_len
182
-
183
  self.config.use_im_start_end = True
184
 
185
  self.config.vision_select_layer = vision_select_layer
186
  self.config.freeze_vision_tower = freeze_vision_tower
187
 
188
  return dict(
 
189
  image_processor_high=image_processor_high,
190
  image_token_len=image_token_len,
191
  )
192
 
 
 
193
 
194
  def forward(
195
  self,
@@ -219,6 +98,9 @@ class GOTQwenModel(Qwen2Model):
219
 
220
 
221
  if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
 
 
 
222
  use_im_start_end = getattr(self.config, "use_im_start_end", -1)
223
 
224
  vision_select_layer = getattr(self.config, "vision_select_layer", -1)
@@ -233,20 +115,31 @@ class GOTQwenModel(Qwen2Model):
233
 
234
  im_end_token = 151858
235
 
 
 
236
  image_features = []
237
 
238
-
239
  for image in images:
240
- P, C, H, W = image.shape
 
 
 
 
 
 
241
  if P == 1:
242
  with torch.set_grad_enabled(False):
243
- cnn_feature = vision_tower_high(image)
 
244
  cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
 
 
245
  image_feature = self.mm_projector_vary(cnn_feature)
246
  image_features.append(image_feature)
247
 
248
  else:
249
- image_patches = torch.unbind(image)
250
  image_patches_features = []
251
  for image_patch in image_patches:
252
  image_p = torch.stack([image_patch])
@@ -256,15 +149,21 @@ class GOTQwenModel(Qwen2Model):
256
  image_feature_p = self.mm_projector_vary(cnn_feature_p)
257
  image_patches_features.append(image_feature_p)
258
  image_feature = torch.cat(image_patches_features, dim=1)
 
 
 
259
  image_features.append(image_feature)
260
 
261
 
 
262
  dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
 
263
  dummy_image_features = dummy_image_features_2
264
  use_im_start_end = True
265
  new_input_embeds = []
266
  for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
267
  if (cur_input_ids == im_patch_token).sum() == 0:
 
268
  cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
269
  new_input_embeds.append(cur_input_embeds)
270
  continue
@@ -323,6 +222,11 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
323
  def get_model(self):
324
  return self.model
325
 
 
 
 
 
 
326
  def forward(
327
  self,
328
  input_ids: torch.LongTensor = None,
@@ -344,6 +248,12 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
344
  )
345
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
346
 
 
 
 
 
 
 
347
  outputs = self.model(
348
  input_ids=input_ids,
349
  past_key_values=past_key_values,
@@ -358,6 +268,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
358
 
359
  )
360
 
 
361
  hidden_states = outputs[0]
362
  logits = self.lm_head(hidden_states)
363
  logits = logits.float()
@@ -457,389 +368,24 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
457
  ):
458
  config = self.get_model().config
459
 
460
-
 
461
  self.resize_token_embeddings(len(tokenizer))
 
462
 
463
  config.im_patch_token = 151859
464
 
465
  config.use_im_start_end = True
466
 
 
467
  if config.use_im_start_end:
 
468
  self.resize_token_embeddings(len(tokenizer))
469
- config.im_start_token, config.im_end_token = 151857, 151858
470
-
471
- def load_image(self, image_file):
472
- if image_file.startswith('http') or image_file.startswith('https'):
473
- response = requests.get(image_file)
474
- image = Image.open(BytesIO(response.content)).convert('RGB')
475
- else:
476
- image = Image.open(image_file).convert('RGB')
477
- return image
478
-
479
- def disable_torch_init(self):
480
- """
481
- Disable the redundant torch default initialization to accelerate model creation.
482
- """
483
- import torch
484
- setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
- setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
-
487
- def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None):
488
-
489
- self.disable_torch_init()
490
-
491
-
492
- image_processor_high = GOTImageEvalProcessor(image_size=1024)
493
-
494
- use_im_start_end = True
495
-
496
- image_token_len = 256
497
-
498
- image = self.load_image(image_file)
499
-
500
- w, h = image.size
501
-
502
- if ocr_type == 'format':
503
- qs = 'OCR with format: '
504
- else:
505
- qs = 'OCR: '
506
-
507
- if ocr_box:
508
- bbox = eval(ocr_box)
509
- if len(bbox) == 2:
510
- bbox[0] = int(bbox[0]/w*1000)
511
- bbox[1] = int(bbox[1]/h*1000)
512
- if len(bbox) == 4:
513
- bbox[0] = int(bbox[0]/w*1000)
514
- bbox[1] = int(bbox[1]/h*1000)
515
- bbox[2] = int(bbox[2]/w*1000)
516
- bbox[3] = int(bbox[3]/h*1000)
517
- if ocr_type == 'format':
518
- qs = str(bbox) + ' ' + 'OCR with format: '
519
- else:
520
- qs = str(bbox) + ' ' + 'OCR: '
521
-
522
- if ocr_color:
523
- if ocr_type == 'format':
524
- qs = '[' + ocr_color + ']' + ' ' + 'OCR with format: '
525
- else:
526
- qs = '[' + ocr_color + ']' + ' ' + 'OCR: '
527
-
528
- if use_im_start_end:
529
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
530
- else:
531
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
532
-
533
-
534
- conv_mpt = Conversation(
535
- system="""<|im_start|>system
536
- You should follow the instructions carefully and explain your answers in detail.""",
537
- # system = None,
538
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
539
- version="mpt",
540
- messages=(),
541
- offset=0,
542
- sep_style=SeparatorStyle.MPT,
543
- sep="<|im_end|>",
544
- )
545
-
546
- conv = conv_mpt.copy()
547
- conv.append_message(conv.roles[0], qs)
548
- conv.append_message(conv.roles[1], None)
549
- prompt = conv.get_prompt()
550
-
551
- print(prompt)
552
 
553
- inputs = tokenizer([prompt])
554
-
555
- image_tensor_1 = image_processor_high(image)
556
-
557
- input_ids = torch.as_tensor(inputs.input_ids).cuda()
558
-
559
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
560
- keywords = [stop_str]
561
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
562
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
563
-
564
-
565
- with torch.autocast("cuda", dtype=torch.bfloat16):
566
- output_ids = self.generate(
567
- input_ids,
568
- images=[image_tensor_1.unsqueeze(0).half().cuda()],
569
- do_sample=False,
570
- num_beams = 1,
571
- no_repeat_ngram_size = 20,
572
- streamer=streamer,
573
- max_new_tokens=4096,
574
- stopping_criteria=[stopping_criteria]
575
- )
576
-
577
-
578
- if render:
579
- print('==============rendering===============')
580
- from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
581
-
582
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
583
-
584
- if outputs.endswith(stop_str):
585
- outputs = outputs[:-len(stop_str)]
586
- outputs = outputs.strip()
587
-
588
- if '**kern' in outputs:
589
- import verovio
590
- from cairosvg import svg2png
591
- import cv2
592
- import numpy as np
593
- tk = verovio.toolkit()
594
- tk.loadData(outputs)
595
- tk.setOptions({"pageWidth": 2100, "footer": 'none',
596
- 'barLineWidth': 0.5, 'beamMaxSlope': 15,
597
- 'staffLineWidth': 0.2, 'spacingStaff': 6})
598
- tk.getPageCount()
599
- svg = tk.renderToSVG()
600
- svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
601
-
602
- svg_to_html(svg, save_render_file)
603
-
604
- if ocr_type == 'format' and '**kern' not in outputs:
605
-
606
-
607
- if '\\begin{tikzpicture}' not in outputs:
608
- html_path_2 = save_render_file
609
- right_num = outputs.count('\\right')
610
- left_num = outputs.count('\left')
611
-
612
- if right_num != left_num:
613
- outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
614
-
615
-
616
- outputs = outputs.replace('"', '``').replace('$', '')
617
-
618
- outputs_list = outputs.split('\n')
619
- gt= ''
620
- for out in outputs_list:
621
- gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
622
-
623
- gt = gt[:-2]
624
-
625
-
626
- lines = content_mmd_to_html
627
- lines = lines.split("const text =")
628
- new_web = lines[0] + 'const text =' + gt + lines[1]
629
-
630
- else:
631
- html_path_2 = save_render_file
632
- outputs = outputs.translate(translation_table)
633
- outputs_list = outputs.split('\n')
634
- gt= ''
635
- for out in outputs_list:
636
- if out:
637
- if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
638
- while out[-1] == ' ':
639
- out = out[:-1]
640
- if out is None:
641
- break
642
-
643
- if out:
644
- if out[-1] != ';':
645
- gt += out[:-1] + ';\n'
646
- else:
647
- gt += out + '\n'
648
- else:
649
- gt += out + '\n'
650
-
651
-
652
- lines = tik_html
653
- lines = lines.split("const text =")
654
- new_web = lines[0] + gt + lines[1]
655
-
656
- with smart_open(html_path_2, 'w') as web_f_new:
657
- web_f_new.write(new_web)
658
-
659
- def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
660
-
661
- def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
662
- best_ratio_diff = float('inf')
663
- best_ratio = (1, 1)
664
- area = width * height
665
- for ratio in target_ratios:
666
- target_aspect_ratio = ratio[0] / ratio[1]
667
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
668
- if ratio_diff < best_ratio_diff:
669
- best_ratio_diff = ratio_diff
670
- best_ratio = ratio
671
- elif ratio_diff == best_ratio_diff:
672
- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
673
- best_ratio = ratio
674
- # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
675
- return best_ratio
676
-
677
- orig_width, orig_height = image.size
678
- aspect_ratio = orig_width / orig_height
679
-
680
- # calculate the existing image aspect ratio
681
- target_ratios = set(
682
- (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
683
- i * j <= max_num and i * j >= min_num)
684
- # print(target_ratios)
685
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
686
-
687
- # find the closest aspect ratio to the target
688
- target_aspect_ratio = find_closest_aspect_ratio(
689
- aspect_ratio, target_ratios, orig_width, orig_height, image_size)
690
-
691
- # print(target_aspect_ratio)
692
- # calculate the target width and height
693
- target_width = image_size * target_aspect_ratio[0]
694
- target_height = image_size * target_aspect_ratio[1]
695
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
696
-
697
- # resize the image
698
- resized_img = image.resize((target_width, target_height))
699
- processed_images = []
700
- for i in range(blocks):
701
- box = (
702
- (i % (target_width // image_size)) * image_size,
703
- (i // (target_width // image_size)) * image_size,
704
- ((i % (target_width // image_size)) + 1) * image_size,
705
- ((i // (target_width // image_size)) + 1) * image_size
706
- )
707
- # split the image
708
- split_img = resized_img.crop(box)
709
- processed_images.append(split_img)
710
- assert len(processed_images) == blocks
711
- if use_thumbnail and len(processed_images) != 1:
712
- thumbnail_img = image.resize((image_size, image_size))
713
- processed_images.append(thumbnail_img)
714
- return processed_images
715
-
716
-
717
- def chat_plus(self, tokenizer, image_file_list, render=False, save_render_file=None):
718
- # Model
719
- self.disable_torch_init()
720
- multi_page=False
721
-
722
-
723
- image_processor_high = GOTImageEvalProcessor(image_size=1024)
724
-
725
- use_im_start_end = True
726
-
727
-
728
- image_token_len = 256
729
-
730
- image_list = []
731
-
732
- if len(image_file_list)>1:
733
- multi_page = True
734
-
735
- if multi_page:
736
- qs = 'OCR with format across multi pages: '
737
- # only for png files
738
- import glob
739
- from natsort import natsorted
740
- # patches = glob.glob(image_file + '/*png')
741
- patches = image_file_list
742
- patches = natsorted(patches)
743
- sub_images = []
744
- for sub_image in patches:
745
- sub_images.append(self.load_image(sub_image))
746
-
747
- ll = len(patches)
748
- print(patches)
749
- print("len ll: ", ll)
750
-
751
- else:
752
- qs = 'OCR with format upon the patch reference: '
753
- img = self.load_image(image_file_list[0])
754
- sub_images = self.dynamic_preprocess(img)
755
- ll = len(sub_images)
756
-
757
- for image in sub_images:
758
- image_tensor_1 = image_processor_high(image)
759
- image_list.append(image_tensor_1)
760
-
761
-
762
- image_list = torch.stack(image_list)
763
-
764
- print('====new images batch size======: ',image_list.shape)
765
-
766
-
767
- if use_im_start_end:
768
- qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
769
- else:
770
- qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
771
-
772
-
773
- conv_mpt = Conversation(
774
- system="""<|im_start|>system
775
- You should follow the instructions carefully and explain your answers in detail.""",
776
- # system = None,
777
- roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
778
- version="mpt",
779
- messages=(),
780
- offset=0,
781
- sep_style=SeparatorStyle.MPT,
782
- sep="<|im_end|>",
783
- )
784
-
785
- conv = conv_mpt.copy()
786
- conv.append_message(conv.roles[0], qs)
787
- conv.append_message(conv.roles[1], None)
788
- prompt = conv.get_prompt()
789
-
790
- print(prompt)
791
-
792
- inputs = tokenizer([prompt])
793
-
794
- input_ids = torch.as_tensor(inputs.input_ids).cuda()
795
-
796
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
797
- keywords = [stop_str]
798
- stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
799
- streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
800
-
801
-
802
- with torch.autocast("cuda", dtype=torch.bfloat16):
803
- output_ids = self.generate(
804
- input_ids,
805
- images=[image_list.half().cuda()],
806
- do_sample=False,
807
- num_beams = 1,
808
- # no_repeat_ngram_size = 20,
809
- streamer=streamer,
810
- max_new_tokens=4096,
811
- stopping_criteria=[stopping_criteria]
812
- )
813
-
814
- if render:
815
- print('==============rendering===============')
816
- from .render_tools import content_mmd_to_html
817
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
818
-
819
- if outputs.endswith(stop_str):
820
- outputs = outputs[:-len(stop_str)]
821
- outputs = outputs.strip()
822
-
823
- html_path_2 = save_render_file
824
- right_num = outputs.count('\\right')
825
- left_num = outputs.count('\left')
826
-
827
- if right_num != left_num:
828
- outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
829
 
830
 
831
- outputs = outputs.replace('"', '``').replace('$', '')
 
832
 
833
- outputs_list = outputs.split('\n')
834
- gt= ''
835
- for out in outputs_list:
836
- gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
837
-
838
- gt = gt[:-2]
839
-
840
- lines = content_mmd_to_html
841
- lines = lines.split("const text =")
842
- new_web = lines[0] + 'const text =' + gt + lines[1]
843
-
844
- with smart_open(html_path_2, 'w') as web_f_new:
845
- web_f_new.write(new_web)
 
1
+ from transformers import AutoConfig, AutoModelForCausalLM, \
2
+ Qwen2Config, Qwen2Model, Qwen2ForCausalLM, \
3
+ CLIPVisionModel, CLIPImageProcessor
4
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
5
  from typing import List, Optional, Tuple, Union
6
+ from transformers.cache_utils import Cache, DynamicCache
 
 
 
7
  import torch
8
  import torch.nn as nn
9
+ import torch.nn.functional as F
10
  from torch.nn import CrossEntropyLoss
11
+ from GOT.utils.constants import *
12
+ from GOT.model.vision_encoder.vary_b import build_vary_vit_b
13
+ from GOT.model.plug.blip_process import BlipImageEvalProcessor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  class GOTConfig(Qwen2Config):
16
  model_type = "GOT"
 
22
  def __init__(self, config: Qwen2Config):
23
  super(GOTQwenModel, self).__init__(config)
24
 
25
+ self.vision_tower_high = build_vary_vit_b()
26
 
27
  self.mm_projector_vary = nn.Linear(1024, 1024)
28
 
 
38
  device="cuda"
39
  ):
40
 
41
+ # Vary old codes, not use in GOT
42
+ image_processor = BlipImageEvalProcessor(image_size=1024)
43
+ # 1024*1024
44
+
45
+ image_processor_high = BlipImageEvalProcessor(image_size=1024)
46
+
47
 
 
48
 
49
  self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
50
 
 
55
 
56
  self.config.vision_tower = vision_tower
57
  self.config.image_token_len = image_token_len
58
+ # self.config.use_im_start_end = use_im_start_end
59
  self.config.use_im_start_end = True
60
 
61
  self.config.vision_select_layer = vision_select_layer
62
  self.config.freeze_vision_tower = freeze_vision_tower
63
 
64
  return dict(
65
+ image_processor=image_processor,
66
  image_processor_high=image_processor_high,
67
  image_token_len=image_token_len,
68
  )
69
 
70
+ # def get_input_embeddings(self, x):
71
+ # return self.wte(x)
72
 
73
  def forward(
74
  self,
 
98
 
99
 
100
  if vision_tower_high is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
101
+ # if True:
102
+ # assert type(images) is list, ValueError("To fit both interleave and conversation, images must be list of batches of images")
103
+ # print(im)
104
  use_im_start_end = getattr(self.config, "use_im_start_end", -1)
105
 
106
  vision_select_layer = getattr(self.config, "vision_select_layer", -1)
 
115
 
116
  im_end_token = 151858
117
 
118
+
119
+
120
  image_features = []
121
 
122
+ print(images.shape)
123
  for image in images:
124
+ P, C, H, W = image[1].shape
125
+ # with torch.set_grad_enabled(True):
126
+ # # print(image[1].shape)
127
+ # cnn_feature = vision_tower_high(image[1])
128
+ # cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256 1024
129
+ # # image_features.append(cnn_feature)
130
+ # image_features_2.append(cnn_feature)
131
  if P == 1:
132
  with torch.set_grad_enabled(False):
133
+ # print(image[1].shape)
134
+ cnn_feature = vision_tower_high(image[1])
135
  cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
136
+ # image_features.append(cnn_feature)
137
+ # image_features_2.append(cnn_feature)
138
  image_feature = self.mm_projector_vary(cnn_feature)
139
  image_features.append(image_feature)
140
 
141
  else:
142
+ image_patches = torch.unbind(image[1])
143
  image_patches_features = []
144
  for image_patch in image_patches:
145
  image_p = torch.stack([image_patch])
 
149
  image_feature_p = self.mm_projector_vary(cnn_feature_p)
150
  image_patches_features.append(image_feature_p)
151
  image_feature = torch.cat(image_patches_features, dim=1)
152
+ # print(P)
153
+ # print(image_feature.shape)
154
+ # exit()
155
  image_features.append(image_feature)
156
 
157
 
158
+
159
  dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
160
+ # dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
161
  dummy_image_features = dummy_image_features_2
162
  use_im_start_end = True
163
  new_input_embeds = []
164
  for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
165
  if (cur_input_ids == im_patch_token).sum() == 0:
166
+ # multimodal LLM, but the current sample is not multimodal
167
  cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
168
  new_input_embeds.append(cur_input_embeds)
169
  continue
 
222
  def get_model(self):
223
  return self.model
224
 
225
+ # def _set_gradient_checkpointing(self, module, value=False):
226
+ # if isinstance(module, GOTQwenModel):
227
+ # module.gradient_checkpointing = value
228
+ # @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
229
+ # @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
230
  def forward(
231
  self,
232
  input_ids: torch.LongTensor = None,
 
248
  )
249
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
250
 
251
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
252
+ # print(input_ids)
253
+ # print(len(images))
254
+
255
+ # print(inputs_embeds)
256
+
257
  outputs = self.model(
258
  input_ids=input_ids,
259
  past_key_values=past_key_values,
 
268
 
269
  )
270
 
271
+
272
  hidden_states = outputs[0]
273
  logits = self.lm_head(hidden_states)
274
  logits = logits.float()
 
368
  ):
369
  config = self.get_model().config
370
 
371
+ # add image patch token <image>
372
+ # tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
373
  self.resize_token_embeddings(len(tokenizer))
374
+ # config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
375
 
376
  config.im_patch_token = 151859
377
 
378
  config.use_im_start_end = True
379
 
380
+ # add image start token <im_start> and end token <im_end>
381
  if config.use_im_start_end:
382
+ # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
383
  self.resize_token_embeddings(len(tokenizer))
384
+ # config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
385
 
386
+ config.im_start_token, config.im_end_token = 151857, 151858
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
 
389
+ AutoConfig.register("GOT", GOTConfig)
390
+ AutoModelForCausalLM.register(GOTConfig, GOTQwenForCausalLM)
391