File size: 3,726 Bytes
e41ca05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from utils import OCR, unnormalize_box


# SROIE entity label set in BIO tagging form, plus the forward/reverse
# index maps the model head uses for its class ids.
labels = ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in enumerate(labels)}

# Load the SROIE-finetuned LayoutLMv3 stack. apply_ocr=False because OCR
# is performed separately (utils.OCR) and words/boxes are fed in manually.
_CHECKPOINT = "Theivaprakasham/layoutlmv3-finetuned-sroie"
tokenizer = LayoutLMv3TokenizerFast.from_pretrained(_CHECKPOINT, apply_ocr=False)
processor = LayoutLMv3Processor.from_pretrained(_CHECKPOINT, apply_ocr=False)
model = LayoutLMv3ForTokenClassification.from_pretrained(_CHECKPOINT)

# Run inference on GPU when one is available, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)


def blur(image, boxes):
    """Gaussian-blur the rectangular regions of *image* given by *boxes*.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image (assumed RGB — TODO confirm callers always pass RGB).
    boxes : iterable of [x0, y0, x1, y1]
        Pixel-coordinate rectangles to blur.

    Returns
    -------
    PIL.Image.Image
        A new RGB image with each box region blurred in place.
    """
    image = np.array(image)
    img_h, img_w = image.shape[:2]

    for box in boxes:
        # Clamp the box to the image bounds: OCR boxes can stray outside
        # the frame, and negative indices would silently slice from the
        # wrong end of the array.
        x0 = max(0, min(img_w, int(box[0])))
        y0 = max(0, min(img_h, int(box[1])))
        x1 = max(0, min(img_w, int(box[2])))
        y1 = max(0, min(img_h, int(box[3])))

        # Skip degenerate boxes — cv2.GaussianBlur raises on an empty ROI.
        if x1 <= x0 or y1 <= y0:
            continue

        roi = image[y0:y1, x0:x1]
        # Very large kernel => strong anonymizing blur; sigma=0 lets OpenCV
        # derive it from the kernel size.
        image[y0:y1, x0:x1] = cv2.GaussianBlur(roi, (201, 201), 0)

    return Image.fromarray(image, 'RGB')


def prediction(image):
    """Extract SROIE entities from a receipt image and blur them out.

    Pipeline: OCR the image, run LayoutLMv3 token classification, collapse
    subword tokens back into words, aggregate words per entity label, then
    blur every predicted entity region except TOTAL.

    Parameters
    ----------
    image : PIL.Image.Image
        Receipt image. OCR() is assumed to return boxes normalized to the
        0-1000 LayoutLM coordinate space — TODO confirm against utils.OCR.

    Returns
    -------
    tuple[dict, PIL.Image.Image]
        (entity label -> comma-joined extracted words, blurred image).
    """
    boxes, words = OCR(image)
    # Processor was loaded with apply_ocr=False, so words/boxes come from
    # our own OCR step; offsets are needed to detect subword pieces below.
    encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding.pop('offset_mapping')

    for k, v in encoding.items():
        encoding[k] = v.to(device)

    outputs = model(**encoding)

    # Per-token class ids (argmax over the label dimension).
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding.bbox.squeeze().tolist()

    inp_ids = encoding.input_ids.squeeze().tolist()
    inp_words = [tokenizer.decode(i) for i in inp_ids]

    width, height = image.size
    # A character offset start != 0 means the token continues the previous
    # word, i.e. it is a subword piece rather than a word start.
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

    # Keep one prediction/box per word (the first subword's), and glue
    # subword strings back onto the word they belong to.
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
    true_words = []

    for id, i in enumerate(inp_words):
        if not is_subword[id]:
            true_words.append(i)
        else:
            true_words[-1] = true_words[-1]+i

    # Drop the <s>/</s> special tokens at either end of the sequence.
    true_predictions = true_predictions[1:-1]
    true_boxes = true_boxes[1:-1]
    true_words = true_words[1:-1]

    preds = []
    l_words = []
    bboxes = []

    # NOTE(review): labels are 'O'/'B-*'/'I-*' and never 'others', so this
    # filter excludes nothing; 'O' entries are only removed from d below.
    # Verify whether 'O' was intended here.
    for i, j in enumerate(true_predictions):
        if j != 'others':
            preds.append(true_predictions[i])
            l_words.append(true_words[i])
            bboxes.append(true_boxes[i])

    # Aggregate the words for each label into one comma-separated string.
    d = {}
    for id, i in enumerate(preds):
        if i not in d.keys():
            d[i] = l_words[id]
        else:
            d[i] = d[i] + ", " + l_words[id]

    d = {k: v.strip() for (k, v) in d.items()}

    # Fold each I-X continuation entry into its B-X entry.
    # NOTE(review): assumes a matching "B-X" key always exists when "I-X"
    # does; an orphan I-X would raise KeyError — confirm this can't happen.
    keys_to_pop = []
    for k, v in d.items():
        if k[:2] == "I-":
            d["B-" + k[2:]] = d["B-" + k[2:]] + ", " + v
            keys_to_pop.append(k)

    # Drop non-entity tokens and the TOTAL field from the returned dict.
    if "O" in d: d.pop("O")
    if "B-TOTAL" in d: d.pop("B-TOTAL")
    for k in keys_to_pop: d.pop(k)

    # Blur every detected entity region except TOTAL (and non-entities).
    blur_boxes = []
    for prediction, box in zip(preds, bboxes):
        if prediction != 'O' and prediction[2:] != 'TOTAL':
            blur_boxes.append(box)

    image = (blur(image, blur_boxes))

    #draw = ImageDraw.Draw(image, "RGBA")
    #font = ImageFont.load_default()

    #for prediction, box in zip(preds, bboxes):
    #    draw.rectangle(box)
    #    draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")

    return d, image