# optimised-ocr / app.py
# Mallisetty Siva Mahesh
# added some changes
# 13c6ddb
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict
import shutil
import torch
import logging
import os
# Set Google Application Credentials
# The value is a relative file name, so it is resolved against the working
# directory of the process that later reads the variable.
# This is a plain assignment: it replaces any value already present in the
# environment.
# NOTE(review): the statement sits in the middle of the import list, presumably
# so the variable is set before the project-local imports below run — confirm
# which module actually depends on it before moving it.
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = (
"titanium-scope-436311-t3-966373f5aa2f.json"
)
from s3_setup import s3_client
import requests
from fastapi import FastAPI, HTTPException, Request
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from dotenv import load_dotenv
import urllib.parse
from utils import doc_processing, extract_document_number_from_file
# Load .env file so the Hugging Face token can be read from the environment.
load_dotenv()

# Access variables
dummy_key = os.getenv("dummy_key")
HUGGINGFACE_AUTH_TOKEN = dummy_key

# Every model shares one device: the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def _load_layoutlm(repo_id):
    """Load one LayoutLMv3 processor/model pair and move the model to ``device``.

    Args:
        repo_id: Hugging Face Hub repository id of the checkpoint to load.

    Returns:
        A ``(processor, model)`` tuple. The processor bundles the tokenizer
        and the image processor; the model is already on ``device``.
    """
    processor = LayoutLMv3Processor.from_pretrained(
        repo_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
    )
    model = LayoutLMv3ForTokenClassification.from_pretrained(
        repo_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
    )
    return processor, model.to(device)


# One checkpoint per document type. Replace the repository ids with your own
# fine-tuned models if applicable.
processor_aadhar, aadhar_model = _load_layoutlm("AuditEdge/doc_ocr_a")
processor_pan, pan_model = _load_layoutlm("AuditEdge/doc_ocr_p")
processor_gst, gst_model = _load_layoutlm("AuditEdge/doc_ocr_new_g")
processor_cheque, cheque_model = _load_layoutlm("AuditEdge/doc_ocr_new_c")

# Verify model and processor are loaded
print("Model and processor loaded successfully!")
print(f"Model is on device: {next(aadhar_model.parameters()).device}")
# Import inference modules
from layoutlmv3FineTuning.Layoutlm_inference.ocr import prepare_batch_for_inference
from layoutlmv3FineTuning.Layoutlm_inference.inference_handler import handle
# Create the FastAPI application.
# NOTE(review): debug=True is set unconditionally — confirm this is intended
# for every environment the service runs in.
app = FastAPI(debug=True)

# Enable CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended.
_CORS_OPTIONS = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Root folders for raw uploads and for pre-processed images.
UPLOAD_FOLDER = "./uploads/"
processing_folder = "./processed_images"

# Per-document-type folders, keyed by the request field name of each file.
# The key order of UPLOAD_DIRS is the order in which uploads are iterated.
UPLOAD_DIRS = {
    f"{kind}_file": f"uploads/{kind}/"
    for kind in ("pan", "aadhar", "gst", "msme", "cin_llpin", "cheque")
}
process_dirs = {
    f"{kind}_file": f"processed_images/{kind}/"
    for kind in ("aadhar", "pan", "cheque", "gst", "msme", "cin_llpin")
}

# Make sure the root folders and every per-type folder exist.
for dir_path in (
    UPLOAD_FOLDER,
    processing_folder,
    *UPLOAD_DIRS.values(),
    *process_dirs.values(),
):
    os.makedirs(dir_path, exist_ok=True)

# Logger configuration
logging.basicConfig(level=logging.INFO)
def perform_inference(file_paths: Dict[str, str], upload_to_s3: bool):
    """Build the per-attachment result for each pre-processed document.

    Args:
        file_paths: Maps a document key (``"pan_file"``, ``"gst_file"``,
            ``"msme_file"``, ``"cin_llpin_file"``, ``"cheque_file"``) to a
            string of the form ``"<processed>&&<original path>"``. For the
            OCR-only keys ``<processed>`` is the already extracted document
            number; for the model-backed keys it is the path of the
            pre-processed image.
        upload_to_s3: When true, the original file is uploaded to the
            ``edgekycdocs`` bucket and its URL is attached to the result.
            An upload failure is logged and leaves the URL as ``None``.

    Returns:
        A dict keyed by ``"attachment_<n>"`` holding one result per
        recognised document. If anything unexpected fails, returns
        ``{"status": "error", "message": "Text extraction failed."}``
        instead of raising.
    """
    # Attachment slot reported for each supported document key.
    # NOTE(review): "aadhar_file" has no slot, so Aadhaar documents are skipped
    # below even though the OCR branch lists them — confirm the intended
    # attachment number before adding it here.
    attachment_nums = {
        "pan_file": 2,
        "gst_file": 4,
        "msme_file": 5,
        "cin_llpin_file": 6,
        "cheque_file": 8,
    }
    # Keys whose document number was already extracted before this call.
    ocr_doc_types = ("msme_file", "cin_llpin_file", "aadhar_file")
    # (model, processor) pair used for each model-backed document key.
    models = {
        "pan_file": (pan_model, processor_pan),
        "gst_file": (gst_model, processor_gst),
        "cheque_file": (cheque_model, processor_cheque),
    }
    bucket_name = "edgekycdocs"
    client = None  # created lazily, once, on the first upload

    try:
        inference_results = {}
        for doc_type, file_path in file_paths.items():
            # Split on the first separator only, so an original path that
            # itself contains "&&" is kept intact.
            processed_file_p, sep, unprocessed_file_path = file_path.partition("&&")
            if not sep:
                raise ValueError(f"Malformed entry for {doc_type}: {file_path!r}")
            print(f"Processing {doc_type}: {processed_file_p}")

            attachment_num = attachment_nums.get(doc_type)
            if attachment_num is None:
                print(f"Skipping {doc_type}, not recognized.")
                continue

            # Upload file to S3 if required
            attachment_url = None
            if upload_to_s3:
                if client is None:
                    client = s3_client()
                # Keys always end in "_file"; the CIN/LLPIN key has two name
                # parts, so it needs its own folder name.
                if doc_type == "cin_llpin_file":
                    folder_name = "cinllpindocs"
                else:
                    folder_name = f"{doc_type.split('_')[0]}docs"
                file_name = os.path.basename(unprocessed_file_path).replace(" ", "_")
                try:
                    response = client.upload_file(
                        unprocessed_file_path, bucket_name, folder_name, file_name
                    )
                    print("The file has been uploaded to S3 bucket", response)
                    attachment_url = response["url"]
                    print(f"File uploaded to S3: {attachment_url}")
                except Exception as e:
                    # Best effort: a failed upload must not fail the extraction.
                    print(f"Failed to upload {file_name} to S3: {e}")
                    attachment_url = None

            if doc_type in ocr_doc_types:
                # The document number was extracted upstream; report it as is.
                result = {
                    "attachment_num": processed_file_p,
                    "attachment_url": attachment_url,
                    "attachment_status": 200,
                    "detect": True,
                }
            elif doc_type in models:
                # The document needs ML model inference (PAN, GST, Cheque).
                model, processor = models[doc_type]
                # Log only the class name; the full model repr is very long.
                print(
                    f"Running ML inference for {doc_type} using {type(model).__name__}"
                )
                inference_batch = prepare_batch_for_inference([processed_file_p])
                name = doc_type.split("_")[0]
                result = handle(inference_batch, model, processor, name)
                result["attachment_url"] = attachment_url
                result["detect"] = True
            else:
                print(f"No model found for {doc_type}, skipping inference.")
                continue

            inference_results[f"attachment_{attachment_num}"] = result
        return inference_results
    except Exception:
        # Boundary handler: callers detect failure through the "status" key.
        logging.exception("Error in perform_inference")
        return {"status": "error", "message": "Text extraction failed."}
# Routes
@app.get("/")
def greet_json():
    """Return the fixed greeting payload served at the root path."""
    greeting = {"Hello": "World!"}
    return greeting
