Gizachew commited on
Commit
52540c8
·
verified ·
1 Parent(s): 160eac3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -29
app.py CHANGED
@@ -1,15 +1,15 @@
1
- # app.py
2
 
3
  import gradio as gr
4
  import torch
5
- from PIL import Image
6
  from model import load_model
7
  from utils import preprocess_image, decode_predictions
8
  import os
9
 
10
  # Load the model (ensure the path is correct)
11
- MODEL_PATH = "./finetuned_recog_model.pth"
12
- FONT_PATH = "./NotoSansEthiopic-Regular.ttf" # Update the path to your font
13
 
14
  # Check if model file exists
15
  if not os.path.exists(MODEL_PATH):
@@ -28,41 +28,36 @@ from matplotlib import font_manager as fm
28
  import matplotlib.pyplot as plt
29
 
30
  ethiopic_font = fm.FontProperties(fname=FONT_PATH, size=15)
 
31
 
32
- def recognize_text(image: Image.Image):
33
  """
34
- Function to recognize text from an image.
 
35
  """
36
- # Preprocess the image
37
- input_tensor = preprocess_image(image).unsqueeze(0).to(device) # [1, 3, 224, 224]
38
 
39
- # Perform inference
40
- with torch.no_grad():
41
- log_probs = model(input_tensor) # [H*W, 1, vocab_size]
42
 
43
- # Decode predictions
44
- recognized_texts = decode_predictions(log_probs)
 
 
 
45
 
46
- return recognized_texts[0]
47
-
48
- def display_image_with_text(image: Image.Image, recognized_text: str):
49
- """
50
- Function to display the image with recognized text.
51
- """
52
- plt.figure(figsize=(6,6))
53
- plt.imshow(image)
54
- plt.axis('off')
55
- plt.title(f"Recognized Text: {recognized_text}", fontproperties=ethiopic_font)
56
- plt.show()
57
- return plt
58
 
59
  # Define Gradio Interface
60
  iface = gr.Interface(
61
- fn=recognize_text,
62
- inputs=gr.inputs.Image(type="pil"),
63
- outputs=gr.outputs.Textbox(),
 
 
 
64
  title="Amharic Text Recognition",
65
- description="Upload an image containing Amharic text, and the model will recognize and display the text."
66
  )
67
 
68
  # Launch the Gradio app
 
1
+ # app.py (Enhanced to display both image and text)
2
 
3
  import gradio as gr
4
  import torch
5
+ from PIL import Image, ImageDraw, ImageFont
6
  from model import load_model
7
  from utils import preprocess_image, decode_predictions
8
  import os
9
 
10
  # Load the model (ensure the path is correct)
11
+ MODEL_PATH = "finetuned_recog_model.pth"
12
+ FONT_PATH = "NotoSansEthiopic-Regular.ttf" # Update the path to your font
13
 
14
  # Check if model file exists
15
  if not os.path.exists(MODEL_PATH):
 
28
  import matplotlib.pyplot as plt
29
 
30
  ethiopic_font = fm.FontProperties(fname=FONT_PATH, size=15)
31
+ pil_font = ImageFont.truetype(FONT_PATH, size=20)
32
 
33
+ def recognize_and_overlay(image: Image.Image):
34
  """
35
+ Function to recognize text and overlay it on the image.
36
+ Returns both the modified image and the recognized text.
37
  """
38
+ recognized_text = recognize_text(image)
 
39
 
40
+ # Create a copy of the image to draw on
41
+ image_with_text = image.copy()
 
42
 
43
+ # Overlay text on the image
44
+ draw = ImageDraw.Draw(image_with_text)
45
+ text_position = (10, 10) # Top-left corner
46
+ text_color = (255, 0, 0) # Red color
47
+ draw.text(text_position, f"Recognized: {recognized_text}", font=pil_font, fill=text_color)
48
 
49
+ return image_with_text, recognized_text
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # Define Gradio Interface
52
  iface = gr.Interface(
53
+ fn=recognize_and_overlay,
54
+ inputs=gr.Image(type="pil", label="Upload Image"),
55
+ outputs=[
56
+ gr.Image(type="pil", label="Image with Recognized Text"),
57
+ gr.Textbox(label="Recognized Text")
58
+ ],
59
  title="Amharic Text Recognition",
60
+ description="Upload an image containing Amharic text. The app will recognize and overlay the text on the image."
61
  )
62
 
63
  # Launch the Gradio app