rishabhv471 committed on
Commit
ccb8696
1 Parent(s): 0599d82

back to og

Browse files
Files changed (1) hide show
  1. app.py +497 -12
app.py CHANGED
@@ -2,6 +2,7 @@ import asyncio
2
  import string
3
  from collections import Counter
4
  from itertools import count, tee
 
5
  import cv2
6
  import matplotlib.pyplot as plt
7
  import numpy as np
@@ -9,54 +10,64 @@ import pandas as pd
9
  import streamlit as st
10
  import torch
11
  from PIL import Image
12
- from transformers import (DetrImageProcessor, TableTransformerForObjectDetection)
 
13
  from vietocr.tool.config import Cfg
14
  from vietocr.tool.predictor import Predictor
15
 
16
  st.set_option('deprecation.showPyplotGlobalUse', False)
17
  st.set_page_config(layout='wide')
18
- st.title("Table Detection and Table Structure Recognition")
19
  st.write(
20
  "Implemented by MSFT team: https://github.com/microsoft/table-transformer")
21
 
22
- # Config (optional, comment out if not using)
23
  # config = Cfg.load_config_from_name('vgg_transformer')
24
- # config = Cfg.load_config_from_name('vgg_seq2seq')
25
- # config['cnn']['pretrained'] = False
26
- # config['device'] = 'cpu'
27
- # config['predictor']['beamsearch'] = False
28
- # detector = Predictor(config)
29
 
30
  table_detection_model = TableTransformerForObjectDetection.from_pretrained(
31
  "microsoft/table-transformer-detection")
 
32
  table_recognition_model = TableTransformerForObjectDetection.from_pretrained(
33
  "microsoft/table-transformer-structure-recognition")
34
 
 
35
  def PIL_to_cv(pil_img):
36
  return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
37
 
 
38
  def cv_to_PIL(cv_img):
39
  return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
40
 
 
41
  async def pytess(cell_pil_img, threshold: float = 0.5):
42
- text, prob = detector.predict(cell_pil_img, return_prob=True) # Assuming detector is defined
43
  if prob < threshold:
44
  return ""
45
  return text.strip()
46
 
 
47
  def sharpen_image(pil_img):
 
48
  img = PIL_to_cv(pil_img)
49
  sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
 
50
  sharpen = cv2.filter2D(img, -1, sharpen_kernel)
51
  pil_img = cv_to_PIL(sharpen)
52
  return pil_img
53
 
 
54
  def uniquify(seq, suffs=count(1)):
55
  """Make all the items unique by adding a suffix (1, 2, etc).
56
-
57
  Credit: https://stackoverflow.com/questions/30650474/python-rename-duplicates-in-list-with-progressive-numbers-without-sorting-list
 
 
58
  """
59
  not_unique = [k for k, v in Counter(seq).items() if v > 1]
 
60
  suff_gens = dict(zip(not_unique, tee(suffs, len(not_unique))))
61
  for idx, s in enumerate(seq):
62
  try:
@@ -65,20 +76,494 @@ def uniquify(seq, suffs=count(1)):
65
  continue
66
  else:
67
  seq[idx] += suffix
 
68
  return seq
69
 
 
70
  def binarizeBlur_image(pil_img):
71
  image = PIL_to_cv(pil_img)
72
  thresh = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY_INV)[1]
 
73
  result = cv2.GaussianBlur(thresh, (5, 5), 0)
74
  result = 255 - result
75
  return cv_to_PIL(result)
76
 
 
77
  def td_postprocess(pil_img):
78
  '''
79
  Removes gray background from tables
80
  '''
81
  img = PIL_to_cv(pil_img)
 
82
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
83
- mask = cv2.inRange(hsv, (0, 0, 100), (255, 5, 255)) # (0, 0, 100), (255, 5, 255)
84
- nzmask = cv2.inRange(hsv, (0, 0, 5), (255, 255, 255)) # (0, 0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import string
3
  from collections import Counter
4
  from itertools import count, tee
5
+
6
  import cv2
7
  import matplotlib.pyplot as plt
8
  import numpy as np
 
10
  import streamlit as st
11
  import torch
12
  from PIL import Image
13
+ from transformers import (DetrImageProcessor,
14
+ TableTransformerForObjectDetection)
15
  from vietocr.tool.config import Cfg
16
  from vietocr.tool.predictor import Predictor
17
 
