File size: 3,931 Bytes
52c57e0
 
 
 
 
 
33024f8
ec3d00d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52c57e0
ec3d00d
 
52c57e0
7fe76ee
 
 
52c57e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b248a87
 
52c57e0
 
 
 
 
 
 
 
 
b248a87
52c57e0
 
 
 
 
 
 
 
 
 
 
b248a87
52c57e0
4072a5e
a8ea53b
0581571
8e43b0e
4072a5e
0581571
52c57e0
80ff823
52c57e0
2a1d11c
52c57e0
2a1d11c
52c57e0
 
70bfa79
52c57e0
 
 
 
80ff823
52c57e0
ad0f1c1
52c57e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import numpy as np
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
from utils import OCR, unnormalize_box

# BIO tag set of the fine-tuned CORD checkpoint, in label-id order:
# index 0 is "O", then one "B-" tag per receipt field, then one "I-" tag per
# field, with the fields listed in the same order both times.
label_list = ["O"] + [
    f"{tag}-{field}"
    for tag in ("B", "I")
    for field in (
        "MENU.CNT",
        "MENU.DISCOUNTPRICE",
        "MENU.NM",
        "MENU.NUM",
        "MENU.PRICE",
        "MENU.SUB.CNT",
        "MENU.SUB.NM",
        "MENU.SUB.PRICE",
        "MENU.UNITPRICE",
        "SUB_TOTAL.DISCOUNT_PRICE",
        "SUB_TOTAL.ETC",
        "SUB_TOTAL.SERVICE_PRICE",
        "SUB_TOTAL.SUBTOTAL_PRICE",
        "SUB_TOTAL.TAX_PRICE",
        "TOTAL.CASHPRICE",
        "TOTAL.CHANGEPRICE",
        "TOTAL.CREDITCARDPRICE",
        "TOTAL.MENUQTY_CNT",
        "TOTAL.TOTAL_PRICE",
    )
]

# Lookup tables between integer class ids (model output indices) and tag names.
id2label = {idx: name for idx, name in enumerate(label_list)}
label2id = {name: idx for idx, name in id2label.items()}

# Fine-tuned LayoutLMv3 checkpoint for CORD receipts. Words and boxes come from
# utils.OCR, so the processor's own OCR step is disabled.
processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord", apply_ocr=False)
# The processor already bundles this checkpoint's (fast) tokenizer; reuse it
# instead of loading the same files a second time. (``apply_ocr`` is an
# image-processor option and was never meaningful for the tokenizer.)
tokenizer = processor.tokenizer
model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# This module only runs inference; make eval mode (dropout off) explicit.
model.eval()


def prediction(image, min_confidence=None):
    """Label the OCR'd words of a receipt image with the CORD LayoutLMv3 model.

    Pipeline:
      1. ``OCR(image)`` (from ``utils``) returns ``(boxes, words)``. The boxes are
         presumably in the 0-1000 layout the processor expects (TODO confirm
         against utils.OCR).
      2. The module-level ``processor`` encodes image + words (truncated to the
         model's max length), and ``model`` classifies every token.
      3. Sub-word tokens are merged back into their word, and the leading/trailing
         special tokens are dropped.
      4. Words are grouped by predicted label, and the image is annotated.

    Args:
        image: PIL image to analyse. It is annotated IN PLACE: a rectangle and a
            "<label>, <confidence>" caption are drawn for every word.
        min_confidence: optional threshold. Words whose top softmax probability
            is below it are relabelled "O". ``None`` (default) keeps every
            prediction, matching the original behaviour.

    Returns:
        ``(entities, image)``. ``entities`` maps each predicted label to a
        ", "-joined string of the words assigned to it, in order of first
        appearance. ``image`` is the same (now annotated) object passed in.
    """
    boxes, words = OCR(image)
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True,
                         return_tensors="pt", truncation=True)
    # The model does not accept offsets; keep them on the CPU for word alignment.
    offset_mapping = encoding.pop('offset_mapping')

    for key, value in encoding.items():
        encoding[key] = value.to(device)

    # Pure inference: skip building an autograd graph (saves memory and time).
    with torch.no_grad():
        outputs = model(**encoding)

    # Batch size is always 1 here; index it explicitly rather than squeeze().
    logits = outputs.logits[0]
    predicted_ids = logits.argmax(-1).tolist()
    confidence_scores = torch.softmax(logits, dim=-1).max(-1).values.tolist()
    token_boxes = encoding['bbox'][0].tolist()
    input_ids = encoding['input_ids'][0].tolist()
    # A token whose character offset does not start at 0 continues the previous
    # word. Special tokens have offset (0, 0), so they count as word starts.
    is_subword = (offset_mapping[0, :, 0] != 0).tolist()

    width, height = image.size
    word_texts, word_labels, word_boxes, word_confidences = [], [], [], []
    for token_id, pred_id, box, conf, subword in zip(
            input_ids, predicted_ids, token_boxes, confidence_scores, is_subword):
        text = tokenizer.decode(token_id)
        if subword:
            # Continuation piece: glue it onto the word started earlier. The word
            # keeps the label/box/confidence of its first token.
            word_texts[-1] += text
            continue
        word_texts.append(text)
        word_labels.append(id2label[pred_id])
        word_boxes.append(unnormalize_box(box, width, height))
        word_confidences.append(conf)

    # Drop the leading and trailing special tokens added by the tokenizer.
    word_texts = word_texts[1:-1]
    word_labels = word_labels[1:-1]
    word_boxes = word_boxes[1:-1]
    word_confidences = word_confidences[1:-1]

    if min_confidence is not None:
        word_labels = ["O" if conf < min_confidence else label
                       for label, conf in zip(word_labels, word_confidences)]

    grouped = {}
    for label, text in zip(word_labels, word_texts):
        grouped.setdefault(label, []).append(text)
    # Decoded tokens carry a leading space; only the joined result is stripped.
    entities = {label: ", ".join(texts).strip() for label, texts in grouped.items()}

    # TODO: post-process the grouped labels into structured JSON.

    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()

    for label, box, conf in zip(word_labels, word_boxes, word_confidences):
        # NOTE(review): no outline colour is given, so Pillow's default ink is
        # used. Confirm the boxes are visible on the intended images.
        draw.rectangle(box)
        # Pillow ignores ``font_size`` when an explicit ``font`` is passed, so the
        # former ``font_size="15"`` argument had no effect and is omitted.
        draw.text((box[0] + 10, box[1] - 10), text=label + ", " + str(conf),
                  font=font, fill="black")

    return entities, image