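# Gradio demo: token classification on receipt images with a LayoutLMv3 model
# fine-tuned on the CORD dataset. The app draws predicted entity boxes on the
# uploaded image and returns the grouped (entity, type) pairs as text.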
import gradio as gr
import os
# Runtime dependency installs (a Spaces workaround; these would normally live
# in requirements.txt / packages.txt). apt-get needs -y to run non-interactively.
os.system('pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu')
os.system('sudo apt-get install -y tesseract-ocr')
os.system('pip install -q pytesseract')
import torch
from datasets import load_dataset, ClassLabel
from transformers import LayoutLMv3ForTokenClassification, LayoutLMv3Processor, LayoutLMv3FeatureExtractor
import pytesseract
import numpy as np
from PIL import ImageDraw, ImageFont
examples = [['./examples/example1.png'],['./examples/example2.png'],['./examples/example3.png']]
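# Load the CORD training split once at startup; it provides the ner_tags label
# names used to map prediction ids back to label strings.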
dataset = load_dataset("nielsr/cord-layoutlmv3")['train']
def get_label_list(labels):
    # Fallback when the label column is not a ClassLabel: collect the sorted
    # set of unique tags seen in the data.
    unique_labels = set()
    for label in labels:
        unique_labels = unique_labels | set(label)
    label_list = list(unique_labels)
    label_list.sort()
    return label_list
def convert_l2n_n2l(dataset):
    # Build id<->label mappings from the dataset's ner_tags feature. Only a
    # ClassLabel feature carries .names, so it is read inside the type check.
    features = dataset.features
    label_column_name = "ner_tags"
    if isinstance(features[label_column_name].feature, ClassLabel):
        label_list = features[label_column_name].feature.names
    else:
        label_list = get_label_list(dataset[label_column_name])
    id2label = dict(enumerate(label_list))
    label2id = {v: k for k, v in enumerate(label_list)}
    return label_list, id2label, label2id, len(label_list)
def label_colour(label):
    # Colour per entity type; unknown labels get None (no outline is drawn).
    label2color = {'MENU.PRICE': 'blue', 'MENU.NM': 'green', 'other': 'green', 'MENU.TOTAL_PRICE': 'red'}
    return label2color.get(label)
def iob_to_label(label):
    # Strip the IOBES prefix (e.g. "B-"); a bare "O" tag becomes 'other'.
    label = label[2:]
    if not label:
        return 'other'
    return label
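# Group word-level IOBES tags into complete entities: an S- tag is a
# single-word entity, while a B-/I-/E- run is joined into one string and
# flushed when the E- tag closes it.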
def convert_results(words, tags):
    ents = set()
    completeword = ""
    for word, tag in zip(words, tags):
        if tag != "O":
            ent_position, ent_type = tag.split("-")
            if ent_position == "S":
                # Single-word entity.
                ents.add((word, ent_type))
            elif ent_position in ("B", "I"):
                # Beginning/inside of a multi-word entity: keep accumulating.
                completeword = completeword + " " + word
            elif ent_position == "E":
                # End of the entity: flush the accumulated words.
                completeword = completeword + " " + word
                ents.add((completeword.strip(), ent_type))
                completeword = ""
    return ents
def unnormalize_box(bbox, width, height):
    # LayoutLMv3 boxes are normalised to a 0-1000 range; scale back to pixels.
    return [
        width * (bbox[0] / 1000),
        height * (bbox[1] / 1000),
        width * (bbox[2] / 1000),
        height * (bbox[3] / 1000),
    ]
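# Run the fine-tuned model on one image and return pixel-space boxes plus a
# label for the first sub-token of every word.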
def predict(image):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Model and processor are reloaded on every call; hoisting them to module
    # level would speed up repeated inference.
    model = LayoutLMv3ForTokenClassification.from_pretrained("keldrenloy/layoutlmv3cordfinetuned").to(device)  # add your model directory here
    processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
    label_list, id2label, label2id, num_labels = convert_l2n_n2l(dataset)
    width, height = image.size

    # The processor runs Tesseract OCR internally and tokenises the words.
    encoding_inputs = processor(image, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    offset_mapping = encoding_inputs.pop('offset_mapping')
    for k, v in encoding_inputs.items():
        encoding_inputs[k] = v.to(device)

    with torch.no_grad():
        outputs = model(**encoding_inputs)

    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    token_boxes = encoding_inputs.bbox.squeeze().tolist()

    # Keep only the first sub-token of each word (offset 0 marks a word start).
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0
    true_predictions = [id2label[pred] for idx, pred in enumerate(predictions) if not is_subword[idx]]
    true_boxes = [unnormalize_box(box, width, height) for idx, box in enumerate(token_boxes) if not is_subword[idx]]
    return true_boxes, true_predictions
def text_extraction(image):
    # Re-run the feature extractor's OCR pass to recover the word strings.
    feature_extractor = LayoutLMv3FeatureExtractor()
    encoding = feature_extractor(image, return_tensors="pt")
    return encoding['words'][0]
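# Gradio callback: annotate the uploaded image with predicted entity boxes and
# return it together with the grouped (entity, type) pairs.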
def image_render(image):
    true_boxes, true_predictions = predict(image)
    # Extract the OCR words before drawing, so the painted boxes and labels
    # cannot interfere with the OCR pass.
    words = text_extraction(image)
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    for prediction, box in zip(true_predictions, true_boxes):
        predicted_label = iob_to_label(prediction)
        draw.rectangle(box, outline=label_colour(predicted_label))
        draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label_colour(predicted_label), font=font)
    extracted_words = convert_results(words, true_predictions)
    return image, extracted_words
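# Force taller image panes in the Gradio UI.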
css = """.output_image, .input_image {height: 600px !important}"""
demo = gr.Interface(fn=image_render,
                    inputs=gr.inputs.Image(type="pil"),
                    outputs=[gr.outputs.Image(type="pil", label="annotated image"), 'text'],
                    css=css,
                    examples=examples,
                    allow_flagging=True,
                    flagging_options=["incorrect", "correct"],
                    flagging_callback=gr.CSVLogger(),
                    flagging_dir="flagged")
if __name__ == "__main__":
    demo.launch()