amst / app.py
Gizachew's picture
Update app.py
84982b3 verified
raw
history blame
2.15 kB
# app.py
import gradio as gr
import torch
from PIL import Image
from model import load_model
from utils import preprocess_image, decode_predictions
import os
# Locations of the fine-tuned recognition checkpoint and of the font used to
# render Ethiopic script (adjust FONT_PATH to wherever your font lives).
MODEL_PATH = "saved_models/finetuned/finetuned_recog_model.pth"
FONT_PATH = "fonts/NotoSansEthiopic-Regular.ttf"

# Fail fast at import time if either asset is missing. The model path is
# checked before the font path, so a missing model is reported first.
for _asset_path, _missing_message in (
    (MODEL_PATH, f"Model file not found at {MODEL_PATH}. Please provide the correct path."),
    (FONT_PATH, f"Font file not found at {FONT_PATH}. Please provide the correct path."),
):
    if not os.path.exists(_asset_path):
        raise FileNotFoundError(_missing_message)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = load_model(MODEL_PATH, device=device)

# Font properties used when drawing Amharic text in matplotlib titles.
from matplotlib import font_manager as fm
import matplotlib.pyplot as plt

ethiopic_font = fm.FontProperties(fname=FONT_PATH, size=15)
def recognize_text(image: Image.Image):
    """Recognize the text contained in an image.

    Args:
        image: PIL image supplied by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        The first entry of ``decode_predictions``' output for the image, or
        an empty string when no image was provided.
    """
    # Gradio hands the callback None when nothing was uploaded; return early
    # instead of letting preprocessing fail on a missing image.
    if image is None:
        return ""

    # Preprocess, add a leading batch dimension and move to the model's device.
    # Original author's note: expected shape [1, 3, 224, 224] - TODO confirm
    # against preprocess_image.
    input_tensor = preprocess_image(image).unsqueeze(0).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Original author's note: output is [H*W, 1, vocab_size] - TODO confirm.
        log_probs = model(input_tensor)

    # Decode the model output; the batch holds a single image, so take entry 0.
    recognized_texts = decode_predictions(log_probs)
    return recognized_texts[0]
def display_image_with_text(image: Image.Image, recognized_text: str):
    """Show an image in a pyplot figure titled with its recognized text.

    Opens a new 6x6 pyplot figure, draws ``image`` with the axes hidden, sets
    the title (rendered with ``ethiopic_font``) to the recognized text, and
    calls ``plt.show()``.

    Args:
        image: Image to draw.
        recognized_text: Text to place in the figure title.

    Returns:
        The ``matplotlib.pyplot`` module itself (not a Figure object).
    """
    caption = f"Recognized Text: {recognized_text}"
    figure_size = (6, 6)

    plt.figure(figsize=figure_size)
    plt.imshow(image)
    plt.axis("off")
    plt.title(caption, fontproperties=ethiopic_font)
    plt.show()
    return plt
# Gradio UI: a single PIL image in, the recognized text out.
# The components are referenced at the top level (gr.Image / gr.Textbox)
# because the legacy gr.inputs / gr.outputs namespaces were deprecated in
# Gradio 3.x and removed in Gradio 4.x.
iface = gr.Interface(
    fn=recognize_text,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Amharic Text Recognition",
    description="Upload an image containing Amharic text, and the model will recognize and display the text.",
)

# Start the web server only when this file is executed as a script, so that
# importing the module does not launch the app.
if __name__ == "__main__":
    iface.launch()