ucaslcl commited on
Commit
5fb0b31
·
verified ·
1 Parent(s): 4c5895b

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +16 -15
modeling_GOT.py CHANGED
@@ -249,7 +249,7 @@ class GOTQwenModel(Qwen2Model):
249
  image_patches_features = []
250
  for image_patch in image_patches:
251
  image_p = torch.stack([image_patch])
252
- print(image_p.shape)
253
  with torch.set_grad_enabled(False):
254
  cnn_feature_p = vision_tower_high(image_p)
255
  cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
@@ -257,7 +257,6 @@ class GOTQwenModel(Qwen2Model):
257
  image_patches_features.append(image_feature_p)
258
  image_feature = torch.cat(image_patches_features, dim=1)
259
  image_features.append(image_feature)
260
- exit()
261
 
262
 
263
  dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
@@ -485,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
485
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
486
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
487
 
488
- def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None):
489
 
490
  self.disable_torch_init()
491
 
@@ -549,7 +548,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
549
  conv.append_message(conv.roles[1], None)
550
  prompt = conv.get_prompt()
551
 
552
- print(prompt)
 
553
 
554
  inputs = tokenizer([prompt])
555
 
@@ -570,7 +570,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
570
  do_sample=False,
571
  num_beams = 1,
572
  no_repeat_ngram_size = 20,
573
- streamer=streamer,
574
  max_new_tokens=4096,
575
  stopping_criteria=[stopping_criteria]
576
  )
@@ -715,7 +715,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
715
  return processed_images
716
 
717
 
718
- def chat_plus(self, tokenizer, image_file_list, render=False, save_render_file=None):
719
  # Model
720
  self.disable_torch_init()
721
  multi_page=False
@@ -730,8 +730,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
730
 
731
  image_list = []
732
 
733
- if len(image_file_list)>1:
734
- multi_page = True
735
 
736
  if multi_page:
737
  qs = 'OCR with format across multi pages: '
@@ -739,19 +739,19 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
739
  import glob
740
  # from natsort import natsorted
741
  # patches = glob.glob(image_file + '/*png')
742
- patches = image_file_list
743
  # patches = natsorted(patches)
744
  sub_images = []
745
  for sub_image in patches:
746
  sub_images.append(self.load_image(sub_image))
747
 
748
  ll = len(patches)
749
- print(patches)
750
- print("len ll: ", ll)
751
 
752
  else:
753
  qs = 'OCR with format upon the patch reference: '
754
- img = self.load_image(image_file_list[0])
755
  sub_images = self.dynamic_preprocess(img)
756
  ll = len(sub_images)
757
 
@@ -762,7 +762,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
762
 
763
  image_list = torch.stack(image_list)
764
 
765
- print('====new images batch size======: ',image_list.shape)
766
 
767
 
768
  if use_im_start_end:
@@ -788,7 +788,8 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
788
  conv.append_message(conv.roles[1], None)
789
  prompt = conv.get_prompt()
790
 
791
- print(prompt)
 
792
 
793
  inputs = tokenizer([prompt])
794
 
@@ -807,7 +808,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
807
  do_sample=False,
808
  num_beams = 1,
809
  # no_repeat_ngram_size = 20,
810
- streamer=streamer,
811
  max_new_tokens=4096,
812
  stopping_criteria=[stopping_criteria]
813
  )
 
249
  image_patches_features = []
250
  for image_patch in image_patches:
251
  image_p = torch.stack([image_patch])
252
+
253
  with torch.set_grad_enabled(False):
254
  cnn_feature_p = vision_tower_high(image_p)
255
  cnn_feature_p = cnn_feature_p.flatten(2).permute(0, 2, 1)
 
257
  image_patches_features.append(image_feature_p)
258
  image_feature = torch.cat(image_patches_features, dim=1)
259
  image_features.append(image_feature)
 
260
 
261
 
262
  dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
 
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False):
488
 
489
  self.disable_torch_init()
490
 
 
548
  conv.append_message(conv.roles[1], None)
549
  prompt = conv.get_prompt()
550
 
551
+ if print_prompt:
552
+ print(prompt)
553
 
554
  inputs = tokenizer([prompt])
555
 
 
570
  do_sample=False,
571
  num_beams = 1,
572
  no_repeat_ngram_size = 20,
573
+ # streamer=streamer,
574
  max_new_tokens=4096,
575
  stopping_criteria=[stopping_criteria]
576
  )
 
715
  return processed_images
716
 
717
 
718
+ def chat_plus(self, tokenizer, image_file, render=False, save_render_file=None, print_prompt=False):
719
  # Model
720
  self.disable_torch_init()
721
  multi_page=False
 
730
 
731
  image_list = []
732
 
733
+ # if len(image_file_list)>1:
734
+ # multi_page = True
735
 
736
  if multi_page:
737
  qs = 'OCR with format across multi pages: '
 
739
  import glob
740
  # from natsort import natsorted
741
  # patches = glob.glob(image_file + '/*png')
742
+ patches = image_file
743
  # patches = natsorted(patches)
744
  sub_images = []
745
  for sub_image in patches:
746
  sub_images.append(self.load_image(sub_image))
747
 
748
  ll = len(patches)
749
+ # print(patches)
750
+ # print("len ll: ", ll)
751
 
752
  else:
753
  qs = 'OCR with format upon the patch reference: '
754
+ img = self.load_image(image_file)
755
  sub_images = self.dynamic_preprocess(img)
756
  ll = len(sub_images)
757
 
 
762
 
763
  image_list = torch.stack(image_list)
764
 
765
+ print('====new images batch size======: \n',image_list.shape)
766
 
767
 
768
  if use_im_start_end:
 
788
  conv.append_message(conv.roles[1], None)
789
  prompt = conv.get_prompt()
790
 
791
+ if print_prompt:
792
+ print(prompt)
793
 
794
  inputs = tokenizer([prompt])
795
 
 
808
  do_sample=False,
809
  num_beams = 1,
810
  # no_repeat_ngram_size = 20,
811
+ # streamer=streamer,
812
  max_new_tokens=4096,
813
  stopping_criteria=[stopping_criteria]
814
  )