KoonJamesZ commited on
Commit
e57a5af
·
verified ·
1 Parent(s): cbde093

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +243 -0
app.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import TableTransformerForObjectDetection
4
+ import matplotlib.pyplot as plt
5
+ from transformers import DetrFeatureExtractor
6
+ import pandas as pd
7
+ import uuid
8
+ from surya.ocr import run_ocr
9
+ # from surya.model.detection.segformer import load_model as load_det_model, load_processor as load_det_processor
10
+ from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
11
+ from surya.model.recognition.model import load_model as load_rec_model
12
+ from surya.model.recognition.processor import load_processor as load_rec_processor
13
+ from PIL import ImageDraw, Image
14
+ import os
15
+ from pdf2image import convert_from_path
16
+ import tempfile
17
+ from ultralyticsplus import YOLO, render_result
18
+ import cv2
19
+ import numpy as np
20
+ from fpdf import FPDF
21
+
22
+ def convert_pdf_images(pdf_path):
23
+ # Convert PDF to images
24
+ images = convert_from_path(pdf_path)
25
+
26
+ # Save each page as a temporary image and collect file paths
27
+ temp_file_paths = []
28
+ for i, page in enumerate(images):
29
+ # Create a temporary file with a unique name
30
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
31
+ page.save(temp_file.name, 'PNG') # Save the image to the temporary file
32
+ temp_file_paths.append(temp_file.name) # Add file path to the list
33
+
34
+ return temp_file_paths[0] # Return the list of temporary file paths
35
+
36
+
37
+ # Load model
38
+ model_yolo = YOLO('keremberke/yolov8m-table-extraction')
39
+
40
+ # Set model parameters
41
+ model_yolo.overrides['conf'] = 0.25 # NMS confidence threshold
42
+ model_yolo.overrides['iou'] = 0.45 # NMS IoU threshold
43
+ model_yolo.overrides['agnostic_nms'] = False # NMS class-agnostic
44
+ model_yolo.overrides['max_det'] = 1000 # maximum number of detections per image
45
+ def resize_image(image, max_dimension=4200, min_dimension=50):
46
+ width, height = image.size
47
+ # Check if the dimensions are within range
48
+ if width > max_dimension or height > max_dimension or width < min_dimension or height < min_dimension:
49
+ scaling_factor = min(max_dimension / max(width, height), min_dimension / min(width, height))
50
+ new_width = int(width * scaling_factor)
51
+ new_height = int(height * scaling_factor)
52
+ return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
53
+ return image
54
+ def crop_table(filename):
55
+ # Set image
56
+ image_path = filename
57
+ image = Image.open(image_path)
58
+ image_np = np.array(image)
59
+
60
+ # Perform inference
61
+ results = model_yolo.predict(image_path)
62
+
63
+ # Extract the first bounding box (assuming there's only one table)
64
+ bbox = results[0].boxes[0]
65
+ x1, y1, x2, y2 = map(int, bbox.xyxy[0]) # Get the bounding box coordinates
66
+
67
+ # Crop the image using the bounding box coordinates
68
+ cropped_image = image_np[y1:y2, x1:x2]
69
+
70
+ # Convert the cropped image to RGB (if it's not already in RGB)
71
+ cropped_image_rgb = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB)
72
+
73
+ # Save the cropped image as a PDF
74
+ cropped_image_pil = Image.fromarray(cropped_image_rgb)
75
+ # Save the cropped image to a temporary file
76
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
77
+ cropped_image_pil.save(temp_file.name)
78
+
79
+ return temp_file.name
80
+
81
+ # new v1.1 checkpoints require no timm anymore
82
+ device = "cuda" if torch.cuda.is_available() else "cpu"
83
+ langs = ["en","th"] # Replace with your languages - optional but recommended
84
+ det_processor, det_model = load_det_processor(), load_det_model()
85
+ rec_model, rec_processor = load_rec_model(), load_rec_processor()
86
+
87
+
88
+
89
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
90
+ [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
91
+ feature_extractor = DetrFeatureExtractor()
92
+
93
+ model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition-v1.1-all")
94
+
95
+
96
+
97
+ def compute_boxes(image_path):
98
+ image = Image.open(image_path).convert("RGB")
99
+ width, height = image.size
100
+
101
+ encoding = feature_extractor(image, return_tensors="pt")
102
+
103
+ with torch.no_grad():
104
+ outputs = model(**encoding)
105
+
106
+ results = feature_extractor.post_process_object_detection(outputs, threshold=0.7, target_sizes=[(height, width)])[0]
107
+ boxes = results['boxes'].tolist()
108
+ labels = results['labels'].tolist()
109
+
110
+ return boxes,labels
111
+
112
+ def extract_table(image_path):
113
+ image = Image.open(image_path)
114
+ boxes,labels = compute_boxes(image_path)
115
+
116
+
117
+ cropped_table_visualized = image.copy()
118
+ draw = ImageDraw.Draw(cropped_table_visualized)
119
+
120
+ for cell in boxes:
121
+ draw.rectangle(cell, outline="red")
122
+ bbox_table = f"{str(uuid.uuid4())}.png"
123
+ cropped_table_visualized.save(bbox_table)
124
+ cell_locations = []
125
+
126
+ for box_row, label_row in zip(boxes, labels):
127
+ if label_row == 2:
128
+ for box_col, label_col in zip(boxes, labels):
129
+ if label_col == 1:
130
+ cell_box = (box_col[0], box_row[1], box_col[2], box_row[3])
131
+ cell_locations.append(cell_box)
132
+
133
+ cell_locations.sort(key=lambda x: (x[1], x[0]))
134
+
135
+ num_columns = 0
136
+ box_old = cell_locations[0]
137
+
138
+ for box in cell_locations[1:]:
139
+ x1, y1, x2, y2 = box
140
+ x1_old, y1_old, x2_old, y2_old = box_old
141
+ num_columns += 1
142
+ if y1 > y1_old:
143
+ break
144
+
145
+ box_old = box
146
+
147
+ headers = []
148
+ for box in cell_locations[:num_columns]:
149
+ x1, y1, x2, y2 = box
150
+ cell_image = resize_image(image.crop((x1, y1, x2, y2)))
151
+ # new_width = cell_image.width *4
152
+ # new_height = cell_image.height *4
153
+ # cell_image = cell_image.resize((new_width, new_height), resample=Image.LANCZOS)
154
+ # cell_text = pytesseract.image_to_string(cell_image, lang='tha+eng')
155
+ # print(cell_text)
156
+
157
+ plt.figure()
158
+ plt.imshow(cell_image)
159
+ plt.axis("off")
160
+ plt.title("Cropped Cell Image")
161
+ plt.show()
162
+
163
+ predictions = run_ocr([cell_image], [langs], det_model, det_processor, rec_model, rec_processor)
164
+ texts = [line.text for line in predictions[0].text_lines]
165
+ all_text = ' '.join(texts)
166
+ print(all_text)
167
+ if all_text:
168
+ headers.append(all_text)
169
+ else:
170
+ headers.append('')
171
+
172
+
173
+ df = pd.DataFrame(columns=headers)
174
+
175
+ row = []
176
+ for box in cell_locations[num_columns:]:
177
+ x1, y1, x2, y2 = box
178
+ cell_image = resize_image(image.crop((x1, y1, x2, y2)))
179
+ # new_width = cell_image.width * 4
180
+ # new_height = cell_image.height * 4
181
+ # cell_image = cell_image.resize((new_width, new_height), resample=Image.LANCZOS)
182
+ # cell_text = pytesseract.image_to_string(cell_image, lang='tha+eng')
183
+ # print(cell_text)
184
+
185
+ plt.figure()
186
+ plt.imshow(cell_image)
187
+ plt.axis("off")
188
+ plt.title("Cropped Cell Image")
189
+ plt.show()
190
+ predictions = run_ocr([cell_image], [langs], det_model, det_processor, rec_model, rec_processor)
191
+ texts = [line.text for line in predictions[0].text_lines]
192
+ all_text = ''.join(texts)
193
+ print(all_text)
194
+ if all_text:
195
+ headers.append(all_text)
196
+ else:
197
+ headers.append('')
198
+
199
+ row.append(all_text)
200
+
201
+ if len(row) == num_columns:
202
+ df.loc[len(df)] = row
203
+ print(row)
204
+ row = []
205
+ filepath = f"{str(uuid.uuid4())}.csv"
206
+ df.to_csv(filepath, index=False)
207
+ return filepath, bbox_table
208
+
209
+ # Function to process the uploaded file
210
+ def process_file(uploaded_file):
211
+ images_table = convert_pdf_images(uploaded_file)
212
+ croped_table = crop_table(images_table)
213
+ filepath, bbox_table = extract_table(croped_table)
214
+ os.remove(images_table)
215
+ os.remove(croped_table)
216
+ return filepath, bbox_table # Return the file path for download
217
+
218
+ # Function to clear the inputs and outputs
219
+ def clear_inputs():
220
+ return None, None, None # Clear both input and output
221
+
222
+ # Define the Gradio interface
223
+ with gr.Blocks() as demo:
224
+ gr.Markdown("## Upload a PDF, Process it, and Download the Processed File")
225
+
226
+ with gr.Row():
227
+ upload = gr.File(label="Upload PDF", type="filepath", file_types=[".pdf"])
228
+ download = gr.File(label="Download Processed PDF")
229
+ with gr.Row():
230
+ process_button = gr.Button("Process")
231
+ clear_button = gr.Button("Clear") # Custom clear button
232
+ image_display = gr.Image(label="Processed Image")
233
+
234
+ # Trigger the file processing with the button click
235
+ process_button.click(process_file, inputs=upload, outputs=[download, image_display])
236
+
237
+ # Trigger clearing inputs and outputs
238
+ clear_button.click(clear_inputs, inputs=None, outputs=[upload, download, image_display])
239
+
240
+ # Launch the interface
241
+ demo.launch()
242
+
243
+ # print(process_file("/content/give me a example table - give me a example table.pdf"))