# optimised-ocr / app.py
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict
import shutil
import torch
import logging
import os
# Set Google Application Credentials
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = (
"titanium-scope-436311-t3-966373f5aa2f.json"
)
from s3_setup import s3_client
import requests
from transformers import LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from dotenv import load_dotenv
import urllib.parse
from utils import doc_processing, extract_document_number_from_file
# Load .env file
load_dotenv()

# The Hugging Face auth token is read from the `dummy_key` entry in .env
HUGGINGFACE_AUTH_TOKEN = os.getenv("dummy_key")
# Aadhaar model
aadhar_model_id = "AuditEdge/doc_ocr_a"  # Replace with your fine-tuned model if applicable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the processor (tokenizer + image processor) and model
processor_aadhar = LayoutLMv3Processor.from_pretrained(
    aadhar_model_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
aadhar_model = LayoutLMv3ForTokenClassification.from_pretrained(
    aadhar_model_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
).to(device)
# PAN model
pan_model_id = "AuditEdge/doc_ocr_p"  # Replace with your fine-tuned model if applicable
processor_pan = LayoutLMv3Processor.from_pretrained(
    pan_model_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
pan_model = LayoutLMv3ForTokenClassification.from_pretrained(
    pan_model_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
).to(device)
# GST model
gst_model_id = "AuditEdge/doc_ocr_new_g"  # Replace with your fine-tuned model if applicable
processor_gst = LayoutLMv3Processor.from_pretrained(
    gst_model_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
gst_model = LayoutLMv3ForTokenClassification.from_pretrained(
    gst_model_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
).to(device)
# Cheque model
cheque_model_id = "AuditEdge/doc_ocr_new_c"  # Replace with your fine-tuned model if applicable
processor_cheque = LayoutLMv3Processor.from_pretrained(
    cheque_model_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
)
cheque_model = LayoutLMv3ForTokenClassification.from_pretrained(
    cheque_model_id, use_auth_token=HUGGINGFACE_AUTH_TOKEN
).to(device)
# Verify models and processors are loaded
print("Models and processors loaded successfully!")
print(f"Model is on device: {next(aadhar_model.parameters()).device}")
# Import inference modules
from layoutlmv3FineTuning.Layoutlm_inference.ocr import prepare_batch_for_inference
from layoutlmv3FineTuning.Layoutlm_inference.inference_handler import handle
# Create FastAPI instance
app = FastAPI(debug=True)
# Enable CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
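# NOTE: wildcard origins together with allow_credentials=True is the
# permissive development setup; restrict allow_origins in production.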
# Configure directories
UPLOAD_FOLDER = "./uploads/"
PROCESSED_FOLDER = "./processed_images"
os.makedirs(UPLOAD_FOLDER, exist_ok=True)  # Ensure the main upload folder exists
os.makedirs(PROCESSED_FOLDER, exist_ok=True)
UPLOAD_DIRS = {
"pan_file": "uploads/pan/",
"aadhar_file": "uploads/aadhar/",
"gst_file": "uploads/gst/",
"msme_file": "uploads/msme/",
"cin_llpin_file": "uploads/cin_llpin/",
"cheque_file": "uploads/cheque/",
}
process_dirs = {
"aadhar_file": "processed_images/aadhar/",
"pan_file": "processed_images/pan/",
"cheque_file": "processed_images/cheque/",
"gst_file": "processed_images/gst/",
"msme_file": "processed_images/msme/",
"cin_llpin_file": "processed_images/cin_llpin/",
}
# Ensure individual directories exist
for dir_path in UPLOAD_DIRS.values():
os.makedirs(dir_path, exist_ok=True)
for dir_path in process_dirs.values():
os.makedirs(dir_path, exist_ok=True)
# Logger configuration
logging.basicConfig(level=logging.INFO)
def perform_inference(file_paths: Dict[str, str], upload_to_s3: bool):
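    """Run OCR / model inference for each uploaded document.

    `file_paths` maps a document key (e.g. "pan_file") to a string of the
    form "<processed image path or extracted number>&&<original file path>".
    Returns a dict keyed by attachment number, or an error dict on failure.
    """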
model_dirs = {
"pan_file": pan_model,
"gst_file": gst_model,
"cheque_file": cheque_model,
}
try:
inference_results = {}
        for doc_type, file_path in file_paths.items():
            # Each value is "<processed path or extracted number>&&<original path>"
            processed_file_p, unprocessed_file_path = file_path.split("&&", 1)
            print(f"Processing {doc_type}: {processed_file_p}")
# Determine the attachment number based on the document type
attachment_num = {
"pan_file": 2,
"gst_file": 4,
"msme_file": 5,
"cin_llpin_file": 6,
"cheque_file": 8,
}.get(doc_type, None)
if attachment_num is None:
print(f"Skipping {doc_type}, not recognized.")
continue
# Upload file to S3 if required
if upload_to_s3:
client = s3_client()
bucket_name = "edgekycdocs"
                if doc_type == "cin_llpin_file":  # keys carry the "_file" suffix
                    folder_name = "cinllpindocs"
                else:
                    folder_name = f"{doc_type.split('_')[0]}docs"
file_name = unprocessed_file_path.split("/")[-1].replace(" ", "_")
try:
response = client.upload_file(
unprocessed_file_path, bucket_name, folder_name, file_name
)
print("The file has been uploaded to S3 bucket", response)
attachment_url = response["url"]
print(f"File uploaded to S3: {attachment_url}")
except Exception as e:
print(f"Failed to upload {file_name} to S3: {e}")
attachment_url = None
else:
attachment_url = None
            # If it's an OCR/regex-based extraction (MSME, CIN/LLPIN, Aadhaar),
            # return the extracted number directly
            if doc_type in ["msme_file", "cin_llpin_file", "aadhar_file"]:
                result = {
                    "attachment_num": processed_file_p,  # Extracted MSME, CIN/LLPIN, or Aadhaar number
"attachment_url": attachment_url,
"attachment_status": 200,
"detect": True,
}
else:
# If the document needs ML model inference (PAN, GST, Cheque)
if doc_type in model_dirs:
print(
f"Running ML inference for {doc_type} using {model_dirs[doc_type]}"
)
images_path = [processed_file_p]
inference_batch = prepare_batch_for_inference(images_path)
context = model_dirs[doc_type]
                    # Resolve the processor by naming convention, e.g. "pan_file" -> processor_pan
                    processor = globals()[f"processor_{doc_type.split('_')[0]}"]
name = doc_type.split("_")[0]
result = handle(inference_batch, context, processor, name)
result["attachment_url"] = attachment_url
result["detect"] = True
else:
print(f"No model found for {doc_type}, skipping inference.")
continue
inference_results[f"attachment_{attachment_num}"] = result
return inference_results
except Exception as e:
print(f"Error in perform_inference: {e}")
return {"status": "error", "message": "Text extraction failed."}
# Routes
@app.get("/")
def greet_json():
return {"Hello": "World!"}
@app.post("/api/aadhar_ocr")
async def aadhar_ocr(
aadhar_file: UploadFile = File(None),
pan_file: UploadFile = File(None),
cheque_file: UploadFile = File(None),
gst_file: UploadFile = File(None),
msme_file: UploadFile = File(None),
cin_llpin_file: UploadFile = File(None),
upload_to_s3: bool = True,
):
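    """Accept document images as multipart uploads, preprocess or OCR them,
    and return per-attachment extraction results."""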
# Handle file uploads
file_paths = {}
for file_type, folder in UPLOAD_DIRS.items():
file = locals()[file_type] # Dynamically access the file arguments
if file:
# Save the file in the respective directory
file_path = os.path.join(folder, file.filename)
print("this is the filename", file.filename)
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
file_paths[file_type] = file_path
# Log received files
logging.info(f"Received files: {list(file_paths.keys())}")
print("file_paths", file_paths)
files = {}
for key, f_path in file_paths.items():
name = os.path.splitext(os.path.basename(f_path))[0]
# Determine id_type: for cin_llpin_file, explicitly set id_type to "cin_llpin"
if key == "cin_llpin_file":
id_type = "cin_llpin"
else:
id_type = key.split("_")[0]
doc_type = os.path.splitext(f_path)[-1].lstrip(".")
if key in ["msme_file", "cin_llpin_file", "aadhar_file"]:
extracted_number = extract_document_number_from_file(f_path, id_type)
if not extracted_number:
logging.error(f"Failed to extract document number from {f_path}")
raise HTTPException(
status_code=400, detail=f"Invalid document format in {key}"
)
files[key] = extracted_number + "&&" + f_path
print("files", files[key])
else:
# For other files, use existing preprocessing.
preprocessing = doc_processing(name, id_type, doc_type, f_path)
response = preprocessing.process()
files[key] = response["output_p"] + "&&" + f_path
# Perform inference
result = perform_inference(files, upload_to_s3)
print("this is the result we got", result)
if "status" in list(result.keys()):
raise Exception("Custom error message")
# if result["status"] == "error":
return {"status": "success", "result": result}
@app.post("/api/document_ocr")
async def document_ocr_s3(request: Request):
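    """Accept a JSON body of document URLs, download each file locally, and
    return per-attachment extraction results.

    Illustrative body (URLs are placeholders):
        {"pan_file": "https://example.com/pan.jpg", "upload_to_s3": false}
    """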
try:
body = await request.json() # Read JSON body
logging.info(f"Received request body: {body}")
except Exception as e:
logging.error(f"Failed to parse JSON request: {e}")
raise HTTPException(status_code=400, detail="Invalid JSON payload")
# Extract file URLs
url_mapping = {
"pan_file": body.get("pan_file"),
"gst_file": body.get("gst_file"),
"msme_file": body.get("msme_file"),
"cin_llpin_file": body.get("cin_llpin_file"),
"cheque_file": body.get("cheque_file"),
}
upload_to_s3 = body.get("upload_to_s3", False)
logging.info(f"URL Mapping: {url_mapping}")
file_paths = {}
for file_type, url in url_mapping.items():
if url:
local_filename = urllib.parse.unquote(url.split("/")[-1]).replace(" ", "_")
file_path = os.path.join(UPLOAD_DIRS[file_type], local_filename)
try:
logging.info(f"Attempting to download {url} for {file_type}...")
response = requests.get(url, stream=True)
response.raise_for_status()
with open(file_path, "wb") as buffer:
shutil.copyfileobj(response.raw, buffer)
file_paths[file_type] = file_path
logging.info(f"Successfully downloaded {file_type} to {file_path}")
except requests.exceptions.RequestException as e:
logging.error(f"Failed to download {url}: {e}")
raise HTTPException(
status_code=400, detail=f"Failed to download file from {url}"
)
logging.info(f"Downloaded files: {list(file_paths.keys())}")
files = {}
for key, f_path in file_paths.items():
        name = os.path.splitext(os.path.basename(f_path))[0]
if key == "cin_llpin_file":
id_type = "cin_llpin"
else:
id_type = key.split("_")[0]
        doc_type = os.path.splitext(f_path)[-1].lstrip(".")
        # For MSME, CIN/LLPIN, and Aadhaar files, extract the document number via OCR and regex
if key in ["msme_file", "cin_llpin_file", "aadhar_file"]:
extracted_number = extract_document_number_from_file(f_path, id_type)
if not extracted_number:
logging.error(f"Failed to extract document number from {f_path}")
raise HTTPException(
status_code=400, detail=f"Invalid document format in {key}"
)
files[key] = extracted_number + "&&" + f_path
else:
# For other documents, use the existing ML model preprocessing
preprocessing = doc_processing(name, id_type, doc_type, f_path)
response = preprocessing.process()
files[key] = response["output_p"] + "&&" + f_path
result = perform_inference(files, upload_to_s3)
if "status" in list(result.keys()):
raise HTTPException(status_code=500, detail="Custom error message")
return {"status": "success", "result": result}