DeepDiveDev committed
Commit e638a74 · verified · 1 Parent(s): c3163b4

Update app.py

Files changed (1)
  1. app.py +122 -56
app.py CHANGED
@@ -1,63 +1,129 @@
-import torch
-import cv2
-import json
-import xml.etree.ElementTree as ET
 import gradio as gr
-from transformers import TrOCRProcessor, VisionEncoderDecoderModel, GPT2LMHeadModel, GPT2Tokenizer
-
-# Load OCR model (TrOCR)
-processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
-model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
-
-# Load GPT-2 model
-GPT2_model = GPT2LMHeadModel.from_pretrained("gpt2")
-GPT2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-
-# Image preprocessing
-def preprocess_image(image_path):
-    image = cv2.imread(image_path)
-    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-    return gray
-
-# Extract text using TrOCR (instead of Tesseract)
-def extract_text(image_path):
-    image = preprocess_image(image_path)
-    pixel_values = processor(image, return_tensors="pt").pixel_values
-    generated_ids = model.generate(pixel_values)
-    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return text
-
-# Generate structured format (JSON/XML)
-def generate_machine_readable_format(text, format_type='json'):
-    if format_type == 'json':
-        return json.dumps({"content": text})
-    elif format_type == 'xml':
-        root = ET.Element("Document")
-        content = ET.SubElement(root, "Content")
-        content.text = text
-        return ET.tostring(root, encoding='unicode')
-    return text
-
-# GPT-2 for structured output
-def generate_structured_output(text):
-    inputs = GPT2_tokenizer.encode(text, return_tensors="pt")
-    outputs = GPT2_model.generate(inputs, max_length=500)
-    return GPT2_tokenizer.decode(outputs[0])
-
-# Convert document
-def convert_document(image, output_format='json'):
-    text = extract_text(image)
-    structured_output = generate_structured_output(text)
-    machine_readable_output = generate_machine_readable_format(structured_output, format_type=output_format)
-    return machine_readable_output
-
-# Gradio UI
-iface = gr.Interface(
-    fn=convert_document,
-    inputs=[gr.Image(type="filepath"), gr.Radio(["json", "xml"], label="Output Format")],
-    outputs="text",
-    title="Document OCR and Conversion",
-    description="Extracts text from images and converts it into structured JSON/XML format."
-)
-
-iface.launch()
+import cv2
+import numpy as np
+import pytesseract
+from PIL import Image
+import io
+import matplotlib.pyplot as plt
+
+# Configure pytesseract path (adjust this based on your installation)
+# pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe' # Uncomment and modify for Windows
+
+def preprocess_image(image):
+    """Preprocess the image to improve OCR accuracy for handwritten text"""
+    # Convert to grayscale if it's a color image
+    if len(image.shape) == 3:
+        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+    else:
+        gray = image.copy()
+
+    # Apply adaptive thresholding
+    thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+                                   cv2.THRESH_BINARY_INV, 11, 2)
+
+    # Noise removal using morphological operations
+    kernel = np.ones((1, 1), np.uint8)
+    opening = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)
+
+    # Dilate to connect components
+    kernel = np.ones((2, 2), np.uint8)
+    dilated = cv2.dilate(opening, kernel, iterations=1)
+
+    return dilated
+
+def perform_ocr(input_image):
+    """Process the image and perform OCR"""
+    if input_image is None:
+        return "No image provided", None
+
+    # Convert from RGB to BGR (OpenCV format)
+    image_np = np.array(input_image)
+    if len(image_np.shape) == 3:
+        image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
+
+    # Preprocess the image
+    preprocessed = preprocess_image(image_np)
+
+    # Convert back to PIL for visualization
+    pil_preprocessed = Image.fromarray(preprocessed)
+
+    # Use pytesseract with specific configurations for handwritten text
+    custom_config = r'--oem 3 --psm 6 -l eng -c preserve_interword_spaces=1 -c tessedit_char_whitelist="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789.,;:\'\"()[]{}!?-+*/=><_%$#@&|~^`\\ "'
+
+    # Perform OCR
+    extracted_text = pytesseract.image_to_string(pil_preprocessed, config=custom_config)
+
+    # Return the extracted text and the preprocessed image for visualization
+    return extracted_text, pil_preprocessed
+
+def ocr_pipeline(input_image):
+    """Complete OCR pipeline with visualization"""
+
+    extracted_text, preprocessed_image = perform_ocr(input_image)
+
+    # Create visualization
+    if input_image is not None and preprocessed_image is not None:
+        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
+        ax1.imshow(input_image)
+        ax1.set_title("Original Image")
+        ax1.axis("off")
+
+        ax2.imshow(preprocessed_image, cmap='gray')
+        ax2.set_title("Preprocessed Image")
+        ax2.axis("off")
+
+        plt.tight_layout()
+
+        # Convert plot to image
+        buf = io.BytesIO()
+        plt.savefig(buf, format='png')
+        buf.seek(0)
+        viz_img = Image.open(buf)
+        plt.close(fig)
+
+        return extracted_text, viz_img
+
+    return extracted_text, None
+
+# Create the Gradio interface
+with gr.Blocks(title="Handwritten OCR App") as app:
+    gr.Markdown("# Handwritten Text OCR Extraction")
+    gr.Markdown("""
+    This app extracts text from handwritten notes.
+    Upload an image containing handwritten text and the app will convert it to digital text.
+    """)
+
+    with gr.Row():
+        with gr.Column():
+            input_image = gr.Image(type="pil", label="Upload Handwritten Image")
+            run_button = gr.Button("Extract Text")
+
+        with gr.Column():
+            output_text = gr.Textbox(label="Extracted Text", lines=15)
+            processed_image = gr.Image(label="Preprocessing Visualization")
+
+    run_button.click(
+        fn=ocr_pipeline,
+        inputs=input_image,
+        outputs=[output_text, processed_image]
+    )
+
+    gr.Markdown("""
+    ## Tips for better results:
+    - Ensure good lighting and contrast in the image
+    - Try to keep the text as horizontal as possible
+    - Clear handwriting works best
+    - For better results, you may need to crop the image to focus on specific sections
+    """)
+
+    # Add example images
+    gr.Examples(
+        examples=[
+            "handwritten_sample.jpg",  # Replace with your example image paths
+        ],
+        inputs=input_image,
+    )
+
+# Launch the app
+if __name__ == "__main__":
+    app.launch()
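
For a quick local check of the new pipeline outside the Gradio UI, a minimal sketch like the following can exercise perform_ocr directly. Assumptions not part of this commit: the file is saved as app.py, the Tesseract binary is installed and on PATH, and "sample.jpg" is a hypothetical placeholder for a handwritten image.

import pytesseract
from PIL import Image

from app import perform_ocr  # assumes the committed file is saved as app.py

# Confirm the Tesseract binary is reachable before running OCR;
# this raises TesseractNotFoundError if it is not installed.
print("Tesseract version:", pytesseract.get_tesseract_version())

# "sample.jpg" is a hypothetical placeholder path, not part of the commit.
image = Image.open("sample.jpg").convert("RGB")
text, preprocessed = perform_ocr(image)

print(text)
if preprocessed is not None:
    preprocessed.save("preprocessed.png")  # inspect what Tesseract actually saw

Note that --psm 6 tells Tesseract to assume a single uniform block of text, so multi-column or heavily skewed pages may segment poorly.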