mike23415 committed on
Commit
e4d75fe
·
verified ·
1 Parent(s): 5309df3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +226 -99
app.py CHANGED
@@ -1,122 +1,249 @@
1
- import os
2
- import uuid
3
  from flask import Flask, request, jsonify
4
  from flask_cors import CORS
5
- from werkzeug.utils import secure_filename
6
- import torch
7
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
8
  from PIL import Image
9
- import cv2
 
 
 
10
  import numpy as np
11
 
 
 
 
 
12
  app = Flask(__name__)
13
  CORS(app)
14
 
15
- # Configure upload folder
16
- UPLOAD_FOLDER = 'uploads'
17
- ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'pdf', 'tif', 'tiff'}
18
- app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
19
- app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 # 16MB max upload
20
-
21
- # Create uploads directory if it doesn't exist
22
- os.makedirs(UPLOAD_FOLDER, exist_ok=True)
23
-
24
- # Load OCR model - Microsoft's TrOCR model
25
- print("Loading OCR model...")
26
- processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten", cache_dir="/huggingface_cache")
27
- model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten", cache_dir="/huggingface_cache")
28
 
29
- # Move model to GPU if available
30
- device = "cuda" if torch.cuda.is_available() else "cpu"
31
- print(f"Using device: {device}")
32
- model.to(device)
33
-
34
- def allowed_file(filename):
35
- return '.' in filename and \
36
- filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
37
-
38
- def preprocess_image(image_path):
39
- # Open image with PIL
40
- image = Image.open(image_path).convert("RGB")
41
-
42
- # Basic enhancement for better OCR results
43
- # Convert to OpenCV format for preprocessing
44
- img = np.array(image)
45
- img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
46
 
47
- # Apply adaptive thresholding to handle varying lighting conditions
48
- gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
49
- thresh = cv2.adaptiveThreshold(gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
50
- cv2.THRESH_BINARY, 11, 2)
51
-
52
- # Convert back to PIL
53
- enhanced_image = Image.fromarray(cv2.cvtColor(thresh, cv2.COLOR_GRAY2RGB))
54
- return enhanced_image
 
 
 
 
 
 
 
55
 
56
- def perform_ocr(image_path):
57
- # Preprocess the image
58
- image = preprocess_image(image_path)
59
-
60
- # Prepare image for the model
61
- pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
62
 
63
- # Generate text
64
- generated_ids = model.generate(
65
- pixel_values,
66
- max_length=64,
67
- num_beams=5,
68
- early_stopping=True
69
- )
70
 
71
- # Decode generated text
72
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
73
- return generated_text.strip()
74
 
75
- @app.route('/ocr', methods=['POST'])
76
- def ocr():
77
- # Check if a file was uploaded
78
- if 'file' not in request.files:
79
- return jsonify({'error': 'No file part'}), 400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- file = request.files['file']
 
 
82
 
83
- # Check if filename is empty
84
- if file.filename == '':
85
- return jsonify({'error': 'No selected file'}), 400
86
 
87
- # Check if file type is allowed
88
- if file and allowed_file(file.filename):
89
- # Create a unique filename
90
- filename = str(uuid.uuid4()) + '_' + secure_filename(file.filename)
91
- file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
92
-
93
- # Save the file
94
- file.save(file_path)
95
-
96
- try:
97
- # Perform OCR
98
- extracted_text = perform_ocr(file_path)
99
-
100
- # Clean up the file after processing
101
- os.remove(file_path)
102
-
103
- return jsonify({
104
- 'success': True,
105
- 'text': extracted_text
106
- })
107
- except Exception as e:
108
- # Log the error
109
- print(f"Error processing image: {str(e)}")
110
- return jsonify({
111
- 'success': False,
112
- 'error': str(e)
113
- }), 500
114
  else:
115
- return jsonify({'error': 'File type not allowed'}), 400
 
 
 
116
 
117
  @app.route('/health', methods=['GET'])
118
  def health_check():
119
- return jsonify({'status': 'healthy'}), 200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
  if __name__ == '__main__':
122
- app.run(host='0.0.0.0', port=5000, debug=False)
 
 
 
 
 
 
 
 
 
1
  from flask import Flask, request, jsonify
2
  from flask_cors import CORS
3
+ import base64
4
+ import io
5
+ import os
6
  from PIL import Image
7
+ import logging
8
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
9
+ import torch
10
+ import easyocr
11
  import numpy as np
12
 
13
# Module-scoped logger; root logging is configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Flask application with cross-origin requests enabled.
app = Flask(__name__)
CORS(app)

# OCR model handles; they stay None until initialize_models() assigns them.
trocr_processor = trocr_model = easyocr_reader = None
 
 
 
 
 
 
 
 
 
24
 
