Iqra Ali committed
Commit e6d304f · 1 Parent(s): 973e927

Update app.py

Files changed (1)
  1. app.py +55 -27
app.py CHANGED
@@ -1,39 +1,67 @@
 
+import re
 import gradio as gr
 import torch
+from transformers import DonutProcessor, VisionEncoderDecoderModel
+import transformers
 from PIL import Image
 
-from donut import DonutModel
+# hide transformers logs
+transformers.logging.disable_default_handler()
 
-def demo_process(input_img):
-    global pretrained_model, task_prompt, task_name
-    # input_img = Image.fromarray(input_img)
-    output = pretrained_model.inference(image=input_img, prompt=task_prompt)["predictions"][0]
-    return output
+# Load the fine-tuned model and processor from the Hugging Face Hub
+processor = DonutProcessor.from_pretrained("Iqra56/Donut_Updated")
+model = VisionEncoderDecoderModel.from_pretrained("Iqra56/Donut_Updated")
 
-task_prompt = f"<s_cord-v2>"
+# Move the model to GPU if one is available
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(device)
 
-image = Image.open("./Binder1_Page_48_Image_0001.png")
-image.save("cord_sample_receipt1.png")
-image = Image.open("./SKMBT_75122072616550_Page_50_Image_0001.png")
-image.save("cord_sample_receipt2.png")
-
-pretrained_model = DonutModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
-pretrained_model.encoder.to(torch.bfloat16)
-pretrained_model.eval()
+def process_document(image, question):
+    # prepare encoder inputs from the uploaded image
+    pixel_values = processor(image, return_tensors="pt").pixel_values
+    # NOTE: this checkpoint uses a bare "<s>" task prompt; the question from
+    # the UI is not folded into the prompt
+    task_prompt = "<s>"
+    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
+
+    # run inference (greedy decoding)
+    outputs = model.generate(
+        pixel_values.to(device),
+        decoder_input_ids=decoder_input_ids.to(device),
+        max_length=model.decoder.config.max_position_embeddings,
+        early_stopping=True,
+        pad_token_id=processor.tokenizer.pad_token_id,
+        eos_token_id=processor.tokenizer.eos_token_id,
+        use_cache=True,
+        num_beams=1,
+        bad_words_ids=[[processor.tokenizer.unk_token_id]],
+        return_dict_in_generate=True,
+    )
+
+    # postprocess: strip special tokens, drop the task start token, convert to JSON
+    sequence = processor.batch_decode(outputs.sequences)[0]
+    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
+    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
+
+    return processor.token2json(sequence)
+
+description = "Gradio Demo for Donut, an instance of `VisionEncoderDecoderModel` fine-tuned on DocVQA (document visual question answering). To use it, upload your image, type a question, and click 'Submit', or click one of the examples to load them. Read more at the links below."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2111.15664' target='_blank'>Donut: OCR-free Document Understanding Transformer</a> | <a href='https://github.com/clovaai/donut' target='_blank'>GitHub Repo</a></p>"
 
 demo = gr.Interface(
-    fn=demo_process,
-    inputs= gr.inputs.Image(type="pil"),
+    fn=process_document,
+    inputs=["image", "text"],
     outputs="json",
-    title=f"Donut 🍩 demonstration for `cord-v2` task",
-    description="""This model is trained with 800 Indonesian receipt images of CORD dataset. <br>
-Demonstrations for other types of documents/tasks are available at https://github.com/clovaai/donut <br>
-More CORD receipt images are available at https://huggingface.co/datasets/naver-clova-ix/cord-v2
-More details are available at:
-- Paper: https://arxiv.org/abs/2111.15664
-- GitHub: https://github.com/clovaai/donut""",
-    examples=[["cord_sample_receipt1.png"], ["cord_sample_receipt2.png"]],
-)
-
-demo.launch()
+    title="Demo: Donut 🍩 for DocVQA",
+    description=description,
+    article=article,
+    enable_queue=True,
+    examples=[["example_1.png", "When is the coffee break?"], ["example_2.jpeg", "What's the population of Stoddard?"]],
+    cache_examples=False)
+
+demo.launch()
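
For reference, `processor.token2json` turns Donut's XML-style tag sequence into a JSON-like dict. A minimal sketch with a hypothetical decoded sequence (the actual tags depend on the fine-tuning data):

    # hypothetical tag sequence; token2json nests the tags into dict keys
    sequence = "<s_question> When is the coffee break?</s_question><s_answer> 11:44 a.m.</s_answer>"
    print(processor.token2json(sequence))
    # expected shape: {'question': 'When is the coffee break?', 'answer': '11:44 a.m.'}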
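
And a minimal local smoke test for `process_document`, assuming the example files referenced in `examples` (e.g. example_1.png) are present in the working directory; run it with the definitions above in scope, before `demo.launch()`:

    from PIL import Image

    image = Image.open("example_1.png")
    print(process_document(image, "When is the coffee break?"))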