Spaces:
Sleeping
Sleeping
File size: 13,545 Bytes
4ec3e55 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 |
import os
import logging
from PIL import Image, ImageDraw
import traceback
import torch
from docquery import pipeline
from docquery.document import load_bytes, load_document, ImageDocument
from docquery.ocr_reader import get_ocr_reader
from pdf2image import convert_from_path
# Turn off the HuggingFace `tokenizers` library's internal parallelism.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Configure the root logger to write DEBUG-and-above records to
# invoice_extraction.log in the working directory (basicConfig is a no-op
# if the root logger already has handlers).
logging.basicConfig(level=logging.DEBUG, filename="invoice_extraction.log")

# Display name -> model checkpoint handed to docquery's `pipeline` factory.
CHECKPOINTS = {
    "LayoutLMv1 for Invoices 🧾": "impira/layoutlm-invoices",
}

# Lazily filled cache: display name -> constructed pipeline
# (populated by InvoiceKeyValuePair.construct_pipeline).
PIPELINES = dict()
class InvoiceKeyValuePair():
    """
    Extract key-value pairs from invoice PDFs using a LayoutLM
    document-question-answering pipeline (built through docquery).

    Each output field is answered by asking one or more natural-language
    questions against the first page of the invoice; the highest-scoring
    answer above the 0.5 confidence threshold is kept.

    Every step logs its entry and, on failure, logs the traceback and returns
    a fallback value (``None`` / ``[]``) instead of raising, so the top-level
    ``extract_key_value_pair`` call does not propagate exceptions.
    """

    def __init__(self):
        # Output field name -> questions posed to the model for that field.
        # The best-scoring answer across all of a field's questions wins.
        self.fields = {
            "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
            "Vendor Address": ["Vendor Address?"],
            "Customer Name": ["Customer Name?"],
            "Customer Address": ["Customer Address?"],
            "Invoice Number": ["Invoice Number?"],
            "Invoice Date": ["Invoice Date?"],
            "Due Date": ["Due Date?"],
            "Subtotal": ["Subtotal?"],
            "Total Tax": ["Total Tax?"],
            "Invoice Total": ["Invoice Total?"],
            "Amount Due": ["Amount Due?"],
            "Payment Terms": ["Payment Terms?"],
            "Remit To Name": ["Remit To Name?"],
            "Remit To Address": ["Remit To Address?"],
        }
        # Display name (a key of CHECKPOINTS) of the checkpoint to run.
        self.model = list(CHECKPOINTS.keys())[0]

    def ensure_list(self, x):
        """Return ``x`` unchanged if it is a list, otherwise ``[x]``.

        The pipeline may return a single answer rather than a list
        (presumably when ``top_k=1``); this normalizes both shapes.
        """
        logging.info("Entering ensure_list with x=%s", x)
        return x if isinstance(x, list) else [x]

    def construct_pipeline(self, task, model):
        """Return the pipeline for ``model``, building and caching it on first use.

        Args:
            task: Task name passed to docquery's ``pipeline`` factory.
            model: Display name; must be a key of ``CHECKPOINTS``.

        Returns:
            The pipeline, or ``None`` if it could not be built (the error is
            logged). Failures are not cached, so a later call retries.
        """
        logging.info("Entering construct_pipeline with task=%s and model=%s", task, model)
        try:
            # The cache is keyed by model only; the task is not part of the
            # key (run_pipeline, the only caller here, always uses one task).
            if model in PIPELINES:
                return PIPELINES[model]
            # Run on the GPU when torch can see one, otherwise on the CPU.
            device = "cuda" if torch.cuda.is_available() else "cpu"
            ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
            PIPELINES[model] = ret
            return ret
        except Exception:
            logging.exception("An error occurred:")
            return None

    def run_pipeline(self, model, question, document, top_k):
        """Ask ``question`` about ``document`` and return the raw pipeline output.

        Returns:
            The pipeline's answer(s) for ``top_k`` (a single answer or a list
            of answers), or ``None`` if the pipeline is unavailable or fails.
        """
        logging.info(
            "Entering run_pipeline with model=%s, question=%s, and document=%s",
            model, question, document,
        )
        try:
            qa_pipeline = self.construct_pipeline("document-question-answering", model)
            if qa_pipeline is None:
                # construct_pipeline has already logged the reason.
                return None
            # document.context is expanded into keyword arguments for the
            # pipeline (it carries the page image(s) and their word boxes).
            return qa_pipeline(question=question, **document.context, top_k=top_k)
        except Exception:
            logging.exception("An error occurred:")
            return None

    def lift_word_boxes(self, document, page):
        """Return the word boxes for page index ``page`` of ``document``.

        Reads the second element of ``document.context["image"][page]``;
        returns ``[]`` on any error.
        """
        logging.info("Entering lift_word_boxes with document=%s and page=%s", document, page)
        try:
            return document.context["image"][page][1]
        except Exception:
            logging.exception("An error occurred:")
            return []

    def expand_bbox(self, word_boxes):
        """Return the smallest ``[min_x, min_y, max_x, max_y]`` enclosing all word boxes.

        Each element of ``word_boxes`` is indexed as ``(word, (x0, y0, x1, y1))``.
        Returns ``None`` for empty input or on error.
        """
        logging.info("Entering expand_bbox with word_boxes=%s", word_boxes)
        try:
            if not word_boxes:
                return None
            x0s, y0s, x1s, y1s = zip(*[word_box[1] for word_box in word_boxes])
            return [min(x0s), min(y0s), max(x1s), max(y1s)]
        except Exception:
            logging.exception("An error occurred:")
            return None

    def normalize_bbox(self, box, width, height, padding=0.005):
        """Convert a 0-1000-scaled box into pixel coordinates of a ``width`` x ``height`` image.

        Args:
            box: ``[min_x, min_y, max_x, max_y]``; each coordinate is divided
                by 1000 to get a fraction of the page (presumably LayoutLM's
                0-1000 normalized layout scale).
            width: Target image width in pixels.
            height: Target image height in pixels.
            padding: Fraction of the page added on every side; when non-zero
                the padded box is clamped to ``[0, 1]`` before scaling.

        Returns:
            ``[x1, y1, x2, y2]`` in pixels, or ``None`` if ``box`` is missing
            or malformed.
        """
        logging.info(
            "Entering normalize_bbox with box=%s, width=%s, height=%s, and padding=%s",
            box, width, height, padding,
        )
        if box is None:
            return None
        try:
            # Scale 0-1000 coordinates down to fractions of the page.
            min_x, min_y, max_x, max_y = [c / 1000 for c in box]
            if padding != 0:
                min_x = max(0, min_x - padding)
                min_y = max(0, min_y - padding)
                max_x = min(max_x + padding, 1)
                max_y = min(max_y + padding, 1)
            return [min_x * width, min_y * height, max_x * width, max_y * height]
        except Exception:
            logging.exception("An error occurred:")
            return None

    def annotate_page(self, prediction, pages, document):
        """Shade the region of ``prediction``'s answer on its page image, in place.

        Does nothing when ``prediction`` is ``None``, lacks ``"word_ids"``, or
        its words yield no bounding box. Errors are logged, not raised.
        """
        logging.info(
            "Entering annotate_page with prediction=%s, pages=%s, and document=%s",
            prediction, pages, document,
        )
        try:
            if prediction is None or "word_ids" not in prediction:
                return
            page_index = prediction["page"]
            image = pages[page_index]
            word_boxes = self.lift_word_boxes(document, page_index)
            box = self.normalize_bbox(
                self.expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
                image.width,
                image.height,
            )
            if box is None:
                return
            x1, y1, x2, y2 = box
            draw = ImageDraw.Draw(image, "RGBA")
            # Semi-transparent (40% opacity) green fill over the answer words.
            draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
        except Exception:
            logging.exception("An error occurred:")

    def process_fields(self, document, fields, model=None):
        """Answer every field's questions against ``document``.

        Args:
            document: Loaded docquery document (``.preview`` and ``.context``
                are used).
            fields: Mapping of field name -> list of questions.
            model: Checkpoint display name; ``None`` (the default) means the
                first key of ``CHECKPOINTS``.

        Returns:
            ``[[field_name, answer_text_or_None], ...]`` in ``fields`` order,
            or ``[]`` if processing fails outright. A question whose pipeline
            call fails only leaves that field unanswered.
        """
        try:
            if model is None:
                # Resolved at call time so the class does not depend on
                # CHECKPOINTS existing when it is defined.
                model = list(CHECKPOINTS.keys())[0]
            logging.info(
                "Entering process_fields with document=%s, fields=%s, and model=%s",
                document, fields, model,
            )
            # RGB copies of the page previews to draw answer highlights on.
            # NOTE(review): these annotated pages are never returned or saved.
            pages = [page.copy().convert("RGB") for page in document.preview]
            table = []
            for field_name, questions in fields.items():
                answers = []
                for question in questions:
                    result = self.run_pipeline(model, question, document, top_k=1)
                    for answer in self.ensure_list(result):
                        # run_pipeline yields None on failure; skip it (and any
                        # non-dict entry) instead of aborting every field.
                        if isinstance(answer, dict) and answer.get("score", 1) > 0.5:
                            answers.append(answer)
                # Highest score wins; on ties the earliest answer is kept.
                top = max(answers, key=lambda a: a.get("score", 0), default=None)
                self.annotate_page(top, pages, document)
                table.append([field_name, top.get("answer") if top is not None else None])
            return table
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_document(self, document, fields, model, error=None):
        """Extract ``fields`` from ``document`` unless loading it failed.

        Returns:
            The table from ``process_fields``, or ``[]`` when ``document`` is
            ``None`` or ``error`` is set.
        """
        logging.info(
            "Entering process_document with document=%s, fields=%s, model=%s, and error=%s",
            document, fields, model, error,
        )
        if document is None or error is not None:
            return []
        return self.process_fields(document, fields, model)

    def process_path(self, path, fields, model):
        """Load the document at ``path`` and extract ``fields`` from it.

        A falsy ``path`` or a load failure yields ``[]`` (the failure is logged).
        """
        logging.info("Entering process_path with path=%s, fields=%s, and model=%s", path, fields, model)
        try:
            error = None
            document = None
            if path:
                try:
                    document = load_document(path)
                except Exception as e:
                    logging.exception("An error occurred:")
                    error = str(e)
            return self.process_document(document, fields, model, error)
        except Exception:
            logging.exception("An error occurred:")
            return []

    def pdf_to_image(self, file_path):
        """Render every page of the PDF at ``file_path`` to ``page_<n>.png``.

        The images are written to the current working directory, overwriting
        any existing files with those names.
        NOTE(review): concurrent calls share these file names.

        Returns:
            The path of the first page (``"page_1.png"``), or ``[]`` if the
            PDF has no pages or conversion fails (callers test truthiness).
        """
        logging.info("Entering pdf_to_image with file_path=%s", file_path)
        try:
            images = convert_from_path(file_path)
            image_paths = []
            for page_number, image in enumerate(images, start=1):
                image_path = f'page_{page_number}.png'
                image.save(image_path)
                image_paths.append(image_path)
            return image_paths[0] if image_paths else []
        except Exception:
            logging.exception("An error occurred:")
            return []

    def process_upload(self, file):
        """Extract this instance's fields from the first page of the PDF at path ``file``.

        Returns the ``[[field_name, answer], ...]`` table, or ``[]`` on failure.
        """
        logging.info("Entering process_upload with file=%s", file)
        try:
            image_path = self.pdf_to_image(file)
            return self.process_path(image_path if image_path else None, self.fields, self.model)
        except Exception:
            logging.exception("An error occurred:")
            return []

    def extract_key_value_pair(self, invoice_file):
        """Return every extracted invoice field as a ``"key: value"`` line.

        Args:
            invoice_file: Object whose ``.name`` is the filesystem path of the
                uploaded PDF (presumably a tempfile-like upload object —
                confirm with the caller).

        Returns:
            One ``"key: value"`` line per field joined with newlines (``""``
            if nothing was extracted), or ``None`` on an unexpected error.
        """
        logging.info("Entering extract_key_value_pair with invoice_file=%s", invoice_file)
        try:
            data = self.process_upload(invoice_file.name)
            return "\n".join(f'{key}: {value}' for key, value in data)
        except Exception:
            logging.exception("An error occurred:")
            return None
|