25
def initialize_models():
    """Load the TrOCR and EasyOCR models into the module-level globals.

    The handles are loaded into locals first and published together, so a
    failure part-way through never leaves the globals half-populated.

    Raises:
        Exception: whatever the underlying loaders raise; it is logged with
            its traceback and re-raised unchanged.
    """
    global trocr_processor, trocr_model, easyocr_reader

    try:
        # Initialize TrOCR for handwritten text (Microsoft's model)
        logger.info("Loading TrOCR model for handwritten text...")
        processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        # Initialize EasyOCR for printed text
        logger.info("Loading EasyOCR for printed text...")
        reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
    except Exception as e:
        # logger.exception keeps the traceback; bare raise keeps the
        # original exception and its context intact.
        logger.exception("Error loading models: %s", e)
        raise

    trocr_processor, trocr_model, easyocr_reader = processor, model, reader
    logger.info("All models loaded successfully!")
44
 
45
def preprocess_image(image, max_size=1024):
    """Normalise an image for OCR: force RGB and cap its longest side.

    Args:
        image: PIL image (the function reads ``mode`` and ``size`` and may
            call ``convert`` and ``resize`` on it).
        max_size: Upper bound, in pixels, for the longest side. Defaults to
            1024, the value that was previously hard-coded.

    Returns:
        The input object itself when it is already RGB and small enough,
        otherwise the converted and/or resized image.
    """
    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Resize if image is too large, preserving the aspect ratio
    longest_side = max(image.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp to 1px: truncation turns the short side of a very elongated
        # image (e.g. 5000x1) into 0, which resize() rejects.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image
 
 
59
 
60
def extract_text_trocr(image):
    """Extract text using TrOCR (good for handwritten text).

    Best-effort: any failure is logged and an empty string is returned so
    that callers can still fall back to another engine.

    Args:
        image: PIL image to recognise.

    Returns:
        The decoded text with surrounding whitespace stripped, or "" when the
        model is not loaded or recognition fails.
    """
    # Report the real cause instead of an opaque "'NoneType' is not callable".
    if trocr_processor is None or trocr_model is None:
        logger.error("TrOCR error: model is not loaded; call initialize_models() first")
        return ""

    try:
        # Preprocess image
        image = preprocess_image(image)

        # Generate pixel values
        pixel_values = trocr_processor(image, return_tensors="pt").pixel_values

        # Generate text
        generated_ids = trocr_model.generate(pixel_values)
        generated_text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text.strip()
    except Exception as e:
        # Keep the traceback: the error is swallowed, so the log is the only trace.
        logger.exception("TrOCR error: %s", e)
        return ""
77
+
78
def extract_text_easyocr(image):
    """Extract text using EasyOCR (good for printed text).

    Best-effort: any failure is logged and an empty string is returned so
    that callers can still fall back to another engine.

    Args:
        image: PIL image to recognise.

    Returns:
        All detected fragments joined by single spaces and stripped, or ""
        when the reader is not loaded or recognition fails.
    """
    # Report the real cause instead of an opaque AttributeError on None.
    if easyocr_reader is None:
        logger.error("EasyOCR error: reader is not loaded; call initialize_models() first")
        return ""

    try:
        # Convert PIL image to numpy array
        image_np = np.array(preprocess_image(image))

        # Extract text (detail=0 is expected to yield plain strings, which
        # the join below relies on)
        fragments = easyocr_reader.readtext(image_np, detail=0)

        # Join all detected text
        return ' '.join(fragments).strip()
    except Exception as e:
        # Keep the traceback: the error is swallowed, so the log is the only trace.
        logger.exception("EasyOCR error: %s", e)
        return ""
93
+
94
# Every accepted ``ocr_type`` spelling mapped to the engine key used in results.
_OCR_ENGINE_ALIASES = {
    "handwritten": "trocr",
    "trocr": "trocr",
    "printed": "easyocr",
    "easyocr": "easyocr",
}


def _choose_auto_text(trocr_text, easyocr_text):
    """Pick the final text for auto mode from the two engine outputs."""
    trocr_len, easyocr_len = len(trocr_text), len(easyocr_text)

    # At most one engine produced text: use it (or "" when neither did).
    if not trocr_len or not easyocr_len:
        return trocr_text or easyocr_text

    # Lengths within 30% of each other: prefer EasyOCR.
    if abs(trocr_len - easyocr_len) / max(trocr_len, easyocr_len) < 0.3:
        return easyocr_text

    # Otherwise use the longer result (EasyOCR wins an exact tie).
    return trocr_text if trocr_len > easyocr_len else easyocr_text


def process_image_ocr(image, ocr_type="auto"):
    """Process image with specified OCR method.

    Args:
        image: PIL image to recognise.
        ocr_type: "auto" (run both engines), "handwritten"/"trocr", or
            "printed"/"easyocr". Matching ignores case and surrounding
            whitespace.

    Returns:
        A dict with one entry per engine that ran ("trocr" and/or
        "easyocr") plus "final", the selected text. An unrecognised
        ``ocr_type`` runs no engine and yields ``{"final": ""}``.
    """
    mode = str(ocr_type).strip().lower()
    run_both = mode == "auto"
    engine = _OCR_ENGINE_ALIASES.get(mode)

    results = {}

    if run_both or engine == "trocr":
        results["trocr"] = extract_text_trocr(image)

    if run_both or engine == "easyocr":
        results["easyocr"] = extract_text_easyocr(image)

    if run_both:
        results["final"] = _choose_auto_text(results["trocr"], results["easyocr"])
    else:
        # Return the specific model result ("" for an unknown type)
        results["final"] = results.get(engine, "")

    return results
130
 
131
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint.

    Returns:
        JSON with a "status" field and "models_loaded", which is true only
        when every OCR model handle has actually been populated.
    """
    models_loaded = all(
        handle is not None
        for handle in (trocr_processor, trocr_model, easyocr_reader)
    )
    return jsonify({"status": "healthy", "models_loaded": models_loaded})
135
+
136
@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Main OCR endpoint.

    Accepts either a multipart upload (file field ``image``, optional form
    field ``type``) or a JSON body (``image_base64``, optional ``type``).
    ``type`` is one of auto, handwritten, printed and defaults to auto.

    Returns:
        200 with the recognised text, 400 when no image is supplied or the
        image cannot be decoded, 500 on any other processing failure.
    """
    try:
        # silent=True: a multipart request carries no JSON body, and reading
        # request.json there raises instead of yielding None.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}

        # Check if image is provided
        upload = request.files.get('image')
        image_data = payload.get('image_base64')
        if upload is None and image_data is None:
            return jsonify({"error": "No image provided"}), 400

        # Get OCR type preference from whichever body format was used
        ocr_type = str(request.form.get('type') or payload.get('type') or 'auto')

        # Load image
        try:
            if upload is not None:
                # File upload
                image = Image.open(upload.stream)
            else:
                # Base64 image
                if not isinstance(image_data, str):
                    raise ValueError("image_base64 must be a string")
                if image_data.startswith('data:image'):
                    # Remove data URL prefix
                    image_data = image_data.partition(',')[2]

                # Decode base64
                image_bytes = base64.b64decode(image_data)
                image = Image.open(io.BytesIO(image_bytes))
        except (ValueError, OSError) as e:
            # binascii.Error is a ValueError and PIL's UnidentifiedImageError
            # is an OSError: both mean the client sent bad data, not that the
            # server failed.
            logger.warning("Rejected undecodable image: %s", e)
            return jsonify({"error": f"Invalid image: {e}", "success": False}), 400

        # Process image
        results = process_image_ocr(image, ocr_type)

        response = {
            "success": True,
            "text": results["final"],
            "type_used": ocr_type,
            "details": {
                "trocr_result": results.get("trocr", ""),
                "easyocr_result": results.get("easyocr", "")
            } if ocr_type == "auto" else {}
        }

        return jsonify(response)

    except Exception as e:
        logger.exception("OCR processing error: %s", e)
        return jsonify({"error": str(e), "success": False}), 500
181
+
182
@app.route('/ocr/batch', methods=['POST'])
def batch_ocr_endpoint():
    """Run OCR over every file posted under the ``images`` field.

    Each file is handled independently: a failure on one file is recorded in
    that file's entry and does not abort the rest of the batch.
    """
    def run_single(position, upload, requested_type):
        """Build the result entry for one uploaded file."""
        entry = {"index": position, "filename": upload.filename}
        try:
            recognised = process_image_ocr(Image.open(upload.stream), requested_type)
            entry["text"] = recognised["final"]
        except Exception as exc:
            entry["error"] = str(exc)
            entry["success"] = False
        else:
            entry["success"] = True
        return entry

    try:
        if 'images' not in request.files:
            return jsonify({"error": "No images provided"}), 400

        uploads = request.files.getlist('images')
        ocr_type = request.form.get('type', 'auto')

        results = [
            run_single(position, upload, ocr_type)
            for position, upload in enumerate(uploads)
        ]

        summary = {
            "success": True,
            "results": results,
            "total_processed": len(results),
        }
        return jsonify(summary)

    except Exception as e:
        logger.error(f"Batch OCR error: {str(e)}")
        return jsonify({"error": str(e), "success": False}), 500
221
+
222
@app.route('/models/info', methods=['GET'])
def models_info():
    """Describe the OCR engines and report whether each one is loaded."""
    trocr_entry = {
        "name": "microsoft/trocr-base-handwritten",
        "description": "Handwritten text recognition",
        "loaded": trocr_model is not None,
    }
    easyocr_entry = {
        "name": "EasyOCR",
        "description": "Printed text recognition",
        "loaded": easyocr_reader is not None,
    }

    payload = {
        "models": {"trocr": trocr_entry, "easyocr": easyocr_entry},
        "supported_types": ["auto", "handwritten", "printed"],
        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"],
    }
    return jsonify(payload)
241
 
242
if __name__ == '__main__':
    # Load the OCR models before the server starts accepting requests.
    # NOTE(review): this is the only visible call to initialize_models(); a
    # WSGI server that imports `app` would not run it - confirm deployment.
    logger.info("Starting OCR service...")
    initialize_models()

    # Listen on all interfaces; the port comes from $PORT, defaulting to 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)), debug=False)