|
|
|
|
|
import gradio as gr |
|
import torch |
|
from PIL import Image, ImageDraw, ImageFont |
|
from model import load_model |
|
from utils import preprocess_image, decode_predictions |
|
import os |
|
|
|
|
|
# Asset locations; both are relative paths, so they resolve against the
# working directory of the process.
MODEL_PATH = "finetuned_recog_model.pth"
FONT_PATH = "NotoSansEthiopic-Regular.ttf"


def _ensure_exists(path, message):
    """Raise FileNotFoundError carrying *message* unless *path* exists."""
    if os.path.exists(path):
        return
    raise FileNotFoundError(message)


# Fail fast at import time; the model file is checked before the font file.
_ensure_exists(
    MODEL_PATH,
    f"Model file not found at {MODEL_PATH}. Please provide the correct path.",
)
_ensure_exists(
    FONT_PATH,
    f"Font file not found at {FONT_PATH}. Please provide the correct path.",
)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = load_model(MODEL_PATH, device=device)
|
|
|
|
|
from matplotlib import font_manager as fm
import matplotlib.pyplot as plt

# Two handles on the same font file: a matplotlib FontProperties object and a
# Pillow font used when drawing text onto images.
ethiopic_font = fm.FontProperties(size=15, fname=FONT_PATH)
pil_font = ImageFont.truetype(font=FONT_PATH, size=20)
|
|
|
def recognize_and_overlay(image: Image.Image):
    """Recognize the text in *image* and draw it onto a copy of the image.

    Args:
        image: The uploaded picture as a PIL image, or ``None`` when the
            request was submitted without an image.

    Returns:
        A ``(annotated_image, recognized_text)`` tuple. The input image is
        never modified. ``(None, "")`` is returned when *image* is ``None``.
    """
    # The handler receives None when nothing was uploaded; return empty
    # outputs instead of failing on the calls below.
    if image is None:
        return None, ""

    # NOTE(review): recognize_text is neither defined nor imported in the
    # visible source (only load_model, preprocess_image and decode_predictions
    # are imported) -- confirm it exists, otherwise this call raises NameError.
    recognized_text = recognize_text(image)

    # Draw on a copy so the caller's image is left untouched.
    image_with_text = image.copy()
    draw = ImageDraw.Draw(image_with_text)

    text_position = (10, 10)  # offset from the top-left corner, in pixels
    text_color = (255, 0, 0)  # red, as an RGB tuple
    draw.text(
        text_position,
        f"Recognized: {recognized_text}",
        font=pil_font,
        fill=text_color,
    )

    return image_with_text, recognized_text
|
|
|
|
|
# Input: a single uploaded image, handed to the callback as a PIL image.
_upload_input = gr.Image(type="pil", label="Upload Image")

# Outputs, in the order returned by recognize_and_overlay:
# the annotated image first, then the recognized text.
_annotated_output = gr.Image(type="pil", label="Image with Recognized Text")
_text_output = gr.Textbox(label="Recognized Text")

iface = gr.Interface(
    fn=recognize_and_overlay,
    inputs=_upload_input,
    outputs=[_annotated_output, _text_output],
    title="Amharic Text Recognition",
    description="Upload an image containing Amharic text. The app will recognize and overlay the text on the image.",
)
|
|
|
|
|
def main():
    """Start the Gradio app."""
    iface.launch()


# Launch only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|