import os
import logging
import traceback

from PIL import Image, ImageDraw
import torch
from docquery import pipeline
from docquery.document import load_bytes, load_document, ImageDocument
from docquery.ocr_reader import get_ocr_reader
from pdf2image import convert_from_path

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Initialize the logger
logging.basicConfig(filename="invoice_extraction.log", level=logging.DEBUG)  # Create a log file

# Checkpoint for different models (display name -> model checkpoint id)
CHECKPOINTS = {
    "LayoutLMv1 for Invoices 🧾": "impira/layoutlm-invoices",
}

# Cache of already-constructed pipelines, keyed by the CHECKPOINTS display name
PIPELINES = {}


class InvoiceKeyValuePair():
    """
    Utility to extract key-value pairs from invoices using LayoutLM.

    Each field in ``self.fields`` maps to one or more questions that are put
    to a document-question-answering pipeline; the best answer per field is
    kept.  Every method logs its entry and swallows exceptions (logging them
    with a traceback) so that a failure yields an empty/None result instead
    of propagating to the caller.
    """

    def __init__(self):
        # Field name -> list of questions asked to the model for that field
        self.fields = {
            "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
            "Vendor Address": ["Vendor Address?"],
            "Customer Name": ["Customer Name?"],
            "Customer Address": ["Customer Address?"],
            "Invoice Number": ["Invoice Number?"],
            "Invoice Date": ["Invoice Date?"],
            "Due Date": ["Due Date?"],
            "Subtotal": ["Subtotal?"],
            "Total Tax": ["Total Tax?"],
            "Invoice Total": ["Invoice Total?"],
            "Amount Due": ["Amount Due?"],
            "Payment Terms": ["Payment Terms?"],
            "Remit To Name": ["Remit To Name?"],
            "Remit To Address": ["Remit To Address?"],
        }
        # Default to the first (currently only) configured checkpoint
        self.model = list(CHECKPOINTS.keys())[0]

    def ensure_list(self, x):
        """Return ``x`` unchanged if it is a list, otherwise ``[x]``.

        Returns an empty list if anything goes wrong (e.g. ``x`` cannot be
        formatted for the log message).
        """
        try:
            logging.info(f'Entering ensure_list with x={x}')
            return x if isinstance(x, list) else [x]
        except Exception:
            logging.exception("An error occurred:")
            return []

    def construct_pipeline(self, task, model):
        """Return the pipeline for ``model``, building and caching it on first use.

        Args:
            task: Pipeline task name passed through to ``docquery.pipeline``.
            model: Key into ``CHECKPOINTS`` (also the cache key in ``PIPELINES``).

        Returns:
            The pipeline object, or ``None`` if it could not be constructed.
        """
        try:
            logging.info(f'Entering construct_pipeline with task={task} and model={model}')

            # Reuse a previously built pipeline for this model if there is one
            if model in PIPELINES:
                return PIPELINES[model]

            # Run on the GPU when one is available, else on the CPU
            device = "cuda" if torch.cuda.is_available() else "cpu"
            ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
            PIPELINES[model] = ret
            return ret
        except Exception:
            logging.exception("An error occurred:")
            return None

    def run_pipeline(self, model, question, document, top_k):
        """Ask ``question`` about ``document`` using the pipeline for ``model``.

        Args:
            model: Key into ``CHECKPOINTS``.
            question: Question text.
            document: Object exposing a ``context`` mapping that is expanded
                into keyword arguments for the pipeline call.
            top_k: Number of answers requested from the pipeline.

        Returns:
            Whatever the pipeline returns, or ``None`` on any failure.
        """
        try:
            logging.info(f'Entering run_pipeline with model={model}, question={question}, and document={document}')

            # Named `qa_pipeline` so the imported `pipeline` factory is not shadowed
            qa_pipeline = self.construct_pipeline("document-question-answering", model)
            if qa_pipeline is None:
                # Construction failed and was already logged; nothing to run
                return None
            return qa_pipeline(question=question, **document.context, top_k=top_k)
        except Exception:
            logging.exception("An error occurred:")
            return None

    def lift_word_boxes(self, document, page):
        """Return the word boxes stored for ``page`` in the document context.

        Reads ``document.context["image"][page][1]``; returns ``[]`` on failure.
        """
        try:
            logging.info(f'Entering lift_word_boxes with document={document} and page={page}')
            return document.context["image"][page][1]
        except Exception:
            logging.exception("An error occurred:")
            return []

    def expand_bbox(self, word_boxes):
        """Return the smallest box enclosing all ``word_boxes``.

        Args:
            word_boxes: Sequence of entries whose element ``[1]`` is a
                4-item ``(min_x, min_y, max_x, max_y)`` box.

        Returns:
            ``[min_x, min_y, max_x, max_y]``, or ``None`` if ``word_boxes`` is
            empty or malformed.
        """
        try:
            logging.info(f'Entering expand_bbox with word_boxes={word_boxes}')

            if not word_boxes:
                return None

            # Transpose the per-word boxes into four coordinate columns
            min_xs, min_ys, max_xs, max_ys = zip(*[entry[1] for entry in word_boxes])
            return [min(min_xs), min(min_ys), max(max_xs), max(max_ys)]
        except Exception:
            logging.exception("An error occurred:")
            return None

    def normalize_bbox(self, box, width, height, padding=0.005):
        """Scale a box given on a 0-1000 scale to pixel coordinates.

        Args:
            box: ``(min_x, min_y, max_x, max_y)``; each value is divided by
                1000 to obtain a fraction of the image size.
            width: Image width in pixels.
            height: Image height in pixels.
            padding: Margin added on every side, as a fraction of the image
                size; the padded box is clamped to ``[0, 1]``.

        Returns:
            ``[x1, y1, x2, y2]`` in pixels, or ``None`` on failure.
        """
        try:
            logging.info(f'Entering normalize_bbox with box={box}, width={width}, height={height}, and padding={padding}')

            min_x, min_y, max_x, max_y = [c / 1000 for c in box]

            if padding != 0:
                min_x = max(0, min_x - padding)
                min_y = max(0, min_y - padding)
                max_x = min(max_x + padding, 1)
                max_y = min(max_y + padding, 1)

            return [min_x * width, min_y * height, max_x * width, max_y * height]
        except Exception:
            logging.exception("An error occurred:")
            return None

    def annotate_page(self, prediction, pages, document):
        """Highlight the words of ``prediction`` on the matching page image.

        Draws a semi-transparent green rectangle, in place, on
        ``pages[prediction["page"]]``.  Does nothing when ``prediction`` is
        ``None``, has no ``"word_ids"``, or no bounding box can be computed.
        """
        try:
            logging.info(f'Entering annotate_page with prediction={prediction}, pages={pages}, and document={document}')

            if prediction is None or "word_ids" not in prediction:
                return

            page_index = prediction["page"]
            image = pages[page_index]
            word_boxes = self.lift_word_boxes(document, page_index)

            # Enclosing box of the predicted words, still on the 0-1000 scale
            bbox = self.expand_bbox([word_boxes[i] for i in prediction["word_ids"]])
            if bbox is None:
                return

            pixel_box = self.normalize_bbox(bbox, image.width, image.height)
            if pixel_box is None:
                return

            x1, y1, x2, y2 = pixel_box
            draw = ImageDraw.Draw(image, "RGBA")
            draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
        except Exception:
            logging.exception("An error occurred:")

    def process_fields(self, document, fields, model=list(CHECKPOINTS.keys())[0]):
        """Answer every field's questions and return the best answer per field.

        Args:
            document: Loaded document exposing ``preview`` (page images) and
                ``context``.
            fields: Mapping of field name -> list of questions.
            model: Key into ``CHECKPOINTS``.

        Returns:
            A list of ``[field_name, answer_or_None]`` rows, one per field, or
            ``[]`` if processing fails as a whole.
        """
        try:
            logging.info(f'Entering process_fields with document={document}, fields={fields}, and model={model}')

            # Work on RGB copies so annotations never touch the originals
            pages = [x.copy().convert("RGB") for x in document.preview]

            table = []
            for field_name, questions in fields.items():
                answers = []
                for q in questions:
                    for a in self.ensure_list(self.run_pipeline(model, q, document, top_k=1)):
                        # run_pipeline yields None on failure; skip it so one
                        # failed question does not discard the whole table
                        if a is not None and a.get("score", 1) > 0.5:
                            answers.append(a)

                # Highest score wins; max() keeps the first answer on ties
                top = max(answers, key=lambda a: a.get("score", 0), default=None)

                self.annotate_page(top, pages, document)
                table.append([field_name, top.get("answer") if top is not None else None])

            return table
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_document(self, document, fields, model, error=None):
        """Extract the field table from ``document`` unless loading failed.

        Returns:
            The table from ``process_fields``, or ``[]`` when ``document`` is
            ``None``, ``error`` is set, or an exception occurs.
        """
        try:
            logging.info(f'Entering process_document with document={document}, fields={fields}, model={model}, and error={error}')

            if document is not None and error is None:
                return self.process_fields(document, fields, model)
            # Nothing to process: return an empty table like every other path
            return []
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_path(self, path, fields, model):
        """Load the document at ``path`` and extract its field table.

        Returns:
            The extracted table, or ``[]`` if ``path`` is falsy, the document
            cannot be loaded, or processing fails.
        """
        try:
            logging.info(f'Entering process_path with path={path}, fields={fields}, and model={model}')

            error = None
            document = None

            if path:
                try:
                    document = load_document(path)
                except Exception as e:
                    logging.exception("An error occurred:")
                    error = str(e)

            return self.process_document(document, fields, model, error)
        except Exception:
            logging.exception("An error occurred:")
            return []

    def pdf_to_image(self, file_path):
        """Render a PDF to PNG files and return the path of the first page.

        Each page is saved as ``page_<n>.png`` (1-based) in the current
        working directory, overwriting any existing file of that name.

        Returns:
            The path of the first page image, or ``[]`` if the PDF has no
            pages or conversion fails.
        """
        try:
            logging.info(f'Entering pdf_to_image with file_path={file_path}')

            images = convert_from_path(file_path)

            image_paths = []
            for i, image in enumerate(images):
                image_path = f'page_{i + 1}.png'
                # Persist the page so it can later be loaded from this path
                image.save(image_path)
                image_paths.append(image_path)

            return image_paths[0] if image_paths else []
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_upload(self, file):
        """Convert the PDF at ``file`` to an image and extract its field table.

        Uses the instance's ``model`` and ``fields``.  Returns ``[]`` on failure.
        """
        try:
            logging.info(f'Entering process_upload with file={file}')

            # pdf_to_image returns a falsy value on failure; pass None then
            image_path = self.pdf_to_image(file)
            return self.process_path(image_path if image_path else None, self.fields, self.model)
        except Exception:
            logging.exception("An error occurred:")
            return []

    def extract_key_value_pair(self, invoice_file):
        """Extract all key-value pairs from an uploaded invoice PDF.

        Args:
            invoice_file: Object whose ``name`` attribute is the PDF path.

        Returns:
            A string with one ``key: value`` line per extracted field, or
            ``None`` if nothing was extracted or an error occurred.
        """
        try:
            logging.info(f'Entering extract_key_value_pair with invoice_file={invoice_file}')

            data = self.process_upload(invoice_file.name)
            if not data:
                return None

            # Report every field, not just the first one
            return "\n".join(f'{key}: {value}' for key, value in data)
        except Exception:
            logging.exception("An error occurred:")
            return None