UniChart / app.py
ahmed-masry's picture
Update app.py
f63a11e verified
raw
history blame
2.39 kB
import gradio as gr
from transformers import DonutProcessor, VisionEncoderDecoderModel
import requests
from PIL import Image
import torch, os, re
import spaces
# Example chart images from the ChartQA dataset, used by the demo's examples.
_EXAMPLE_CHARTS = (
    (
        "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/test/png/74801584018932.png",
        "chart_example_1.png",
    ),
    (
        "https://raw.githubusercontent.com/vis-nlp/ChartQA/main/ChartQA%20Dataset/val/png/multi_col_1229.png",
        "chart_example_2.png",
    ),
)
for _url, _filename in _EXAMPLE_CHARTS:
    # Only fetch files that are not already on disk, so a restart does not
    # re-download them (and does not fail when the files are already present
    # but the remote host is unreachable).
    if not os.path.exists(_filename):
        torch.hub.download_url_to_file(_url, _filename)

# UniChart checkpoint: loaded as a VisionEncoderDecoderModel with its DonutProcessor.
model_name = "ahmed-masry/unichart-base-960"
model = VisionEncoderDecoderModel.from_pretrained(model_name)
processor = DonutProcessor.from_pretrained(model_name)
@spaces.GPU
def predict(image, input_prompt):
    """Run UniChart on a chart image with a task prompt and return its answer.

    Args:
        image: The chart as a PIL image (the Gradio input is declared with
            ``type="pil"``), or ``None`` when nothing was uploaded.
        input_prompt: Task prompt, e.g. ``"<summarize_chart>"`` or
            ``"<extract_data_table>"``, optionally followed by question text.

    Returns:
        The text decoded after the ``<s_answer>`` marker, with the tokenizer's
        EOS and pad tokens removed and surrounding whitespace stripped.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload a chart image.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # The decoder is primed with "<prompt> <s_answer>"; the model continues
    # with the answer after the <s_answer> marker.
    decoder_input_ids = processor.tokenizer(
        f"{input_prompt or ''} <s_answer>",
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids
    # Uploaded PNG charts are frequently RGBA or palette images; normalise to
    # 3-channel RGB before handing them to the image processor.
    pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
    # The decoded text echoes the decoder prompt (task tag, any question text,
    # and possibly a decoder start token). Keep only what follows <s_answer>,
    # falling back to the whole sequence if the marker is somehow absent.
    return sequence.split("<s_answer>", 1)[-1].strip()
# Demo widgets: a chart image and a task prompt go in, the model's text comes out.
image = gr.Image(type="pil", label="Chart Image")
input_prompt = gr.Textbox(label="Input Prompt")
model_output = gr.Textbox(label="Model Output")

# One example per task prompt, using the chart PNGs fetched at startup.
examples = [
    ["chart_example_1.png", "<summarize_chart>"],
    ["chart_example_2.png", "<extract_data_table>"],
]
title = "Interactive Gradio Demo for UniChart-base-960 model"

interface = gr.Interface(
    fn=predict,
    inputs=[image, input_prompt],
    outputs=model_output,
    examples=examples,
    title=title,
    theme="gradio/soft",
)
interface.launch()