18
# --- Streamlit page setup ---------------------------------------------------
# The plotting helpers below call `pyplot()` without passing a figure, so the
# 'deprecation.showPyplotGlobalUse' option is switched off here.
st.set_option("deprecation.showPyplotGlobalUse", False)
st.set_page_config(layout="wide")
st.title("Table Detection and Table Structure Recognition By VWITS")
st.write(
    "Implemented by MSFT team: https://github.com/microsoft/table-transformer"
)

# --- OCR predictor (VietOCR) --------------------------------------------------
# Alternative config name kept for reference: 'vgg_transformer'.
config = Cfg.load_config_from_name("vgg_seq2seq")
config["cnn"]["pretrained"] = False
config["device"] = "cpu"
config["predictor"]["beamsearch"] = False
# Module-level predictor used by `pytess` for every cell image.
detector = Predictor(config)

# --- Table Transformer checkpoints ------------------------------------------
# One model locates tables on the page, the other recognises rows/columns
# inside a cropped table.
table_detection_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-detection"
)
table_recognition_model = TableTransformerForObjectDetection.from_pretrained(
    "microsoft/table-transformer-structure-recognition"
)
36
 
37
+
38
def PIL_to_cv(pil_img):
    """Return `pil_img` as an OpenCV array, converting RGB to BGR channel order."""
    rgb_array = np.array(pil_img)
    return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR)
40
 
41
+
42
def cv_to_PIL(cv_img):
    """Return the OpenCV array `cv_img` as a PIL image, converting BGR to RGB."""
    rgb_array = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_array)
44
 
45
+
46
async def pytess(cell_pil_img, threshold: float = 0.5):
    """OCR a single cell image with the module-level `detector`.

    Returns the stripped recognised text, or an empty string when the
    probability reported by the predictor is below `threshold`.
    """
    recognised, confidence = detector.predict(cell_pil_img, return_prob=True)
    return "" if confidence < threshold else recognised.strip()
51
 
52
+
53
def sharpen_image(pil_img):
    """Sharpen a PIL image with a 3x3 kernel (centre 9, neighbours -1)."""
    kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
    # ddepth=-1 keeps the output depth equal to the input depth.
    sharpened = cv2.filter2D(PIL_to_cv(pil_img), -1, kernel)
    return cv_to_PIL(sharpened)
61
 
62
+
63
def uniquify(seq, suffs=None):
    """Make all the items unique by adding a suffix (1, 2, etc).

    Credit: https://stackoverflow.com/questions/30650474/python-rename-duplicates-in-list-with-progressive-numbers-without-sorting-list

    `seq` is mutable sequence of strings. It is modified in place and also
    returned. Items that occur only once are left untouched.
    `suffs` is an optional alternative suffix iterable. When omitted, a fresh
    `count(1)` is created for every call, so numbering always starts at 1
    (a shared default iterator would carry its position over between calls).
    """
    if suffs is None:
        suffs = count(1)

    not_unique = [k for k, v in Counter(seq).items() if v > 1]

    # Each duplicated value gets its own independent copy of the suffix stream.
    suff_gens = dict(zip(not_unique, tee(suffs, len(not_unique))))
    for idx, s in enumerate(seq):
        try:
            suffix = str(next(suff_gens[s]))
        except KeyError:
            # `s` is not duplicated: keep it as it is.
            continue
        else:
            seq[idx] += suffix

    return seq
81
 
82
+
83
def binarizeBlur_image(pil_img):
    """Binarise (inverse threshold at 150), blur with a 5x5 Gaussian and invert back.

    Takes and returns a PIL image.
    """
    bgr = PIL_to_cv(pil_img)
    _, inverted = cv2.threshold(bgr, 150, 255, cv2.THRESH_BINARY_INV)
    blurred = cv2.GaussianBlur(inverted, (5, 5), 0)
    # Undo the inversion introduced by THRESH_BINARY_INV.
    return cv_to_PIL(255 - blurred)
90
 
91
+
92
def td_postprocess(pil_img):
    '''
    Removes gray background from tables

    Pixels whose HSV saturation is at most 5 and value at least 100 are
    painted white, except where the eroded "value >= 5" mask excludes them.
    '''
    bgr = PIL_to_cv(pil_img)
    hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)

    # Low-saturation, reasonably bright pixels.
    gray_mask = cv2.inRange(hsv, (0, 0, 100), (255, 5, 255))
    # Pixels with value >= 5, shrunk by a 3x3 erosion.
    non_dark = cv2.erode(cv2.inRange(hsv, (0, 0, 5), (255, 255, 255)),
                         np.ones((3, 3)))
    combined = gray_mask & non_dark

    cleaned = bgr.copy()
    cleaned[np.where(combined)] = 255

    return cv_to_PIL(cleaned)
