ucaslcl commited on
Commit
ccdbcff
·
verified ·
1 Parent(s): 35202c0

Update modeling_GOT.py

Browse files
Files changed (1) hide show
  1. modeling_GOT.py +70 -70
modeling_GOT.py CHANGED
@@ -590,84 +590,84 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
590
  stopping_criteria=[stopping_criteria]
591
  )
592
 
593
- outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
594
-
595
- if outputs.endswith(stop_str):
596
- outputs = outputs[:-len(stop_str)]
597
- outputs = outputs.strip()
598
- response_str = outputs
599
-
600
- if render:
601
- print('==============rendering===============')
602
- from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
603
-
604
- if '**kern' in outputs:
605
- import verovio
606
- tk = verovio.toolkit()
607
- tk.loadData(outputs)
608
- tk.setOptions({"pageWidth": 2100, "footer": 'none',
609
- 'barLineWidth': 0.5, 'beamMaxSlope': 15,
610
- 'staffLineWidth': 0.2, 'spacingStaff': 6})
611
- tk.getPageCount()
612
- svg = tk.renderToSVG()
613
- svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
614
-
615
- svg_to_html(svg, save_render_file)
616
-
617
- if ocr_type == 'format' and '**kern' not in outputs:
618
 
619
-
620
- if '\\begin{tikzpicture}' not in outputs:
621
- html_path_2 = save_render_file
622
- right_num = outputs.count('\\right')
623
- left_num = outputs.count('\left')
624
 
625
- if right_num != left_num:
626
- outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
 
 
 
 
 
 
 
 
627
 
 
628
 
629
- outputs = outputs.replace('"', '``').replace('$', '')
630
 
631
- outputs_list = outputs.split('\n')
632
- gt= ''
633
- for out in outputs_list:
634
- gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
635
-
636
- gt = gt[:-2]
637
-
638
-
639
- lines = content_mmd_to_html
640
- lines = lines.split("const text =")
641
- new_web = lines[0] + 'const text =' + gt + lines[1]
642
-
643
- else:
644
- html_path_2 = save_render_file
645
- outputs = outputs.translate(translation_table)
646
- outputs_list = outputs.split('\n')
647
- gt= ''
648
- for out in outputs_list:
649
- if out:
650
- if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
651
- while out[-1] == ' ':
652
- out = out[:-1]
653
- if out is None:
654
- break
655
-
656
- if out:
657
- if out[-1] != ';':
658
- gt += out[:-1] + ';\n'
659
- else:
660
- gt += out + '\n'
661
- else:
662
- gt += out + '\n'
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
 
665
- lines = tik_html
666
- lines = lines.split("const text =")
667
- new_web = lines[0] + gt + lines[1]
668
 
669
- with smart_open(html_path_2, 'w') as web_f_new:
670
- web_f_new.write(new_web)
671
  return response_str
672
 
673
  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
 
590
  stopping_criteria=[stopping_criteria]
591
  )
592
 
593
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
594
+
595
+ if outputs.endswith(stop_str):
596
+ outputs = outputs[:-len(stop_str)]
597
+ outputs = outputs.strip()
598
+ response_str = outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
 
600
+ if render:
601
+ print('==============rendering===============')
602
+ from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
 
 
603
 
604
+ if '**kern' in outputs:
605
+ import verovio
606
+ tk = verovio.toolkit()
607
+ tk.loadData(outputs)
608
+ tk.setOptions({"pageWidth": 2100, "footer": 'none',
609
+ 'barLineWidth': 0.5, 'beamMaxSlope': 15,
610
+ 'staffLineWidth': 0.2, 'spacingStaff': 6})
611
+ tk.getPageCount()
612
+ svg = tk.renderToSVG()
613
+ svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
614
 
615
+ svg_to_html(svg, save_render_file)
616
 
617
+ if ocr_type == 'format' and '**kern' not in outputs:
618
 
619
+
620
+ if '\\begin{tikzpicture}' not in outputs:
621
+ html_path_2 = save_render_file
622
+ right_num = outputs.count('\\right')
623
+ left_num = outputs.count('\left')
624
+
625
+ if right_num != left_num:
626
+ outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')
627
+
628
+
629
+ outputs = outputs.replace('"', '``').replace('$', '')
630
+
631
+ outputs_list = outputs.split('\n')
632
+ gt= ''
633
+ for out in outputs_list:
634
+ gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
635
+
636
+ gt = gt[:-2]
637
+
638
+
639
+ lines = content_mmd_to_html
640
+ lines = lines.split("const text =")
641
+ new_web = lines[0] + 'const text =' + gt + lines[1]
642
+
643
+ else:
644
+ html_path_2 = save_render_file
645
+ outputs = outputs.translate(translation_table)
646
+ outputs_list = outputs.split('\n')
647
+ gt= ''
648
+ for out in outputs_list:
649
+ if out:
650
+ if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
651
+ while out[-1] == ' ':
652
+ out = out[:-1]
653
+ if out is None:
654
+ break
655
+
656
+ if out:
657
+ if out[-1] != ';':
658
+ gt += out[:-1] + ';\n'
659
+ else:
660
+ gt += out + '\n'
661
+ else:
662
+ gt += out + '\n'
663
 
664
 
665
+ lines = tik_html
666
+ lines = lines.split("const text =")
667
+ new_web = lines[0] + gt + lines[1]
668
 
669
+ with smart_open(html_path_2, 'w') as web_f_new:
670
+ web_f_new.write(new_web)
671
  return response_str
672
 
673
  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):