|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
import base64 |
|
import io |
|
import os |
|
from PIL import Image |
|
import logging |
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
import torch |
|
import easyocr |
|
import numpy as np |
|
|
|
|
|
# Configure root logging at INFO so the model-loading progress messages show up.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application with Flask-CORS applied using its default options.
# NOTE(review): Flask-CORS defaults permit cross-origin requests from any
# origin — confirm that is intended for a deployed service.
app = Flask(__name__)
CORS(app)

# Model handles; initialize_models() fills them in, until then they are None.
trocr_processor = trocr_model = easyocr_reader = None
|
|
|
def initialize_models():
    """Load the OCR models into the module-level globals.

    Populates ``trocr_processor`` and ``trocr_model`` from the
    ``microsoft/trocr-base-handwritten`` checkpoint (handwritten text) and
    ``easyocr_reader`` with an English EasyOCR reader (printed text). EasyOCR
    is told to use the GPU when ``torch.cuda.is_available()``; no explicit
    device placement is done for the TrOCR model.

    Raises:
        Exception: whatever a loader raised. It is logged together with its
            traceback and then re-raised unchanged.
    """
    global trocr_processor, trocr_model, easyocr_reader

    try:
        logger.info("Loading TrOCR model for handwritten text...")
        trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
        trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

        logger.info("Loading EasyOCR for printed text...")
        easyocr_reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

        logger.info("All models loaded successfully!")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error(str(e)) lost;
        # a bare ``raise`` re-raises the original exception untouched.
        logger.exception("Error loading models: %s", e)
        raise
|
|
|
def preprocess_image(image, max_size=1024):
    """Normalise an image before handing it to an OCR engine.

    Non-RGB images are converted to RGB. When the longest side exceeds
    ``max_size`` pixels the image is downscaled with LANCZOS resampling so the
    longest side becomes ``max_size``, preserving the aspect ratio.

    Args:
        image: PIL image to normalise.
        max_size: Largest allowed side length in pixels (default 1024).

    Returns:
        A PIL image. The input object itself is returned when it is already RGB
        and within ``max_size``.
    """
    if image.mode != 'RGB':
        image = image.convert('RGB')

    longest_side = max(image.size)
    if longest_side > max_size:
        ratio = max_size / longest_side
        # Clamp every side to >= 1 px: a very thin image (e.g. 3000x2) would
        # otherwise truncate to a 0-pixel side, which PIL cannot resize to.
        new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image
|
|
|
def extract_text_trocr(image, max_new_tokens=None):
    """Recognise text with TrOCR (suited to handwritten text).

    Args:
        image: PIL image. It is passed through preprocess_image() first.
        max_new_tokens: Optional cap on generated tokens, forwarded to
            ``generate``. ``None`` (the default) leaves the model's own
            generation-length configuration in effect.

    Returns:
        The decoded text with surrounding whitespace stripped. Returns "" when
        the TrOCR models are not loaded or recognition fails; failures are
        logged with their traceback.
    """
    if trocr_processor is None or trocr_model is None:
        logger.warning("TrOCR model not loaded; skipping handwritten OCR")
        return ""

    try:
        image = preprocess_image(image)
        pixel_values = trocr_processor(image, return_tensors="pt").pixel_values

        generate_kwargs = {}
        if max_new_tokens is not None:
            generate_kwargs["max_new_tokens"] = max_new_tokens
        generated_ids = trocr_model.generate(pixel_values, **generate_kwargs)

        generated_text = trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text.strip()
    except Exception as e:
        # Best effort by design: process_image_ocr's "auto" mode falls back to
        # the EasyOCR result when this returns "".
        logger.exception(f"TrOCR error: {str(e)}")
        return ""
|
|
|
def extract_text_easyocr(image):
    """Recognise text with EasyOCR (suited to printed text).

    Args:
        image: PIL image. It is passed through preprocess_image() and converted
            to a NumPy array before recognition.

    Returns:
        All recognised fragments joined with single spaces and stripped.
        Returns "" when the reader is not loaded or recognition fails; failures
        are logged with their traceback.
    """
    if easyocr_reader is None:
        logger.warning("EasyOCR reader not loaded; skipping printed-text OCR")
        return ""

    try:
        image_np = np.array(preprocess_image(image))
        # detail=0 asks readtext for plain text fragments, which are joined below.
        fragments = easyocr_reader.readtext(image_np, detail=0)
        return ' '.join(fragments).strip()
    except Exception as e:
        # Best effort by design: process_image_ocr's "auto" mode falls back to
        # the TrOCR result when this returns "".
        logger.exception(f"EasyOCR error: {str(e)}")
        return ""
