# NOTE: removed Hugging Face Spaces page-scrape artifacts that preceded the code
# (runtime status lines, file size, commit hashes, and a line-number gutter) —
# they were not part of the Python source and broke parsing.
# --- Environment setup (Hugging Face Spaces bootstrap) ---
# NOTE(review): installing dependencies via os.system at import time is a
# Spaces-specific workaround; a requirements.txt/packages.txt would normally
# be preferred — confirm the Space configuration before changing.
import os
os.system('pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu')
os.system('sudo apt-get install tesseract-ocr')  # system binary required by pytesseract
os.system('pip install -q pytesseract')
# Silence the tokenizers fork-related parallelism warning.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import gradio as gr
import torch
from datasets import load_dataset, ClassLabel
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor,LayoutLMv3FeatureExtractor
import pytesseract
import numpy as np
from PIL import ImageDraw, ImageFont
# Demo images shown in the gradio UI.
examples = [['example1.png'],['example2.png'],['example3.png']]
# CORD receipts dataset; used only to recover the label list / id mappings.
dataset = load_dataset("nielsr/cord-layoutlmv3")['train']
def get_label_list(labels):
    """Return the sorted list of distinct tags found across all tag sequences.

    Args:
        labels: an iterable of tag sequences (one sequence per example).
    Returns:
        Sorted list of the unique tags.
    """
    seen = set()
    for tag_sequence in labels:
        seen.update(tag_sequence)
    return sorted(seen)
def convert_l2n_n2l(dataset):
    """Build the label list and id<->label mappings for the `ner_tags` column.

    Args:
        dataset: a `datasets.Dataset` with a `ner_tags` feature.
    Returns:
        (label_list, id2label, label2id, num_labels)
    """
    features = dataset.features
    label_column_name = "ner_tags"
    feature = features[label_column_name].feature
    # BUGFIX: `.names` was previously read unconditionally, so the
    # non-ClassLabel fallback branch could never execute (AttributeError
    # would be raised first). Access `.names` only when it exists.
    if isinstance(feature, ClassLabel):
        label_list = feature.names
    else:
        label_list = get_label_list(dataset[label_column_name])
    id2label = dict(enumerate(label_list))
    label2id = {v: k for k, v in enumerate(label_list)}
    return label_list, id2label, label2id, len(label_list)
def label_colour(label):
    """Map an entity label to its drawing colour.

    Args:
        label: entity label string (e.g. 'MENU.PRICE').
    Returns:
        A PIL colour name, or None for unmapped labels / 'other'.
    """
    label2color = {'MENU.PRICE': 'blue', 'MENU.NM': 'red', 'other': None,
                   'MENU.NUM': 'orange', 'TOTAL.TOTAL_PRICE': 'green'}
    # dict.get already returns None for missing keys, so the previous
    # membership check + lookup was a redundant double lookup.
    return label2color.get(label)
def iob_to_label(label):
    """Strip the two-character IOB prefix (e.g. 'B-') from a tag.

    Tags without an entity part (e.g. 'O') map to 'other'.
    """
    entity = label[2:]
    return entity if entity else 'other'
def convert_results(words, tags):
    """Group IOB/IOBES-tagged tokens into (entity_text, entity_type) pairs.

    'S-' tags form single-word entities; 'B-'/'I-' tokens accumulate until an
    'E-' tag closes the entity. 'O' tokens are skipped.

    Args:
        words: sequence of word strings.
        tags: parallel sequence of tags ('O' or 'POS-TYPE').
    Returns:
        A set of (entity_text, entity_type) tuples.
    """
    ents = set()
    parts = []
    for word, tag in zip(words, tags):
        if tag == "O":
            continue
        # maxsplit=1 keeps entity types containing '-' intact.
        ent_position, ent_type = tag.split("-", 1)
        if ent_position == "S":
            ents.add((word, ent_type))
        elif ent_position in ("B", "I"):
            parts.append(word)
        elif ent_position == "E":
            parts.append(word)
            # BUGFIX: joining the collected words (instead of prefixing each
            # with " ") avoids the stray leading space the old code produced.
            ents.add((" ".join(parts), ent_type))
            parts = []
    return ents
def unnormalize_box(bbox, width, height):
    """Rescale a 0-1000 normalised [x0, y0, x1, y1] box to pixel coordinates."""
    dims = (width, height, width, height)
    return [dim * (coord / 1000) for coord, dim in zip(bbox, dims)]
def predict(image):
    """Run LayoutLMv3 token classification on a receipt image.

    Draws a coloured box + label over each predicted token and extracts the
    menu-item names and prices.

    Args:
        image: PIL.Image of a receipt (supplied by the gradio input widget).
    Returns:
        (annotated image, list of menu-name strings, list of price strings)
    """
    # NOTE(review): the model and processor are re-loaded on every call, which
    # is slow; kept here to preserve the original behaviour.
    model = LayoutLMv3ForTokenClassification.from_pretrained("keldrenloy/LayoutLMv3FineTunedwithCORDandSGReceipts")  # add your model directory here
    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
    label_list, id2label, label2id, num_labels = convert_l2n_n2l(dataset)
    width, height = image.size

    encoding_inputs = processor(image, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding_inputs.pop('offset_mapping')
    with torch.no_grad():
        outputs = model(**encoding_inputs)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding_inputs.bbox.squeeze().tolist()
    # A token is a subword continuation when its character offset does not
    # start at 0; keep only the first subword of each word.
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]

    # BUGFIX: OCR the image BEFORE drawing annotations on it. Previously
    # text_extraction ran on the already-annotated image, so the rectangles
    # and label text were fed back into Tesseract and corrupted the word list
    # that is paired with the (clean-image) predictions below.
    words = text_extraction(image)

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for prediction, box in zip(true_predictions, true_boxes):
        predicted_label = iob_to_label(prediction)
        colour = label_colour(predicted_label)  # one lookup per token instead of two
        draw.rectangle(box, outline=colour)
        draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=colour, font=font)

    extracted_words = convert_results(words, true_predictions)
    menu_list = []
    price_list = []
    # NOTE(review): `extracted_words` is a set, so the "item N" numbering is
    # arbitrary order — confirm whether ordering matters to the UI.
    for idx, item in enumerate(extracted_words):
        if item[1] == 'MENU.NM':
            menu_list.append(f"item {idx}.{item[0]}")
        if item[1] == 'MENU.PRICE':
            price_list.append(f"item {idx}. ${item[0]}")
    return image, menu_list, price_list
def text_extraction(image):
    """OCR the image via the LayoutLMv3 feature extractor; return its word list."""
    extractor = LayoutLMv3FeatureExtractor()
    return extractor(image, return_tensors="pt")['words'][0]
# Force tall image panels so receipts are readable.
css = """.output_image, .input_image {height: 600px !important}"""
# Gradio UI: one image in, annotated image + menu names + prices out.
# BUGFIX: removed the trailing " |" page-scrape artifact that followed
# demo.launch(...) and made the file a syntax error.
demo = gr.Interface(fn = predict,
                    inputs = gr.inputs.Image(type="pil"),
                    outputs = [gr.outputs.Image(type="pil", label="annotated image"),'text','text'],
                    css = css,
                    examples = examples,
                    allow_flagging=True,
                    flagging_options=["incorrect", "correct"],
                    flagging_callback = gr.CSVLogger(),
                    flagging_dir = "flagged",
                    analytics_enabled = True, enable_queue=True
                    )
demo.launch(debug=False)