File size: 5,225 Bytes
52c57e0
 
 
 
 
 
 
05e8dab
0a411ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05e8dab
0a411ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52c57e0
6f35df8
e69a81f
 
 
52c57e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b248a87
 
52c57e0
 
 
 
 
 
 
 
 
b248a87
52c57e0
 
 
 
 
 
 
 
 
 
 
b248a87
52c57e0
6a47a94
a8ea53b
0581571
8e43b0e
6a47a94
0581571
52c57e0
80ff823
52c57e0
2a1d11c
52c57e0
2a1d11c
52c57e0
 
70bfa79
52c57e0
 
 
 
80ff823
52c57e0
ad0f1c1
52c57e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import torch
import numpy as np
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
from utils import OCR, unnormalize_box


# Label vocabulary of the CORD receipt fine-tune used by this module.
# Entity types listed in label-id order; each type gets a "B-" (begin) tag
# and an "I-" (inside) tag, and "O" (outside any entity) takes id 0.
_CORD_ENTITY_TYPES = (
    "MENU.CNT",
    "MENU.DISCOUNTPRICE",
    "MENU.NM",
    "MENU.NUM",
    "MENU.PRICE",
    "MENU.SUB.CNT",
    "MENU.SUB.NM",
    "MENU.SUB.PRICE",
    "MENU.UNITPRICE",
    "SUB_TOTAL.DISCOUNT_PRICE",
    "SUB_TOTAL.ETC",
    "SUB_TOTAL.SERVICE_PRICE",
    "SUB_TOTAL.SUBTOTAL_PRICE",
    "SUB_TOTAL.TAX_PRICE",
    "TOTAL.CASHPRICE",
    "TOTAL.CHANGEPRICE",
    "TOTAL.CREDITCARDPRICE",
    "TOTAL.MENUQTY_CNT",
    "TOTAL.TOTAL_PRICE",
)

# id -> tag. Keys are *strings* ("0".."38"), as in a JSON model config, so an
# integer class id must be converted with str() before lookup.
# Layout: "0" -> "O", "1".."19" -> B- tags, "20".."38" -> I- tags.
id2label = {
    str(position): tag
    for position, tag in enumerate(
        ["O"]
        + ["B-" + entity for entity in _CORD_ENTITY_TYPES]
        + ["I-" + entity for entity in _CORD_ENTITY_TYPES]
    )
}

# tag -> id. Values are *integer* ids; "O" (id 0) is inserted last so the
# iteration order matches the historical table.
label2id = {tag: int(position) for position, tag in id2label.items() if tag != "O"}
label2id["O"] = 0

# Hugging Face Hub id of the fine-tuned LayoutLMv3 CORD checkpoint.
# NOTE(review): the previous comment here named nielsr/layoutlmv3-finetuned-cord,
# presumably the upstream checkpoint this one derives from - confirm.
MODEL_ID = "mp-02/layoutlmv3-finetuned-cord"

# apply_ocr=False: words and boxes are supplied by utils.OCR, so the image
# processor must not run its own OCR.
processor = LayoutLMv3Processor.from_pretrained(MODEL_ID, apply_ocr=False)
# Reuse the processor's tokenizer (instead of loading a second copy) so that
# decoding in prediction() uses exactly the tokenizer that produced the ids.
tokenizer = processor.tokenizer
model = LayoutLMv3ForTokenClassification.from_pretrained(MODEL_ID)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# This module only runs inference; make eval mode (no dropout) explicit.
model.eval()


def prediction(image, *, threshold=0.5):
    """Run OCR + LayoutLMv3 token classification on a document image.

    The image is OCR'd with ``OCR`` (from ``utils``), encoded with the
    module-level ``processor`` and classified with ``model``. Predictions are
    reduced to one per word by keeping only the first sub-token of each word
    (a token whose character offset does not start at 0 is treated as a
    continuation of the previous word). The first and last entries of the
    sequence are dropped. Words whose confidence is below ``threshold`` are
    relabelled ``"O"``.

    Args:
        image: PIL image of the document. It is drawn on *in place*: every
            word's box and a "label, confidence" caption are rendered onto it.
        threshold: minimum softmax probability a word needs to keep its
            predicted label; below it the label becomes ``"O"``.

    Returns:
        A ``(fields, image)`` tuple. ``fields`` maps each label (including
        ``"O"``) to the decoded text of the words assigned to it, joined with
        ``", "`` and stripped. ``image`` is the annotated input image.
    """
    boxes, words = OCR(image)
    encoding = processor(
        image,
        words,
        boxes=boxes,
        return_offsets_mapping=True,
        return_tensors="pt",
        truncation=True,
    )
    # Offsets are only needed here to find word boundaries; the model does
    # not take them as input.
    offset_mapping = encoding.pop("offset_mapping")

    for key, value in encoding.items():
        encoding[key] = value.to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**encoding)

    # A single image was encoded, so index [0] selects its (only) sequence.
    logits = outputs.logits[0]
    predictions = logits.argmax(-1).tolist()
    confidence_scores = torch.softmax(logits, dim=-1).max(-1).values.tolist()
    token_boxes = encoding.bbox[0].tolist()
    input_ids = encoding.input_ids[0].tolist()
    # Tokens whose offset does not start at 0 continue the preceding word.
    is_subword = (offset_mapping[0, :, 0] != 0).tolist()

    width, height = image.size
    word_labels, word_boxes, word_scores, word_texts = [], [], [], []
    for token_id, label_id, score, box, subword in zip(
        input_ids, predictions, confidence_scores, token_boxes, is_subword
    ):
        piece = tokenizer.decode(token_id)
        if subword:
            # Glue the continuation onto its word; its own label, box and
            # score are ignored.
            word_texts[-1] += piece
            continue
        word_texts.append(piece)
        # id2label is keyed by *string* ids while argmax yields ints; looking
        # up the int directly raised KeyError.
        word_labels.append(id2label[str(label_id)])
        word_boxes.append(unnormalize_box(box, width, height))
        word_scores.append(score)

    # Drop the first and last entries (presumably the sequence's start/end
    # special tokens).
    word_labels = word_labels[1:-1]
    word_boxes = word_boxes[1:-1]
    word_texts = word_texts[1:-1]
    word_scores = word_scores[1:-1]

    # Low-confidence words fall back to the "outside" label.
    word_labels = [
        "O" if score < threshold else label
        for label, score in zip(word_labels, word_scores)
    ]

    # Group word texts by label, preserving first-occurrence order of labels.
    grouped = {}
    for label, text in zip(word_labels, word_texts):
        grouped.setdefault(label, []).append(text)
    fields = {label: ", ".join(texts).strip() for label, texts in grouped.items()}

    # TODO: process the json

    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()
    for label, box, score in zip(word_labels, word_boxes, word_scores):
        draw.rectangle(box)
        # Caption just above-right of the box's top-left corner. A former
        # font_size="15" argument was dropped: Pillow only consults
        # font_size when no font is passed.
        draw.text(
            (box[0] + 10, box[1] - 10),
            text=f"{label}, {score}",
            font=font,
            fill="black",
        )

    return fields, image