rishabhv471 committed on
Commit
0599d82
1 Parent(s): 58ab6bf
Files changed (1) hide show
  1. app.py +12 -497
app.py CHANGED
@@ -2,7 +2,6 @@ import asyncio
2
  import string
3
  from collections import Counter
4
  from itertools import count, tee
5
-
6
  import cv2
7
  import matplotlib.pyplot as plt
8
  import numpy as np
@@ -10,64 +9,54 @@ import pandas as pd
10
  import streamlit as st
11
  import torch
12
  from PIL import Image
13
- from transformers import (DetrImageProcessor,
14
- TableTransformerForObjectDetection)
15
  from vietocr.tool.config import Cfg
16
  from vietocr.tool.predictor import Predictor
17
 
18
  st.set_option('deprecation.showPyplotGlobalUse', False)
19
  st.set_page_config(layout='wide')
20
- st.title("Table Detection and Table Structure Recognition By VWITS")
21
  st.write(
22
  "Implemented by MSFT team: https://github.com/microsoft/table-transformer")
23
 
 
24
  # config = Cfg.load_config_from_name('vgg_transformer')
25
- config = Cfg.load_config_from_name('vgg_seq2seq')
26
- config['cnn']['pretrained'] = False
27
- config['device'] = 'cpu'
28
- config['predictor']['beamsearch'] = False
29
- detector = Predictor(config)
30
 
31
  table_detection_model = TableTransformerForObjectDetection.from_pretrained(
32
  "microsoft/table-transformer-detection")
33
-
34
  table_recognition_model = TableTransformerForObjectDetection.from_pretrained(
35
  "microsoft/table-transformer-structure-recognition")
36
 
37
-
38
  def PIL_to_cv(pil_img):
39
  return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
40
 
41
-
42
  def cv_to_PIL(cv_img):
43
  return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
44
 
45
-
46
  async def pytess(cell_pil_img, threshold: float = 0.5):
47
- text, prob = detector.predict(cell_pil_img, return_prob=True)
48
  if prob < threshold:
49
  return ""
50
  return text.strip()
51
 
52
-
53
  def sharpen_image(pil_img):
54
-
55
  img = PIL_to_cv(pil_img)
56
  sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
57
-
58
  sharpen = cv2.filter2D(img, -1, sharpen_kernel)
59
  pil_img = cv_to_PIL(sharpen)
60
  return pil_img
61
 
62
-
63
  def uniquify(seq, suffs=count(1)):
64
  """Make all the items unique by adding a suffix (1, 2, etc).
 
65
  Credit: https://stackoverflow.com/questions/30650474/python-rename-duplicates-in-list-with-progressive-numbers-without-sorting-list
66
- `seq` is mutable sequence of strings.
67
- `suffs` is an optional alternative suffix iterable.
68
  """
69
  not_unique = [k for k, v in Counter(seq).items() if v > 1]
70
-
71
  suff_gens = dict(zip(not_unique, tee(suffs, len(not_unique))))
72
  for idx, s in enumerate(seq):
73
  try:
@@ -76,494 +65,20 @@ def uniquify(seq, suffs=count(1)):
76
  continue
77
  else:
78
  seq[idx] += suffix
79
-
80
  return seq
81
 
82
-
83
  def binarizeBlur_image(pil_img):
84
  image = PIL_to_cv(pil_img)
85
  thresh = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY_INV)[1]
86
-
87
  result = cv2.GaussianBlur(thresh, (5, 5), 0)
88
  result = 255 - result
89
  return cv_to_PIL(result)
90
 
91
-
92
  def td_postprocess(pil_img):
93
  '''
94
  Removes gray background from tables
95
  '''
96
  img = PIL_to_cv(pil_img)
