ucaslcl commited on
Commit
35202c0
·
verified ·
1 Parent(s): 5b7a219

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +50 -26
modeling_GOT.py CHANGED
@@ -484,7 +484,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
- def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False):
488
 
489
  self.disable_torch_init()
490
 
@@ -565,18 +565,30 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
565
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
566
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
 
568
-
569
- with torch.autocast("cuda", dtype=torch.bfloat16):
570
- output_ids = self.generate(
571
- input_ids,
572
- images=[image_tensor_1.unsqueeze(0).half().cuda()],
573
- do_sample=False,
574
- num_beams = 1,
575
- no_repeat_ngram_size = 20,
576
- # streamer=streamer,
577
- max_new_tokens=4096,
578
- stopping_criteria=[stopping_criteria]
579
- )
 
 
 
 
 
 
 
 
 
 
 
 
580
 
581
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
582
 
@@ -716,7 +728,7 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
716
  return processed_images
717
 
718
 
719
- def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False):
720
  # Model
721
  self.disable_torch_init()
722
  multi_page=False
@@ -807,18 +819,30 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
807
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
808
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
809
 
810
-
811
- with torch.autocast("cuda", dtype=torch.bfloat16):
812
- output_ids = self.generate(
813
- input_ids,
814
- images=[image_list.half().cuda()],
815
- do_sample=False,
816
- num_beams = 1,
817
- # no_repeat_ngram_size = 20,
818
- # streamer=streamer,
819
- max_new_tokens=4096,
820
- stopping_criteria=[stopping_criteria]
821
- )
 
 
 
 
 
 
 
 
 
 
 
 
822
 
823
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
824
 
 
484
  setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
485
  setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
486
 
487
+ def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
488
 
489
  self.disable_torch_init()
490
 
 
565
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
566
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
567
 
568
+ if stream_flag:
569
+ with torch.autocast("cuda", dtype=torch.bfloat16):
570
+ output_ids = self.generate(
571
+ input_ids,
572
+ images=[image_tensor_1.unsqueeze(0).half().cuda()],
573
+ do_sample=False,
574
+ num_beams = 1,
575
+ no_repeat_ngram_size = 20,
576
+ streamer=streamer,
577
+ max_new_tokens=4096,
578
+ stopping_criteria=[stopping_criteria]
579
+ )
580
+ else:
581
+ with torch.autocast("cuda", dtype=torch.bfloat16):
582
+ output_ids = self.generate(
583
+ input_ids,
584
+ images=[image_tensor_1.unsqueeze(0).half().cuda()],
585
+ do_sample=False,
586
+ num_beams = 1,
587
+ no_repeat_ngram_size = 20,
588
+ # streamer=streamer,
589
+ max_new_tokens=4096,
590
+ stopping_criteria=[stopping_criteria]
591
+ )
592
 
593
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
594
 
 
728
  return processed_images
729
 
730
 
731
+ def chat_crop(self, tokenizer, image_file, ocr_type, render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag = False):
732
  # Model
733
  self.disable_torch_init()
734
  multi_page=False
 
819
  stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
820
  streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
821
 
822
+ if stream_flag:
823
+ with torch.autocast("cuda", dtype=torch.bfloat16):
824
+ output_ids = self.generate(
825
+ input_ids,
826
+ images=[image_list.half().cuda()],
827
+ do_sample=False,
828
+ num_beams = 1,
829
+ # no_repeat_ngram_size = 20,
830
+ streamer=streamer,
831
+ max_new_tokens=4096,
832
+ stopping_criteria=[stopping_criteria]
833
+ )
834
+ else:
835
+ with torch.autocast("cuda", dtype=torch.bfloat16):
836
+ output_ids = self.generate(
837
+ input_ids,
838
+ images=[image_list.half().cuda()],
839
+ do_sample=False,
840
+ num_beams = 1,
841
+ # no_repeat_ngram_size = 20,
842
+ # streamer=streamer,
843
+ max_new_tokens=4096,
844
+ stopping_criteria=[stopping_criteria]
845
+ )
846
 
847
  outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
848