# Gradio demo app: Donut-based OCR for Dominican vehicle registrations ("matriculas").
# Runs in fp16 (.half()) when a CUDA device is available.
import torch
from sconf import Config
from PIL import Image, ImageOps
from donut import DonutConfig, DonutModel
# Quiet noisy library output: suppress Python warnings and raise the
# transformers logger threshold to WARNING.
import warnings
warnings.filterwarnings("ignore")
from transformers import logging
logging.set_verbosity_warning()

# Runtime configuration (checkpoint path, input size, ...) from YAML.
config = Config(default="./config.yaml")

# Probe for CUDA once; the flag drives device placement and fp16 below.
has_cuda = torch.cuda.is_available()
print(f"Is CUDA available: {has_cuda}")
if has_cuda:
    device_name = torch.cuda.get_device_name(torch.cuda.current_device())
    print(f"CUDA device: {device_name}")
# Instantiate the Donut model from the pretrained checkpoint named in config.
model = DonutModel.from_pretrained(
    config.pretrained_model_name_or_path,
    input_size=config.input_size,
    max_length=config.max_position_embeddings,  # self.config.max_length
    align_long_axis=config.align_long_axis,
    ignore_mismatched_sizes=True,
)

# On GPU: cast weights to fp16 and move them onto the CUDA device.
# (nn.Module.half() is in-place and returns self, so chaining is safe.)
if has_cuda:
    model = model.half().to("cuda")

# Prompt token selecting the fine-tuned "matricula" task.
task_name = "matricula"
task_prompt = f"<s_{task_name}>"
def predict_matricula(model, task_name, image, size=(1280, 960), resample=None):
    """Run Donut inference on a single matricula image.

    Args:
        model: a DonutModel already placed on its target device/dtype.
        task_name: task identifier; the prompt ``<s_{task_name}>`` selects
            the fine-tuned decoding task.
        image: PIL image as uploaded by the user.
        size: (width, height) the image is resized to before inference.
            Defaults to the previously hard-coded 1280x960.
        resample: PIL resampling filter for the resize; defaults to
            ``Image.Resampling.NEAREST`` (the previous hard-coded choice).

    Returns:
        The first prediction produced by ``model.inference``.
    """
    if resample is None:
        resample = Image.Resampling.NEAREST
    # Honor the EXIF orientation tag so phone photos are not rotated.
    image = ImageOps.exif_transpose(image)
    image = image.resize(size=size, resample=resample)
    # Ensure inference mode (disables dropout etc.).
    model.eval()
    return model.inference(image=image, prompt=f"<s_{task_name}>")["predictions"][0]
import gradio as gr

# Adapter so the Interface callback carries only the image argument.
def _infer(image):
    return predict_matricula(model, task_name="matricula", image=image)

# One-image-in / JSON-out web UI around the predictor.
demo = gr.Interface(
    fn=_infer,
    title="Demo: Donut 🍩 for DR Matriculas",
    description="Dominican Vehicle **Matriculas OCR** Infering",
    inputs=gr.Image(label="Matricula", sources="upload", type="pil", show_label=True),
    outputs=[gr.JSON(label="Matricula JSON", show_label=True, value={})],
)
demo.launch(share=True)