97
-
98
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
99
- mask = cv2.inRange(hsv, (0, 0, 100),
100
- (255, 5, 255)) # (0, 0, 100), (255, 5, 255)
101
- nzmask = cv2.inRange(hsv, (0, 0, 5),
102
- (255, 255, 255)) # (0, 0, 5), (255, 255, 255))
103
- nzmask = cv2.erode(nzmask, np.ones((3, 3))) # (3,3)
104
- mask = mask & nzmask
105
-
106
- new_img = img.copy()
107
- new_img[np.where(mask)] = 255
108
-
109
- return cv_to_PIL(new_img)
110
-
111
-
112
- # def super_res(pil_img):
113
- # # requires opencv-contrib-python installed without the opencv-python
114
- # sr = dnn_superres.DnnSuperResImpl_create()
115
- # image = PIL_to_cv(pil_img)
116
- # model_path = "./LapSRN_x8.pb"
117
- # model_name = model_path.split('/')[1].split('_')[0].lower()
118
- # model_scale = int(model_path.split('/')[1].split('_')[1].split('.')[0][1])
119
-
120
- # sr.readModel(model_path)
121
- # sr.setModel(model_name, model_scale)
122
- # final_img = sr.upsample(image)
123
- # final_img = cv_to_PIL(final_img)
124
-
125
- # return final_img
126
-
127
-
128
- def table_detector(image, THRESHOLD_PROBA):
129
- '''
130
- Table detection using DEtect-object TRansformer pre-trained on 1 million tables
131
-
132
- '''
133
-
134
- feature_extractor = DetrImageProcessor(do_resize=True,
135
- size=800,
136
- max_size=800)
137
- encoding = feature_extractor(image, return_tensors="pt")
138
-
139
- with torch.no_grad():
140
- outputs = table_detection_model(**encoding)
141
-
142
- probas = outputs.logits.softmax(-1)[0, :, :-1]
143
- keep = probas.max(-1).values > THRESHOLD_PROBA
144
-
145
- target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
146
- postprocessed_outputs = feature_extractor.post_process(
147
- outputs, target_sizes)
148
- bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
149
-
150
- return (probas[keep], bboxes_scaled)
151
-
152
-
153
- def table_struct_recog(image, THRESHOLD_PROBA):
154
- '''
155
- Table structure recognition using DEtect-object TRansformer pre-trained on 1 million tables
156
- '''
157
-
158
- feature_extractor = DetrImageProcessor(do_resize=True,
159
- size=1000,
160
- max_size=1000)
161
- encoding = feature_extractor(image, return_tensors="pt")
162
-
163
- with torch.no_grad():
164
- outputs = table_recognition_model(**encoding)
165
-
166
- probas = outputs.logits.softmax(-1)[0, :, :-1]
167
- keep = probas.max(-1).values > THRESHOLD_PROBA
168
-
169
- target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
170
- postprocessed_outputs = feature_extractor.post_process(
171
- outputs, target_sizes)
172
- bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
173
-
174
- return (probas[keep], bboxes_scaled)
175
-
176
-
177
- class TableExtractionPipeline():
178
-
179
- colors = ["red", "blue", "green", "yellow", "orange", "violet"]
180
-
181
- # colors = ["red", "blue", "green", "red", "red", "red"]
182
-
183
- def add_padding(self,
184
- pil_img,
185
- top,
186
- right,
187
- bottom,
188
- left,
189
- color=(255, 255, 255)):
190
- '''
191
- Image padding as part of TSR pre-processing to prevent missing table edges
192
- '''
193
- width, height = pil_img.size
194
- new_width = width + right + left
195
- new_height = height + top + bottom
196
- result = Image.new(pil_img.mode, (new_width, new_height), color)
197
- result.paste(pil_img, (left, top))
198
- return result
199
-
200
- def plot_results_detection(self, c1, model, pil_img, prob, boxes,
201
- delta_xmin, delta_ymin, delta_xmax, delta_ymax):
202
- '''
203
- crop_tables and plot_results_detection must have same co-ord shifts because 1 only plots the other one updates co-ordinates
204
- '''
205
- # st.write('img_obj')
206
- # st.write(pil_img)
207
- plt.imshow(pil_img)
208
- ax = plt.gca()
209
-
210
- for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
211
- cl = p.argmax()
212
- xmin, ymin, xmax, ymax = xmin - delta_xmin, ymin - delta_ymin, xmax + delta_xmax, ymax + delta_ymax
213
- ax.add_patch(
214
- plt.Rectangle((xmin, ymin),
215
- xmax - xmin,
216
- ymax - ymin,
217
- fill=False,
218
- color='red',
219
- linewidth=3))
220
- text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
221
- ax.text(xmin - 20,
222
- ymin - 50,
223
- text,
224
- fontsize=10,
225
- bbox=dict(facecolor='yellow', alpha=0.5))
226
- plt.axis('off')
227
- c1.pyplot()
228
-
229
- def crop_tables(self, pil_img, prob, boxes, delta_xmin, delta_ymin,
230
- delta_xmax, delta_ymax):
231
- '''
232
- crop_tables and plot_results_detection must have same co-ord shifts because 1 only plots the other one updates co-ordinates
233
- '''
234
- cropped_img_list = []
235
-
236
- for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
237
-
238
- xmin, ymin, xmax, ymax = xmin - delta_xmin, ymin - delta_ymin, xmax + delta_xmax, ymax + delta_ymax
239
- cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
240
- cropped_img_list.append(cropped_img)
241
-
242
- return cropped_img_list
243
-
244
- def generate_structure(self, c2, model, pil_img, prob, boxes,
245
- expand_rowcol_bbox_top, expand_rowcol_bbox_bottom):
246
- '''
247
- Co-ordinates are adjusted here by 3 'pixels'
248
- To plot table pillow image and the TSR bounding boxes on the table
249
- '''
250
- # st.write('img_obj')
251
- # st.write(pil_img)
252
- plt.figure(figsize=(32, 20))
253
- plt.imshow(pil_img)
254
- ax = plt.gca()
255
- rows = {}
256
- cols = {}
257
- idx = 0
258
-
259
- for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
260
-
261
- xmin, ymin, xmax, ymax = xmin, ymin, xmax, ymax
262
- cl = p.argmax()
263
- class_text = model.config.id2label[cl.item()]
264
- text = f'{class_text}: {p[cl]:0.2f}'
265
- # or (class_text == 'table column')
266
- if (class_text
267
- == 'table row') or (class_text
268
- == 'table projected row header') or (
269
- class_text == 'table column'):
270
- ax.add_patch(
271
- plt.Rectangle((xmin, ymin),
272
- xmax - xmin,
273
- ymax - ymin,
274
- fill=False,
275
- color=self.colors[cl.item()],
276
- linewidth=2))
277
- ax.text(xmin - 10,
278
- ymin - 10,
279
- text,
280
- fontsize=5,
281
- bbox=dict(facecolor='yellow', alpha=0.5))
282
-
283
- if class_text == 'table row':
284
- rows['table row.' +
285
- str(idx)] = (xmin, ymin - expand_rowcol_bbox_top, xmax,
286
- ymax + expand_rowcol_bbox_bottom)
287
- if class_text == 'table column':
288
- cols['table column.' +
289
- str(idx)] = (xmin, ymin - expand_rowcol_bbox_top, xmax,
290
- ymax + expand_rowcol_bbox_bottom)
291
-
292
- idx += 1
293
-
294
- plt.axis('on')
295
- c2.pyplot()
296
- return rows, cols
297
-
298
- def sort_table_featuresv2(self, rows: dict, cols: dict):
299
- # Sometimes the header and first row overlap, and we need the header bbox not to have first row's bbox inside the headers bbox
300
- rows_ = {
301
- table_feature: (xmin, ymin, xmax, ymax)
302
- for table_feature, (
303
- xmin, ymin, xmax,
304
- ymax) in sorted(rows.items(), key=lambda tup: tup[1][1])
305
- }
306
- cols_ = {
307
- table_feature: (xmin, ymin, xmax, ymax)
308
- for table_feature, (
309
- xmin, ymin, xmax,
310
- ymax) in sorted(cols.items(), key=lambda tup: tup[1][0])
311
- }
312
-
313
- return rows_, cols_
314
-
315
- def individual_table_featuresv2(self, pil_img, rows: dict, cols: dict):
316
-
317
- for k, v in rows.items():
318
- xmin, ymin, xmax, ymax = v
319
- cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
320
- rows[k] = xmin, ymin, xmax, ymax, cropped_img
321
-
322
- for k, v in cols.items():
323
- xmin, ymin, xmax, ymax = v
324
- cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
325
- cols[k] = xmin, ymin, xmax, ymax, cropped_img
326
-
327
- return rows, cols
328
-
329
- def object_to_cellsv2(self, master_row: dict, cols: dict,
330
- expand_rowcol_bbox_top, expand_rowcol_bbox_bottom,
331
- padd_left):
332
- '''Removes redundant bbox for rows&columns and divides each row into cells from columns
333
- Args:
334
-
335
- Returns:
336
-
337
-
338
- '''
339
- cells_img = {}
340
- header_idx = 0
341
- row_idx = 0
342
- previous_xmax_col = 0
343
- new_cols = {}
344
- new_master_row = {}
345
- previous_ymin_row = 0
346
- new_cols = cols
347
- new_master_row = master_row
348
- ## Below 2 for loops remove redundant bounding boxes ###
349
- # for k_col, v_col in cols.items():
350
- # xmin_col, _, xmax_col, _, col_img = v_col
351
- # if (np.isclose(previous_xmax_col, xmax_col, atol=5)) or (xmin_col >= xmax_col):
352
- # print('Found a column with double bbox')
353
- # continue
354
- # previous_xmax_col = xmax_col
355
- # new_cols[k_col] = v_col
356
-
357
- # for k_row, v_row in master_row.items():
358
- # _, ymin_row, _, ymax_row, row_img = v_row
359
- # if (np.isclose(previous_ymin_row, ymin_row, atol=5)) or (ymin_row >= ymax_row):
360
- # print('Found a row with double bbox')
361
- # continue
362
- # previous_ymin_row = ymin_row
363
- # new_master_row[k_row] = v_row
364
- ######################################################
365
- for k_row, v_row in new_master_row.items():
366
-
367
- _, _, _, _, row_img = v_row
368
- xmax, ymax = row_img.size
369
- xa, ya, xb, yb = 0, 0, 0, ymax
370
- row_img_list = []
371
- # plt.imshow(row_img)
372
- # st.pyplot()
373
- for idx, kv in enumerate(new_cols.items()):
374
- k_col, v_col = kv
375
- xmin_col, _, xmax_col, _, col_img = v_col
376
- xmin_col, xmax_col = xmin_col - padd_left - 10, xmax_col - padd_left
377
- xa = xmin_col
378
- xb = xmax_col
379
- if idx == 0:
380
- xa = 0
381
- if idx == len(new_cols) - 1:
382
- xb = xmax
383
- xa, ya, xb, yb = xa, ya, xb, yb
384
-
385
- row_img_cropped = row_img.crop((xa, ya, xb, yb))
386
- row_img_list.append(row_img_cropped)
387
-
388
- cells_img[k_row + '.' + str(row_idx)] = row_img_list
389
- row_idx += 1
390
-
391
- return cells_img, len(new_cols), len(new_master_row) - 1
392
-
393
- def clean_dataframe(self, df):
394
- '''
395
- Remove irrelevant symbols that appear with tesseractOCR
396
- '''
397
- # df.columns = [col.replace('|', '') for col in df.columns]
398
-
399
- for col in df.columns:
400
-
401
- df[col] = df[col].str.replace("'", '', regex=True)
402
- df[col] = df[col].str.replace('"', '', regex=True)
403
- df[col] = df[col].str.replace(']', '', regex=True)
404
- df[col] = df[col].str.replace('[', '', regex=True)
405
- df[col] = df[col].str.replace('{', '', regex=True)
406
- df[col] = df[col].str.replace('}', '', regex=True)
407
- return df
408
-
409
- @st.cache
410
- def convert_df(self, df):
411
- return df.to_csv().encode('utf-8')
412
-
413
- def create_dataframe(self, c3, cell_ocr_res: list, max_cols: int,
414
- max_rows: int):
415
- '''Create dataframe using list of cell values of the table, also checks for valid header of dataframe
416
- Args:
417
- cell_ocr_res: list of strings, each element representing a cell in a table
418
- max_cols, max_rows: number of columns and rows
419
- Returns:
420
- dataframe : final dataframe after all pre-processing
421
- '''
422
-
423
- headers = cell_ocr_res[:max_cols]
424
- new_headers = uniquify(headers,
425
- (f' {x!s}' for x in string.ascii_lowercase))
426
- counter = 0
427
-
428
- cells_list = cell_ocr_res[max_cols:]
429
- df = pd.DataFrame("", index=range(0, max_rows), columns=new_headers)
430
-
431
- cell_idx = 0
432
- for nrows in range(max_rows):
433
- for ncols in range(max_cols):
434
- df.iat[nrows, ncols] = str(cells_list[cell_idx])
435
- cell_idx += 1
436
-
437
- ## To check if there are duplicate headers if result of uniquify+col == col
438
- ## This check removes headers when all headers are empty or if median of header word count is less than 6
439
- for x, col in zip(string.ascii_lowercase, new_headers):
440
- if f' {x!s}' == col:
441
- counter += 1
442
- header_char_count = [len(col) for col in new_headers]
443
-
444
- # if (counter == len(new_headers)) or (statistics.median(header_char_count) < 6):
445
- # st.write('woooot')
446
- # df.columns = uniquify(df.iloc[0], (f' {x!s}' for x in string.ascii_lowercase))
447
- # df = df.iloc[1:,:]
448
-
449
- df = self.clean_dataframe(df)
450
-
451
- c3.dataframe(df)
452
- csv = self.convert_df(df)
453
- c3.download_button("Download table",
454
- csv,
455
- "file.csv",
456
- "text/csv",
457
- key='download-csv')
458
-
459
- return df
460
-
461
- async def start_process(self, image_path: str, TD_THRESHOLD, TSR_THRESHOLD,
462
- OCR_THRESHOLD, padd_top, padd_left, padd_bottom,
463
- padd_right, delta_xmin, delta_ymin, delta_xmax,
464
- delta_ymax, expand_rowcol_bbox_top,
465
- expand_rowcol_bbox_bottom):
466
- '''
467
- Initiates process of generating pandas dataframes from raw pdf-page images
468
-
469
- '''
470
- image = Image.open(image_path).convert("RGB")
471
- probas, bboxes_scaled = table_detector(image,
472
- THRESHOLD_PROBA=TD_THRESHOLD)
473
-
474
- if bboxes_scaled.nelement() == 0:
475
- st.write('No table found in the pdf-page image')
476
- return ''
477
-
478
- # try:
479
- # st.write('Document: '+image_path.split('/')[-1])
480
- c1, c2, c3 = st.columns((1, 1, 1))
481
-
482
- self.plot_results_detection(c1, table_detection_model, image, probas,
483
- bboxes_scaled, delta_xmin, delta_ymin,
484
- delta_xmax, delta_ymax)
485
- cropped_img_list = self.crop_tables(image, probas, bboxes_scaled,
486
- delta_xmin, delta_ymin, delta_xmax,
487
- delta_ymax)
488
-
489
- for unpadded_table in cropped_img_list:
490
-
491
- table = self.add_padding(unpadded_table, padd_top, padd_right,
492
- padd_bottom, padd_left)
493
- # table = super_res(table)
494
- # table = binarizeBlur_image(table)
495
- # table = sharpen_image(table) # Test sharpen image next
496
- # table = td_postprocess(table)
497
-
498
- probas, bboxes_scaled = table_struct_recog(
499
- table, THRESHOLD_PROBA=TSR_THRESHOLD)
500
- rows, cols = self.generate_structure(c2, table_recognition_model,
501
- table, probas, bboxes_scaled,
502
- expand_rowcol_bbox_top,
503
- expand_rowcol_bbox_bottom)
504
- # st.write(len(rows), len(cols))
505
- rows, cols = self.sort_table_featuresv2(rows, cols)
506
- master_row, cols = self.individual_table_featuresv2(
507
- table, rows, cols)
508
-
509
- cells_img, max_cols, max_rows = self.object_to_cellsv2(
510
- master_row, cols, expand_rowcol_bbox_top,
511
- expand_rowcol_bbox_bottom, padd_left)
512
-
513
- sequential_cell_img_list = []
514
- for k, img_list in cells_img.items():
515
- for img in img_list:
516
- # img = super_res(img)
517
- # img = sharpen_image(img) # Test sharpen image next
518
- # img = binarizeBlur_image(img)
519
- # img = self.add_padding(img, 10,10,10,10)
520
- # plt.imshow(img)
521
- # c3.pyplot()
522
- sequential_cell_img_list.append(
523
- pytess(cell_pil_img=img, threshold=OCR_THRESHOLD))
524
-
525
- cell_ocr_res = await asyncio.gather(*sequential_cell_img_list)
526
-
527
- self.create_dataframe(c3, cell_ocr_res, max_cols, max_rows)
528
- st.write(
529
- 'Errors in OCR is due to either quality of the image or performance of the OCR'
530
- )
531
- # except:
532
- # st.write('Either incorrectly identified table or no table, to debug remove try/except')
533
- # break
534
- # break
535
-
536
-
537
- if __name__ == "__main__":
538
-
539
- img_name = st.file_uploader("Upload an image with table(s)")
540
- st1, st2, st3 = st.columns((1, 1, 1))
541
- TD_th = st1.slider('Table detection threshold', 0.0, 1.0, 0.8)
542
- TSR_th = st2.slider('Table structure recognition threshold', 0.0, 1.0, 0.8)
543
- OCR_th = st3.slider("Text Probs Threshold", 0.0, 1.0, 0.5)
544
-
545
- st1, st2, st3, st4 = st.columns((1, 1, 1, 1))
546
-
547
- padd_top = st1.slider('Padding top', 0, 200, 40)
548
- padd_left = st2.slider('Padding left', 0, 200, 40)
549
- padd_right = st3.slider('Padding right', 0, 200, 40)
550
- padd_bottom = st4.slider('Padding bottom', 0, 200, 40)
551
-
552
- te = TableExtractionPipeline()
553
- # for img in image_list:
554
- if img_name is not None:
555
- asyncio.run(
556
- te.start_process(img_name,
557
- TD_THRESHOLD=TD_th,
558
- TSR_THRESHOLD=TSR_th,
559
- OCR_THRESHOLD=OCR_th,
560
- padd_top=padd_top,
561
- padd_left=padd_left,
562
- padd_bottom=padd_bottom,
563
- padd_right=padd_right,
564
- delta_xmin=0,
565
- delta_ymin=0,
566
- delta_xmax=0,
567
- delta_ymax=0,
568
- expand_rowcol_bbox_top=0,
569
- expand_rowcol_bbox_bottom=0))
 
