import re
import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests
from io import BytesIO
import json
import os
processor = DonutProcessor.from_pretrained("to-be/donut-base-finetuned-invoices")
model = VisionEncoderDecoderModel.from_pretrained("to-be/donut-base-finetuned-invoices")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
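# Donut pairs a Swin Transformer image encoder with a BART-style text decoder, so the
# processor handles both the image preprocessing and the tokenization. Inference runs on GPU when available.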
def update_status(state):
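    # Intended to toggle the animated "snowangel" GIFs while a document is being processed;
    # the imgout.clear/.change handlers that would call this are currently commented out below.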
    if state == "start_or_clear":
        state = 'processing'  # current state becomes 'processing'
        return (gr.update(value="snowangel.gif", visible=True), gr.update(value="snowangel.gif", visible=True))
    elif state == "processing":
        state = 'finished_processing'  # current state becomes 'finished_processing'
        return (gr.update(value="", visible=False), gr.update(value="", visible=False))
    elif state == "finished_processing":
        state = 'processing'  # current state becomes 'processing'
        return (gr.update(value="snowangel.gif", visible=True), gr.update(value="snowangel.gif", visible=True))
def process_document(image,sendimg):
    if sendimg:
        im1 = Image.fromarray(image)
    else:
        im1 = Image.open('./no_image.jpg')
    # keep track of demo count
    resp = requests.get('https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut%2Fdemo&label=demos%20served&labelColor=%23edd239&countColor=%23d9e3f0')
    # send notification through telegram
    TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
    CHAT_ID = os.getenv('TELEGRAM_CHANNEL_ID')
    url = f'https://api.telegram.org/bot{TOKEN}/sendPhoto?chat_id={CHAT_ID}'
    bio = BytesIO()
    bio.name = 'image.jpeg'
    im1.save(bio, 'JPEG')
    bio.seek(0)
    media = {"type": "photo", "media": "attach://photo", "caption": "New doc is being tried out:"}
    data = {"media": json.dumps(media)}
    try:
        response = requests.post(url, files={'photo': bio}, data=data)
    except requests.RequestException:
        print("telegram api error")
    # prepare encoder inputs
    pixel_values = processor(image, return_tensors="pt").pixel_values
    # prepare decoder inputs
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
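    # The decoder is primed with a task prompt; this checkpoint uses the "<s_cord-v2>"
    # start token, after which the model generates the extracted fields as tagged tokens.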
    # generate answer
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    # postprocess
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
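    # Illustrative example (not actual model output): the cleaned sequence looks roughly like
    #   "<s_DocType>Invoice</s_DocType><s_InvoiceNumber>2023-001</s_InvoiceNumber>..."
    # and token2json() below converts it into {"DocType": "Invoice", "InvoiceNumber": "2023-001", ...}.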
    # img2.update(visible=False)  # no-op: component updates must be returned from the event handler, so this line is disabled
    return processor.token2json(sequence), image
title = '<table align="center" border="0" cellpadding="1" cellspacing="1" ><tbody><tr><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling_small.gif" style="float:right; height:50px; width:50px" /></td><td style="text-align:center"><h1>Demo: invoice header extraction with Donut</h1></td><td style="text-align:center"><img alt="" src="https://huggingface.co/spaces/to-be/invoice_document_headers_extraction_with_donut/resolve/main/circling2_small.gif" style="float:left; height:50px; width:50px" /></td></tr></tbody></table>'
paragraph0 = '<p><strong>(update 29/03/2023: for more info, you can read <a href="https://toon-beerten.medium.com/hands-on-document-data-extraction-with-transformer-7130df3b6132">my article on medium</a>)<br />(update 28/04/2023: want to finetune with your own data? Read <a href="https://towardsdatascience.com/ocr-free-document-data-extraction-with-transformers-1-2-b5a826bc2ac3">this article</a>)</strong></p>'
paragraph1 = '<p>The basic idea of the base 🍩 model is to give it an image as input and get the extracted indexes back as text. No bounding boxes or confidences are generated.<br /> I finetuned it on invoices. For more info, see the <a href="https://arxiv.org/abs/2111.15664">original paper</a> and the 🤗 <a href="https://huggingface.co/naver-clova-ix/donut-base">model</a>.</p>'
paragraph2 = '<p><strong>Training</strong>:<br />The model was trained with a few thousand annotated invoices and non-invoices (for those the doctype will be \'Other\'). They span different countries and languages and are always one page only. The dataset is unfortunately proprietary. The model is set to an input resolution of 1280x1920 pixels, so any sample with a dpi higher than 150 has no added value.<br />It was trained for about 4 hours on an NVIDIA RTX A4000 for 20k steps, with a val_metric of 0.03413819904382196 at the end.<br />The <u>following indexes</u> were included in the train set:</p><ul><li><span style="font-family:Calibri"><span style="color:black">DocType</span></span></li><li><span style="font-family:Calibri"><span style="color:black">Currency</span></span></li><li><span style="font-family:Calibri"><span style="color:black">DocumentDate</span></span></li><li><span style="font-family:Calibri"><span style="color:black">GrossAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">InvoiceNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">NetAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">TaxAmount</span></span></li><li><span style="font-family:Calibri"><span style="color:black">OrderNumber</span></span></li><li><span style="font-family:Calibri"><span style="color:black">CreditorCountry</span></span></li></ul>'
paragraph3 = '<p><strong>Benchmark observations:</strong><br />Of all documents in the validation set, 60% had every index captured correctly.</p><p>Here are the results per index:</p><p style="margin-left:40px"><img alt="" src="https://s3.amazonaws.com/moonup/production/uploads/1677749023966-6335a49ceb6132ca653239a0.png" style="height:70%; width:70%" /></p><p>Some other observations:<br />- when trying a non-invoice document, it is quite reliably identified as Doctype: \'Other\'<br />- the validation set contained mostly the same invoice layouts as the train set; validated against completely differently sourced invoices, the results would be different<br />- the document date is recognized across different notations, but it is often wrong because the dataset did not cover a diverse enough time span of dates</p>'
#demo = gr.Interface(fn=process_document,inputs=gr_image,outputs="json",title="Demo: Donut 🍩 for invoice header retrieval", description=description,
# article=article,enable_queue=True, examples=[["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]], cache_examples=False)
paragraph4 = '<p><strong>Try it out:</strong><br />To use it, simply upload your image and click \'Extract\', or click one of the examples to load them.<br /><em>(because this is running on the free CPU tier, it will take about 40 secs before you see a result. On a GPU it takes less than 2 seconds)</em></p><p> </p><p>Have fun 😎</p><p>Toon Beerten</p>'
smallprint = '<p>✤ <span style="font-size:11px">To get an idea of usage, you can opt to let me be notified personally via Telegram with the uploaded image. All data is automatically deleted after 48 hours.</span></p>'
css = "#inp {height: auto !important; width: 100% !important;}"
visit_badge = '<a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut"><img src="https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fto-be%2Finvoice_document_headers_extraction_with_donut&labelColor=%23edd239&countColor=%23d9e3f0&style=flat" /></a>'
# css = "@media screen and (max-width: 600px) { .output_image, .input_image {height:20rem !important; width: 100% !important;} }"
# css = ".output_image, .input_image {height: 600px !important}"
#css = ".image-preview {height: auto !important;}"
#css='div {margin-left: auto; margin-right: auto; width: 100%;background-image: url("background.gif"); repeat 0 0;}')
with gr.Blocks(css=css) as demo:
    state = gr.State(value='start_or_clear')
    gr.HTML(title)
    gr.HTML(paragraph0)
    gr.HTML(paragraph1)
    gr.HTML(paragraph2)
    gr.HTML(paragraph3)
    gr.HTML(paragraph4)
    with gr.Row().style():
        with gr.Column(scale=1):
            inp = gr.Image(label='Upload invoice here:')  #.style(height=400)
        with gr.Column(scale=2):
            gr.Examples([["example.jpg"], ["example_2.jpg"], ["example_3.jpg"]], inputs=[inp], label='Or use one of these examples:')
    with gr.Row().style(equal_height=True, height=200, rounded=False):
        with gr.Column(scale=1):
            img2 = gr.Image("drinking.gif", label=' ', visible=False).style(rounded=True)
        with gr.Column(scale=2):
            btn = gr.Button(" ↓ Extract ↓ ")
        with gr.Column(scale=2):
            #img3 = gr.Image("snowangel.gif",label=' ',visible=False).style(rounded=True)
            sendimg = gr.Checkbox(value=True, label="Allow usage data collection for at most 48 hours ✤")
    with gr.Row().style():
        with gr.Column(scale=2):
            imgout = gr.Image(label='Uploaded document:', elem_id="inp")
        with gr.Column(scale=1):
            jsonout = gr.JSON(label='Extracted information:')
    #imgout.clear(fn=update_status,inputs=state,outputs=[img2,img3])
    #imgout.change(fn=update_status,inputs=state,outputs=[img2,img3])
    btn.click(fn=process_document, inputs=[inp, sendimg], outputs=[jsonout, imgout])
    gr.HTML(smallprint)
    gr.HTML(visit_badge)
demo.launch()