mp-02 commited on
Commit
336da81
1 Parent(s): de1560a

Update sroie_inference.py

Browse files
Files changed (1) hide show
  1. sroie_inference.py +10 -10
sroie_inference.py CHANGED
@@ -6,9 +6,9 @@ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv
6
  from utils import OCR, unnormalize_box
7
 
8
 
9
- tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
10
- processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
11
- model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")
12
 
13
  id2label = model.config.id2label
14
  label2id = model.config.label2id
@@ -85,8 +85,8 @@ def prediction(image):
85
  d[i] = d[i] + ", " + true_words[id]
86
  d = {k: v.strip() for (k, v) in d.items()}
87
 
88
- if "O" in d: d.pop("O")
89
- if "TOTAL" in d: d.pop("TOTAL")
90
 
91
  blur_boxes = []
92
  for prediction, box in zip(true_predictions, true_boxes):
@@ -95,12 +95,12 @@ def prediction(image):
95
 
96
  image = (blur(image, blur_boxes))
97
 
98
- #draw = ImageDraw.Draw(image, "RGBA")
99
- #font = ImageFont.load_default()
100
 
101
- #for prediction, box in zip(true_predictions, true_boxes):
102
- # draw.rectangle(box)
103
- # draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")
104
 
105
  return d, image
106
 
 
6
  from utils import OCR, unnormalize_box
7
 
8
 
9
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie", apply_ocr=False)
10
+ processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie", apply_ocr=False)
11
+ model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie")
12
 
13
  id2label = model.config.id2label
14
  label2id = model.config.label2id
 
85
  d[i] = d[i] + ", " + true_words[id]
86
  d = {k: v.strip() for (k, v) in d.items()}
87
 
88
+ #if "O" in d: d.pop("O")
89
+ #if "TOTAL" in d: d.pop("TOTAL")
90
 
91
  blur_boxes = []
92
  for prediction, box in zip(true_predictions, true_boxes):
 
95
 
96
  image = (blur(image, blur_boxes))
97
 
98
+ draw = ImageDraw.Draw(image, "RGBA")
99
+ font = ImageFont.load_default()
100
 
101
+ for prediction, box in zip(true_predictions, true_boxes):
102
+ draw.rectangle(box)
103
+ draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")
104
 
105
  return d, image
106