File size: 2,147 Bytes
739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
# app.py
import os

import gradio as gr
import matplotlib.pyplot as plt
import torch
from matplotlib import font_manager as fm
from PIL import Image

from model import load_model
from utils import preprocess_image, decode_predictions
# Locations of the fine-tuned recognition weights and of the Ethiopic font
# used to render Amharic text. Both are relative paths, so they are resolved
# against the current working directory.
MODEL_PATH = "saved_models/finetuned/finetuned_recog_model.pth"
FONT_PATH = "fonts/NotoSansEthiopic-Regular.ttf"  # Update the path to your font

# Fail fast with a clear message, before the (potentially slow) model load,
# if either required file is missing.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please provide the correct path.")
if not os.path.exists(FONT_PATH):
    raise FileNotFoundError(f"Font file not found at {FONT_PATH}. Please provide the correct path.")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): nothing here calls model.eval(); confirm that load_model does,
# otherwise dropout / batch-norm layers would stay in training mode during
# inference.
model = load_model(MODEL_PATH, device=device)

# Font properties used by matplotlib when drawing Amharic (Ethiopic) text.
# The matplotlib imports (fm, plt) live in the top-of-file import block.
ethiopic_font = fm.FontProperties(fname=FONT_PATH, size=15)
def recognize_text(image: Image.Image):
    """
    Recognize the text contained in a single image.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the form is submitted without an uploaded image.

    Returns:
        The first decoded string produced by ``decode_predictions`` for the
        image, or an empty string when no image was provided.
    """
    # Gradio hands the function None when no image is uploaded; return an
    # empty result instead of letting preprocess_image fail on it.
    if image is None:
        return ""

    # Add a leading batch dimension of size 1 and move to the model's device.
    # NOTE(review): the original comment gives the shape as [1, 3, 224, 224];
    # that depends on preprocess_image — confirm.
    input_tensor = preprocess_image(image).unsqueeze(0).to(device)

    # Inference only: disable autograd tracking.
    with torch.no_grad():
        # NOTE(review): original comment says the output is
        # [H*W, 1, vocab_size] log-probabilities — confirm against the model.
        log_probs = model(input_tensor)

    # The batch holds exactly one image, so take the first decoded entry.
    recognized_texts = decode_predictions(log_probs)
    return recognized_texts[0]
def display_image_with_text(image: Image.Image, recognized_text: str):
    """
    Draw *image* in a new 6x6-inch matplotlib figure, titled with the text.

    The title is rendered with the module-level ``ethiopic_font`` so the
    Amharic characters use the bundled Ethiopic font.

    Args:
        image: Image to draw; passed unchanged to ``plt.imshow``.
        recognized_text: Text shown after "Recognized Text: " in the title.

    Returns:
        The ``matplotlib.pyplot`` module itself (not a Figure object).

    NOTE(review): the Gradio interface in this file only wires up
    ``recognize_text``; this helper is not used by it. The new figure is
    never closed, and ``plt.show()`` typically does not display anything on
    a non-interactive (server) backend, so repeated calls there leave
    figures open.
    """
    plt.figure(figsize=(6, 6))
    plt.imshow(image)

    # plt.axis / plt.title both act on the current axes; address it directly.
    current_axes = plt.gca()
    current_axes.axis('off')
    current_axes.set_title(
        f"Recognized Text: {recognized_text}",
        fontproperties=ethiopic_font,
    )

    plt.show()
    return plt
# Gradio UI: a single PIL image in, the recognized text out.
# The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3.x and
# removed in 4.0; the top-level gr.Image / gr.Textbox components replace them.
iface = gr.Interface(
    fn=recognize_text,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Amharic Text Recognition",
    description="Upload an image containing Amharic text, and the model will recognize and display the text."
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()
|