File size: 2,393 Bytes
4463d10 739fe18 84982b3 739fe18 52540c8 84982b3 739fe18 84982b3 52540c8 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 739fe18 84982b3 52540c8 739fe18 4463d10 739fe18 52540c8 739fe18 52540c8 739fe18 52540c8 4463d10 52540c8 739fe18 4463d10 739fe18 84982b3 739fe18 52540c8 4463d10 84982b3 52540c8 739fe18 84982b3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
# app.py
import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from model import load_model
from utils import preprocess_image, decode_predictions
import os
# Locations of the fine-tuned recognition weights and of the font used to
# render the recognized text (ensure both paths are correct).
MODEL_PATH = "finetuned_recog_model.pth"
FONT_PATH = "NotoSansEthiopic-Regular.ttf"  # Update the path to your font


def _require_file(path: str, message: str) -> None:
    """Raise ``FileNotFoundError(message)`` unless ``path`` exists."""
    if os.path.exists(path):
        return
    raise FileNotFoundError(message)


# Fail at import time, before any loading is attempted, when an asset is
# missing. The model is checked first, then the font.
_require_file(
    MODEL_PATH,
    f"Model file not found at {MODEL_PATH}. Please provide the correct path.",
)
_require_file(
    FONT_PATH,
    f"Font file not found at {FONT_PATH}. Please provide the correct path.",
)

# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = load_model(MODEL_PATH, device=device)

# Load the font for rendering Amharic text, once for matplotlib and once for
# PIL drawing.
from matplotlib import font_manager as fm
import matplotlib.pyplot as plt

ethiopic_font = fm.FontProperties(fname=FONT_PATH, size=15)
pil_font = ImageFont.truetype(FONT_PATH, size=20)
def recognize_text(image: Image.Image) -> str:
    """Run the recognition model on ``image`` and return the decoded text.

    The image goes through ``preprocess_image``, gains a leading batch
    dimension, is moved to the module-level ``device`` and is fed to the
    module-level ``model`` with gradient tracking disabled. The model output
    is handed to ``decode_predictions`` and the first decoded entry is
    returned.
    """
    # Original author's note: the batched input is expected to be
    # [1, 3, 224, 224] -- depends on preprocess_image, TODO confirm.
    batch = preprocess_image(image).unsqueeze(0)
    batch = batch.to(device)

    # Inference only, so no autograd graph is needed.
    # Original author's note: output is [H*W, 1, vocab_size] -- TODO confirm.
    with torch.no_grad():
        model_output = model(batch)

    # Only a single image was submitted, so take the first decoded result.
    decoded = decode_predictions(model_output)
    return decoded[0]
def recognize_and_overlay(image: Image.Image) -> Image.Image:
    """Recognize the text in ``image`` and return a copy with the text drawn on it.

    Args:
        image: Picture to run recognition on. It is not modified.

    Returns:
        A new RGB image carrying the caption ``"Recognized: <text>"`` in red,
        drawn with ``pil_font`` at offset (10, 10) from the top-left corner.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # The Gradio image input yields None when nothing was uploaded; report
        # that to the user instead of failing inside preprocessing.
        raise gr.Error("Please upload an image before submitting.")

    # Recognition runs on the image exactly as it was received.
    recognized_text = recognize_text(image)

    # Draw on an RGB copy rather than on the caller's image:
    #  * the input object stays untouched, and
    #  * the 3-tuple fill below is valid whatever the source mode was
    #    (single-band images reject an RGB tuple as a drawing colour).
    canvas = image.convert("RGB")
    text_position = (10, 10)  # Top-left corner
    text_color = (255, 0, 0)  # Red
    ImageDraw.Draw(canvas).text(
        text_position,
        f"Recognized: {recognized_text}",
        font=pil_font,
        fill=text_color,
    )
    return canvas
# Define Gradio Interface: one PIL image in, one annotated PIL image out.
_upload_box = gr.Image(type="pil", label="Upload Image")
_result_box = gr.Image(type="pil", label="Image with Recognized Text")

iface = gr.Interface(
    fn=recognize_and_overlay,
    inputs=_upload_box,
    outputs=_result_box,
    title="Amharic Text Recognition",
    description="Upload an image containing Amharic text. The app will recognize and overlay the text on the image.",
)

# Launch the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()
|