Update app.py
Browse filesEnsuring the model runs in spce
app.py
CHANGED
@@ -1,3 +1,56 @@
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
import gradio as gr
|
3 |
|
4 |
+
import torch
|
5 |
+
from transformers import DonutProcessor, VisionEncoderDecoderModel
|
6 |
+
|
7 |
+
processor = DonutProcessor.from_pretrained("Travad98/donut-finetuned-sogc-trademarks-1883-2001")
|
8 |
+
model = VisionEncoderDecoderModel.from_pretrained("Travad98/donut-finetuned-sogc-trademarks-1883-2001")
|
9 |
+
|
10 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
+
model.to(device)
|
12 |
+
|
13 |
+
def process_document(image):
|
14 |
+
# prepare encoder inputs
|
15 |
+
pixel_values = processor(image, return_tensors="pt").pixel_values
|
16 |
+
|
17 |
+
# prepare decoder inputs
|
18 |
+
task_prompt = "<s_cord-v2>"
|
19 |
+
decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
|
20 |
+
|
21 |
+
# generate answer
|
22 |
+
outputs = model.generate(
|
23 |
+
pixel_values.to(device),
|
24 |
+
decoder_input_ids=decoder_input_ids.to(device),
|
25 |
+
max_length=model.decoder.config.max_position_embeddings,
|
26 |
+
early_stopping=True,
|
27 |
+
pad_token_id=processor.tokenizer.pad_token_id,
|
28 |
+
eos_token_id=processor.tokenizer.eos_token_id,
|
29 |
+
use_cache=True,
|
30 |
+
num_beams=1,
|
31 |
+
bad_words_ids=[[processor.tokenizer.unk_token_id]],
|
32 |
+
return_dict_in_generate=True,
|
33 |
+
)
|
34 |
+
|
35 |
+
# postprocess
|
36 |
+
sequence = processor.batch_decode(outputs.sequences)[0]
|
37 |
+
sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
|
38 |
+
sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
|
39 |
+
|
40 |
+
return processor.token2json(sequence)
|
41 |
+
|
42 |
+
description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on extracted SOGC Trademark dataset. To use it, simply upload your image and click 'submit', or click one of the examples to load them. Read more at the links below."
|
43 |
+
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>Github Repo</a></p>"
|
44 |
+
|
45 |
+
demo = gr.Interface(
|
46 |
+
fn=process_document,
|
47 |
+
inputs="image",
|
48 |
+
outputs="json",
|
49 |
+
title="Demo: Donut 🍩 for Document Parsing",
|
50 |
+
description=description,
|
51 |
+
article=article,
|
52 |
+
enable_queue=True,
|
53 |
+
examples=[["example.png"], ["example_2.png"], ["example_3.png"]],
|
54 |
+
cache_examples=False)
|
55 |
+
|
56 |
+
demo.launch()
|