File size: 2,355 Bytes
1bcd3a0 2a6e4c0 1bcd3a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import re
import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel
# Hugging Face Hub checkpoint used for BOTH the processor and the model.
# Keeping it in one constant guarantees the two are loaded from the same repo.
MODEL_ID = "Travad98/donut-finetuned-sogc-trademarks-1883-2001"

processor = DonutProcessor.from_pretrained(MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_ID)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
def process_document(image):
    """Parse a document image with the fine-tuned Donut model.

    Args:
        image: The image supplied by the Gradio ``"image"`` input component.
            It is handed directly to ``processor``. ``None`` (nothing
            uploaded) is rejected up front.

    Returns:
        The structured result of ``processor.token2json`` for the generated
        token sequence.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    # Gradio's Image input yields None when no image was uploaded. Fail with a
    # clear message instead of an opaque error from inside the image processor.
    if image is None:
        raise ValueError("No image provided; please upload a document image.")

    # Encoder inputs: pixel tensor produced by the Donut image processor.
    pixel_values = processor(image, return_tensors="pt").pixel_values

    # Decoder inputs: generation is seeded with the task start token.
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # Single-beam decoding, capped at the decoder's maximum position count;
    # the unknown token is banned from the output.
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Post-process: drop EOS / padding tokens, then the first tag (the task
    # start token), and convert the remaining tag sequence to structured output.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
# Text rendered above the demo inputs.
description = (
    "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` "
    "fine-tuned on extracted SOGC Trademark dataset. To use it, simply "
    "upload your image and click 'submit', or click one of the examples "
    "to load them. Read more at the links below."
)

# HTML footer linking to the paper and the upstream repository.
article = (
    "<p style='text-align: center'>"
    "<a href='https://arxiv.org/abs/2111.15664' target='_blank'>"
    "Donut: OCR-free Document Understanding Transformer</a> | "
    "<a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a>"
    "</p>"
)

# Wire the parser into a single-image-in, JSON-out interface and start it.
demo = gr.Interface(
    fn=process_document,
    inputs="image",
    outputs="json",
    examples=[
        [filename]
        for filename in ("example.png", "example_2.png", "example_3.png")
    ],
    cache_examples=False,
    title="Demo: Donut 🍩 for Document Parsing",
    description=description,
    article=article,
    enable_queue=True,
)
demo.launch()