ORC / app.py
mike23415's picture
Update app.py
e4d75fe verified
raw
history blame
8.44 kB
from flask import Flask, request, jsonify
from flask_cors import CORS
import base64
import io
import os
from PIL import Image
import logging
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import easyocr
import numpy as np
# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The Flask application object, passed through flask_cors' CORS().
app = Flask(__name__)
CORS(app)

# OCR model handles shared by the request handlers. They stay None until
# initialize_models() assigns them.
trocr_processor = trocr_model = easyocr_reader = None
def initialize_models():
    """Load the OCR models into the module-level globals.

    Assigns ``trocr_processor`` and ``trocr_model`` from the
    ``microsoft/trocr-base-handwritten`` checkpoint and creates an English
    ``easyocr.Reader`` in ``easyocr_reader``. EasyOCR is asked to use the GPU
    only when ``torch.cuda.is_available()`` returns True.

    Raises:
        Exception: Any error raised while loading is logged (with its
            traceback) and then re-raised unchanged.
    """
    global trocr_processor, trocr_model, easyocr_reader

    try:
        # Initialize TrOCR for handwritten text (Microsoft's model)
        logger.info("Loading TrOCR model for handwritten text...")
        trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        # Initialize EasyOCR for printed text
        logger.info("Loading EasyOCR for printed text...")
        easyocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

        logger.info("All models loaded successfully!")
    except Exception as e:
        # logger.exception records the traceback alongside the message, and a
        # bare `raise` re-raises the active exception without re-binding it.
        logger.exception("Error loading models: %s", e)
        raise
