Update app.py
app.py
CHANGED
@@ -1,122 +1,249 @@
-import os
-import uuid
 from flask import Flask, request, jsonify
 from flask_cors import CORS
-
-import …
-
 from PIL import Image
-import …
 import numpy as np

 app = Flask(__name__)
 CORS(app)

-# …
-app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # 16MB max upload
-
-# Create uploads directory if it doesn't exist
-os.makedirs(UPLOAD_FOLDER, exist_ok=True)
-
-# Load OCR model - Microsoft's TrOCR model
-print("Loading OCR model...")
-processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten", cache_dir="/huggingface_cache")
-model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten", cache_dir="/huggingface_cache")

-…
-model.to(device)
-
-def allowed_file(filename):
-    return '.' in filename and \
-        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
-
-def preprocess_image(image_path):
-    # Open image with PIL
-    image = Image.open(image_path).convert("RGB")
-
-    # Basic enhancement for better OCR results
-    # Convert to OpenCV format for preprocessing
-    img = np.array(image)
-    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-    …

-def …
-    …
-    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

-    # …
-    …
-    )

-    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return generated_text.strip()

-…

-    # …
-    if …
-    …
-        })
-    except Exception as e:
-        # Log the error
-        print(f"Error processing image: {str(e)}")
-        return jsonify({
-            'success': False,
-            'error': str(e)
-        }), 500
    else:
-        …

 @app.route('/health', methods=['GET'])
 def health_check():
-    …

 if __name__ == '__main__':
-    …
 from flask import Flask, request, jsonify
 from flask_cors import CORS
+import base64
+import io
+import os
 from PIL import Image
+import logging
+from transformers import TrOCRProcessor, VisionEncoderDecoderModel
+import torch
+import easyocr
 import numpy as np

+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
 app = Flask(__name__)
 CORS(app)

+# Global variables for models
+trocr_processor = None
+trocr_model = None
+easyocr_reader = None

+def initialize_models():
+    """Initialize OCR models"""
+    global trocr_processor, trocr_model, easyocr_reader

+    try:
+        # Initialize TrOCR for handwritten text (Microsoft's model)
+        logger.info("Loading TrOCR model for handwritten text...")
+        trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
+        trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
+
+        # Initialize EasyOCR for printed text
+        logger.info("Loading EasyOCR for printed text...")
+        easyocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
+
+        logger.info("All models loaded successfully!")
+
+    except Exception as e:
+        logger.error(f"Error loading models: {str(e)}")
+        raise e

+def preprocess_image(image):
+    """Preprocess image for better OCR results"""
+    # Convert to RGB if needed
+    if image.mode != 'RGB':
+        image = image.convert('RGB')

+    # Resize if image is too large
+    max_size = 1024
+    if max(image.size) > max_size:
+        ratio = max_size / max(image.size)
+        new_size = tuple(int(dim * ratio) for dim in image.size)
+        image = image.resize(new_size, Image.Resampling.LANCZOS)

+    return image

+def extract_text_trocr(image):
+    """Extract text using TrOCR (good for handwritten text)"""
+    try:
+        # Preprocess image
+        image = preprocess_image(image)
+
+        # Generate pixel values
+        pixel_values = trocr_processor(image, return_tensors="pt").pixel_values
+
+        # Generate text
+        generated_ids = trocr_model.generate(pixel_values)
+        generated_text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+
+        return generated_text.strip()
+    except Exception as e:
+        logger.error(f"TrOCR error: {str(e)}")
+        return ""
+
+def extract_text_easyocr(image):
+    """Extract text using EasyOCR (good for printed text)"""
+    try:
+        # Convert PIL image to numpy array
+        image_np = np.array(preprocess_image(image))
+
+        # Extract text
+        results = easyocr_reader.readtext(image_np, detail=0)
+
+        # Join all detected text
+        extracted_text = ' '.join(results)
+        return extracted_text.strip()
+    except Exception as e:
+        logger.error(f"EasyOCR error: {str(e)}")
+        return ""
+
+def process_image_ocr(image, ocr_type="auto"):
+    """Process image with specified OCR method"""
+    results = {}

+    if ocr_type in ["auto", "handwritten", "trocr"]:
+        trocr_text = extract_text_trocr(image)
+        results["trocr"] = trocr_text

+    if ocr_type in ["auto", "printed", "easyocr"]:
+        easyocr_text = extract_text_easyocr(image)
+        results["easyocr"] = easyocr_text

+    # For auto mode, return the longer result or combine both
+    if ocr_type == "auto":
+        trocr_len = len(results.get("trocr", ""))
+        easyocr_len = len(results.get("easyocr", ""))
+
+        if trocr_len > 0 and easyocr_len > 0:
+            # If both have results, combine them intelligently
+            if abs(trocr_len - easyocr_len) / max(trocr_len, easyocr_len) < 0.3:
+                # If lengths are similar, prefer EasyOCR for printed text
+                results["final"] = results["easyocr"]
+            else:
+                # Use the longer result
+                results["final"] = results["trocr"] if trocr_len > easyocr_len else results["easyocr"]
+        elif trocr_len > 0:
+            results["final"] = results["trocr"]
+        elif easyocr_len > 0:
+            results["final"] = results["easyocr"]
+        else:
+            results["final"] = ""
    else:
+        # Return the specific model result
+        results["final"] = results.get(ocr_type.replace("handwritten", "trocr").replace("printed", "easyocr"), "")
+
+    return results

 @app.route('/health', methods=['GET'])
 def health_check():
+    """Health check endpoint"""
+    return jsonify({"status": "healthy", "models_loaded": True})
+
+@app.route('/ocr', methods=['POST'])
+def ocr_endpoint():
+    """Main OCR endpoint"""
+    try:
+        # Check if an image was provided, either as a file upload or as a
+        # base64 string in a JSON body; get_json(silent=True) avoids raising
+        # when the request has no JSON body at all
+        json_body = request.get_json(silent=True) or {}
+        if 'image' not in request.files and 'image_base64' not in json_body:
+            return jsonify({"error": "No image provided"}), 400
+
+        # Get OCR type preference (form field or JSON field)
+        ocr_type = request.form.get('type') or json_body.get('type', 'auto')  # auto, handwritten, printed
146 |
+
|
147 |
+
# Load image
|
148 |
+
if 'image' in request.files:
|
149 |
+
# File upload
|
150 |
+
image_file = request.files['image']
|
151 |
+
image = Image.open(image_file.stream)
|
152 |
+
else:
|
153 |
+
# Base64 image
|
154 |
+            image_data = json_body['image_base64']
+            if image_data.startswith('data:image'):
+                # Remove data URL prefix
+                image_data = image_data.split(',')[1]
+
+            # Decode base64
+            image_bytes = base64.b64decode(image_data)
+            image = Image.open(io.BytesIO(image_bytes))
+
+        # Process image
+        results = process_image_ocr(image, ocr_type)
+
+        response = {
+            "success": True,
+            "text": results["final"],
+            "type_used": ocr_type,
+            "details": {
+                "trocr_result": results.get("trocr", ""),
+                "easyocr_result": results.get("easyocr", "")
+            } if ocr_type == "auto" else {}
+        }
+
+        return jsonify(response)
+
+    except Exception as e:
+        logger.error(f"OCR processing error: {str(e)}")
+        return jsonify({"error": str(e), "success": False}), 500
+
+@app.route('/ocr/batch', methods=['POST'])
+def batch_ocr_endpoint():
+    """Batch OCR endpoint for multiple images"""
+    try:
+        if 'images' not in request.files:
+            return jsonify({"error": "No images provided"}), 400
+
+        images = request.files.getlist('images')
+        ocr_type = request.form.get('type', 'auto')
+
+        results = []
+        for i, image_file in enumerate(images):
+            try:
+                image = Image.open(image_file.stream)
+                ocr_results = process_image_ocr(image, ocr_type)
+
+                results.append({
+                    "index": i,
+                    "filename": image_file.filename,
+                    "text": ocr_results["final"],
+                    "success": True
+                })
+            except Exception as e:
+                results.append({
+                    "index": i,
+                    "filename": image_file.filename,
+                    "error": str(e),
+                    "success": False
+                })
+
+        return jsonify({
+            "success": True,
+            "results": results,
+            "total_processed": len(results)
+        })
+
+    except Exception as e:
+        logger.error(f"Batch OCR error: {str(e)}")
+        return jsonify({"error": str(e), "success": False}), 500
+
+@app.route('/models/info', methods=['GET'])
+def models_info():
+    """Get information about loaded models"""
+    return jsonify({
+        "models": {
+            "trocr": {
+                "name": "microsoft/trocr-base-handwritten",
+                "description": "Handwritten text recognition",
+                "loaded": trocr_model is not None
+            },
+            "easyocr": {
+                "name": "EasyOCR",
+                "description": "Printed text recognition",
+                "loaded": easyocr_reader is not None
+            }
+        },
+        "supported_types": ["auto", "handwritten", "printed"],
+        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"]
+    })

 if __name__ == '__main__':
+    # Initialize models on startup
+    logger.info("Starting OCR service...")
+    initialize_models()
+
+    # Run the app
+    port = int(os.environ.get('PORT', 5000))
+    app.run(host='0.0.0.0', port=port, debug=False)
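For reference, a minimal client-side sketch of how the new endpoints could be exercised once the service is running. It assumes the server listens on http://localhost:5000 (the default in app.run above) and that the image files named here exist locally; the field names image, images, image_base64, and type match the handlers in this commit.

import base64

import requests

BASE_URL = "http://localhost:5000"  # assumed local dev address

# Health check and model info
print(requests.get(f"{BASE_URL}/health").json())
print(requests.get(f"{BASE_URL}/models/info").json())

# Single image as a multipart file upload; let the service pick the model
with open("sample.png", "rb") as f:  # placeholder file name
    resp = requests.post(
        f"{BASE_URL}/ocr",
        files={"image": f},
        data={"type": "auto"},  # auto, handwritten, or printed
    )
print(resp.json()["text"])

# Single image as a base64 JSON payload
with open("sample.png", "rb") as f:
    encoded = base64.b64encode(f.read()).decode("ascii")
resp = requests.post(f"{BASE_URL}/ocr", json={"image_base64": encoded, "type": "printed"})
print(resp.json())

# Several images in one request via the batch endpoint
with open("page1.png", "rb") as f1, open("page2.png", "rb") as f2:  # placeholder file names
    resp = requests.post(
        f"{BASE_URL}/ocr/batch",
        files=[("images", f1), ("images", f2)],
        data={"type": "handwritten"},
    )
print(resp.json()["total_processed"])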