|
|
|
def _choose_auto_text(trocr_text, easyocr_text):
    """Select the "auto"-mode answer from the two engines' outputs.

    If only one engine produced text, that text wins. If both did and their
    lengths differ by less than 30% of the longer one, EasyOCR wins. Otherwise
    the longer text wins, with ties going to EasyOCR.
    """
    trocr_len, easyocr_len = len(trocr_text), len(easyocr_text)

    if not trocr_len:
        return easyocr_text if easyocr_len else ""
    if not easyocr_len:
        return trocr_text

    relative_gap = abs(trocr_len - easyocr_len) / max(trocr_len, easyocr_len)
    if relative_gap < 0.3:
        return easyocr_text
    return trocr_text if trocr_len > easyocr_len else easyocr_text


def process_image_ocr(image, ocr_type="auto"):
    """Run the requested OCR engine(s) on ``image`` and pick a final text.

    ``ocr_type`` values:
      * "handwritten" or "trocr"  -> TrOCR only
      * "printed" or "easyocr"    -> EasyOCR only
      * "auto"                    -> both engines, final chosen by _choose_auto_text
    Any other value runs no engine and yields an empty final text.

    Returns:
        A dict with a "trocr" and/or "easyocr" entry for each engine that ran,
        plus "final" holding the selected text.
    """
    output = {}

    if ocr_type in ["auto", "handwritten", "trocr"]:
        output["trocr"] = extract_text_trocr(image)

    if ocr_type in ["auto", "printed", "easyocr"]:
        output["easyocr"] = extract_text_easyocr(image)

    if ocr_type != "auto":
        # Map the user-facing aliases onto the engine keys stored above.
        engine_key = ocr_type.replace("handwritten", "trocr").replace("printed", "easyocr")
        output["final"] = output.get(engine_key, "")
        return output

    output["final"] = _choose_auto_text(output.get("trocr", ""), output.get("easyocr", ""))
    return output
|
|
|
@app.route('/health', methods=['GET'])
def health_check():
    """Health check endpoint.

    Always responds with HTTP 200 and ``{"status": ..., "models_loaded": bool}``.
    ``models_loaded`` is True only when the TrOCR processor, the TrOCR model and
    the EasyOCR reader globals are all populated. ``status`` is "healthy" in
    that case and "degraded" otherwise, for example when the app is served
    without initialize_models() having run.
    """
    models_loaded = all(
        component is not None
        for component in (trocr_processor, trocr_model, easyocr_reader)
    )
    return jsonify({
        "status": "healthy" if models_loaded else "degraded",
        "models_loaded": models_loaded,
    })
|
|
|
@app.route('/ocr', methods=['POST'])
def ocr_endpoint():
    """Main OCR endpoint: recognise text in a single image.

    Accepts either:
      * multipart/form-data with an ``image`` file and an optional ``type``
        form field, or
      * a JSON body ``{"image_base64": "<base64 or data:image URL>",
        "type": "..."}``.

    ``type`` is "auto" (default), "handwritten"/"trocr" or "printed"/"easyocr".

    Returns:
        200 with ``{"success": True, "text", "type_used", "details"}``.
            ``details`` holds both engines' raw output in "auto" mode and is
            empty otherwise.
        400 for a missing image, an undecodable image or an unsupported type.
        500 with ``{"error", "success": False}`` for unexpected failures.
    """
    # Every ocr_type that process_image_ocr() acts on. Anything else would
    # silently yield empty text, so it is rejected up front.
    valid_types = ("auto", "handwritten", "printed", "trocr", "easyocr")

    try:
        # silent=True returns None instead of raising for non-JSON bodies, such
        # as a multipart post without a file, so those get a 400 rather than a 500.
        payload = request.get_json(silent=True)
        if not isinstance(payload, dict):
            payload = {}

        if 'image' not in request.files and 'image_base64' not in payload:
            return jsonify({"error": "No image provided"}), 400

        # Form field for multipart uploads, JSON field for base64 requests.
        ocr_type = request.form.get('type', payload.get('type', 'auto'))
        if ocr_type not in valid_types:
            return jsonify({
                "error": f"Unsupported OCR type: {ocr_type!r}; expected one of {list(valid_types)}",
                "success": False,
            }), 400

        try:
            if 'image' in request.files:
                image = Image.open(request.files['image'].stream)
            else:
                image_data = payload['image_base64']
                if not isinstance(image_data, str):
                    raise ValueError("image_base64 must be a string")
                if image_data.startswith('data:image'):
                    # Drop the "data:image/<fmt>;base64," prefix of a data URL.
                    # partition() cannot raise on a URL that lacks a comma.
                    image_data = image_data.partition(',')[2]
                image = Image.open(io.BytesIO(base64.b64decode(image_data)))
            # Image.open is lazy. Decode the pixels now so corrupt uploads are
            # reported as client errors instead of failing later during OCR.
            image.load()
        except (OSError, ValueError) as e:
            # Bad base64 raises binascii.Error, which is a ValueError.
            # Unreadable image data raises PIL's UnidentifiedImageError,
            # which is an OSError.
            logger.warning(f"Rejected undecodable image: {e}")
            return jsonify({"error": f"Invalid image: {e}", "success": False}), 400

        results = process_image_ocr(image, ocr_type)

        if ocr_type == "auto":
            details = {
                "trocr_result": results.get("trocr", ""),
                "easyocr_result": results.get("easyocr", ""),
            }
        else:
            details = {}

        return jsonify({
            "success": True,
            "text": results["final"],
            "type_used": ocr_type,
            "details": details,
        })

    except Exception as e:
        logger.exception(f"OCR processing error: {str(e)}")
        return jsonify({"error": str(e), "success": False}), 500