def preprocess_image(image, max_size=1024):
    """Normalize an image before it is handed to an OCR engine.

    The image is converted to RGB when its mode is anything else, and is
    scaled down (aspect ratio preserved, LANCZOS resampling) when its longest
    side exceeds ``max_size``.

    Args:
        image: Object exposing ``mode``, ``size``, ``convert`` and ``resize``
            (the callers in this file pass images opened with ``Image.open``).
        max_size: Maximum allowed length, in pixels, of the longest side.
            Defaults to 1024.

    Returns:
        The input image itself when no change was needed, otherwise the
        converted and/or resized image.
    """
    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Resize if image is too large
    longest_side = max(image.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp each side to at least 1 pixel: for very elongated images the
        # short side would otherwise truncate to 0 and make resize() fail.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image
def extract_text_trocr(image):
    """Recognize text with the TrOCR model (good for handwritten text).

    The image goes through preprocess_image(), is encoded by the TrOCR
    processor, and the model's generated token ids are decoded back to text.

    Returns:
        The first decoded string with surrounding whitespace stripped, or an
        empty string when any step raises (the error is logged).
    """
    try:
        prepared = preprocess_image(image)

        # Encode the image as PyTorch tensors and pull out the pixel values.
        encoded = trocr_processor(prepared, return_tensors="pt")
        token_ids = trocr_model.generate(encoded.pixel_values)

        # Decode the generated ids; only the first decoded string is used.
        decoded = trocr_processor.batch_decode(token_ids, skip_special_tokens=True)
        first_text = decoded[0]
        return first_text.strip()
    except Exception as e:
        logger.error(f"TrOCR error: {str(e)}")
        return ""
def extract_text_easyocr(image):
    """Recognize text with EasyOCR (good for printed text).

    The image goes through preprocess_image(), is converted to a numpy array
    and passed to the EasyOCR reader with ``detail=0``.

    Returns:
        All detected strings joined by single spaces and stripped, or an
        empty string when any step raises (the error is logged).
    """
    try:
        prepared = preprocess_image(image)
        pixels = np.array(prepared)

        fragments = easyocr_reader.readtext(pixels, detail=0)

        # Merge every detected fragment into one space-separated string.
        joined = ' '.join(fragments)
        return joined.strip()
    except Exception as e:
        logger.error(f"EasyOCR error: {str(e)}")
        return ""
def _select_auto_text(handwritten_text, printed_text):
    """Pick the text to report in auto mode from the two engine outputs."""
    # When at most one engine produced text, use whichever is non-empty
    # (or the empty string when neither did).
    if not handwritten_text or not printed_text:
        return handwritten_text or printed_text

    n_hand = len(handwritten_text)
    n_print = len(printed_text)

    # Lengths within 30% of the longer one count as "similar": prefer EasyOCR.
    if abs(n_hand - n_print) / max(n_hand, n_print) < 0.3:
        return printed_text

    # Otherwise the longer output wins (EasyOCR on a tie).
    return handwritten_text if n_hand > n_print else printed_text


def process_image_ocr(image, ocr_type="auto"):
    """Run the OCR engine(s) selected by ``ocr_type`` on ``image``.

    Args:
        image: Image handed to extract_text_trocr() / extract_text_easyocr().
        ocr_type: "auto" runs both engines; "handwritten" or "trocr" runs
            TrOCR only; "printed" or "easyocr" runs EasyOCR only. Any other
            value runs no engine.

    Returns:
        A dict holding a "trocr" and/or "easyocr" entry for each engine that
        ran, plus "final": the selected text (empty string when nothing was
        produced or the type is not recognized).
    """
    results = {}

    if ocr_type in ("auto", "handwritten", "trocr"):
        results["trocr"] = extract_text_trocr(image)

    if ocr_type in ("auto", "printed", "easyocr"):
        results["easyocr"] = extract_text_easyocr(image)

    if ocr_type == "auto":
        results["final"] = _select_auto_text(
            results.get("trocr", ""),
            results.get("easyocr", ""),
        )
        return results

    # Map the user-facing type names onto the engine keys used in `results`.
    engine_key = ocr_type.replace("handwritten", "trocr").replace("printed", "easyocr")
    results["final"] = results.get(engine_key, "")
    return results
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint.

    Returns:
        JSON with ``status`` ("healthy" whenever the app can answer) and
        ``models_loaded``, which is True only when the TrOCR processor, the
        TrOCR model and the EasyOCR reader globals have all been assigned.
    """
    # Report the real state of the globals rather than a constant: they stay
    # None until initialize_models() has run successfully.
    models_loaded = all(
        handle is not None
        for handle in (trocr_processor, trocr_model, easyocr_reader)
    )
    return jsonify({"status": "healthy", "models_loaded": models_loaded})
@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Main OCR endpoint.

    Accepts either a multipart upload with an ``image`` file part, or a JSON
    body with an ``image_base64`` string (optionally a ``data:image...`` URL).
    The OCR type is read from the form field ``type``, falling back to the
    JSON field ``type`` and finally to "auto".

    Returns:
        200 with ``success``, ``text``, ``type_used`` and ``details`` (the
        per-engine results, only populated in auto mode);
        400 when no image is supplied or the supplied image/type is invalid;
        500 with the error message for any other failure.
    """
    try:
        # silent=True: multipart requests carry no JSON body, and parsing
        # must not fail for them. Non-object JSON is treated as empty.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}

        # Check if image is provided
        if 'image' not in request.files and 'image_base64' not in payload:
            return jsonify({"error": "No image provided"}), 400

        # Get OCR type preference (auto, handwritten, printed); the form
        # field wins, JSON callers may pass it in the body instead.
        ocr_type = request.form.get('type')
        if ocr_type is None:
            ocr_type = payload.get('type', 'auto')
        if not isinstance(ocr_type, str):
            return jsonify({"error": "'type' must be a string", "success": False}), 400

        # Load image
        try:
            if 'image' in request.files:
                # File upload
                image = Image.open(request.files['image'].stream)
            else:
                # Base64 image
                image_data = payload['image_base64']
                if not isinstance(image_data, str):
                    raise ValueError("'image_base64' must be a string")
                if image_data.startswith('data:image'):
                    # Remove data URL prefix (everything up to the first comma)
                    image_data = image_data.partition(',')[2]
                image = Image.open(io.BytesIO(base64.b64decode(image_data)))
        except (ValueError, OSError) as e:
            # ValueError covers malformed base64 (binascii.Error); OSError
            # covers data Pillow cannot identify as an image. Both are client
            # errors, not server faults.
            logger.warning("Rejected invalid image: %s", e)
            return jsonify({"error": f"Invalid image: {e}", "success": False}), 400

        # Process image
        results = process_image_ocr(image, ocr_type)

        response = {
            "success": True,
            "text": results["final"],
            "type_used": ocr_type,
            "details": {
                "trocr_result": results.get("trocr", ""),
                "easyocr_result": results.get("easyocr", "")
            } if ocr_type == "auto" else {}
        }
        return jsonify(response)
    except Exception as e:
        logger.exception("OCR processing error: %s", e)
        return jsonify({"error": str(e), "success": False}), 500
@app.route('/ocr/batch', methods=['POST'])
def batch_ocr_endpoint():
    """Batch OCR endpoint for multiple images.

    Expects a multipart request with one or more ``images`` file parts and an
    optional ``type`` form field (default "auto"). Each file is processed
    independently, so a failure on one file is reported in that file's entry
    and does not abort the rest.

    Returns:
        200 with ``results`` (one entry per file, in upload order) and
        ``total_processed``; 400 when no ``images`` part is present; 500 with
        the error message for failures outside the per-file handling.
    """
    try:
        if 'images' not in request.files:
            return jsonify({"error": "No images provided"}), 400

        uploads = request.files.getlist('images')
        ocr_type = request.form.get('type', 'auto')

        results = []
        for position, upload in enumerate(uploads):
            try:
                picture = Image.open(upload.stream)
                text = process_image_ocr(picture, ocr_type)["final"]
            except Exception as e:
                outcome = {"error": str(e), "success": False}
            else:
                outcome = {"text": text, "success": True}

            results.append({
                "index": position,
                "filename": upload.filename,
                **outcome,
            })

        summary = {
            "success": True,
            "results": results,
            "total_processed": len(results),
        }
        return jsonify(summary)
    except Exception as e:
        logger.error(f"Batch OCR error: {str(e)}")
        return jsonify({"error": str(e), "success": False}), 500
@app.route('/models/info', methods=['GET'])
def models_info():
    """Describe the OCR models and the request options this service exposes.

    Each model's ``loaded`` flag reflects whether its module-level global is
    currently set (not None).
    """
    trocr_entry = {
        "name": "microsoft/trocr-base-handwritten",
        "description": "Handwritten text recognition",
        "loaded": trocr_model is not None,
    }
    easyocr_entry = {
        "name": "EasyOCR",
        "description": "Printed text recognition",
        "loaded": easyocr_reader is not None,
    }

    info = {
        "models": {"trocr": trocr_entry, "easyocr": easyocr_entry},
        "supported_types": ["auto", "handwritten", "printed"],
        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"],
    }
    return jsonify(info)
if __name__ == '__main__':
    logger.info("Starting OCR service...")

    # Load both OCR models before the server starts accepting requests.
    initialize_models()

    # Bind to all interfaces; the PORT environment variable overrides the
    # default port 5000.
    listen_port = int(os.getenv('PORT', 5000))
    app.run(host='0.0.0.0', port=listen_port, debug=False)