naveenvenkatesh commited on
Commit
0e5d81d
·
1 Parent(s): 57c9413

Delete invoice_extractor.py

Browse files
Files changed (1) hide show
  1. invoice_extractor.py +0 -341
invoice_extractor.py DELETED
@@ -1,341 +0,0 @@
1
- import os
2
- import logging
3
- from PIL import Image, ImageDraw
4
- import traceback
5
- import torch
6
- from docquery import pipeline
7
- from docquery.document import load_bytes, load_document, ImageDocument
8
- from docquery.ocr_reader import get_ocr_reader
9
- from pdf2image import convert_from_path
10
-
11
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
12
-
13
- # Initialize the logger
14
- logging.basicConfig(filename="invoice_extraction.log", level=logging.DEBUG) # Create a log file
15
-
16
- # Checkpoint for different models
17
- CHECKPOINTS = {
18
- "LayoutLMv1 for Invoices 🧾": "impira/layoutlm-invoices",
19
- }
20
- PIPELINES = {}
21
-
22
-
23
- class InvoiceKeyValuePair():
24
-
25
- """
26
- This class provides a utility to extract key-value pairs from invoices using LayoutLM.
27
- """
28
-
29
- def __init__(self):
30
-
31
- self.fields = {
32
- "Vendor Name": ["Vendor Name - Logo?", "Vendor Name - Address?"],
33
- "Vendor Address": ["Vendor Address?"],
34
- "Customer Name": ["Customer Name?"],
35
- "Customer Address": ["Customer Address?"],
36
- "Invoice Number": ["Invoice Number?"],
37
- "Invoice Date": ["Invoice Date?"],
38
- "Due Date": ["Due Date?"],
39
- "Subtotal": ["Subtotal?"],
40
- "Total Tax": ["Total Tax?"],
41
- "Invoice Total": ["Invoice Total?"],
42
- "Amount Due": ["Amount Due?"],
43
- "Payment Terms": ["Payment Terms?"],
44
- "Remit To Name": ["Remit To Name?"],
45
- "Remit To Address": ["Remit To Address?"],
46
- }
47
- self.model = list(CHECKPOINTS.keys())[0]
48
-
49
- def ensure_list(self, x):
50
- try:
51
- # Log the function entry
52
- logging.info(f'Entering ensure_list with x={x}')
53
-
54
- # Check if 'x' is already a list
55
- if isinstance(x, list):
56
- return x
57
- else:
58
- # If 'x' is not a list, wrap it in a list and return
59
- return [x]
60
- except Exception as e:
61
- # Log exceptions
62
- logging.error("An error occurred:", exc_info=True)
63
- return []
64
-
65
- def construct_pipeline(self, task, model):
66
- try:
67
- # Log the function entry
68
- logging.info(f'Entering construct_pipeline with task={task} and model={model}')
69
-
70
- # Global dictionary to cache pipelines based on model checkpoint names
71
- global PIPELINES
72
-
73
- # Check if a pipeline for the specified model already exists in the cache
74
- if model in PIPELINES:
75
- # If it exists, return the cached pipeline
76
- return PIPELINES[model]
77
- try:
78
- # Determine the device to use for inference (GPU if available, else CPU)
79
- device = "cuda" if torch.cuda.is_available() else "cpu"
80
-
81
- # Create the pipeline using the specified task and model checkpoint
82
- ret = pipeline(task=task, model=CHECKPOINTS[model], device=device)
83
-
84
- # Cache the created pipeline for future use
85
- PIPELINES[model] = ret
86
-
87
- # Return the constructed pipeline
88
- return ret
89
- except Exception as e:
90
- # Handle exceptions and log the error message
91
- logging.error("An error occurred:", exc_info=True)
92
- return None
93
- except Exception as e:
94
- # Log exceptions
95
- logging.error("An error occurred:", exc_info=True)
96
- return None
97
-
98
- def run_pipeline(self, model, question, document, top_k):
99
- try:
100
- # Log the function entry
101
- logging.info(f'Entering run_pipeline with model={model}, question={question}, and document={document}')
102
-
103
- # Use the construct_pipeline method to get or create a pipeline for the specified model
104
- pipeline = self.construct_pipeline("document-question-answering", model)
105
-
106
- # Use the constructed pipeline to perform question-answering on the document
107
- # Pass the question, document context, and top_k as arguments to the pipeline
108
- return pipeline(question=question, **document.context, top_k=top_k)
109
- except Exception as e:
110
- # Log exceptions
111
- logging.error("An error occurred:", exc_info=True)
112
- return None
113
-
114
- def lift_word_boxes(self, document, page):
115
- try:
116
- # Log the function entry
117
- logging.info(f'Entering lift_word_boxes with document={document} and page={page}')
118
-
119
- # Extract the word boxes for the specified page from the document's context
120
- return document.context["image"][page][1]
121
- except Exception as e:
122
- # Log exceptions
123
- logging.error("An error occurred:", exc_info=True)
124
- return []
125
-
126
- def expand_bbox(self, word_boxes):
127
- try:
128
- # Log the function entry
129
- logging.info(f'Entering expand_bbox with word_boxes={word_boxes}')
130
-
131
- # Check if the input list of word boxes is empty
132
- if len(word_boxes) == 0:
133
- return None
134
-
135
- # Extract the minimum and maximum coordinates of the word boxes
136
- min_x, min_y, max_x, max_y = zip(*[x[1] for x in word_boxes])
137
-
138
- # Calculate the overall minimum and maximum coordinates
139
- min_x, min_y, max_x, max_y = [min(min_x), min(min_y), max(max_x), max(max_y)]
140
-
141
- # Return the expanded bounding box as [min_x, min_y, max_x, max_y]
142
- return [min_x, min_y, max_x, max_y]
143
- except Exception as e:
144
- # Log exceptions
145
- logging.error("An error occurred:", exc_info=True)
146
- return None
147
-
148
- def normalize_bbox(self, box, width, height, padding=0.005):
149
- try:
150
- # Log the function entry
151
- logging.info(f'Entering normalize_bbox with box={box}, width={width}, height={height}, and padding={padding}')
152
-
153
- # Extract the bounding box coordinates and convert them from millimeters to fractions
154
- min_x, min_y, max_x, max_y = [c / 1000 for c in box]
155
-
156
- # Apply padding if specified (as a fraction of image dimensions)
157
- if padding != 0:
158
- min_x = max(0, min_x - padding)
159
- min_y = max(0, min_y - padding)
160
- max_x = min(max_x + padding, 1)
161
- max_y = min(max_y + padding, 1)
162
-
163
- # Scale the normalized coordinates to match the image dimensions
164
- return [min_x * width, min_y * height, max_x * width, max_y * height]
165
- except Exception as e:
166
- # Log exceptions
167
- logging.error("An error occurred:", exc_info=True)
168
- return None
169
-
170
- def annotate_page(self, prediction, pages, document):
171
- try:
172
- # Log the function entry
173
- logging.info(f'Entering annotate_page with prediction={prediction}, pages={pages}, and document={document}')
174
-
175
- # Check if a prediction exists and contains word_ids
176
- if prediction is not None and "word_ids" in prediction:
177
-
178
- # Get the image of the page where the prediction was made
179
- image = pages[prediction["page"]]
180
-
181
- # Create a drawing object for the image
182
- draw = ImageDraw.Draw(image, "RGBA")
183
-
184
- # Extract word boxes for the page
185
- word_boxes = self.lift_word_boxes(document, prediction["page"])
186
-
187
- # Expand and normalize the bounding box of the predicted words
188
- x1, y1, x2, y2 = self.normalize_bbox(
189
- self.expand_bbox([word_boxes[i] for i in prediction["word_ids"]]),
190
- image.width,
191
- image.height,
192
- )
193
-
194
- # Draw a semi-transparent green rectangle around the predicted words
195
- draw.rectangle(((x1, y1), (x2, y2)), fill=(0, 255, 0, int(0.4 * 255)))
196
- except Exception as e:
197
- # Log exceptions
198
- logging.error("An error occurred:", exc_info=True)
199
-
200
- def process_fields(self, document, fields, model=list(CHECKPOINTS.keys())[0]):
201
- try:
202
- # Log the function entry
203
- logging.info(f'Entering process_fields with document={document}, fields={fields}, and model={model}')
204
-
205
- # Convert preview pages of the document to RGB format
206
- pages = [x.copy().convert("RGB") for x in document.preview]
207
-
208
- # Initialize dictionaries to store results
209
- ret = {}
210
- table = []
211
-
212
- # Iterate through the fields and associated questions
213
- for (field_name, questions) in fields.items():
214
-
215
- # Extract answers for each question and filter based on score
216
- answers = [
217
- a
218
- for q in questions
219
- for a in self.ensure_list(self.run_pipeline(model, q, document, top_k=1))
220
- if a.get("score", 1) > 0.5
221
- ]
222
-
223
- # Sort answers by score (higher score first)
224
- answers.sort(key=lambda x: -x.get("score", 0) if x else 0)
225
-
226
- # Get the top answer (if any)
227
- top = answers[0] if len(answers) > 0 else None
228
-
229
- # Annotate the page with the top answer's bounding box
230
- self.annotate_page(top, pages, document)
231
-
232
- # Store the top answer for the field and add it to the table
233
- ret[field_name] = top
234
- table.append([field_name, top.get("answer") if top is not None else None])
235
-
236
- # Return the table of key-value pairs
237
- return table
238
- except Exception as e:
239
- # Log exceptions
240
- logging.error("An error occurred:", exc_info=True)
241
- return []
242
-
243
- def process_document(self, document, fields, model, error=None):
244
- try:
245
- # Log the function entry
246
- logging.info(f'Entering process_document with document={document}, fields={fields}, model={model}, and error={error}')
247
-
248
- # Check if the document is not None and no error occurred during processing
249
- if document is not None and error is None:
250
-
251
- # Process the fields in the document using the specified model
252
- table = self.process_fields(document, fields, model)
253
- return table
254
- except Exception as e:
255
- # Log exceptions
256
- logging.error("An error occurred:", exc_info=True)
257
- return []
258
-
259
- def process_path(self, path, fields, model):
260
- try:
261
- # Log the function entry
262
- logging.info(f'Entering process_path with path={path}, fields={fields}, and model={model}')
263
-
264
- # Initialize error and document variables
265
- error = None
266
- document = None
267
-
268
- # Check if a file path is provided
269
- if path:
270
- try:
271
- # Load the document from the specified file path
272
- document = load_document(path)
273
- except Exception as e:
274
- # Handle exceptions and store the error message
275
- logging.error("An error occurred:", exc_info=True)
276
- error = str(e)
277
-
278
- # Process the loaded document and extract key-value pairs
279
- return self.process_document(document, fields, model, error)
280
- except Exception as e:
281
- # Log exceptions
282
- logging.error("An error occurred:", exc_info=True)
283
- return []
284
-
285
- def pdf_to_image(self, file_path):
286
- try:
287
- # Log the function entry
288
- logging.info(f'Entering pdf_to_image with file_path={file_path}')
289
-
290
- # Convert PDF to a list of image objects (one for each page)
291
- images = convert_from_path(file_path)
292
-
293
- # Loop through each image and save it
294
- for i, image in enumerate(images):
295
- image_path = f'page_{i + 1}.png'
296
-
297
- return image_path
298
- except Exception as e:
299
- # Log exceptions
300
- logging.error("An error occurred:", exc_info=True)
301
- return []
302
-
303
- def process_upload(self, file):
304
- try:
305
- # Log the function entry
306
- logging.info(f'Entering process_upload with file={file}')
307
-
308
- # Get the model and fields from the instance
309
- model = self.model
310
- fields = self.fields
311
-
312
- # Convert the uploaded PDF file to a list of image files
313
- image = self.pdf_to_image(file)
314
-
315
- # Use the first generated image file as the file path for processing
316
- file = image
317
-
318
- # Process the document (image) and extract key-value pairs
319
- return self.process_path(file if file else None, fields, model)
320
- except Exception as e:
321
- # Log exceptions
322
- logging.error("An error occurred:", exc_info=True)
323
- return []
324
-
325
- def extract_key_value_pair(self, invoice_file):
326
- try:
327
- # Log the function entry
328
- logging.info(f'Entering extract_key_value_pair with invoice_file={invoice_file}')
329
-
330
- # Process the uploaded invoice PDF file and extract key-value pairs
331
- data = self.process_upload(invoice_file.name)
332
-
333
- # Iterate through the extracted key-value pairs and print them
334
- for item in data:
335
- key, value = item
336
- return f'{key}: {value}'
337
-
338
- except Exception as e:
339
- # Log exceptions
340
- logging.error("An error occurred:", exc_info=True)
341
-