|
|
|
@app.route('/ocr/batch', methods=['POST'])
def batch_ocr_endpoint():
    """Batch OCR endpoint: recognise text in several uploaded images.

    Expects multipart/form-data with one or more ``images`` file fields and an
    optional ``type`` form field, which takes the same values as /ocr. Each
    file is processed independently. A failure on one file is logged and
    recorded in that file's entry; it does not abort the batch.

    Returns:
        200 with ``{"success": True, "results": [...], "total_processed": n}``.
            Each entry carries "index", "filename" and "success", plus either
            "text" or "error".
        400 when no images are sent or ``type`` is unsupported.
        500 with ``{"error", "success": False}`` for unexpected failures.
    """
    # Every ocr_type that process_image_ocr() acts on. Anything else would
    # silently yield empty text for every file.
    valid_types = ("auto", "handwritten", "printed", "trocr", "easyocr")

    try:
        if 'images' not in request.files:
            return jsonify({"error": "No images provided"}), 400

        images = request.files.getlist('images')
        ocr_type = request.form.get('type', 'auto')
        if ocr_type not in valid_types:
            return jsonify({
                "error": f"Unsupported OCR type: {ocr_type!r}; expected one of {list(valid_types)}",
                "success": False,
            }), 400

        results = []
        for index, image_file in enumerate(images):
            try:
                image = Image.open(image_file.stream)
                ocr_results = process_image_ocr(image, ocr_type)
            except Exception as e:
                # Report the failure for this file only and keep going.
                logger.exception(f"Batch OCR error on {image_file.filename!r}: {e}")
                results.append({
                    "index": index,
                    "filename": image_file.filename,
                    "error": str(e),
                    "success": False,
                })
            else:
                results.append({
                    "index": index,
                    "filename": image_file.filename,
                    "text": ocr_results["final"],
                    "success": True,
                })

        return jsonify({
            "success": True,
            "results": results,
            "total_processed": len(results),
        })

    except Exception as e:
        logger.exception(f"Batch OCR error: {str(e)}")
        return jsonify({"error": str(e), "success": False}), 500
|
|
|
@app.route('/models/info', methods=['GET'])
def models_info():
    """Describe the OCR models the service uses and whether each is loaded.

    "loaded" is derived from the module-level model globals. "supported_types"
    lists the user-facing names; process_image_ocr() also accepts the engine
    aliases "trocr" and "easyocr".
    """
    trocr_entry = {
        "name": "microsoft/trocr-base-handwritten",
        "description": "Handwritten text recognition",
        "loaded": trocr_model is not None,
    }
    easyocr_entry = {
        "name": "EasyOCR",
        "description": "Printed text recognition",
        "loaded": easyocr_reader is not None,
    }

    payload = {
        "models": {"trocr": trocr_entry, "easyocr": easyocr_entry},
        "supported_types": ["auto", "handwritten", "printed"],
        "supported_formats": ["PNG", "JPEG", "JPG", "BMP", "TIFF"],
    }
    return jsonify(payload)
|
|
|
if __name__ == '__main__':
    # Script entry point. Within the visible source this is the only call to
    # initialize_models(). Serving `app` some other way (presumably a WSGI
    # server importing this module) would leave the models unloaded.
    logger.info("Starting OCR service...")
    initialize_models()

    # Listen on all interfaces; the port comes from $PORT and defaults to 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)), debug=False)