# donut-test/app.py: Gradio Space for Donut DocVQA with selectors for generation parameters
import gradio as gr
import re
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
from PIL import Image


def process_filename(filename, question):
    print(f"Image file: {filename}")
    print(f"Question: {question}")
    image = Image.open(filename).convert("RGB")
    # process_image() also expects the four generation-option flags; forward the UI defaults ("True")
    return process_image("True", "True", "True", "True", image, question)
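

# process_image() loads the DocVQA-finetuned Donut checkpoint, builds the task prompt
# from the user's question, runs generation with the options selected in the UI, and
# returns the parsed answer string.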
def process_image(set_use_cache, set_return_dict_in_generate, set_early_stopping, set_output_scores, image, question):
    repo_id = "naver-clova-ix/donut-base-finetuned-docvqa"
    print(f"Model repo: {repo_id}")

    processor = DonutProcessor.from_pretrained(repo_id)
    model = VisionEncoderDecoderModel.from_pretrained(repo_id)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Device used: {device}")
    model.to(device)

    # prepare decoder inputs
    prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
    decoder_input_ids = processor.tokenizer(prompt, add_special_tokens=False, return_tensors="pt").input_ids

    pixel_values = processor(image, return_tensors="pt").pixel_values
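    # Run generation; the four flags arrive from the UI radio selectors as
    # "True"/"False" strings and are forwarded to model.generate() as booleans.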
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=set_use_cache == "True",
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=set_return_dict_in_generate == "True",
        early_stopping=set_early_stopping == "True",
        output_scores=set_output_scores == "True",
    )
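    # generate() returns a GenerateOutput object only when return_dict_in_generate=True;
    # otherwise it returns the generated token ids directly.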
    print(outputs)
    sequences = outputs.sequences if set_return_dict_in_generate == "True" else outputs
    sequence_data = processor.batch_decode(sequences)
    print(f"Sequence data: {sequence_data}")

    sequence = sequence_data[0]
    print(f"Sequence: {sequence}")
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token

    print(processor.token2json(sequence))
    return processor.token2json(sequence)["answer"]
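
# Illustrative usage sketch (not exercised by the Space itself): calling process_image()
# directly with a hypothetical local file; the file name and question are placeholders.
#
#   answer = process_image(
#       "True", "True", "True", "True",
#       Image.open("invoice.png").convert("RGB"),
#       "What is the invoice number?",
#   )
#   print(answer)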


description = "DocVQA (document visual question answering)"

demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Radio(["True", "False"], value="True", label="Use cache", info="Sets the model.generate() use_cache argument"),
        gr.Radio(["True", "False"], value="True", label="Dict in generate", info="Sets the model.generate() return_dict_in_generate argument"),
        gr.Radio(["True", "False"], value="True", label="Early stopping", info="Sets the model.generate() early_stopping argument"),
        gr.Radio(["True", "False"], value="True", label="Output scores", info="Sets the model.generate() output_scores argument"),
        "image",
        gr.Textbox(label="Question"),
    ],
    outputs=gr.Textbox(label="Response"),
    title="Extract data from image",
    description=description,
    cache_examples=True,
)

demo.launch()