mp-02 commited on
Commit
52c57e0
·
verified ·
1 Parent(s): fc77b97

Update cord_inference.py

Browse files
Files changed (1) hide show
  1. cord_inference.py +80 -80
cord_inference.py CHANGED
@@ -1,80 +1,80 @@
1
- import torch
2
- import numpy as np
3
- from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
4
- from PIL import Image, ImageDraw, ImageFont
5
- from utils import OCR, unnormalize_box
6
-
7
-
8
- labels = ["O", "B-MENU.NM", "B-MENU.NUM", "B-MENU.UNITPRICE", "B-MENU.CNT", "B-MENU.DISCOUNTPRICE", "B-MENU.PRICE", "B-MENU.ITEMSUBTOTAL", "B-MENU.VATYN", "B-MENU.ETC", "B-MENU.SUB.NM", "B-MENU.SUB.UNITPRICE", "B-MENU.SUB.CNT", "B-MENU.SUB.PRICE", "B-MENU.SUB.ETC", "B-VOID_MENU.NM", "B-VOID_MENU.PRICE", "B-SUB_TOTAL.SUBTOTAL_PRICE", "B-SUB_TOTAL.DISCOUNT_PRICE", "B-SUB_TOTAL.SERVICE_PRICE", "B-SUB_TOTAL.OTHERSVC_PRICE", "B-SUB_TOTAL.TAX_PRICE", "B-SUB_TOTAL.ETC", "B-TOTAL.TOTAL_PRICE", "B-TOTAL.TOTAL_ETC", "B-TOTAL.CASHPRICE", "B-TOTAL.CHANGEPRICE", "B-TOTAL.CREDITCARDPRICE", "B-TOTAL.EMONEYPRICE", "B-TOTAL.MENUTYPE_CNT", "B-TOTAL.MENUQTY_CNT", "I-MENU.NM", "I-MENU.NUM", "I-MENU.UNITPRICE", "I-MENU.CNT", "I-MENU.DISCOUNTPRICE", "I-MENU.PRICE", "I-MENU.ITEMSUBTOTAL", "I-MENU.VATYN", "I-MENU.ETC", "I-MENU.SUB.NM", "I-MENU.SUB.UNITPRICE", "I-MENU.SUB.CNT", "I-MENU.SUB.PRICE", "I-MENU.SUB.ETC", "I-VOID_MENU.NM", "I-VOID_MENU.PRICE", "I-SUB_TOTAL.SUBTOTAL_PRICE", "I-SUB_TOTAL.DISCOUNT_PRICE", "I-SUB_TOTAL.SERVICE_PRICE", "I-SUB_TOTAL.OTHERSVC_PRICE", "I-SUB_TOTAL.TAX_PRICE", "I-SUB_TOTAL.ETC", "I-TOTAL.TOTAL_PRICE", "I-TOTAL.TOTAL_ETC", "I-TOTAL.CASHPRICE", "I-TOTAL.CHANGEPRICE", "I-TOTAL.CREDITCARDPRICE", "I-TOTAL.EMONEYPRICE", "I-TOTAL.MENUTYPE_CNT", "I-TOTAL.MENUQTY_CNT"]
9
- id2label = {v: k for v, k in enumerate(labels)}
10
- label2id = {k: v for v, k in enumerate(labels)}
11
-
12
- tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-cord", apply_ocr=False)
13
- processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord", apply_ocr=False)
14
- model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord")
15
-
16
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
- model.to(device)
18
-
19
-
20
- def prediction(image):
21
- boxes, words = OCR(image)
22
- encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
23
- offset_mapping = encoding.pop('offset_mapping')
24
-
25
- for k, v in encoding.items():
26
- encoding[k] = v.to(device)
27
-
28
- outputs = model(**encoding)
29
-
30
- predictions = outputs.logits.argmax(-1).squeeze().tolist()
31
- token_boxes = encoding.bbox.squeeze().tolist()
32
-
33
- inp_ids = encoding.input_ids.squeeze().tolist()
34
- inp_words = [tokenizer.decode(i) for i in inp_ids]
35
-
36
- width, height = image.size
37
- is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
38
-
39
- true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
40
- true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
41
- true_words = []
42
-
43
- for id, i in enumerate(inp_words):
44
- if not is_subword[id]:
45
- true_words.append(i)
46
- else:
47
- true_words[-1] = true_words[-1]+i
48
-
49
- true_predictions = true_predictions[1:-1]
50
- true_boxes = true_boxes[1:-1]
51
- true_words = true_words[1:-1]
52
-
53
- preds = []
54
- l_words = []
55
- bboxes = []
56
-
57
- for i, j in enumerate(true_predictions):
58
- if j != 'others':
59
- preds.append(true_predictions[i])
60
- l_words.append(true_words[i])
61
- bboxes.append(true_boxes[i])
62
-
63
- d = {}
64
- for id, i in enumerate(preds):
65
- if i not in d.keys():
66
- d[i] = l_words[id]
67
- else:
68
- d[i] = d[i] + ", " + l_words[id]
69
- d = {k: v.strip() for (k, v) in d.items()}
70
-
71
- # TODO: process the json
72
-
73
- draw = ImageDraw.Draw(image, "RGBA")
74
- font = ImageFont.load_default()
75
-
76
- for prediction, box in zip(preds, bboxes):
77
- draw.rectangle(box)
78
- draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black")
79
-
80
- return d, image
 
1
+ import torch
2
+ import numpy as np
3
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ from utils import OCR, unnormalize_box
6
+
7
+
8
+ labels = ["O", "B-MENU.NM", "B-MENU.NUM", "B-MENU.UNITPRICE", "B-MENU.CNT", "B-MENU.DISCOUNTPRICE", "B-MENU.PRICE", "B-MENU.ITEMSUBTOTAL", "B-MENU.VATYN", "B-MENU.ETC", "B-MENU.SUB.NM", "B-MENU.SUB.UNITPRICE", "B-MENU.SUB.CNT", "B-MENU.SUB.PRICE", "B-MENU.SUB.ETC", "B-VOID_MENU.NM", "B-VOID_MENU.PRICE", "B-SUB_TOTAL.SUBTOTAL_PRICE", "B-SUB_TOTAL.DISCOUNT_PRICE", "B-SUB_TOTAL.SERVICE_PRICE", "B-SUB_TOTAL.OTHERSVC_PRICE", "B-SUB_TOTAL.TAX_PRICE", "B-SUB_TOTAL.ETC", "B-TOTAL.TOTAL_PRICE", "B-TOTAL.TOTAL_ETC", "B-TOTAL.CASHPRICE", "B-TOTAL.CHANGEPRICE", "B-TOTAL.CREDITCARDPRICE", "B-TOTAL.EMONEYPRICE", "B-TOTAL.MENUTYPE_CNT", "B-TOTAL.MENUQTY_CNT", "I-MENU.NM", "I-MENU.NUM", "I-MENU.UNITPRICE", "I-MENU.CNT", "I-MENU.DISCOUNTPRICE", "I-MENU.PRICE", "I-MENU.ITEMSUBTOTAL", "I-MENU.VATYN", "I-MENU.ETC", "I-MENU.SUB.NM", "I-MENU.SUB.UNITPRICE", "I-MENU.SUB.CNT", "I-MENU.SUB.PRICE", "I-MENU.SUB.ETC", "I-VOID_MENU.NM", "I-VOID_MENU.PRICE", "I-SUB_TOTAL.SUBTOTAL_PRICE", "I-SUB_TOTAL.DISCOUNT_PRICE", "I-SUB_TOTAL.SERVICE_PRICE", "I-SUB_TOTAL.OTHERSVC_PRICE", "I-SUB_TOTAL.TAX_PRICE", "I-SUB_TOTAL.ETC", "I-TOTAL.TOTAL_PRICE", "I-TOTAL.TOTAL_ETC", "I-TOTAL.CASHPRICE", "I-TOTAL.CHANGEPRICE", "I-TOTAL.CREDITCARDPRICE", "I-TOTAL.EMONEYPRICE", "I-TOTAL.MENUTYPE_CNT", "I-TOTAL.MENUQTY_CNT"]
9
+ id2label = {v: k for v, k in enumerate(labels)}
10
+ label2id = {k: v for v, k in enumerate(labels)}
11
+
12
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-cord2", apply_ocr=False)
13
+ processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord2", apply_ocr=False)
14
+ model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord2")
15
+
16
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ model.to(device)
18
+
19
+
20
+ def prediction(image):
21
+ boxes, words = OCR(image)
22
+ encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
23
+ offset_mapping = encoding.pop('offset_mapping')
24
+
25
+ for k, v in encoding.items():
26
+ encoding[k] = v.to(device)
27
+
28
+ outputs = model(**encoding)
29
+
30
+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
31
+ token_boxes = encoding.bbox.squeeze().tolist()
32
+
33
+ inp_ids = encoding.input_ids.squeeze().tolist()
34
+ inp_words = [tokenizer.decode(i) for i in inp_ids]
35
+
36
+ width, height = image.size
37
+ is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
38
+
39
+ true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
40
+ true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
41
+ true_words = []
42
+
43
+ for id, i in enumerate(inp_words):
44
+ if not is_subword[id]:
45
+ true_words.append(i)
46
+ else:
47
+ true_words[-1] = true_words[-1]+i
48
+
49
+ true_predictions = true_predictions[1:-1]
50
+ true_boxes = true_boxes[1:-1]
51
+ true_words = true_words[1:-1]
52
+
53
+ preds = []
54
+ l_words = []
55
+ bboxes = []
56
+
57
+ for i, j in enumerate(true_predictions):
58
+ if j != 'others':
59
+ preds.append(true_predictions[i])
60
+ l_words.append(true_words[i])
61
+ bboxes.append(true_boxes[i])
62
+
63
+ d = {}
64
+ for id, i in enumerate(preds):
65
+ if i not in d.keys():
66
+ d[i] = l_words[id]
67
+ else:
68
+ d[i] = d[i] + ", " + l_words[id]
69
+ d = {k: v.strip() for (k, v) in d.items()}
70
+
71
+ # TODO: process the json
72
+
73
+ draw = ImageDraw.Draw(image, "RGBA")
74
+ font = ImageFont.load_default()
75
+
76
+ for prediction, box in zip(preds, bboxes):
77
+ draw.rectangle(box)
78
+ draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black")
79
+
80
+ return d, image