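# Gradio demo app: Donut 🍩 OCR for Dominican Republic vehicle "matrículas"
# (registration documents): upload an image, get the parsed fields back as JSON.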
import warnings

import gradio as gr
import torch
from donut import DonutModel
from PIL import Image, ImageOps
from sconf import Config

warnings.filterwarnings("ignore")

from transformers import logging

logging.set_verbosity_warning()
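
# Model and preprocessing settings (checkpoint path, input size, max length,
# long-axis alignment) come from the local config.yaml.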
config = Config(default="./config.yaml")
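
# Only query the CUDA device name when a GPU is actually available.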
has_cuda = torch.cuda.is_available()
print(f"Is CUDA available: {has_cuda}")
if has_cuda:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
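
# Load the fine-tuned Donut checkpoint. ignore_mismatched_sizes lets
# from_pretrained tolerate weight-shape mismatches, e.g. when input_size
# differs from the resolution the checkpoint was trained at.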
model = DonutModel.from_pretrained(
    config.pretrained_model_name_or_path,
    input_size=config.input_size,
    max_length=config.max_position_embeddings,
    align_long_axis=config.align_long_axis,
    ignore_mismatched_sizes=True,
)
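
# Run on the GPU when available; eval() disables dropout for inference.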
device = "cuda" if has_cuda else "cpu"
model = model.to(device)
model.eval()

task_name = "matricula"
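
# Preprocess an uploaded image and run Donut inference with the task prompt,
# returning the first parsed prediction as a dict.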
def predict_matricula(model, task_name, image):
    # Respect the EXIF orientation tag so phone photos come in upright.
    image = ImageOps.exif_transpose(image)
    image = image.resize(size=(1280, 960), resample=Image.Resampling.NEAREST)
    result = model.inference(image=image, prompt=f"<s_{task_name}>")["predictions"][0]
    return result
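
# Wire the model into a simple web UI: image upload in, parsed JSON out.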
demo = gr.Interface(
    fn=lambda x: predict_matricula(model, task_name=task_name, image=x),
    title="Demo: Donut 🍩 for DR Matriculas",
    description="Dominican Vehicle **Matriculas OCR** Inference",
    inputs=gr.Image(label="Matricula", sources=["upload"], type="pil", show_label=True),
    outputs=[gr.JSON(label="Matricula JSON", show_label=True, value={})],
)
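
# share=True also serves the demo through a temporary public gradio.live link
# when the app is run locally.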
demo.launch(share=True)