Update app.py
app.py
CHANGED
@@ -1,63 +1,129 @@
-import torch
-import cv2
-import json
-import xml.etree.ElementTree as ET
 import gradio as gr
-
-
-
-
-
-
-# Load GPT-2 model
-GPT2_model = GPT2LMHeadModel.from_pretrained("gpt2")
-GPT2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-
-# Image preprocessing
-def preprocess_image(image_path):
-    image = cv2.imread(image_path)
-    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-    return gray
 
-#
-
-    image = preprocess_image(image_path)
-    pixel_values = processor(image, return_tensors="pt").pixel_values
-    generated_ids = model.generate(pixel_values)
-    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return text
 
-
-
-if
-
-
-
-
-
-
-
 
-
-
-
-
 
-
-
-
-
-
 
-# Gradio
-
-
-
-
-
-
-
 
-
 import gradio as gr
+import cv2
+import numpy as np
+import pytesseract
+from PIL import Image
+import io
+import matplotlib.pyplot as plt
 
+# Configure pytesseract path (adjust this based on your installation)
+# pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'  # Uncomment and modify for Windows
 
+def preprocess_image(image):
+    """Preprocess the image to improve OCR accuracy for handwritten text"""
+    # Convert to grayscale if it's a color image
+    if len(image.shape) == 3:
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    else:
+        gray = image.copy()
+
+    # Apply adaptive thresholding
+    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+                                   cv2.THRESH_BINARY_INV, 11, 2)
+
+    # Noise removal using morphological operations
+    kernel = np.ones((1, 1), np.uint8)
+    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
+
+    # Dilate to connect components
+    kernel = np.ones((2, 2), np.uint8)
+    dilated = cv2.dilate(opening, kernel, iterations=1)
+
+    return dilated
 
+def perform_ocr(input_image):
+    """Process the image and perform OCR"""
+    if input_image is None:
+        return "No image provided", None
+
+    # Convert from RGB (PIL) to BGR (OpenCV format)
+    image_np = np.array(input_image)
+    if len(image_np.shape) == 3:
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+
+    # Preprocess the image
+    preprocessed = preprocess_image(image_np)
+
+    # Convert back to PIL for visualization
+    pil_preprocessed = Image.fromarray(preprocessed)
+
+    # Use pytesseract with specific configurations for handwritten text
+    custom_config = r'--oem 3 --psm 6 -l eng -c preserve_interword_spaces=1 -c tessedit_char_whitelist="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,;:\'\"()[]{}!?-+*/=><_%$#@&|~^`\\ "'
+
+    # Perform OCR
+    extracted_text = pytesseract.image_to_string(pil_preprocessed, config=custom_config)
+
+    # Return the extracted text and the preprocessed image for visualization
+    return extracted_text, pil_preprocessed
 
+def ocr_pipeline(input_image):
+    """Complete OCR pipeline with visualization"""
+
+    extracted_text, preprocessed_image = perform_ocr(input_image)
+
+    # Create visualization
+    if input_image is not None and preprocessed_image is not None:
+        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
+        ax1.imshow(input_image)
+        ax1.set_title("Original Image")
+        ax1.axis("off")
+
+        ax2.imshow(preprocessed_image, cmap='gray')
+        ax2.set_title("Preprocessed Image")
+        ax2.axis("off")
+
+        plt.tight_layout()
+
+        # Convert plot to image
+        buf = io.BytesIO()
+        plt.savefig(buf, format='png')
+        buf.seek(0)
+        viz_img = Image.open(buf)
+        plt.close(fig)
+
+        return extracted_text, viz_img
+
+    return extracted_text, None
 
+# Create the Gradio interface
+with gr.Blocks(title="Handwritten OCR App") as app:
+    gr.Markdown("# Handwritten Text OCR Extraction")
+    gr.Markdown("""
+    This app extracts text from handwritten notes.
+    Upload an image containing handwritten text and the app will convert it to digital text.
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(type="pil", label="Upload Handwritten Image")
+            run_button = gr.Button("Extract Text")
+
+        with gr.Column():
+            output_text = gr.Textbox(label="Extracted Text", lines=15)
+            processed_image = gr.Image(label="Preprocessing Visualization")
+
+    run_button.click(
+        fn=ocr_pipeline,
+        inputs=input_image,
+        outputs=[output_text, processed_image]
+    )
+
+    gr.Markdown("""
+    ## Tips for better results:
+    - Ensure good lighting and contrast in the image
+    - Keep the text as horizontal as possible
+    - Clear handwriting works best
+    - If results are poor, try cropping the image to focus on specific sections
+    """)
+
+    # Add example images
+    gr.Examples(
+        examples=[
+            "handwritten_sample.jpg",  # Replace with your example image paths
+        ],
+        inputs=input_image,
+    )
 
+# Launch the app
+if __name__ == "__main__":
+    app.launch()