110
+
111
+
112
+ # def super_res(pil_img):
113
+ # # requires opencv-contrib-python installed without the opencv-python
114
+ # sr = dnn_superres.DnnSuperResImpl_create()
115
+ # image = PIL_to_cv(pil_img)
116
+ # model_path = "./LapSRN_x8.pb"
117
+ # model_name = model_path.split('/')[1].split('_')[0].lower()
118
+ # model_scale = int(model_path.split('/')[1].split('_')[1].split('.')[0][1])
119
+
120
+ # sr.readModel(model_path)
121
+ # sr.setModel(model_name, model_scale)
122
+ # final_img = sr.upsample(image)
123
+ # final_img = cv_to_PIL(final_img)
124
+
125
+ # return final_img
126
+
127
+
128
def table_detector(image, THRESHOLD_PROBA):
    '''
    Table detection using DEtect-object TRansformer pre-trained on 1 million tables

    Runs `table_detection_model` on `image` (a PIL image) and returns a tuple
    `(probabilities, boxes)` restricted to the predictions whose best class
    probability exceeds `THRESHOLD_PROBA`. Boxes are rescaled to the size of
    `image` by the processor's `post_process`.
    '''
    processor = DetrImageProcessor(do_resize=True, size=800, max_size=800)
    inputs = processor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = table_detection_model(**inputs)

    # First batch item only; the last logit column is dropped (presumably the
    # DETR "no object" class -- confirm against the model config).
    class_probs = outputs.logits.softmax(-1)[0, :, :-1]
    confident = class_probs.max(-1).values > THRESHOLD_PROBA

    # PIL reports (width, height); the tensor holds it reversed, shape (1, 2).
    target_sizes = torch.tensor([image.size[::-1]])
    detections = processor.post_process(outputs, target_sizes)[0]

    return (class_probs[confident], detections['boxes'][confident])
151
+
152
+
153
def table_struct_recog(image, THRESHOLD_PROBA):
    '''
    Table structure recognition using DEtect-object TRansformer pre-trained on 1 million tables

    Runs `table_recognition_model` on `image` (a PIL image of one table) and
    returns a tuple `(probabilities, boxes)` restricted to the predictions
    whose best class probability exceeds `THRESHOLD_PROBA`. Boxes are rescaled
    to the size of `image` by the processor's `post_process`.
    '''
    processor = DetrImageProcessor(do_resize=True, size=1000, max_size=1000)
    inputs = processor(image, return_tensors="pt")

    with torch.no_grad():
        outputs = table_recognition_model(**inputs)

    # First batch item only; the last logit column is dropped (presumably the
    # DETR "no object" class -- confirm against the model config).
    class_probs = outputs.logits.softmax(-1)[0, :, :-1]
    confident = class_probs.max(-1).values > THRESHOLD_PROBA

    # PIL reports (width, height); the tensor holds it reversed, shape (1, 2).
    target_sizes = torch.tensor([image.size[::-1]])
    structure = processor.post_process(outputs, target_sizes)[0]

    return (class_probs[confident], structure['boxes'][confident])
175
+
176
+
177
class TableExtractionPipeline():
    """Table extraction for one page image, rendered through Streamlit.

    `start_process` chains the steps: detect tables, crop and pad each one,
    recognise its rows and columns, slice every row into cell images, OCR the
    cells and show the resulting dataframe with a CSV download button.
    """

    # Box colours for the structure plot, indexed by predicted class id.
    colors = ["red", "blue", "green", "yellow", "orange", "violet"]

    # Characters removed from every cell by `clean_dataframe`:  ' " [ ] { }
    # Written as one escaped character class so that '[' is never compiled as
    # a regular expression on its own.
    _OCR_NOISE_PATTERN = r"""['"\[\]{}]"""

    def add_padding(self,
                    pil_img,
                    top,
                    right,
                    bottom,
                    left,
                    color=(255, 255, 255)):
        '''
        Image padding as part of TSR pre-processing to prevent missing table edges

        Returns a new image of the same mode, enlarged by the given margins
        (in pixels) and filled with `color`, with `pil_img` pasted at
        (left, top).
        '''
        width, height = pil_img.size
        new_width = width + right + left
        new_height = height + top + bottom
        result = Image.new(pil_img.mode, (new_width, new_height), color)
        result.paste(pil_img, (left, top))
        return result

    def plot_results_detection(self, c1, model, pil_img, prob, boxes,
                               delta_xmin, delta_ymin, delta_xmax, delta_ymax):
        '''
        crop_tables and plot_results_detection must have same co-ord shifts because 1 only plots the other one updates co-ordinates

        Draws every detected box (grown by the delta_* margins) with its label
        and score over `pil_img`, then renders the current pyplot figure into
        the Streamlit container `c1`.
        '''
        plt.imshow(pil_img)
        ax = plt.gca()

        for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
            cl = p.argmax()
            xmin, ymin, xmax, ymax = (xmin - delta_xmin, ymin - delta_ymin,
                                      xmax + delta_xmax, ymax + delta_ymax)
            ax.add_patch(
                plt.Rectangle((xmin, ymin),
                              xmax - xmin,
                              ymax - ymin,
                              fill=False,
                              color='red',
                              linewidth=3))
            text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
            ax.text(xmin - 20,
                    ymin - 50,
                    text,
                    fontsize=10,
                    bbox=dict(facecolor='yellow', alpha=0.5))
        plt.axis('off')
        c1.pyplot()

    def crop_tables(self, pil_img, prob, boxes, delta_xmin, delta_ymin,
                    delta_xmax, delta_ymax):
        '''
        crop_tables and plot_results_detection must have same co-ord shifts because 1 only plots the other one updates co-ordinates

        Returns one cropped PIL image per detected box, each box grown by the
        delta_* margins before cropping.
        '''
        cropped_img_list = []

        for _, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
            xmin, ymin, xmax, ymax = (xmin - delta_xmin, ymin - delta_ymin,
                                      xmax + delta_xmax, ymax + delta_ymax)
            cropped_img_list.append(pil_img.crop((xmin, ymin, xmax, ymax)))

        return cropped_img_list

    def generate_structure(self, c2, model, pil_img, prob, boxes,
                           expand_rowcol_bbox_top, expand_rowcol_bbox_bottom):
        '''
        To plot table pillow image and the TSR bounding boxes on the table

        Draws row, projected-row-header and column boxes over `pil_img` in the
        Streamlit container `c2` and collects the row and column boxes.

        Returns:
            rows, cols: dicts keyed 'table row.<idx>' / 'table column.<idx>'
            (idx is the position of the prediction) with values
            (xmin, ymin - expand_rowcol_bbox_top, xmax,
             ymax + expand_rowcol_bbox_bottom).
        '''
        plt.figure(figsize=(32, 20))
        plt.imshow(pil_img)
        ax = plt.gca()
        rows = {}
        cols = {}
        plotted_classes = ('table row', 'table projected row header',
                           'table column')

        for idx, (p, (xmin, ymin, xmax, ymax)) in enumerate(
                zip(prob, boxes.tolist())):
            cl = p.argmax()
            class_text = model.config.id2label[cl.item()]
            text = f'{class_text}: {p[cl]:0.2f}'

            if class_text in plotted_classes:
                ax.add_patch(
                    plt.Rectangle((xmin, ymin),
                                  xmax - xmin,
                                  ymax - ymin,
                                  fill=False,
                                  color=self.colors[cl.item()],
                                  linewidth=2))
                ax.text(xmin - 10,
                        ymin - 10,
                        text,
                        fontsize=5,
                        bbox=dict(facecolor='yellow', alpha=0.5))

            expanded_box = (xmin, ymin - expand_rowcol_bbox_top, xmax,
                            ymax + expand_rowcol_bbox_bottom)
            if class_text == 'table row':
                rows['table row.' + str(idx)] = expanded_box
            if class_text == 'table column':
                cols['table column.' + str(idx)] = expanded_box

        plt.axis('on')
        c2.pyplot()
        return rows, cols

    def sort_table_featuresv2(self, rows: dict, cols: dict):
        """Return new dicts with rows ordered by ymin and columns by xmin."""
        # Sometimes the header and first row overlap, and we need the header bbox not to have first row's bbox inside the headers bbox
        rows_ = dict(sorted(rows.items(), key=lambda tup: tup[1][1]))
        cols_ = dict(sorted(cols.items(), key=lambda tup: tup[1][0]))

        return rows_, cols_

    def individual_table_featuresv2(self, pil_img, rows: dict, cols: dict):
        """Attach the cropped image of every row and column to its entry.

        Both dicts are updated in place: each value becomes
        (xmin, ymin, xmax, ymax, cropped_img). The same dicts are returned.
        """
        for features in (rows, cols):
            # Only values are replaced, so the key set is stable while iterating.
            for k, (xmin, ymin, xmax, ymax) in features.items():
                cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
                features[k] = xmin, ymin, xmax, ymax, cropped_img

        return rows, cols

    def object_to_cellsv2(self, master_row: dict, cols: dict,
                          expand_rowcol_bbox_top, expand_rowcol_bbox_bottom,
                          padd_left):
        '''Divides each row into cells from columns

        Args:
            master_row: rows as produced by `individual_table_featuresv2`
                (the fifth tuple element is the row image).
            cols: columns in the same 5-tuple format.
            expand_rowcol_bbox_top, expand_rowcol_bbox_bottom: accepted for
                interface compatibility; not used in this method.
            padd_left: left padding that was added to the table image; it is
                subtracted from the column x-coordinates before cropping.

        Returns:
            cells_img: dict mapping '<row key>.<row index>' to the list of
                cell images of that row, left to right.
            number of columns
            number of rows minus one (the first row is consumed as the header
            by `create_dataframe`).
        '''
        cells_img = {}
        last_col = len(cols) - 1

        for row_idx, (k_row, v_row) in enumerate(master_row.items()):
            row_img = v_row[4]
            row_width, row_height = row_img.size
            row_img_list = []

            for idx, v_col in enumerate(cols.values()):
                xmin_col, _, xmax_col, _, _ = v_col
                # Extra 10 px on the left of every cell, on top of the padding shift.
                xa = xmin_col - padd_left - 10
                xb = xmax_col - padd_left
                # First and last cells are stretched to the row edges.
                if idx == 0:
                    xa = 0
                if idx == last_col:
                    xb = row_width

                row_img_list.append(row_img.crop((xa, 0, xb, row_height)))

            cells_img[k_row + '.' + str(row_idx)] = row_img_list

        return cells_img, len(cols), len(master_row) - 1

    def clean_dataframe(self, df):
        '''
        Remove irrelevant symbols that appear with tesseractOCR

        Strips the characters ' " [ ] { } from every column. A single escaped
        character class is used: passing a lone '[' with regex=True is an
        invalid regular expression once pandas stops treating single-character
        patterns as literals.
        '''
        for col in df.columns:
            df[col] = df[col].str.replace(self._OCR_NOISE_PATTERN,
                                          '',
                                          regex=True)
        return df

    # NOTE(review): `st.cache` is reported as deprecated in newer Streamlit
    # releases in favour of `st.cache_data`; confirm against the pinned
    # Streamlit version before switching.
    @st.cache
    def convert_df(self, df):
        """Return `df` serialised as UTF-8 encoded CSV bytes."""
        return df.to_csv().encode('utf-8')

    def create_dataframe(self, c3, cell_ocr_res: list, max_cols: int,
                         max_rows: int, download_key='download-csv'):
        '''Create dataframe using list of cell values of the table
        Args:
            c3: Streamlit container that receives the table and download button
            cell_ocr_res: list of strings, each element representing a cell in a table
            max_cols, max_rows: number of columns and rows
            download_key: Streamlit widget key of the download button; pass a
                distinct value per table when several tables are rendered in
                the same run.
        Returns:
            dataframe : final dataframe after all pre-processing
        '''
        # The first `max_cols` cells are the header row; duplicates get a
        # ' a', ' b', ... suffix so that column labels are unique.
        headers = cell_ocr_res[:max_cols]
        new_headers = uniquify(headers,
                               (f' {x!s}' for x in string.ascii_lowercase))

        cells_list = cell_ocr_res[max_cols:]
        df = pd.DataFrame("", index=range(0, max_rows), columns=new_headers)

        # Cells arrive row by row, left to right.
        cell_idx = 0
        for nrows in range(max_rows):
            for ncols in range(max_cols):
                df.iat[nrows, ncols] = str(cells_list[cell_idx])
                cell_idx += 1

        df = self.clean_dataframe(df)

        c3.dataframe(df)
        csv = self.convert_df(df)
        c3.download_button("Download table",
                           csv,
                           "file.csv",
                           "text/csv",
                           key=download_key)

        return df

    async def start_process(self, image_path: str, TD_THRESHOLD, TSR_THRESHOLD,
                            OCR_THRESHOLD, padd_top, padd_left, padd_bottom,
                            padd_right, delta_xmin, delta_ymin, delta_xmax,
                            delta_ymax, expand_rowcol_bbox_top,
                            expand_rowcol_bbox_bottom):
        '''
        Initiates process of generating pandas dataframes from raw pdf-page images

        `image_path` is handed to `Image.open`. Returns '' when no table is
        detected, otherwise None; results are written to the Streamlit page.
        '''
        image = Image.open(image_path).convert("RGB")
        probas, bboxes_scaled = table_detector(image,
                                               THRESHOLD_PROBA=TD_THRESHOLD)

        if bboxes_scaled.nelement() == 0:
            st.write('No table found in the pdf-page image')
            return ''

        c1, c2, c3 = st.columns((1, 1, 1))

        self.plot_results_detection(c1, table_detection_model, image, probas,
                                    bboxes_scaled, delta_xmin, delta_ymin,
                                    delta_xmax, delta_ymax)
        cropped_img_list = self.crop_tables(image, probas, bboxes_scaled,
                                            delta_xmin, delta_ymin, delta_xmax,
                                            delta_ymax)

        for table_idx, unpadded_table in enumerate(cropped_img_list):

            table = self.add_padding(unpadded_table, padd_top, padd_right,
                                     padd_bottom, padd_left)
            # Optional pre-processing that is currently disabled:
            # binarizeBlur_image, sharpen_image, td_postprocess.

            probas, bboxes_scaled = table_struct_recog(
                table, THRESHOLD_PROBA=TSR_THRESHOLD)
            rows, cols = self.generate_structure(c2, table_recognition_model,
                                                 table, probas, bboxes_scaled,
                                                 expand_rowcol_bbox_top,
                                                 expand_rowcol_bbox_bottom)
            rows, cols = self.sort_table_featuresv2(rows, cols)
            master_row, cols = self.individual_table_featuresv2(
                table, rows, cols)

            cells_img, max_cols, max_rows = self.object_to_cellsv2(
                master_row, cols, expand_rowcol_bbox_top,
                expand_rowcol_bbox_bottom, padd_left)

            # One OCR coroutine per cell, in row-major order.
            sequential_cell_img_list = [
                pytess(cell_pil_img=img, threshold=OCR_THRESHOLD)
                for img_list in cells_img.values()
                for img in img_list
            ]

            cell_ocr_res = await asyncio.gather(*sequential_cell_img_list)

            # A per-table key keeps the download buttons distinct when the
            # page contains more than one table.
            self.create_dataframe(c3, cell_ocr_res, max_cols, max_rows,
                                  download_key=f'download-csv-{table_idx}')
            st.write(
                'Errors in OCR is due to either quality of the image or performance of the OCR'
            )
