# document-ocr-ai / app.py
# Uploaded by mayank-youdataai ("Upload 2 files")
# Revision 74f6a97 (verified)
import os
import json
import tempfile
from fastapi import FastAPI, UploadFile, File, HTTPException
from paddleocr import PPStructure
import logging
import paddle
# Configure root logging at INFO level and create this module's logger,
# which every function below uses for status and error messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application object; the route decorators below register on it.
app = FastAPI()
# Module-level cache for the OCR engine. It stays None until
# init_ocr_engine() builds the engine on first use.
ocr_engine = None
# Function to initialize the PaddleOCR engine based on GPU availability
def init_ocr_engine():
    """Return the shared PPStructure engine, building it on first use.

    The engine is cached in the module-level ``ocr_engine`` variable, so
    later calls return the same instance without constructing it again.
    Whether the GPU is requested is decided by ``is_gpu_available()`` at
    construction time.
    """
    global ocr_engine

    # Fast path: an earlier call already built the engine.
    if ocr_engine is not None:
        return ocr_engine

    gpu_enabled = is_gpu_available()
    device_message = (
        "NVIDIA GPU detected, running PaddleOCR on GPU."
        if gpu_enabled
        else "No GPU detected, running PaddleOCR on CPU."
    )
    logger.info(device_message)

    # Options handed to PPStructure; only ``use_gpu`` varies at runtime.
    engine_options = {
        'table': True,
        'ocr': True,
        'show_log': True,
        'layout_score_threshold': 0.1,
        'structure_version': 'PP-StructureV2',
        'use_gpu': gpu_enabled,
    }
    ocr_engine = PPStructure(**engine_options)
    return ocr_engine
# Function to check for GPU availability using Paddle
def is_gpu_available():
    """Report whether PaddlePaddle can use a CUDA device.

    Returns True only when the installed Paddle build is compiled with
    CUDA and ``paddle.device.cuda.device_count()`` is greater than zero.
    """
    # The device count is only queried when the build has CUDA support.
    if not paddle.is_compiled_with_cuda():
        return False
    return paddle.device.cuda.device_count() > 0
# Function to perform OCR on a file and return the structured result (nothing is written to disk here)
def perform_ocr_and_save(pdf_path, save_folder='./output'):
    """Run the shared OCR engine on ``pdf_path`` and return its output.

    ``save_folder`` is accepted but not used; this function writes
    nothing to disk. An empty or otherwise falsy engine result is logged
    as an error and then returned unchanged, so callers must check the
    return value themselves.
    """
    engine = init_ocr_engine()
    # The path is handed to the engine as-is; no page splitting is done here.
    ocr_output = engine(pdf_path)
    if ocr_output:
        return ocr_output
    logger.error(f"OCR failed for {pdf_path}")
    return ocr_output
# Function to format results to strings and sort them
def format_to_strings_and_sort(results):
    """Flatten per-page OCR regions into records sorted top-to-bottom.

    Args:
        results: Iterable of pages. Each page is an iterable of region
            dicts with ``'type'``, ``'bbox'`` and ``'res'`` keys. For a
            ``'table'`` region, ``'res'`` is indexed by ``'html'``; for
            any other type, ``'res'`` is iterated and every item must
            provide ``'text'`` and ``'confidence'``.

    Returns:
        A new list of dicts ordered by ``(page_num, y_coordinate)``.
        ``page_num`` starts at 1 and ``y_coordinate`` is ``bbox[1]``.
        Table regions yield one record carrying ``'html'``; other regions
        yield one record per item in ``'res'`` (none if it is empty).
        Records that tie on the sort key keep their input order.

    Raises:
        KeyError: If a region or response dict lacks one of the keys above.
    """
    logger.info("Formatting and sorting OCR results.")
    formatted_data = []
    for page_num, regions in enumerate(results, start=1):
        for region in regions:
            # Named region_type so the builtin ``type`` is not shadowed.
            region_type = region['type']
            bbox = region['bbox']
            payload = region['res']
            # Same for every record of this region, so computed once.
            top_y = bbox[1]

            if region_type == 'table':
                formatted_data.append({
                    'page_num': page_num,
                    'type': region_type,
                    'html': payload['html'],
                    'bbox': bbox,
                    'y_coordinate': top_y,
                })
                continue

            formatted_data.extend(
                {
                    'page_num': page_num,
                    'type': region_type,
                    'text': line['text'],
                    'confidence': line['confidence'],
                    'bbox': bbox,
                    'y_coordinate': top_y,
                }
                for line in payload
            )

    # list.sort is stable, so ties keep the order in which they were read.
    formatted_data.sort(key=lambda rec: (rec['page_num'], rec['y_coordinate']))
    logger.info("Sorting completed.")
    return formatted_data
