"""Flask OCR service backed by TrOCR (handwritten text) and EasyOCR (printed text).

Endpoints:
    GET  /health       - liveness check plus whether the OCR models are loaded.
    POST /ocr          - OCR one image (multipart ``image`` or JSON ``image_base64``).
    POST /ocr/batch    - OCR several multipart ``images`` in one request.
    GET  /models/info  - description of the configured models.
"""

import base64
import io
import logging
import os
import threading

import easyocr
import numpy as np
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app)

TROCR_MODEL_NAME = "microsoft/trocr-base-handwritten"

# Longest image side (pixels) fed to the OCR engines; larger images are downscaled.
MAX_IMAGE_SIDE = 1024

# Every value accepted for the ``type`` parameter. "trocr"/"easyocr" are
# aliases of "handwritten"/"printed" that process_image_ocr has always accepted.
SUPPORTED_OCR_TYPES = ("auto", "handwritten", "printed", "trocr", "easyocr")

# Global variables for models (populated by initialize_models)
trocr_processor = None
trocr_model = None
easyocr_reader = None

# Serialises lazy model loading: the server may handle requests on several
# threads, and loading the models twice concurrently is wasteful.
_models_lock = threading.Lock()


class _ImageInputError(ValueError):
    """Raised when the client-supplied image payload cannot be decoded (maps to HTTP 400)."""


def initialize_models():
    """Load the TrOCR processor/model and the EasyOCR reader into the module globals.

    Raises:
        Exception: whatever the underlying loaders raise; it is logged with a
            traceback and re-raised so the failure is not silent.
    """
    global trocr_processor, trocr_model, easyocr_reader

    try:
        # Initialize TrOCR for handwritten text (Microsoft's model)
        logger.info("Loading TrOCR model for handwritten text...")
        trocr_processor = TrOCRProcessor.from_pretrained(TROCR_MODEL_NAME)
        trocr_model = VisionEncoderDecoderModel.from_pretrained(TROCR_MODEL_NAME)

        # Initialize EasyOCR for printed text; use the GPU only when torch sees one.
        logger.info("Loading EasyOCR for printed text...")
        easyocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

        logger.info("All models loaded successfully!")
    except Exception as e:
        logger.exception("Error loading models: %s", e)
        raise


def _models_loaded():
    """Return True when every OCR model global has been populated."""
    return all(
        model is not None
        for model in (trocr_processor, trocr_model, easyocr_reader)
    )


def _ensure_models_loaded():
    """Load the models on first use if they are not loaded yet.

    ``initialize_models`` is only called from the ``__main__`` guard, so when
    the app is served some other way (e.g. by an external WSGI server) the
    globals would otherwise stay ``None``. Every OCR call would then fail
    inside the extractors and silently return "".
    """
    if _models_loaded():
        return
    with _models_lock:
        # Re-check under the lock: another thread may have finished loading.
        if not _models_loaded():
            initialize_models()