@app.post("/api/aadhar_ocr")
async def aadhar_ocr(
    aadhar_file: UploadFile = File(None),
    pan_file: UploadFile = File(None),
    cheque_file: UploadFile = File(None),
    gst_file: UploadFile = File(None),
    msme_file: UploadFile = File(None),
    cin_llpin_file: UploadFile = File(None),
    upload_to_s3: bool = True,
):
    """Save the uploaded documents, pre-process them and run extraction.

    Each file parameter is optional; only the files that were sent are
    handled. ``upload_to_s3`` is forwarded to ``perform_inference``.

    Returns:
        ``{"status": "success", "result": <perform_inference result>}``.

    Raises:
        HTTPException: 400 when a file name is invalid or no document number
            can be extracted from an MSME, CIN/LLPIN or Aadhaar file; 500
            when ``perform_inference`` reports an error.
    """
    uploads = {
        "aadhar_file": aadhar_file,
        "pan_file": pan_file,
        "cheque_file": cheque_file,
        "gst_file": gst_file,
        "msme_file": msme_file,
        "cin_llpin_file": cin_llpin_file,
    }

    # Handle file uploads
    file_paths = {}
    for file_type, folder in UPLOAD_DIRS.items():
        file = uploads.get(file_type)
        if not file:
            continue
        print("this is the filename", file.filename)
        # The file name is chosen by the client: keep only its last path
        # component so the upload cannot be written outside `folder`.
        safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
        if not safe_name:
            # A part without a file name carries no document; treat as absent.
            logging.warning(f"Ignoring {file_type}: upload has no file name")
            continue
        if safe_name in (".", ".."):
            raise HTTPException(
                status_code=400, detail=f"Invalid file name in {file_type}"
            )
        # Save the file in the respective directory
        file_path = os.path.join(folder, safe_name)
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        file_paths[file_type] = file_path

    # Log received files
    logging.info(f"Received files: {list(file_paths.keys())}")
    print("file_paths", file_paths)

    files = {}
    for key, f_path in file_paths.items():
        name = os.path.splitext(os.path.basename(f_path))[0]
        # The CIN/LLPIN key has two name parts, so it cannot be derived by
        # taking the text before the first underscore.
        if key == "cin_llpin_file":
            id_type = "cin_llpin"
        else:
            id_type = key.split("_")[0]
        doc_type = os.path.splitext(f_path)[-1].lstrip(".")

        if key in ["msme_file", "cin_llpin_file", "aadhar_file"]:
            extracted_number = extract_document_number_from_file(f_path, id_type)
            if not extracted_number:
                logging.error(f"Failed to extract document number from {f_path}")
                raise HTTPException(
                    status_code=400, detail=f"Invalid document format in {key}"
                )
            files[key] = extracted_number + "&&" + f_path
            print("files", files[key])
        else:
            # For other files, use existing preprocessing.
            preprocessing = doc_processing(name, id_type, doc_type, f_path)
            response = preprocessing.process()
            files[key] = response["output_p"] + "&&" + f_path

    # Perform inference
    result = perform_inference(files, upload_to_s3)
    print("this is the result we got", result)
    # perform_inference signals failure with a top-level "status" key.
    if "status" in result:
        raise HTTPException(status_code=500, detail="Custom error message")
    return {"status": "success", "result": result}