# Function to save results to a JSON file
def save_to_json(data, filename):
    """Serialize ``data`` as JSON with a four-space indent into ``filename``.

    The file is opened in text write mode, so an existing file with the
    same name is overwritten.
    """
    logger.info(f"Saving sorted results to {filename}.")
    with open(filename, mode="w") as out_handle:
        json.dump(data, out_handle, indent=4)
# FastAPI endpoint to process uploaded PDF
@app.post("/process-ocr/")
async def process_ocr(file: UploadFile = File(...)):
    """Run OCR on an uploaded PDF and return the sorted, flattened result.

    The upload is written to a temporary ``.pdf`` file, passed to
    ``perform_ocr_and_save``, reshaped by ``format_to_strings_and_sort``,
    dumped to ``result_json.json`` for debugging, and returned.

    Args:
        file: The uploaded file; its content type must be
            ``application/pdf``.

    Returns:
        The list produced by ``format_to_strings_and_sort``.

    Raises:
        HTTPException: 400 when the upload is not a PDF; 500 when OCR
            yields no result or any other error occurs while processing.
    """
    # Validate outside the try block: previously the 400 raised here was
    # caught by the generic handler below and reached the client as a 500.
    if file.content_type != "application/pdf":
        logger.warning(f"Invalid file type uploaded: {file.content_type}")
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a PDF file.")

    # Bound before the try so the cleanup in ``finally`` never touches an
    # unassigned name when temp-file creation itself fails.
    temp_file_path = None
    try:
        # delete=False keeps the file on disk after the handle closes so
        # the OCR engine can reopen it by path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
            temp_file_path = temp_file.name
            contents = await file.read()
            temp_file.write(contents)
        logger.info(f"Temporary file created at: {temp_file_path}")

        # NOTE(review): this call is synchronous inside an async handler;
        # confirm whether blocking the event loop here is acceptable.
        result = perform_ocr_and_save(temp_file_path)
        # The helper logs a failure for any falsy result but returns it
        # unchanged, so test truthiness rather than ``is None``.
        if not result:
            raise HTTPException(status_code=500, detail="OCR processing failed. Check the input file.")

        result_json = format_to_strings_and_sort(result)
        # Debugging aid: the same fixed file is overwritten on each request.
        save_to_json(result_json, 'result_json.json')
        return result_json
    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception(f"An error occurred during OCR processing: {e}")
        raise HTTPException(status_code=500, detail="An error occurred during OCR processing.") from e
    finally:
        # Remove the temporary upload whether or not processing succeeded.
        if temp_file_path is not None and os.path.exists(temp_file_path):
            os.remove(temp_file_path)
            logger.info(f"Temporary file {temp_file_path} deleted.")
# Endpoint to check if GPU is available
@app.get("/check-gpu/")
def check_gpu():
    """Report whether ``is_gpu_available()`` currently detects a GPU.

    Returns a dict with a boolean ``gpu_available`` flag and a
    human-readable ``message`` describing which device will be used.
    """
    gpu_present = is_gpu_available()
    if not gpu_present:
        return {"gpu_available": False, "message": "NVIDIA GPU is not available, using CPU instead."}
    return {"gpu_available": True, "message": "NVIDIA GPU is available and will be used."}