mithenks committed on
Commit
ebd3f73
·
1 Parent(s): a5f2584

first commit

Browse files
Files changed (3) hide show
  1. app.py +59 -0
  2. packages.txt +0 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import re
3
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
4
+ import torch
5
+ from PIL import Image
6
+
7
def process_filename(filename, question):
    """Answer ``question`` about the document image stored at ``filename``.

    Args:
        filename: Path (or file object) accepted by ``PIL.Image.open``.
        question: Question to ask about the document.

    Returns:
        Whatever ``process_image`` returns for the loaded RGB image.
    """
    print(f"Image file: {filename}")
    print(f"Question: {question}")
    # convert() yields an independent, fully loaded image, so the file handle
    # can be released as soon as the conversion is done.
    with Image.open(filename) as source:
        image = source.convert("RGB")
    # process_image requires the question as well; omitting it raised TypeError.
    return process_image(image, question)
12
+
13
+
14
_DONUT_CHECKPOINT = "naver-clova-ix/donut-base-finetuned-docvqa"
_DONUT_CACHE = {}


def _load_donut():
    """Return ``(processor, model, device)``, loading the checkpoint on first use."""
    if "loaded" not in _DONUT_CACHE:
        processor = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
        model = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        _DONUT_CACHE["loaded"] = (processor, model, device)
    return _DONUT_CACHE["loaded"]


def process_image(image, question):
    """Run the Donut DocVQA checkpoint on ``image`` and answer ``question``.

    Args:
        image: Image object accepted by ``DonutProcessor`` (the callers in
            this file pass a PIL image).
        question: Question text inserted into the decoder prompt.

    Returns:
        A three-element list. On success it is ``[True, answer, ""]``. When
        the decoded output has no ``answer`` field it is
        ``[False, "", message]`` instead of raising ``KeyError``.
        NOTE(review): the slots look like (success, answer, error) - confirm
        against any external callers.
    """
    # Loading the checkpoint is by far the slowest step; do it once per
    # process rather than once per request.
    processor, model, device = _load_donut()
    tokenizer = processor.tokenizer

    # prepare decoder inputs
    prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
    decoder_input_ids = tokenizer(
        prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=False,
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    # Parse once and reuse the result for both logging and the return value.
    parsed = processor.token2json(sequence)
    print(parsed)

    answer = parsed.get("answer") if isinstance(parsed, dict) else None
    if answer is None:
        return [False, "", f"No answer found in model output: {sequence!r}"]
    return [True, answer, ""]
44
+
45
def process_document(image, question):
    """Gradio callback: return only the answer element of ``process_image``.

    Args:
        image: Image supplied by the Gradio ``"image"`` input.
        question: Text entered in the "Question" textbox.

    Returns:
        The element at index 1 of the list returned by ``process_image``.
    """
    return process_image(image, question)[1]
48
+
49
# Text passed to the interface as its description.
description = "DocVQA (document visual question answering)"

# Web UI: an image plus a free-text question go in, a single text response
# comes out. No examples are supplied to the interface here.
demo = gr.Interface(
    title="Extract data from image",
    description=description,
    fn=process_document,
    inputs=["image", gr.Textbox(label="Question")],
    outputs=gr.Textbox(label="Response"),
    cache_examples=True,
)

demo.launch()
packages.txt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch