# Gradio demo: receipt understanding with LayoutLMv3 fine-tuned on CORD plus
# Singapore receipts. The app OCRs an uploaded receipt, classifies each word
# (menu item, price, total, ...), draws the predictions on the image, and
# lists the detected menu items and prices.
import os

# Install runtime dependencies at startup (CPU-only PyTorch and Tesseract OCR).
# On Hugging Face Spaces these would normally go in requirements.txt /
# packages.txt; they are installed here to keep the script self-contained.
os.system('pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu')
os.system('sudo apt-get install -y tesseract-ocr')
os.system('pip install -q pytesseract')
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import gradio as gr
import torch
from datasets import load_dataset, ClassLabel
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, LayoutLMv3FeatureExtractor
import pytesseract  # OCR backend used internally by the LayoutLMv3 feature extractor
import numpy as np
from PIL import ImageDraw, ImageFont

examples = [['example1.png'], ['example2.png'], ['example3.png']]
# The CORD train split is loaded only to recover the label names for the
# id2label / label2id mappings below.
dataset = load_dataset("nielsr/cord-layoutlmv3")['train']

def get_label_list(labels):
    # Collect the sorted set of unique tags from a list of per-example tag lists.
    unique_labels = set()
    for label in labels:
        unique_labels = unique_labels | set(label)
    label_list = list(unique_labels)
    label_list.sort()
    return label_list
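
# A minimal worked example (hypothetical tags):
#   get_label_list([['O', 'B-MENU.NM'], ['O', 'S-MENU.PRICE']])
#   -> ['B-MENU.NM', 'O', 'S-MENU.PRICE']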

def convert_l2n_n2l(dataset):
    # Build label <-> id mappings from the dataset's "ner_tags" column.
    features = dataset.features
    label_column_name = "ner_tags"

    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
    else:
        # Not a ClassLabel feature, so scan the column for the unique tags.
        label_list = get_label_list(dataset[label_column_name])
    id2label = {k: v for k, v in enumerate(label_list)}
    label2id = {v: k for k, v in enumerate(label_list)}

    return label_list, id2label, label2id, len(label_list)
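
# Illustrative mappings only; the real CORD label set and index order come
# from the dataset itself:
#   id2label -> {0: 'O', 1: 'B-MENU.NM', 2: 'I-MENU.NM', ...}
#   label2id -> {'O': 0, 'B-MENU.NM': 1, 'I-MENU.NM': 2, ...}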

def label_colour(label):
    # Entity types without an entry (or mapped to None) are drawn without colour.
    label2color = {'MENU.PRICE': 'blue', 'MENU.NM': 'red', 'other': None, 'MENU.NUM': 'orange', 'TOTAL.TOTAL_PRICE': 'green'}
    return label2color.get(label)
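
# e.g. label_colour('MENU.PRICE') -> 'blue'; label_colour('unknown') -> None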

def iob_to_label(label):
    # Strip the IOBES prefix ("B-", "I-", "E-", "S-"); the bare "O" tag becomes 'other'.
    label = label[2:]
    if not label:
        return 'other'
    return label
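
# e.g. iob_to_label('B-MENU.NM') -> 'MENU.NM'; iob_to_label('O') -> 'other'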

def convert_results(words, tags):
    # Group IOBES-tagged words into (text, entity_type) pairs:
    # S = single-word entity; B/I/E = begin / inside / end of a multi-word entity.
    ents = set()
    completeword = ""
    for word, tag in zip(words, tags):
        if tag == "O":
            continue
        ent_position, ent_type = tag.split("-")
        if ent_position == "S":
            ents.add((word, ent_type))
        else:
            completeword = (completeword + " " + word).strip()
            if ent_position == "E":
                # The entity is complete; emit it and reset the buffer.
                ents.add((completeword, ent_type))
                completeword = ""
    return ents
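
# A worked example (hypothetical inputs):
#   convert_results(["Iced", "Latte", "4.50"],
#                   ["B-MENU.NM", "E-MENU.NM", "S-MENU.PRICE"])
#   -> {("Iced Latte", "MENU.NM"), ("4.50", "MENU.PRICE")}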

def unnormalize_box(bbox, width, height):
    # LayoutLMv3 boxes are normalised to a 0-1000 grid; map them back to pixels.
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]
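
# e.g. unnormalize_box([100, 200, 300, 400], width=500, height=1000)
#      -> [50.0, 200.0, 150.0, 400.0]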

def predict(image):
    # NOTE: the model and processor are reloaded on every call for simplicity;
    # hoisting them to module level would make repeated predictions faster.
    model = LayoutLMv3ForTokenClassification.from_pretrained("keldrenloy/LayoutLMv3FineTunedwithCORDandSGReceipts") # add your model directory here
    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
    label_list, id2label, label2id, num_labels = convert_l2n_n2l(dataset)
    width, height = image.size

    # The processor runs Tesseract OCR internally (apply_ocr=True by default)
    # and normalises the word boxes to the 0-1000 grid LayoutLMv3 expects.
    encoding_inputs = processor(image, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding_inputs.pop('offset_mapping')

    with torch.no_grad():
        outputs = model(**encoding_inputs)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding_inputs.bbox.squeeze().tolist()

    # Keep only the first sub-token of every word (character offset 0), then
    # drop the leading [CLS] and trailing [SEP] entries so the predictions
    # line up one-to-one with the OCR words.
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]][1:-1]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]][1:-1]

    # Extract the words before annotating the image, so the drawn boxes and
    # label text do not corrupt the OCR output.
    words = text_extraction(image)

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()

    for prediction, box in zip(true_predictions, true_boxes):
        predicted_label = iob_to_label(prediction)
        draw.rectangle(box, outline=label_colour(predicted_label))
        draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label_colour(predicted_label), font=font)

    extracted_words = convert_results(words, true_predictions)

    # convert_results returns a set, so the item numbering below is arbitrary;
    # it simply enumerates the detected entities.
    menu_list = []
    price_list = []
    for idx, item in enumerate(extracted_words):
        if item[1] == 'MENU.NM':
            menu_list.append(f"item {idx}. {item[0]}")
        if item[1] == 'MENU.PRICE':
            price_list.append(f"item {idx}. ${item[0]}")

    return image, menu_list, price_list
    
def text_extraction(image):
    # Re-run the feature extractor's built-in Tesseract OCR to get the word
    # list; the processor performs the same OCR internally but does not
    # expose the recognised words.
    feature_extractor = LayoutLMv3FeatureExtractor()
    encoding = feature_extractor(image, return_tensors="pt")
    return encoding['words'][0]
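
# Hypothetical output for a receipt image:
#   text_extraction(image) -> ['Iced', 'Latte', '4.50', 'TOTAL', ...]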

css = """.output_image, .input_image {height: 600px !important}"""

demo = gr.Interface(fn=predict,
                    inputs=gr.Image(type="pil"),
                    outputs=[gr.Image(type="pil", label="annotated image"),
                             gr.Textbox(label="menu items"),
                             gr.Textbox(label="prices")],
                    css=css,
                    examples=examples,
                    allow_flagging="manual",
                    flagging_options=["incorrect", "correct"],
                    flagging_callback=gr.CSVLogger(),
                    flagging_dir="flagged",
                    analytics_enabled=True)
demo.queue().launch(debug=False)