mp-02 commited on
Commit
6d1caf6
·
verified ·
1 Parent(s): e2111cd

Upload folder using huggingface_hub

Browse files
Files changed (6) hide show
  1. README.md +2 -8
  2. app.py +52 -0
  3. cord_inference.py +81 -0
  4. requirements.txt +6 -0
  5. sroie_inference.py +114 -0
  6. utils.py +40 -0
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: LayoutLMv3 For Recepits
3
- emoji: 👁
4
- colorFrom: indigo
5
- colorTo: green
6
  sdk: gradio
7
  sdk_version: 4.40.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: LayoutLMv3_for_recepits
3
+ app_file: app.py
 
 
4
  sdk: gradio
5
  sdk_version: 4.40.0
 
 
6
  ---
 
 
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cord_inference import prediction as cord_prediction
2
+ from sroie_inference import prediction as sroie_prediction
3
+ import gradio as gr
4
+ import json
5
+
6
+ def prediction(image_path: str):
7
+
8
+ #we first use mp-02/layoutlmv3-finetuned-cord on the image, which gives us a JSON with some info and a blurred image
9
+ d, image = sroie_prediction(image_path)
10
+
11
+ #we save the blurred image in order to pass it to the other model
12
+ image_path_blurred = image_path.split('.')[0] + '_blurred.' + image_path.split('.')[1]
13
+ image.save(image_path_blurred)
14
+
15
+ #then we use the model fine-tuned on sroie (for now it is Theivaprakasham/layoutlmv3-finetuned-sroie)
16
+ d1, image1 = cord_prediction(image_path_blurred)
17
+
18
+ #we then link the two json files
19
+ if len(d) == 0:
20
+ k = d1
21
+ else:
22
+ k = json.dumps(d).split('}')[0] + ', ' + json.dumps(d1).split('{')[1]
23
+
24
+ return d, image, d1, image1, k
25
+
26
+ # p,i,j = prediction("11990982-img.png")
27
+ # print(p)
28
+
29
+
30
+ title = "Interactive demo: LayoutLMv3 for receipts"
31
+ description = "Demo for Microsoft's LayoutLMv3, a Transformer for state-of-the-art document image understanding tasks. This particular model is fine-tuned on CORD and SROIE, which are datasets of receipts.\n It firsts uses the fine-tune on SROIE to extract date, company and address, then the fine-tune on CORD for the other info.\n To use it, simply upload an image or use the example image below. Results will show up in a few seconds."
32
+ examples = [['image.jpg']]
33
+
34
+ css = """.output_image, .input_image {height: 600px !important}"""
35
+
36
+
37
+ # we use a gradio interface that takes in input an image and return a JSON file that contains its info
38
+ # we show also the intermediate steps (first we take some info with the model fine-tuned on SROIE and we blur the relative boxes
39
+ # then we pass the image to the model fine-tuned on CORD
40
+ iface = gr.Interface(fn=prediction,
41
+ inputs=gr.Image(type="filepath"),
42
+ outputs=[gr.JSON(label="json parsing"),
43
+ gr.Image(type="pil", label="blurred image"),
44
+ gr.JSON(label="json parsing"),
45
+ gr.Image(type="pil", label="annotated image"),
46
+ gr.JSON(label="json parsing")],
47
+ title=title,
48
+ description=description,
49
+ examples=examples,
50
+ css=css)
51
+
52
+ iface.launch()
cord_inference.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ from utils import OCR, unnormalize_box
6
+
7
+
8
+ labels = ["O", "B-MENU.NM", "B-MENU.NUM", "B-MENU.UNITPRICE", "B-MENU.CNT", "B-MENU.DISCOUNTPRICE", "B-MENU.PRICE", "B-MENU.ITEMSUBTOTAL", "B-MENU.VATYN", "B-MENU.ETC", "B-MENU.SUB.NM", "B-MENU.SUB.UNITPRICE", "B-MENU.SUB.CNT", "B-MENU.SUB.PRICE", "B-MENU.SUB.ETC", "B-VOID_MENU.NM", "B-VOID_MENU.PRICE", "B-SUB_TOTAL.SUBTOTAL_PRICE", "B-SUB_TOTAL.DISCOUNT_PRICE", "B-SUB_TOTAL.SERVICE_PRICE", "B-SUB_TOTAL.OTHERSVC_PRICE", "B-SUB_TOTAL.TAX_PRICE", "B-SUB_TOTAL.ETC", "B-TOTAL.TOTAL_PRICE", "B-TOTAL.TOTAL_ETC", "B-TOTAL.CASHPRICE", "B-TOTAL.CHANGEPRICE", "B-TOTAL.CREDITCARDPRICE", "B-TOTAL.EMONEYPRICE", "B-TOTAL.MENUTYPE_CNT", "B-TOTAL.MENUQTY_CNT", "I-MENU.NM", "I-MENU.NUM", "I-MENU.UNITPRICE", "I-MENU.CNT", "I-MENU.DISCOUNTPRICE", "I-MENU.PRICE", "I-MENU.ITEMSUBTOTAL", "I-MENU.VATYN", "I-MENU.ETC", "I-MENU.SUB.NM", "I-MENU.SUB.UNITPRICE", "I-MENU.SUB.CNT", "I-MENU.SUB.PRICE", "I-MENU.SUB.ETC", "I-VOID_MENU.NM", "I-VOID_MENU.PRICE", "I-SUB_TOTAL.SUBTOTAL_PRICE", "I-SUB_TOTAL.DISCOUNT_PRICE", "I-SUB_TOTAL.SERVICE_PRICE", "I-SUB_TOTAL.OTHERSVC_PRICE", "I-SUB_TOTAL.TAX_PRICE", "I-SUB_TOTAL.ETC", "I-TOTAL.TOTAL_PRICE", "I-TOTAL.TOTAL_ETC", "I-TOTAL.CASHPRICE", "I-TOTAL.CHANGEPRICE", "I-TOTAL.CREDITCARDPRICE", "I-TOTAL.EMONEYPRICE", "I-TOTAL.MENUTYPE_CNT", "I-TOTAL.MENUQTY_CNT"]
9
+ id2label = {v: k for v, k in enumerate(labels)}
10
+ label2id = {k: v for v, k in enumerate(labels)}
11
+
12
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained("mp-02/layoutlmv3-finetuned-cord", apply_ocr=False)
13
+ processor = LayoutLMv3Processor.from_pretrained("mp-02/layoutlmv3-finetuned-cord", apply_ocr=False)
14
+ model = LayoutLMv3ForTokenClassification.from_pretrained("mp-02/layoutlmv3-finetuned-cord")
15
+
16
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17
+ model.to(device)
18
+
19
+
20
+ def prediction(image_path: str):
21
+ image = Image.open(image_path).convert('RGB')
22
+ boxes, words = OCR(image_path)
23
+ encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
24
+ offset_mapping = encoding.pop('offset_mapping')
25
+
26
+ for k, v in encoding.items():
27
+ encoding[k] = v.to(device)
28
+
29
+ outputs = model(**encoding)
30
+
31
+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
32
+ token_boxes = encoding.bbox.squeeze().tolist()
33
+
34
+ inp_ids = encoding.input_ids.squeeze().tolist()
35
+ inp_words = [tokenizer.decode(i) for i in inp_ids]
36
+
37
+ width, height = image.size
38
+ is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
39
+
40
+ true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
41
+ true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
42
+ true_words = []
43
+
44
+ for id, i in enumerate(inp_words):
45
+ if not is_subword[id]:
46
+ true_words.append(i)
47
+ else:
48
+ true_words[-1] = true_words[-1]+i
49
+
50
+ true_predictions = true_predictions[1:-1]
51
+ true_boxes = true_boxes[1:-1]
52
+ true_words = true_words[1:-1]
53
+
54
+ preds = []
55
+ l_words = []
56
+ bboxes = []
57
+
58
+ for i, j in enumerate(true_predictions):
59
+ if j != 'others':
60
+ preds.append(true_predictions[i])
61
+ l_words.append(true_words[i])
62
+ bboxes.append(true_boxes[i])
63
+
64
+ d = {}
65
+ for id, i in enumerate(preds):
66
+ if i not in d.keys():
67
+ d[i] = l_words[id]
68
+ else:
69
+ d[i] = d[i] + ", " + l_words[id]
70
+ d = {k: v.strip() for (k, v) in d.items()}
71
+
72
+ # TODO: process the json
73
+
74
+ draw = ImageDraw.Draw(image, "RGBA")
75
+ font = ImageFont.load_default()
76
+
77
+ for prediction, box in zip(preds, bboxes):
78
+ draw.rectangle(box)
79
+ draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black")
80
+
81
+ return d, image
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ json
2
+ torch
3
+ cv2
4
+ PIL
5
+ transformers
6
+ paddleocr
sroie_inference.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ from transformers import LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
6
+ from utils import OCR, unnormalize_box
7
+
8
+
9
+ labels = ["O", "B-COMPANY", "I-COMPANY", "B-DATE", "I-DATE", "B-ADDRESS", "I-ADDRESS", "B-TOTAL", "I-TOTAL"]
10
+ id2label = {v: k for v, k in enumerate(labels)}
11
+ label2id = {k: v for v, k in enumerate(labels)}
12
+
13
+ tokenizer = LayoutLMv3TokenizerFast.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie", apply_ocr=False)
14
+ processor = LayoutLMv3Processor.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie", apply_ocr=False)
15
+ model = LayoutLMv3ForTokenClassification.from_pretrained("Theivaprakasham/layoutlmv3-finetuned-sroie")
16
+
17
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
18
+ model.to(device)
19
+
20
+
21
+ def blur(image, boxes):
22
+ img = cv2.imread(image)
23
+ for box in boxes:
24
+ blur_x = int(box[0])
25
+ blur_y = int(box[1])
26
+ blur_width = int(box[2]-box[0])
27
+ blur_height = int(box[3]-box[1])
28
+
29
+ roi = img[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width]
30
+ blur_image = cv2.GaussianBlur(roi, (201, 201), 0)
31
+ img[blur_y:blur_y + blur_height, blur_x:blur_x + blur_width] = blur_image
32
+
33
+ cv2.imwrite("images/example_with_blur.jpg", img)
34
+ return "example_with_blur.jpg"
35
+
36
+
37
+ def prediction(image_path: str):
38
+ boxes, words = OCR(image_path)
39
+ image = Image.open(image_path).convert('RGB')
40
+ encoding = processor(image, words, boxes=boxes, return_offsets_mapping=True, return_tensors="pt", truncation=True)
41
+ offset_mapping = encoding.pop('offset_mapping')
42
+
43
+ for k, v in encoding.items():
44
+ encoding[k] = v.to(device)
45
+
46
+ outputs = model(**encoding)
47
+
48
+ predictions = outputs.logits.argmax(-1).squeeze().tolist()
49
+ token_boxes = encoding.bbox.squeeze().tolist()
50
+
51
+ inp_ids = encoding.input_ids.squeeze().tolist()
52
+ inp_words = [tokenizer.decode(i) for i in inp_ids]
53
+
54
+ width, height = image.size
55
+ is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
56
+
57
+ true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
58
+ true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
59
+ true_words = []
60
+
61
+ for id, i in enumerate(inp_words):
62
+ if not is_subword[id]:
63
+ true_words.append(i)
64
+ else:
65
+ true_words[-1] = true_words[-1]+i
66
+
67
+ true_predictions = true_predictions[1:-1]
68
+ true_boxes = true_boxes[1:-1]
69
+ true_words = true_words[1:-1]
70
+
71
+ preds = []
72
+ l_words = []
73
+ bboxes = []
74
+
75
+ for i, j in enumerate(true_predictions):
76
+ if j != 'others':
77
+ preds.append(true_predictions[i])
78
+ l_words.append(true_words[i])
79
+ bboxes.append(true_boxes[i])
80
+
81
+ d = {}
82
+ for id, i in enumerate(preds):
83
+ if i not in d.keys():
84
+ d[i] = l_words[id]
85
+ else:
86
+ d[i] = d[i] + ", " + l_words[id]
87
+
88
+ d = {k: v.strip() for (k, v) in d.items()}
89
+
90
+ keys_to_pop = []
91
+ for k, v in d.items():
92
+ if k[:2] == "I-":
93
+ d["B-" + k[2:]] = d["B-" + k[2:]] + ", " + v
94
+ keys_to_pop.append(k)
95
+
96
+ if "O" in d: d.pop("O")
97
+ if "B-TOTAL" in d: d.pop("B-TOTAL")
98
+ for k in keys_to_pop: d.pop(k)
99
+
100
+ blur_boxes = []
101
+ for prediction, box in zip(preds, bboxes):
102
+ if prediction != 'O' and prediction[2:] != 'TOTAL':
103
+ blur_boxes.append(box)
104
+
105
+ image = Image.open(blur(image_path, blur_boxes))
106
+
107
+ draw = ImageDraw.Draw(image, "RGBA")
108
+ font = ImageFont.load_default()
109
+ for prediction, box in zip(preds, bboxes):
110
+ draw.rectangle(box)
111
+ draw.text((box[0]+10, box[1]-10), text=prediction, font=font, fill="black", font_size="8")
112
+
113
+ return d, image
114
+
utils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from paddleocr import PaddleOCR
2
+ from PIL import Image
3
+
4
+ def normalize_bbox(bbox, width, height):
5
+
6
+ return [
7
+ int(1000 * (bbox[0] / width)),
8
+ int(1000 * (bbox[1] / height)),
9
+ int(1000 * (bbox[2] / width)),
10
+ int(1000 * (bbox[3] / height)),
11
+ ]
12
+
13
+ def unnormalize_box(bbox, width, height):
14
+
15
+ return [
16
+ width * (bbox[0] / 1000),
17
+ height * (bbox[1] / 1000),
18
+ width * (bbox[2] / 1000),
19
+ height * (bbox[3] / 1000),
20
+ ]
21
+
22
+
23
+ def OCR(image_path: str):
24
+ ocr = PaddleOCR(use_angle_cls=True)
25
+ image = Image.open(image_path)
26
+ result = ocr.ocr(image_path, cls=True)
27
+ bboxes = []
28
+ words = []
29
+
30
+ for idx in range(len(result)):
31
+ res = result[idx]
32
+
33
+ for line in res:
34
+ # print(line)
35
+ # print(line[0][0] + line[0][2])
36
+ bboxes.append(normalize_bbox(line[0][0]+line[0][2], image.width, image.height))
37
+ # print(line[1][0])
38
+ words.append(line[1][0])
39
+
40
+ return bboxes, words