File size: 2,955 Bytes
52c57e0
 
 
 
 
 
 
9b31f19
 
 
52c57e0
1655071
 
 
52c57e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b248a87
 
52c57e0
 
 
 
 
 
 
 
 
b248a87
52c57e0
 
 
 
 
 
 
 
 
 
 
b248a87
52c57e0
a8ea53b
f4b0061
358c7b8
0581571
52c57e0
80ff823
f4b0061
 
52c57e0
2a1d11c
52c57e0
2a1d11c
52c57e0
e3e9c8c
52c57e0
70bfa79
52c57e0
 
 
 
80ff823
52c57e0
ad0f1c1
52c57e0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
import numpy as np
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from PIL import Image, ImageDraw, ImageFont
from utils import OCR, unnormalize_box


# Hugging Face Hub checkpoint used for every prediction in this module.
MODEL_NAME = "mp-02/layoutlmv3-base-cord"

# The processor bundles the image processor and the tokenizer. Words and boxes
# come from utils.OCR instead, so the processor's built-in OCR is disabled.
processor = LayoutLMv3Processor.from_pretrained(MODEL_NAME, apply_ocr=False)
# Reuse the processor's tokenizer rather than loading the same checkpoint a
# second time (the tokenizer itself has no `apply_ocr` option).
tokenizer = processor.tokenizer
model = LayoutLMv3ForTokenClassification.from_pretrained(MODEL_NAME)

# Mappings between class indices and label strings, as stored in the model config.
id2label = model.config.id2label
label2id = model.config.label2id

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Inference only: make sure dropout etc. are disabled.
model.eval()


def _strip_bio_prefix(label):
    """Return *label* without a leading ``B-``/``I-`` tag; other labels are returned unchanged."""
    if label.startswith(("B-", "I-")):
        return label[2:]
    return label


def _group_by_entity(labels, words):
    """Join the words of each non-``O`` entity type with ``", "``.

    Each word is stripped before joining, so a leading space left over from
    decoding a single token does not end up right after the comma.
    Keys appear in order of first occurrence in *labels*.
    """
    grouped = {}
    for label, word in zip(labels, words):
        if label == "O":
            continue
        grouped.setdefault(_strip_bio_prefix(label), []).append(word.strip())
    return {entity: ", ".join(parts) for entity, parts in grouped.items()}


def _draw_predictions(image, labels, boxes, confidences):
    """Draw every box with its label and confidence onto *image* (in place)."""
    draw = ImageDraw.Draw(image, "RGBA")
    font = ImageFont.load_default()
    for label, box, confidence in zip(labels, boxes, confidences):
        draw.rectangle(box)
        # Pillow ignores `font_size` when an explicit `font` is given, so it is not passed.
        draw.text((box[0] + 10, box[1] - 10), text=f"{label}, {confidence:.2f}", font=font, fill="black")


def prediction(image, confidence_threshold=0.5):
    """Run LayoutLMv3 token classification on a document image.

    Words and boxes are obtained with ``utils.OCR``. The model then labels
    every word (the label of its first sub-token is used). Words whose
    confidence is below *confidence_threshold* are relabeled ``"O"``.

    Args:
        image: PIL image of the document. It is annotated in place, unless it
            first has to be converted to RGB (see below).
        confidence_threshold: minimum softmax probability required to keep a
            predicted label; defaults to 0.5.

    Returns:
        A tuple ``(entities, image)``:
        - ``entities`` maps each entity type (label with any ``B-``/``I-``
          prefix removed) to its words joined with ``", "``; ``"O"`` is excluded.
        - ``image`` has every word box drawn with its label and confidence.
    """
    # ImageDraw.Draw(image, "RGBA") raises "mode mismatch" for modes other
    # than RGB/RGBA (e.g. grayscale "L" or palette "P" uploads).
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")

    boxes, words = OCR(image)
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding.pop('offset_mapping')

    for key, value in encoding.items():
        encoding[key] = value.to(device)

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)

    # Batch size is 1; index [0] rather than squeeze() so the results are always lists.
    logits = outputs.logits[0]
    predictions = logits.argmax(-1).tolist()
    confidence_scores = torch.softmax(logits, dim=-1).max(-1).values.tolist()
    token_boxes = encoding["bbox"][0].tolist()
    input_ids = encoding["input_ids"][0].tolist()

    # A token continues the previous word when its start offset within the word
    # is non-zero; special tokens have offset (0, 0) and count as word starts.
    is_subword = [start != 0 for start, _ in offset_mapping[0].tolist()]

    width, height = image.size
    word_texts, word_labels, word_boxes, word_scores = [], [], [], []
    for token_id, subword, pred, box, score in zip(input_ids, is_subword, predictions, token_boxes, confidence_scores):
        piece = tokenizer.decode(token_id)
        if subword:
            # Glue sub-tokens onto the word they belong to; labels, boxes and
            # scores are taken from the word's first token only.
            word_texts[-1] += piece
            continue
        word_texts.append(piece)
        word_labels.append(id2label[pred])
        word_boxes.append(unnormalize_box(box, width, height))
        word_scores.append(score)

    # Drop the entries for the leading and trailing special tokens.
    true_words = word_texts[1:-1]
    true_boxes = word_boxes[1:-1]
    true_confidence_scores = word_scores[1:-1]
    true_predictions = [
        "O" if conf < confidence_threshold else label
        for label, conf in zip(word_labels[1:-1], true_confidence_scores)
    ]

    entities = _group_by_entity(true_predictions, true_words)

    # TODO: post-process the grouped entities into the final JSON structure.

    _draw_predictions(image, true_predictions, true_boxes, true_confidence_scores)

    return entities, image