Spaces:
Running
Running
Update sroie_inference.py
Browse files- sroie_inference.py +10 -10
sroie_inference.py
CHANGED
@@ -6,9 +6,9 @@ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv
|
|
6 |
from utils import OCR, unnormalize_box
|
7 |
|
8 |
|
9 |
-
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
|
10 |
-
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-sroie", apply_ocr=False)
|
11 |
-
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-sroie")
|
12 |
|
13 |
id2label = model.config.id2label
|
14 |
label2id = model.config.label2id
|
@@ -85,8 +85,8 @@ def prediction(image):
|
|
85 |
d[i] = d[i] + ", " + true_words[id]
|
86 |
d = {k: v.strip() for (k, v) in d.items()}
|
87 |
|
88 |
-
if "O" in d: d.pop("O")
|
89 |
-
if "TOTAL" in d: d.pop("TOTAL")
|
90 |
|
91 |
blur_boxes = []
|
92 |
for prediction, box in zip(true_predictions, true_boxes):
|
@@ -95,12 +95,12 @@ def prediction(image):
|
|
95 |
|
96 |
image = (blur(image, blur_boxes))
|
97 |
|
98 |
-
|
99 |
-
|
100 |
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
|
105 |
return d, image
|
106 |
|
|
|
6 |
from utils import OCR, unnormalize_box
|
7 |
|
8 |
|
9 |
+
tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie", apply_ocr=False)
|
10 |
+
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie", apply_ocr=False)
|
11 |
+
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord-sroie")
|
12 |
|
13 |
id2label = model.config.id2label
|
14 |
label2id = model.config.label2id
|
|
|
85 |
d[i] = d[i] + ", " + true_words[id]
|
86 |
d = {k: v.strip() for (k, v) in d.items()}
|
87 |
|
88 |
+
#if "O" in d: d.pop("O")
|
89 |
+
#if "TOTAL" in d: d.pop("TOTAL")
|
90 |
|
91 |
blur_boxes = []
|
92 |
for prediction, box in zip(true_predictions, true_boxes):
|
|
|
95 |
|
96 |
image = (blur(image, blur_boxes))
|
97 |
|
98 |
+
draw = ImageDraw.Draw(image, "RGBA")
|
99 |
+
font = ImageFont.load_default()
|
100 |
|
101 |
+
for prediction, box in zip(true_predictions, true_boxes):
|
102 |
+
draw.rectangle(box)
|
103 |
+
draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")
|
104 |
|
105 |
return d, image
|
106 |
|