File size: 2,393 Bytes
4463d10
739fe18
84982b3
739fe18
52540c8
84982b3
 
 
739fe18
84982b3
52540c8
 
739fe18
84982b3
 
 
739fe18
84982b3
 
 
739fe18
84982b3
 
 
739fe18
84982b3
 
 
739fe18
84982b3
52540c8
739fe18
4463d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739fe18
52540c8
739fe18
52540c8
739fe18
52540c8
4463d10
52540c8
 
 
739fe18
4463d10
739fe18
84982b3
739fe18
52540c8
 
4463d10
84982b3
52540c8
739fe18
 
84982b3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# app.py

import gradio as gr
import torch
from PIL import Image, ImageDraw, ImageFont
from model import load_model
from utils import preprocess_image, decode_predictions
import os

# Locations of the trained weights and of the Ethiopic font (relative paths).
MODEL_PATH = "finetuned_recog_model.pth"
FONT_PATH = "NotoSansEthiopic-Regular.ttf"  # Update the path to your font

# Abort startup early when the weights file is missing.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(
        f"Model file not found at {MODEL_PATH}. Please provide the correct path."
    )

# Same early check for the font used to render the recognized text.
if not os.path.exists(FONT_PATH):
    raise FileNotFoundError(
        f"Font file not found at {FONT_PATH}. Please provide the correct path."
    )

# Prefer the GPU when CUDA is usable, otherwise stay on the CPU, then build
# the recognizer on that device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = load_model(MODEL_PATH, device=device)

# Load the font for rendering Amharic text
# NOTE(review): matplotlib is imported here, after the module has already run
# startup code, rather than with the other imports at the top of the file.
from matplotlib import font_manager as fm
import matplotlib.pyplot as plt

# Two handles on the same font file:
#   * ethiopic_font - matplotlib FontProperties at size 15.
#     NOTE(review): neither `ethiopic_font` nor `plt` is referenced anywhere
#     else in the visible file - confirm whether they are still needed.
#   * pil_font - Pillow font at size 20, used by recognize_and_overlay() to
#     draw the recognized text onto the image.
ethiopic_font = fm.FontProperties(fname=FONT_PATH, size=15)
pil_font = ImageFont.truetype(FONT_PATH, size=20)

def recognize_text(image: Image.Image) -> str:
    """
    Run the recognition model on a single image and return the decoded text.

    The image goes through ``preprocess_image``, gains a leading batch
    dimension, is moved to the module-level ``device`` and is fed to the
    module-level ``model`` with gradient tracking disabled. The model output
    is handed to ``decode_predictions`` and the first decoded entry is
    returned.

    Args:
        image: The picture to read text from.

    Returns:
        The first element of the sequence produced by ``decode_predictions``.
    """
    # Single-image batch on the inference device.
    # NOTE(review): the original comment gives the shape as [1, 3, 224, 224];
    # confirm against utils.preprocess_image.
    batch = preprocess_image(image)
    batch = batch.unsqueeze(0)
    batch = batch.to(device)

    # Inference only, so no autograd bookkeeping is needed.
    # NOTE(review): the original comment describes the output as
    # [H*W, 1, vocab_size] log-probabilities; confirm against the model.
    with torch.no_grad():
        model_output = model(batch)

    # One decoded entry per batch item is assumed; only one image was sent.
    decoded = decode_predictions(model_output)
    return decoded[0]

def recognize_and_overlay(image: Image.Image) -> Image.Image:
    """
    Recognize the text in an image and return a copy with the text drawn on it.

    The caption ``"Recognized: <text>"`` is rendered in red with the
    module-level ``pil_font`` near the top-left corner. The caption is drawn
    on an RGB copy of the input, so the caller's image is not modified and
    the RGB fill colour is valid whatever the input mode is (grayscale,
    palette, RGBA, ...).

    Args:
        image: The picture to process. May be ``None`` when the Gradio form
            is submitted without an upload.

    Returns:
        A new RGB image carrying the recognized-text caption.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Show a readable message in the UI instead of an opaque traceback
        # from deep inside preprocessing.
        raise gr.Error("Please upload an image before submitting.")

    recognized_text = recognize_text(image)

    # convert() always returns a new image, so drawing below never touches
    # the object the caller passed in.
    annotated = image.convert("RGB")

    # Overlay text on the copy
    draw = ImageDraw.Draw(annotated)
    text_position = (10, 10)  # Top-left corner
    text_color = (255, 0, 0)  # Red color
    draw.text(text_position, f"Recognized: {recognized_text}", font=pil_font, fill=text_color)

    return annotated

# Gradio UI: one uploaded image in, the annotated image out. Both components
# exchange PIL images with recognize_and_overlay.
iface = gr.Interface(
    title="Amharic Text Recognition",
    description="Upload an image containing Amharic text. The app will recognize and overlay the text on the image.",
    fn=recognize_and_overlay,
    inputs=gr.Image(label="Upload Image", type="pil"),
    outputs=gr.Image(label="Image with Recognized Text", type="pil"),
)

# Start the web app only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    iface.launch()