def preprocess_image(image, max_size=MAX_IMAGE_SIDE):
    """Normalise a PIL image for OCR.

    Converts to RGB and, if the longest side exceeds *max_size*, downscales
    proportionally with LANCZOS resampling. Calling it on an already
    processed image is a no-op.

    Args:
        image: PIL image.
        max_size: maximum allowed length of the longest side, in pixels.

    Returns:
        The (possibly new) RGB PIL image.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    longest = max(image.size)
    if longest > max_size:
        ratio = max_size / longest
        # max(1, ...) keeps very thin images (e.g. 5000x1) from collapsing to a
        # zero-size dimension, which is not a usable image.
        new_size = tuple(max(1, round(dim * ratio)) for dim in image.size)
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image


def extract_text_trocr(image):
    """Recognise text with TrOCR (intended for handwriting).

    Best-effort: any failure is logged and "" is returned, so that auto mode
    can still fall back to EasyOCR's result.

    NOTE(review): the trocr-base-handwritten checkpoint is believed to be
    trained on single text lines; confirm behaviour on multi-line images.

    Returns:
        The decoded, whitespace-stripped text, or "" on error.
    """
    try:
        image = preprocess_image(image)
        pixel_values = trocr_processor(image, return_tensors="pt").pixel_values
        generated_ids = trocr_model.generate(pixel_values)
        generated_text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text.strip()
    except Exception as e:
        logger.exception("TrOCR error: %s", e)
        return ""


def extract_text_easyocr(image):
    """Recognise text with EasyOCR (intended for printed text).

    Best-effort: any failure is logged and "" is returned.

    Returns:
        All detected text fragments joined by single spaces and stripped,
        or "" on error.
    """
    try:
        image_np = np.array(preprocess_image(image))
        # detail=0 makes readtext return only the recognised strings.
        fragments = easyocr_reader.readtext(image_np, detail=0)
        return ' '.join(fragments).strip()
    except Exception as e:
        logger.exception("EasyOCR error: %s", e)
        return ""


def _choose_auto_result(trocr_text, easyocr_text):
    """Pick the "final" text for auto mode from the two engine outputs.

    If both engines produced text and their lengths are within 30% of each
    other, EasyOCR's output is preferred. Otherwise the longer output wins.
    If only one engine produced text, that one is used; if neither did, "".
    """
    trocr_len = len(trocr_text)
    easyocr_len = len(easyocr_text)
    if trocr_len and easyocr_len:
        if abs(trocr_len - easyocr_len) / max(trocr_len, easyocr_len) < 0.3:
            return easyocr_text
        return trocr_text if trocr_len > easyocr_len else easyocr_text
    return trocr_text or easyocr_text


def process_image_ocr(image, ocr_type="auto"):
    """Run the OCR engine(s) selected by *ocr_type* on *image*.

    Args:
        image: PIL image.
        ocr_type: "auto" (run both engines), "handwritten"/"trocr" (TrOCR
            only) or "printed"/"easyocr" (EasyOCR only).

    Returns:
        dict holding the raw output of each engine that ran (keys "trocr" /
        "easyocr") plus the selected text under "final".

    Raises:
        ValueError: if *ocr_type* is not one of SUPPORTED_OCR_TYPES.
            Previously an unknown type silently produced an empty result.
    """
    if ocr_type not in SUPPORTED_OCR_TYPES:
        raise ValueError(f"Unsupported OCR type: {ocr_type!r}")

    # Normalise once up front; the extractors' own preprocessing is then a no-op.
    image = preprocess_image(image)

    results = {}
    if ocr_type in ("auto", "handwritten", "trocr"):
        results["trocr"] = extract_text_trocr(image)
    if ocr_type in ("auto", "printed", "easyocr"):
        results["easyocr"] = extract_text_easyocr(image)

    if ocr_type == "auto":
        results["final"] = _choose_auto_result(results["trocr"], results["easyocr"])
    else:
        # Exactly one engine ran; its output is the final answer.
        engine = "trocr" if ocr_type in ("handwritten", "trocr") else "easyocr"
        results["final"] = results[engine]

    return results


def _open_image(stream):
    """Open and fully decode an image from a binary stream.

    ``Image.open`` is lazy, so ``load()`` is called here. That way corrupt or
    truncated data surfaces as a client error instead of failing later inside
    OCR.

    Raises:
        _ImageInputError: if the data is not a decodable image.
    """
    try:
        image = Image.open(stream)
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        # UnidentifiedImageError and truncated-file errors are OSError subclasses.
        raise _ImageInputError(f"Invalid image: {e}") from e
    return image


def _decode_base64_image(image_data):
    """Decode base64 image data into an image.

    The data may be raw base64 or a ``data:image/...;base64,`` URL.

    Raises:
        _ImageInputError: if the payload is not a string, not valid base64,
            or not a decodable image.
    """
    if not isinstance(image_data, str):
        raise _ImageInputError("image_base64 must be a string")
    if image_data.startswith('data:image'):
        # Strip the data-URL header. partition() yields "" rather than raising
        # IndexError when the comma is missing.
        image_data = image_data.partition(',')[2]
    try:
        image_bytes = base64.b64decode(image_data)
    except ValueError as e:
        # binascii.Error (bad padding/characters) is a ValueError subclass.
        raise _ImageInputError(f"Invalid base64 image data: {e}") from e
    return _open_image(io.BytesIO(image_bytes))


def _get_json_body():
    """Return the JSON request body if it is an object, otherwise {}.

    ``silent=True`` makes non-JSON requests (e.g. multipart uploads) yield
    None instead of raising.
    """
    body = request.get_json(silent=True)
    return body if isinstance(body, dict) else {}


def _get_ocr_type(json_body):
    """Return the requested OCR type: form field ``type``, else JSON key ``type``, else "auto"."""
    ocr_type = request.form.get('type')
    if ocr_type is None:
        ocr_type = json_body.get('type')
    return 'auto' if ocr_type is None else ocr_type


def _unsupported_type_response(ocr_type):
    """Build the HTTP 400 response for an unrecognised ``type`` value."""
    return jsonify({
        "error": f"Unsupported OCR type: {ocr_type!r}. "
                 f"Supported types: {', '.join(SUPPORTED_OCR_TYPES)}",
        "success": False,
    }), 400


@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint; ``models_loaded`` reflects whether the OCR models are in memory."""
    return jsonify({"status": "healthy", "models_loaded": _models_loaded()})


