ucaslcl commited on
Commit
d610c94
·
verified ·
1 Parent(s): a31b327

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +181 -1
modeling_GOT.py CHANGED
@@ -541,7 +541,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
541
  offset=0,
542
  sep_style=SeparatorStyle.MPT,
543
  sep="<|im_end|>",
544
- )
545
 
546
  conv = conv_mpt.copy()
547
  conv.append_message(conv.roles[0], qs)
@@ -657,3 +657,183 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
657
  # with open(html_path_2, 'w') as web_f_new:
658
  # web_f_new.write(new_web)
659
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  offset=0,
542
  sep_style=SeparatorStyle.MPT,
543
  sep="<|im_end|>",
544
+ )
545
 
546
  conv = conv_mpt.copy()
547
  conv.append_message(conv.roles[0], qs)
 
657
  # with open(html_path_2, 'w') as web_f_new:
658
  # web_f_new.write(new_web)
659
 
660
+ def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
661
+
662
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
663
+ best_ratio_diff = float('inf')
664
+ best_ratio = (1, 1)
665
+ area = width * height
666
+ for ratio in target_ratios:
667
+ target_aspect_ratio = ratio[0] / ratio[1]
668
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
669
+ if ratio_diff < best_ratio_diff:
670
+ best_ratio_diff = ratio_diff
671
+ best_ratio = ratio
672
+ elif ratio_diff == best_ratio_diff:
673
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
674
+ best_ratio = ratio
675
+ # print(f'width: {width}, height: {height}, best_ratio: {best_ratio}')
676
+ return best_ratio
677
+
678
+ orig_width, orig_height = image.size
679
+ aspect_ratio = orig_width / orig_height
680
+
681
+ # calculate the existing image aspect ratio
682
+ target_ratios = set(
683
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
684
+ i * j <= max_num and i * j >= min_num)
685
+ # print(target_ratios)
686
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
687
+
688
+ # find the closest aspect ratio to the target
689
+ target_aspect_ratio = find_closest_aspect_ratio(
690
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
691
+
692
+ # print(target_aspect_ratio)
693
+ # calculate the target width and height
694
+ target_width = image_size * target_aspect_ratio[0]
695
+ target_height = image_size * target_aspect_ratio[1]
696
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
697
+
698
+ # resize the image
699
+ resized_img = image.resize((target_width, target_height))
700
+ processed_images = []
701
+ for i in range(blocks):
702
+ box = (
703
+ (i % (target_width // image_size)) * image_size,
704
+ (i // (target_width // image_size)) * image_size,
705
+ ((i % (target_width // image_size)) + 1) * image_size,
706
+ ((i // (target_width // image_size)) + 1) * image_size
707
+ )
708
+ # split the image
709
+ split_img = resized_img.crop(box)
710
+ processed_images.append(split_img)
711
+ assert len(processed_images) == blocks
712
+ if use_thumbnail and len(processed_images) != 1:
713
+ thumbnail_img = image.resize((image_size, image_size))
714
+ processed_images.append(thumbnail_img)
715
+ return processed_images
716
+
717
+
718
+ def chat_crop(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, multi_page=False):
719
+ # Model
720
+ self.disable_torch_init()
721
+
722
+
723
+ image_processor_high = GOTImageEvalProcessor(image_size=1024)
724
+
725
+ use_im_start_end = True
726
+
727
+
728
+ image_token_len = 256
729
+
730
+ image_list = []
731
+
732
+ if multi_page:
733
+ qs = 'OCR with format across multi pages: '
734
+ # only for png files
735
+ import glob
736
+ from natsort import natsorted
737
+ patches = glob.glob(image_file + '/*png')
738
+ patches = natsorted(patches)
739
+ sub_images = []
740
+ for sub_image in patches:
741
+ sub_images.append(self.load_image(sub_image))
742
+
743
+ ll = len(patches)
744
+
745
+ else:
746
+ qs = 'OCR with format upon the patch reference: '
747
+ img = self.load_image(image_file)
748
+ sub_images = self.dynamic_preprocess(img)
749
+ ll = len(sub_images)
750
+
751
+ for image in sub_images:
752
+ image_tensor_1 = image_processor_high(image)
753
+ image_list.append(image_tensor_1)
754
+
755
+
756
+ image_list = torch.stack(image_list)
757
+
758
+ print('====new images batch size======: ',image_list.shape)
759
+
760
+
761
+ if use_im_start_end:
762
+ qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len*ll + DEFAULT_IM_END_TOKEN + '\n' + qs
763
+ else:
764
+ qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
765
+
766
+
767
+ conv_mpt = Conversation(
768
+ system="""<|im_start|>system
769
+ You should follow the instructions carefully and explain your answers in detail.""",
770
+ # system = None,
771
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
772
+ version="mpt",
773
+ messages=(),
774
+ offset=0,
775
+ sep_style=SeparatorStyle.MPT,
776
+ sep="<|im_end|>",
777
+ )
778
+
779
+ conv = conv_mpt.copy()
780
+ conv.append_message(conv.roles[0], qs)
781
+ conv.append_message(conv.roles[1], None)
782
+ prompt = conv.get_prompt()
783
+
784
+
785
+ inputs = tokenizer([prompt])
786
+
787
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
788
+
789
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
790
+ keywords = [stop_str]
791
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
792
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
793
+
794
+
795
+ with torch.autocast("cuda", dtype=torch.bfloat16):
796
+ output_ids = self.generate(
797
+ input_ids,
798
+ images=[(image_list.half().cuda(), image_list.half().cuda())],
799
+ do_sample=False,
800
+ num_beams = 1,
801
+ # no_repeat_ngram_size = 20,
802
+ streamer=streamer,
803
+ max_new_tokens=4096,
804
+ stopping_criteria=[stopping_criteria]
805
+ )
806
+
807
+ # if render:
808
+ # print('==============rendering===============')
809
+ # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
810
+
811
+ # if outputs.endswith(stop_str):
812
+ # outputs = outputs[:-len(stop_str)]
813
+ # outputs = outputs.strip()
814
+
815
+ # html_path = "./render_tools/" + "/content-mmd-to-html.html"
816
+ # html_path_2 = "./results/demo.html"
817
+ # right_num = outputs.count('\\right')
818
+ # left_num = outputs.count('\left')
819
+
820
+ # if right_num != left_num:
821
+ # outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
822
+
823
+
824
+ # outputs = outputs.replace('"', '``').replace('$', '')
825
+
826
+ # outputs_list = outputs.split('\n')
827
+ # gt= ''
828
+ # for out in outputs_list:
829
+ # gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
830
+
831
+ # gt = gt[:-2]
832
+
833
+ # with open(html_path, 'r') as web_f:
834
+ # lines = web_f.read()
835
+ # lines = lines.split("const text =")
836
+ # new_web = lines[0] + 'const text =' + gt + lines[1]
837
+
838
+ # with open(html_path_2, 'w') as web_f_new:
839
+ # web_f_new.write(new_web)