2
  import string
3
  from collections import Counter
4
  from itertools import count, tee
 
5
  import cv2
6
  import matplotlib.pyplot as plt
7
  import numpy as np
 
9
  import streamlit as st
10
  import torch
11
  from PIL import Image
12
+ from transformers import (DetrImageProcessor, TableTransformerForObjectDetection)
 
13
  from vietocr.tool.config import Cfg
14
  from vietocr.tool.predictor import Predictor
15
 
16
  st.set_option('deprecation.showPyplotGlobalUse', False)
17
  st.set_page_config(layout='wide')
18
+ st.title("Table Detection and Table Structure Recognition")
19
  st.write(
20
  "Implemented by MSFT team: https://github.com/microsoft/table-transformer")
21
 
22
+ # VietOCR config - currently commented out; `pytess` below still calls `detector`, so re-enable these lines before running OCR
23
  # config = Cfg.load_config_from_name('vgg_transformer')
24
+ # config = Cfg.load_config_from_name('vgg_seq2seq')
25
+ # config['cnn']['pretrained'] = False
26
+ # config['device'] = 'cpu'
27
+ # config['predictor']['beamsearch'] = False
28
+ # detector = Predictor(config)
29
 
30
  table_detection_model = TableTransformerForObjectDetection.from_pretrained(
31
  "microsoft/table-transformer-detection")
 