@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Main OCR endpoint.

    Accepts either a multipart upload in field ``image`` or a JSON body with
    ``image_base64`` (raw base64 or a data URL). If both are present, the
    upload wins. The engine is chosen by ``type`` (form field, or JSON key for
    JSON requests): auto, handwritten or printed; the default is "auto".

    Returns:
        JSON with the final text and the type used. In auto mode it also
        includes each engine's raw output. Status 400 for missing/invalid
        input, 500 for internal errors.
    """
    try:
        json_body = _get_json_body()
        if 'image' not in request.files and 'image_base64' not in json_body:
            return jsonify({"error": "No image provided"}), 400

        ocr_type = _get_ocr_type(json_body)
        if ocr_type not in SUPPORTED_OCR_TYPES:
            return _unsupported_type_response(ocr_type)

        if 'image' in request.files:
            image = _open_image(request.files['image'].stream)
        else:
            image = _decode_base64_image(json_body['image_base64'])

        _ensure_models_loaded()
        results = process_image_ocr(image, ocr_type)

        response = {
            "success": True,
            "text": results["final"],
            "type_used": ocr_type,
            "details": {
                "trocr_result": results.get("trocr", ""),
                "easyocr_result": results.get("easyocr", "")
            } if ocr_type == "auto" else {}
        }
        return jsonify(response)

    except _ImageInputError as e:
        return jsonify({"error": str(e), "success": False}), 400
    except Exception as e:
        logger.exception("OCR processing error: %s", e)
        return jsonify({"error": str(e), "success": False}), 500


@app.route('/ocr/batch', methods=['POST'])
def batch_ocr_endpoint():
    """Batch OCR endpoint for multiple images.

    Expects multipart field(s) ``images`` and an optional form field ``type``.
    A failure on one image is reported in that image's entry and does not
    abort the batch.
    """
    try:
        if 'images' not in request.files:
            return jsonify({"error": "No images provided"}), 400

        images = request.files.getlist('images')
        ocr_type = request.form.get('type', 'auto')
        if ocr_type not in SUPPORTED_OCR_TYPES:
            return _unsupported_type_response(ocr_type)

        _ensure_models_loaded()

        results = []
        for i, image_file in enumerate(images):
            try:
                image = _open_image(image_file.stream)
                ocr_results = process_image_ocr(image, ocr_type)
                results.append({
                    "index": i,
                    "filename": image_file.filename,
                    "text": ocr_results["final"],
                    "success": True
                })
            except Exception as e:
                # Deliberate per-item isolation: record the error and keep going.
                logger.warning("Batch item %d (%s) failed: %s", i, image_file.filename, e)
                results.append({
                    "index": i,
                    "filename": image_file.filename,
                    "error": str(e),
                    "success": False
                })

        return jsonify({
            "success": True,
            "results": results,
            "total_processed": len(results)
        })

    except Exception as e:
        logger.exception("Batch OCR error: %s", e)
        return jsonify({"error": str(e), "success": False}), 500


@app.route('/models/info', methods=['GET'])
def models_info():
    """Get information about loaded models."""
    return jsonify({
        "models": {
            "trocr": {
                "name": TROCR_MODEL_NAME,
                "description": "Handwritten text recognition",
                "loaded": trocr_model is not None
            },
            "easyocr": {
                "name": "EasyOCR",
                "description": "Printed text recognition",
                "loaded": easyocr_reader is not None
            }
        },
        "supported_types": ["auto", "handwritten", "printed"],
        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"]
    })


if __name__ == '__main__':
    # Initialize models eagerly on startup so the first request is not slow.
    logger.info("Starting OCR service...")
    initialize_models()

    # Run the app
    port = int(os.environ.get('PORT', 5000))
    app.run(host='0.0.0.0', port=port, debug=False)