# Contract_Management / invoice_extractor.py
# Author: naveenvenkatesh — commit 4ec3e55 ("Upload 5 files"), 13.5 kB
import os
import logging
from PIL import Image, ImageDraw
import traceback
import torch
from docquery import pipeline
from docquery.document import load_bytes, load_document, ImageDocument
from docquery.ocr_reader import get_ocr_reader
from pdf2image import convert_from_path
# Turn off HuggingFace tokenizers parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Route every log record at DEBUG level or above to a local log file.
logging.basicConfig(level=logging.DEBUG, filename="invoice_extraction.log")

# Human-readable model label -> HuggingFace checkpoint id.
CHECKPOINTS = dict(
    [("LayoutLMv1 for Invoices 🧾", "impira/layoutlm-invoices")]
)

# Pipelines built so far, keyed by model label; filled lazily on first use.
PIPELINES = {}
class InvoiceKeyValuePair:
    """
    Extract key-value pairs from invoices using a LayoutLM document-QA pipeline.

    Every target field is mapped to one or more natural-language questions.
    Each question is asked against the document, and the highest-scoring
    answer above a confidence threshold becomes that field's value.
    """

    def __init__(self):
        # Field name -> questions asked to the QA model for that field.
        self.fields = {
            "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
            "Vendor Address": ["Vendor Address?"],
            "Customer Name": ["Customer Name?"],
            "Customer Address": ["Customer Address?"],
            "Invoice Number": ["Invoice Number?"],
            "Invoice Date": ["Invoice Date?"],
            "Due Date": ["Due Date?"],
            "Subtotal": ["Subtotal?"],
            "Total Tax": ["Total Tax?"],
            "Invoice Total": ["Invoice Total?"],
            "Amount Due": ["Amount Due?"],
            "Payment Terms": ["Payment Terms?"],
            "Remit To Name": ["Remit To Name?"],
            "Remit To Address": ["Remit To Address?"],
        }
        # Label (a CHECKPOINTS key) of the model used by default.
        self.model = list(CHECKPOINTS)[0]

    def ensure_list(self, x):
        """Return *x* unchanged if it is a list, otherwise wrap it as ``[x]``."""
        logging.info('Entering ensure_list with x=%s', x)
        return x if isinstance(x, list) else [x]

    def construct_pipeline(self, task, model):
        """
        Return the pipeline for *model*, building and caching it on first use.

        Args:
            task: task name passed to ``docquery.pipeline``.
            model: a key of ``CHECKPOINTS``.

        Returns:
            The pipeline, or ``None`` if it could not be built (the error is logged).
        """
        logging.info('Entering construct_pipeline with task=%s and model=%s', task, model)
        # NOTE(review): the cache is keyed by model only. Every caller in this
        # class uses the same task, so this is safe today.
        if model in PIPELINES:
            return PIPELINES[model]
        try:
            # Prefer the GPU when torch can see one.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            qa_pipeline = pipeline(task=task, model=CHECKPOINTS[model], device=device)
        except Exception:
            logging.exception("An error occurred:")
            return None
        PIPELINES[model] = qa_pipeline
        return qa_pipeline

    def run_pipeline(self, model, question, document, top_k):
        """
        Ask *question* about *document* and return the raw pipeline output.

        The document's ``context`` mapping is passed as keyword arguments.
        Returns ``None`` if the pipeline is unavailable or raises.
        """
        logging.info(
            'Entering run_pipeline with model=%s, question=%s, and document=%s',
            model, question, document,
        )
        try:
            qa_pipeline = self.construct_pipeline("document-question-answering", model)
            if qa_pipeline is None:
                # Construction already failed and was logged.
                return None
            return qa_pipeline(question=question, **document.context, top_k=top_k)
        except Exception:
            logging.exception("An error occurred:")
            return None

    def lift_word_boxes(self, document, page):
        """
        Return the word boxes stored for *page* in ``document.context["image"]``.

        Reads ``document.context["image"][page][1]``. Each entry is assumed to
        carry its box at index 1 (see expand_bbox); TODO confirm against docquery.
        Returns ``[]`` on failure.
        """
        logging.info('Entering lift_word_boxes with document=%s and page=%s', document, page)
        try:
            return document.context["image"][page][1]
        except Exception:
            logging.exception("An error occurred:")
            return []

    def expand_bbox(self, word_boxes):
        """
        Return the smallest ``[min_x, min_y, max_x, max_y]`` enclosing all boxes.

        Each item of *word_boxes* holds its ``(x0, y0, x1, y1)`` box at index 1.
        Returns ``None`` for empty input or on error.
        """
        logging.info('Entering expand_bbox with word_boxes=%s', word_boxes)
        try:
            if not word_boxes:
                return None
            boxes = [item[1] for item in word_boxes]
            return [
                min(box[0] for box in boxes),
                min(box[1] for box in boxes),
                max(box[2] for box in boxes),
                max(box[3] for box in boxes),
            ]
        except Exception:
            logging.exception("An error occurred:")
            return None

    def normalize_bbox(self, box, width, height, padding=0.005):
        """
        Convert a box on a 0-1000 normalized scale to pixel coordinates.

        Args:
            box: ``(min_x, min_y, max_x, max_y)`` in 0-1000 units, or ``None``.
            width, height: target image size in pixels.
            padding: margin added on each side, as a fraction of the image size.
                When non-zero, the padded box is clamped to the image.

        Returns:
            ``[x1, y1, x2, y2]`` in pixels, or ``None`` if *box* is missing/invalid.
        """
        logging.info(
            'Entering normalize_bbox with box=%s, width=%s, height=%s, and padding=%s',
            box, width, height, padding,
        )
        if box is None:
            return None
        try:
            # 0-1000 scale -> fraction of the page.
            min_x, min_y, max_x, max_y = [coord / 1000 for coord in box]
            if padding != 0:
                min_x = max(0, min_x - padding)
                min_y = max(0, min_y - padding)
                max_x = min(max_x + padding, 1)
                max_y = min(max_y + padding, 1)
            return [min_x * width, min_y * height, max_x * width, max_y * height]
        except Exception:
            logging.exception("An error occurred:")
            return None

    def annotate_page(self, prediction, pages, document):
        """
        Shade the words of *prediction*'s answer on its page image, in place.

        Does nothing when *prediction* is ``None``, has no ``"word_ids"``, or
        its bounding box cannot be computed. Errors are logged, not raised.
        """
        logging.info(
            'Entering annotate_page with prediction=%s, pages=%s, and document=%s',
            prediction, pages, document,
        )
        try:
            if prediction is None or "word_ids" not in prediction:
                return
            page_index = prediction["page"]
            image = pages[page_index]
            word_boxes = self.lift_word_boxes(document, page_index)
            bbox = self.normalize_bbox(
                self.expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
                image.width,
                image.height,
            )
            if bbox is None:
                return
            x1, y1, x2, y2 = bbox
            draw = ImageDraw.Draw(image, "RGBA")
            # Semi-transparent green (40% alpha) over the answer region.
            draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
        except Exception:
            logging.exception("An error occurred:")

    def process_fields(self, document, fields, model=None, min_score=0.5):
        """
        Answer every field's questions and return one row per field.

        Args:
            document: loaded document exposing ``preview`` (page images) and
                ``context`` (keyword arguments for the pipeline).
            fields: mapping of field name -> list of questions.
            model: a CHECKPOINTS key; defaults to ``self.model``.
            min_score: answers must score strictly above this to be kept.
                Answers without a score are kept.

        Returns:
            ``[[field_name, answer_text_or_None], ...]`` in *fields* order, or
            ``[]`` on an unexpected error.
        """
        if model is None:
            model = self.model
        logging.info(
            'Entering process_fields with document=%s, fields=%s, and model=%s',
            document, fields, model,
        )
        try:
            # NOTE(review): these RGB copies are annotated below but never
            # returned, so the highlighting is currently not visible to callers.
            pages = [page.copy().convert("RGB") for page in document.preview]
            table = []
            for field_name, questions in fields.items():
                candidates = [
                    answer
                    for question in questions
                    for answer in self.ensure_list(
                        self.run_pipeline(model, question, document, top_k=1)
                    )
                    # run_pipeline returns None on failure; skip it instead of
                    # letting one bad question abort every field.
                    if answer is not None and answer.get("score", 1) > min_score
                ]
                # Highest score wins. On ties max() keeps the first candidate,
                # as a stable descending sort would.
                top = max(
                    candidates,
                    key=lambda answer: answer.get("score", 0) if answer else 0,
                    default=None,
                )
                self.annotate_page(top, pages, document)
                table.append([field_name, top.get("answer") if top is not None else None])
            return table
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_document(self, document, fields, model, error=None):
        """
        Extract field values from an already-loaded *document*.

        Returns the process_fields table, or ``[]`` when there is no document
        or loading reported *error* (never ``None``, so callers can iterate it).
        """
        logging.info(
            'Entering process_document with document=%s, fields=%s, model=%s, and error=%s',
            document, fields, model, error,
        )
        if document is None or error is not None:
            return []
        try:
            return self.process_fields(document, fields, model)
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_path(self, path, fields, model):
        """
        Load the document at *path* and extract field values from it.

        A falsy *path* or a load failure (logged) yields ``[]``.
        """
        logging.info(
            'Entering process_path with path=%s, fields=%s, and model=%s',
            path, fields, model,
        )
        try:
            document = None
            error = None
            if path:
                try:
                    document = load_document(path)
                except Exception as exc:
                    logging.exception("An error occurred:")
                    error = str(exc)
            return self.process_document(document, fields, model, error)
        except Exception:
            logging.exception("An error occurred:")
            return []

    def pdf_to_image(self, file_path):
        """
        Render every page of the PDF at *file_path* to ``page_<n>.png`` (1-based).

        The files are written to the current working directory.
        NOTE(review): concurrent calls overwrite each other's files.

        Returns:
            The path of the first page image, or ``None`` if the PDF has no
            pages or conversion fails (the error is logged).
        """
        logging.info('Entering pdf_to_image with file_path=%s', file_path)
        try:
            images = convert_from_path(file_path)
            image_paths = []
            for index, image in enumerate(images, start=1):
                image_path = f'page_{index}.png'
                # The original code built the name but never wrote the file,
                # so load_document could not open it.
                image.save(image_path)
                image_paths.append(image_path)
        except Exception:
            logging.exception("An error occurred:")
            return None
        return image_paths[0] if image_paths else None

    def process_upload(self, file):
        """
        Convert the PDF at *file* to images and extract fields from the first page.

        NOTE(review): only the first page is analysed; later pages are ignored.
        Returns the process_path table, or ``[]`` on error.
        """
        logging.info('Entering process_upload with file=%s', file)
        try:
            image_path = self.pdf_to_image(file)
            return self.process_path(image_path or None, self.fields, self.model)
        except Exception:
            logging.exception("An error occurred:")
            return []

    def extract_key_value_pair(self, invoice_file):
        """
        Extract all key-value pairs from an invoice PDF.

        Args:
            invoice_file: a ``str``/``os.PathLike`` path, or an upload object
                exposing the path as ``.name``.

        Returns:
            One ``"key: value"`` line per field, joined with newlines (``""`` if
            nothing was extracted), or ``None`` on an unexpected error.
        """
        logging.info('Entering extract_key_value_pair with invoice_file=%s', invoice_file)
        try:
            if isinstance(invoice_file, (str, os.PathLike)):
                path = os.fspath(invoice_file)
            else:
                path = invoice_file.name
            data = self.process_upload(path) or []
            # Report every pair; the original returned after the first one.
            return "\n".join(f'{key}: {value}' for key, value in data)
        except Exception:
            logging.exception("An error occurred:")
            return None