32
  table_recognition_model = TableTransformerForObjectDetection.from_pretrained(
33
  "microsoft/table-transformer-structure-recognition")
34
 
 
35
  def PIL_to_cv(pil_img):
36
  return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
37
 
 
38
  def cv_to_PIL(cv_img):
39
  return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
40
 
 
41
  async def pytess(cell_pil_img, threshold: float = 0.5):
42
+ text, prob = detector.predict(cell_pil_img, return_prob=True) # NOTE(review): `detector` is commented out above - NameError unless it is re-enabled
43
  if prob < threshold:
44
  return ""
45
  return text.strip()
46
 
 
47
  def sharpen_image(pil_img):
 
48
  img = PIL_to_cv(pil_img)
49
  sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
 
50
  sharpen = cv2.filter2D(img, -1, sharpen_kernel)
51
  pil_img = cv_to_PIL(sharpen)
52
  return pil_img
53
 
 
54
  def uniquify(seq, suffs=count(1)):
55
  """Make all the items unique by adding a suffix (1, 2, etc).
56
+
57
  Credit: https://stackoverflow.com/questions/30650474/python-rename-duplicates-in-list-with-progressive-numbers-without-sorting-list
 
 
58
  """
59
  not_unique = [k for k, v in Counter(seq).items() if v > 1]
 
60
  suff_gens = dict(zip(not_unique, tee(suffs, len(not_unique))))
61
  for idx, s in enumerate(seq):
62
  try:
 
65
  continue
66
  else:
67
  seq[idx] += suffix
 
68
  return seq
69
 
 
70
  def binarizeBlur_image(pil_img):
71
  image = PIL_to_cv(pil_img)
72
  thresh = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY_INV)[1]
 
73
  result = cv2.GaussianBlur(thresh, (5, 5), 0)
74
  result = 255 - result
75
  return cv_to_PIL(result)
76
 
 
77
  def td_postprocess(pil_img):
78
  '''
79
  Removes gray background from tables
80
  '''
81
  img = PIL_to_cv(pil_img)
 
82
  hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
83
+ mask = cv2.inRange(hsv, (0, 0, 100), (255, 5, 255)) # (0, 0, 100), (255, 5, 255)
84
+ nzmask = cv2.inRange(hsv, (0, 0, 5), (255, 255, 255)) # (0, 0,