# Source: Hugging Face Space file `app.py` by Vinay15
# (commit d74aaed, "Update app.py", 1.04 kB).
import os
import tempfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModel, AutoModelForImageClassification, AutoTokenizer
# Load the GOT-OCR2.0 model and its tokenizer.
# GOT-OCR2.0 is a generative OCR model distributed as custom "remote code",
# not an image classifier: it has to be loaded through ``AutoModel`` with
# ``trust_remote_code=True`` and is queried through its ``chat`` method.
# NOTE(review): ``trust_remote_code=True`` executes Python code downloaded
# from the model repository; consider pinning ``revision=`` to a known commit.
# NOTE(review): the model card's reference usage places the model on CUDA;
# a CPU-only host may need a CPU-adapted fork of the model — confirm the
# target hardware.
MODEL_ID = "ucaslcl/GOT-OCR2_0"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map="cuda",
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Inference only: disable dropout and other training-time behaviour.
model = model.eval()
def perform_ocr(image):
    """Run GOT-OCR2.0 plain-text OCR on an image and return the recognized text.

    Args:
        image: A ``PIL.Image.Image``. The Gradio input component is built
            with ``type="pil"``, so uploads arrive in this form.

    Returns:
        The text produced by ``model.chat(tokenizer, <path>, ocr_type="ocr")``.

    Raises:
        ValueError: If ``image`` is not a PIL image (for example ``None``
            when the form is submitted without an upload).
    """
    if not isinstance(image, Image.Image):
        raise ValueError("Input must be a PIL Image")
    # Normalise palette/alpha/greyscale images to 3-channel RGB.
    image = image.convert("RGB")

    # GOT-OCR2.0's ``chat`` API takes an image *file path*, not a tensor or a
    # PIL object, so write the image to a temporary PNG and always remove it.
    fd, path = tempfile.mkstemp(suffix=".png")
    try:
        # Close the file before the model reopens it (required on Windows).
        with os.fdopen(fd, "wb") as fh:
            image.save(fh, format="PNG")
        with torch.no_grad():
            text = model.chat(tokenizer, path, ocr_type="ocr")
    finally:
        os.remove(path)
    return text
# Create the Gradio interface.
# ``gr.inputs.*`` was removed in Gradio 4; components now live directly on
# the ``gr`` namespace. ``type="pil"`` makes Gradio hand ``perform_ocr`` a
# ``PIL.Image.Image``, which is the type it checks for.
iface = gr.Interface(
    fn=perform_ocr,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="OCR with GOT-OCR2.0",
    description="Upload an image for Optical Character Recognition."
)
# Launch the interface (runs at import time, as a Spaces ``app.py`` expects).
iface.launch()