@app.post("/api/document_ocr")
async def document_ocr_s3(request: Request):
    """Download the documents named in a JSON body and run extraction on them.

    The body is a JSON object that may contain a URL under any of
    ``pan_file``, ``gst_file``, ``msme_file``, ``cin_llpin_file`` and
    ``cheque_file``, plus an optional ``upload_to_s3`` flag (default
    ``False``) that is forwarded to ``perform_inference``.

    Returns:
        ``{"status": "success", "result": <perform_inference result>}``.

    Raises:
        HTTPException: 400 when the body is not a JSON object, a URL is not
            usable, a download fails, or no document number can be extracted
            from an MSME or CIN/LLPIN file; 500 when ``perform_inference``
            reports an error.
    """
    try:
        body = await request.json()  # Read JSON body
        logging.info(f"Received request body: {body}")
    except Exception as e:
        logging.error(f"Failed to parse JSON request: {e}")
        raise HTTPException(status_code=400, detail="Invalid JSON payload") from e
    if not isinstance(body, dict):
        # Valid JSON that is not an object (list, string, ...) has no fields.
        logging.error("JSON request body is not an object")
        raise HTTPException(status_code=400, detail="Invalid JSON payload")

    # Extract file URLs
    url_keys = ("pan_file", "gst_file", "msme_file", "cin_llpin_file", "cheque_file")
    url_mapping = {key: body.get(key) for key in url_keys}
    upload_to_s3 = body.get("upload_to_s3", False)
    logging.info(f"URL Mapping: {url_mapping}")

    # Seconds to wait for the remote server, so a stalled host cannot hang
    # the request indefinitely.
    download_timeout = 30

    # NOTE(review): the URLs are supplied by the caller and fetched from this
    # server — consider restricting the allowed hosts.
    # NOTE(review): requests is a blocking client used inside an async handler.
    file_paths = {}
    for file_type, url in url_mapping.items():
        if not url:
            continue
        if not isinstance(url, str):
            raise HTTPException(
                status_code=400, detail=f"Invalid file URL for {file_type}"
            )
        # Name the local file after the last component of the URL *path*: the
        # query string is dropped and decoded separators or ".." cannot move
        # the file outside its folder.
        url_path = urllib.parse.unquote(urllib.parse.urlsplit(url).path)
        local_filename = os.path.basename(url_path.replace("\\", "/")).replace(
            " ", "_"
        )
        if local_filename in ("", ".", ".."):
            raise HTTPException(
                status_code=400, detail=f"Invalid file URL for {file_type}"
            )
        file_path = os.path.join(UPLOAD_DIRS[file_type], local_filename)
        try:
            logging.info(f"Attempting to download {url} for {file_type}...")
            with requests.get(
                url, stream=True, timeout=download_timeout
            ) as response:
                response.raise_for_status()
                with open(file_path, "wb") as buffer:
                    # iter_content undoes gzip/deflate transfer encoding,
                    # which copying response.raw directly would not.
                    for chunk in response.iter_content(chunk_size=64 * 1024):
                        buffer.write(chunk)
            file_paths[file_type] = file_path
            logging.info(f"Successfully downloaded {file_type} to {file_path}")
        except requests.exceptions.RequestException as e:
            logging.error(f"Failed to download {url}: {e}")
            raise HTTPException(
                status_code=400, detail=f"Failed to download file from {url}"
            ) from e
    logging.info(f"Downloaded files: {list(file_paths.keys())}")

    files = {}
    for key, f_path in file_paths.items():
        # Derive name and extension the same way the upload endpoint does.
        name = os.path.splitext(os.path.basename(f_path))[0]
        # The CIN/LLPIN key has two name parts, so it cannot be derived by
        # taking the text before the first underscore.
        if key == "cin_llpin_file":
            id_type = "cin_llpin"
        else:
            id_type = key.split("_")[0]
        doc_type = os.path.splitext(f_path)[-1].lstrip(".")

        # For MSME and CIN/LLPIN files, extract document number via OCR and regex
        if key in ["msme_file", "cin_llpin_file", "aadhar_file"]:
            extracted_number = extract_document_number_from_file(f_path, id_type)
            if not extracted_number:
                logging.error(f"Failed to extract document number from {f_path}")
                raise HTTPException(
                    status_code=400, detail=f"Invalid document format in {key}"
                )
            files[key] = extracted_number + "&&" + f_path
        else:
            # For other documents, use the existing ML model preprocessing
            preprocessing = doc_processing(name, id_type, doc_type, f_path)
            response = preprocessing.process()
            files[key] = response["output_p"] + "&&" + f_path

    result = perform_inference(files, upload_to_s3)
    # perform_inference signals failure with a top-level "status" key.
    if "status" in result:
        raise HTTPException(status_code=500, detail="Custom error message")
    return {"status": "success", "result": result}