531
+ # except:
532
+ # st.write('Either incorrectly identified table or no table, to debug remove try/except')
533
+ # break
534
+ # break
535
+
536
+
537
if __name__ == "__main__":

    img_name = st.file_uploader("Upload an image with table(s)")

    # First control row: the three confidence thresholds.
    st1, st2, st3 = st.columns((1, 1, 1))
    TD_th = st1.slider('Table detection threshold', 0.0, 1.0, 0.8)
    TSR_th = st2.slider('Table structure recognition threshold', 0.0, 1.0, 0.8)
    OCR_th = st3.slider("Text Probs Threshold", 0.0, 1.0, 0.5)

    # Second control row: padding (in pixels) added around each cropped table.
    st1, st2, st3, st4 = st.columns((1, 1, 1, 1))

    padd_top = st1.slider('Padding top', 0, 200, 40)
    padd_left = st2.slider('Padding left', 0, 200, 40)
    padd_right = st3.slider('Padding right', 0, 200, 40)
    padd_bottom = st4.slider('Padding bottom', 0, 200, 40)

    te = TableExtractionPipeline()

    # Nothing to do until the user has uploaded a file.
    if img_name is not None:
        pipeline_kwargs = dict(
            TD_THRESHOLD=TD_th,
            TSR_THRESHOLD=TSR_th,
            OCR_THRESHOLD=OCR_th,
            padd_top=padd_top,
            padd_left=padd_left,
            padd_bottom=padd_bottom,
            padd_right=padd_right,
            # Box growth and row/column expansion are fixed at zero here.
            delta_xmin=0,
            delta_ymin=0,
            delta_xmax=0,
            delta_ymax=0,
            expand_rowcol_bbox_top=0,
            expand_rowcol_bbox_bottom=0,
        )
        asyncio.run(te.start_process(img_name, **pipeline_kwargs))