dungnt7 commited on
Commit
cebb0c1
1 Parent(s): 234d38f

[VietOCR] Add VietOCR

Browse files
Files changed (2) hide show
  1. app.py +559 -5
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,9 +1,563 @@
1
- import gradio as gr
 
 
 
 
2
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- def greet(name):
5
- return "Hello " + name + "!!"
 
 
 
6
 
 
 
 
 
 
7
 
8
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
9
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import json
3
+ import string
4
+ from collections import Counter
5
+ from itertools import count, tee
6
 
7
+ import cv2
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import pandas as pd
11
+ import streamlit as st
12
+ import torch
13
+ from PIL import Image
14
+ from transformers import (DetrImageProcessor,
15
+ TableTransformerForObjectDetection)
16
+ from vietocr.tool.config import Cfg
17
+ from vietocr.tool.predictor import Predictor
18
 
19
+ st.set_option('deprecation.showPyplotGlobalUse', False)
20
+ st.set_page_config(layout='wide')
21
+ st.title("Table Detection and Table Structure Recognition")
22
+ st.write(
23
+ "Implemented by MSFT team: https://github.com/microsoft/table-transformer")
24
 
25
+ config = Cfg.load_config_from_name('vgg_transformer')
26
+ config['cnn']['pretrained'] = False
27
+ config['device'] = 'cpu'
28
+ config['predictor']['beamsearch'] = True
29
+ detector = Predictor(config)
30
 
31
+ table_detection_model = TableTransformerForObjectDetection.from_pretrained(
32
+ "microsoft/table-transformer-detection")
33
+
34
+ table_recognition_model = TableTransformerForObjectDetection.from_pretrained(
35
+ "microsoft/table-transformer-structure-recognition")
36
+
37
+
38
+ def PIL_to_cv(pil_img):
39
+ return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
40
+
41
+
42
+ def cv_to_PIL(cv_img):
43
+ return Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
44
+
45
+
46
+ async def pytess(cell_pil_img):
47
+ text = detector.predict(cell_pil_img)
48
+ return text.strip()
49
+
50
+
51
+ def sharpen_image(pil_img):
52
+
53
+ img = PIL_to_cv(pil_img)
54
+ sharpen_kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
55
+
56
+ sharpen = cv2.filter2D(img, -1, sharpen_kernel)
57
+ pil_img = cv_to_PIL(sharpen)
58
+ return pil_img
59
+
60
+
61
+ def uniquify(seq, suffs=count(1)):
62
+ """Make all the items unique by adding a suffix (1, 2, etc).
63
+ Credit: https://stackoverflow.com/questions/30650474/python-rename-duplicates-in-list-with-progressive-numbers-without-sorting-list
64
+ `seq` is mutable sequence of strings.
65
+ `suffs` is an optional alternative suffix iterable.
66
+ """
67
+ not_unique = [k for k, v in Counter(seq).items() if v > 1]
68
+
69
+ suff_gens = dict(zip(not_unique, tee(suffs, len(not_unique))))
70
+ for idx, s in enumerate(seq):
71
+ try:
72
+ suffix = str(next(suff_gens[s]))
73
+ except KeyError:
74
+ continue
75
+ else:
76
+ seq[idx] += suffix
77
+
78
+ return seq
79
+
80
+
81
+ def binarizeBlur_image(pil_img):
82
+ image = PIL_to_cv(pil_img)
83
+ thresh = cv2.threshold(image, 150, 255, cv2.THRESH_BINARY_INV)[1]
84
+
85
+ result = cv2.GaussianBlur(thresh, (5, 5), 0)
86
+ result = 255 - result
87
+ return cv_to_PIL(result)
88
+
89
+
90
+ def td_postprocess(pil_img):
91
+ '''
92
+ Removes gray background from tables
93
+ '''
94
+ img = PIL_to_cv(pil_img)
95
+
96
+ hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
97
+ mask = cv2.inRange(hsv, (0, 0, 100),
98
+ (255, 5, 255)) # (0, 0, 100), (255, 5, 255)
99
+ nzmask = cv2.inRange(hsv, (0, 0, 5),
100
+ (255, 255, 255)) # (0, 0, 5), (255, 255, 255))
101
+ nzmask = cv2.erode(nzmask, np.ones((3, 3))) # (3,3)
102
+ mask = mask & nzmask
103
+
104
+ new_img = img.copy()
105
+ new_img[np.where(mask)] = 255
106
+
107
+ return cv_to_PIL(new_img)
108
+
109
+
110
+ # def super_res(pil_img):
111
+ # # requires opencv-contrib-python installed without the opencv-python
112
+ # sr = dnn_superres.DnnSuperResImpl_create()
113
+ # image = PIL_to_cv(pil_img)
114
+ # model_path = "./LapSRN_x8.pb"
115
+ # model_name = model_path.split('/')[1].split('_')[0].lower()
116
+ # model_scale = int(model_path.split('/')[1].split('_')[1].split('.')[0][1])
117
+
118
+ # sr.readModel(model_path)
119
+ # sr.setModel(model_name, model_scale)
120
+ # final_img = sr.upsample(image)
121
+ # final_img = cv_to_PIL(final_img)
122
+
123
+ # return final_img
124
+
125
+
126
+ def table_detector(image, THRESHOLD_PROBA):
127
+ '''
128
+ Table detection using DEtect-object TRansformer pre-trained on 1 million tables
129
+
130
+ '''
131
+
132
+ feature_extractor = DetrImageProcessor(do_resize=True,
133
+ size=800,
134
+ max_size=800)
135
+ encoding = feature_extractor(image, return_tensors="pt")
136
+
137
+ with torch.no_grad():
138
+ outputs = table_detection_model(**encoding)
139
+
140
+ probas = outputs.logits.softmax(-1)[0, :, :-1]
141
+ keep = probas.max(-1).values > THRESHOLD_PROBA
142
+
143
+ target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
144
+ postprocessed_outputs = feature_extractor.post_process(
145
+ outputs, target_sizes)
146
+ bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
147
+
148
+ return (probas[keep], bboxes_scaled)
149
+
150
+
151
+ def table_struct_recog(image, THRESHOLD_PROBA):
152
+ '''
153
+ Table structure recognition using DEtect-object TRansformer pre-trained on 1 million tables
154
+ '''
155
+
156
+ feature_extractor = DetrImageProcessor(do_resize=True,
157
+ size=1000,
158
+ max_size=1000)
159
+ encoding = feature_extractor(image, return_tensors="pt")
160
+
161
+ with torch.no_grad():
162
+ outputs = table_recognition_model(**encoding)
163
+
164
+ probas = outputs.logits.softmax(-1)[0, :, :-1]
165
+ keep = probas.max(-1).values > THRESHOLD_PROBA
166
+
167
+ target_sizes = torch.tensor(image.size[::-1]).unsqueeze(0)
168
+ postprocessed_outputs = feature_extractor.post_process(
169
+ outputs, target_sizes)
170
+ bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]
171
+
172
+ return (probas[keep], bboxes_scaled)
173
+
174
+
175
+ class TableExtractionPipeline():
176
+
177
+ colors = ["red", "blue", "green", "yellow", "orange", "violet"]
178
+
179
+ # colors = ["red", "blue", "green", "red", "red", "red"]
180
+
181
+ def add_padding(self,
182
+ pil_img,
183
+ top,
184
+ right,
185
+ bottom,
186
+ left,
187
+ color=(255, 255, 255)):
188
+ '''
189
+ Image padding as part of TSR pre-processing to prevent missing table edges
190
+ '''
191
+ width, height = pil_img.size
192
+ new_width = width + right + left
193
+ new_height = height + top + bottom
194
+ result = Image.new(pil_img.mode, (new_width, new_height), color)
195
+ result.paste(pil_img, (left, top))
196
+ return result
197
+
198
+ def plot_results_detection(self, c1, model, pil_img, prob, boxes,
199
+ delta_xmin, delta_ymin, delta_xmax, delta_ymax):
200
+ '''
201
+ crop_tables and plot_results_detection must have same co-ord shifts because 1 only plots the other one updates co-ordinates
202
+ '''
203
+ # st.write('img_obj')
204
+ # st.write(pil_img)
205
+ plt.imshow(pil_img)
206
+ ax = plt.gca()
207
+
208
+ for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
209
+ cl = p.argmax()
210
+ xmin, ymin, xmax, ymax = xmin - delta_xmin, ymin - delta_ymin, xmax + delta_xmax, ymax + delta_ymax
211
+ ax.add_patch(
212
+ plt.Rectangle((xmin, ymin),
213
+ xmax - xmin,
214
+ ymax - ymin,
215
+ fill=False,
216
+ color='red',
217
+ linewidth=3))
218
+ text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
219
+ ax.text(xmin - 20,
220
+ ymin - 50,
221
+ text,
222
+ fontsize=10,
223
+ bbox=dict(facecolor='yellow', alpha=0.5))
224
+ plt.axis('off')
225
+ c1.pyplot()
226
+
227
+ def crop_tables(self, pil_img, prob, boxes, delta_xmin, delta_ymin,
228
+ delta_xmax, delta_ymax):
229
+ '''
230
+ crop_tables and plot_results_detection must have same co-ord shifts because 1 only plots the other one updates co-ordinates
231
+ '''
232
+ cropped_img_list = []
233
+
234
+ for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
235
+
236
+ xmin, ymin, xmax, ymax = xmin - delta_xmin, ymin - delta_ymin, xmax + delta_xmax, ymax + delta_ymax
237
+ cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
238
+ cropped_img_list.append(cropped_img)
239
+
240
+ return cropped_img_list
241
+
242
+ def generate_structure(self, c2, model, pil_img, prob, boxes,
243
+ expand_rowcol_bbox_top, expand_rowcol_bbox_bottom):
244
+ '''
245
+ Co-ordinates are adjusted here by 3 'pixels'
246
+ To plot table pillow image and the TSR bounding boxes on the table
247
+ '''
248
+ # st.write('img_obj')
249
+ # st.write(pil_img)
250
+ plt.figure(figsize=(32, 20))
251
+ plt.imshow(pil_img)
252
+ ax = plt.gca()
253
+ rows = {}
254
+ cols = {}
255
+ idx = 0
256
+
257
+ for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes.tolist()):
258
+
259
+ xmin, ymin, xmax, ymax = xmin, ymin, xmax, ymax
260
+ cl = p.argmax()
261
+ class_text = model.config.id2label[cl.item()]
262
+ text = f'{class_text}: {p[cl]:0.2f}'
263
+ # or (class_text == 'table column')
264
+ if (class_text
265
+ == 'table row') or (class_text
266
+ == 'table projected row header') or (
267
+ class_text == 'table column'):
268
+ ax.add_patch(
269
+ plt.Rectangle((xmin, ymin),
270
+ xmax - xmin,
271
+ ymax - ymin,
272
+ fill=False,
273
+ color=self.colors[cl.item()],
274
+ linewidth=2))
275
+ ax.text(xmin - 10,
276
+ ymin - 10,
277
+ text,
278
+ fontsize=5,
279
+ bbox=dict(facecolor='yellow', alpha=0.5))
280
+
281
+ if class_text == 'table row':
282
+ rows['table row.' +
283
+ str(idx)] = (xmin, ymin - expand_rowcol_bbox_top, xmax,
284
+ ymax + expand_rowcol_bbox_bottom)
285
+ if class_text == 'table column':
286
+ cols['table column.' +
287
+ str(idx)] = (xmin, ymin - expand_rowcol_bbox_top, xmax,
288
+ ymax + expand_rowcol_bbox_bottom)
289
+
290
+ idx += 1
291
+
292
+ plt.axis('on')
293
+ c2.pyplot()
294
+ return rows, cols
295
+
296
+ def sort_table_featuresv2(self, rows: dict, cols: dict):
297
+ # Sometimes the header and first row overlap, and we need the header bbox not to have first row's bbox inside the headers bbox
298
+ rows_ = {
299
+ table_feature: (xmin, ymin, xmax, ymax)
300
+ for table_feature, (
301
+ xmin, ymin, xmax,
302
+ ymax) in sorted(rows.items(), key=lambda tup: tup[1][1])
303
+ }
304
+ cols_ = {
305
+ table_feature: (xmin, ymin, xmax, ymax)
306
+ for table_feature, (
307
+ xmin, ymin, xmax,
308
+ ymax) in sorted(cols.items(), key=lambda tup: tup[1][0])
309
+ }
310
+
311
+ return rows_, cols_
312
+
313
+ def individual_table_featuresv2(self, pil_img, rows: dict, cols: dict):
314
+
315
+ for k, v in rows.items():
316
+ xmin, ymin, xmax, ymax = v
317
+ cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
318
+ rows[k] = xmin, ymin, xmax, ymax, cropped_img
319
+
320
+ for k, v in cols.items():
321
+ xmin, ymin, xmax, ymax = v
322
+ cropped_img = pil_img.crop((xmin, ymin, xmax, ymax))
323
+ cols[k] = xmin, ymin, xmax, ymax, cropped_img
324
+
325
+ return rows, cols
326
+
327
+ def object_to_cellsv2(self, master_row: dict, cols: dict,
328
+ expand_rowcol_bbox_top, expand_rowcol_bbox_bottom,
329
+ padd_left):
330
+ '''Removes redundant bbox for rows&columns and divides each row into cells from columns
331
+ Args:
332
+
333
+ Returns:
334
+
335
+
336
+ '''
337
+ cells_img = {}
338
+ header_idx = 0
339
+ row_idx = 0
340
+ previous_xmax_col = 0
341
+ new_cols = {}
342
+ new_master_row = {}
343
+ previous_ymin_row = 0
344
+ new_cols = cols
345
+ new_master_row = master_row
346
+ ## Below 2 for loops remove redundant bounding boxes ###
347
+ # for k_col, v_col in cols.items():
348
+ # xmin_col, _, xmax_col, _, col_img = v_col
349
+ # if (np.isclose(previous_xmax_col, xmax_col, atol=5)) or (xmin_col >= xmax_col):
350
+ # print('Found a column with double bbox')
351
+ # continue
352
+ # previous_xmax_col = xmax_col
353
+ # new_cols[k_col] = v_col
354
+
355
+ # for k_row, v_row in master_row.items():
356
+ # _, ymin_row, _, ymax_row, row_img = v_row
357
+ # if (np.isclose(previous_ymin_row, ymin_row, atol=5)) or (ymin_row >= ymax_row):
358
+ # print('Found a row with double bbox')
359
+ # continue
360
+ # previous_ymin_row = ymin_row
361
+ # new_master_row[k_row] = v_row
362
+ ######################################################
363
+ for k_row, v_row in new_master_row.items():
364
+
365
+ _, _, _, _, row_img = v_row
366
+ xmax, ymax = row_img.size
367
+ xa, ya, xb, yb = 0, 0, 0, ymax
368
+ row_img_list = []
369
+ # plt.imshow(row_img)
370
+ # st.pyplot()
371
+ for idx, kv in enumerate(new_cols.items()):
372
+ k_col, v_col = kv
373
+ xmin_col, _, xmax_col, _, col_img = v_col
374
+ xmin_col, xmax_col = xmin_col - padd_left - 10, xmax_col - padd_left
375
+ xa = xmin_col
376
+ xb = xmax_col
377
+ if idx == 0:
378
+ xa = 0
379
+ if idx == len(new_cols) - 1:
380
+ xb = xmax
381
+ xa, ya, xb, yb = xa, ya, xb, yb
382
+
383
+ row_img_cropped = row_img.crop((xa, ya, xb, yb))
384
+ row_img_list.append(row_img_cropped)
385
+
386
+ cells_img[k_row + '.' + str(row_idx)] = row_img_list
387
+ row_idx += 1
388
+
389
+ return cells_img, len(new_cols), len(new_master_row) - 1
390
+
391
+ def clean_dataframe(self, df):
392
+ '''
393
+ Remove irrelevant symbols that appear with tesseractOCR
394
+ '''
395
+ # df.columns = [col.replace('|', '') for col in df.columns]
396
+
397
+ for col in df.columns:
398
+
399
+ df[col] = df[col].str.replace("'", '', regex=True)
400
+ df[col] = df[col].str.replace('"', '', regex=True)
401
+ df[col] = df[col].str.replace(']', '', regex=True)
402
+ df[col] = df[col].str.replace('[', '', regex=True)
403
+ df[col] = df[col].str.replace('{', '', regex=True)
404
+ df[col] = df[col].str.replace('}', '', regex=True)
405
+ return df
406
+
407
+ # @st.cache
408
+ def convert_df(self, df):
409
+ return df.to_csv().encode('utf-8')
410
+
411
+ def create_dataframe(self, c3, cell_ocr_res: list, max_cols: int,
412
+ max_rows: int):
413
+ '''Create dataframe using list of cell values of the table, also checks for valid header of dataframe
414
+ Args:
415
+ cell_ocr_res: list of strings, each element representing a cell in a table
416
+ max_cols, max_rows: number of columns and rows
417
+ Returns:
418
+ dataframe : final dataframe after all pre-processing
419
+ '''
420
+
421
+ headers = cell_ocr_res[:max_cols]
422
+ new_headers = uniquify(headers,
423
+ (f' {x!s}' for x in string.ascii_lowercase))
424
+ counter = 0
425
+
426
+ cells_list = cell_ocr_res[max_cols:]
427
+ df = pd.DataFrame("", index=range(0, max_rows), columns=new_headers)
428
+
429
+ cell_idx = 0
430
+ for nrows in range(max_rows):
431
+ for ncols in range(max_cols):
432
+ df.iat[nrows, ncols] = str(cells_list[cell_idx])
433
+ cell_idx += 1
434
+
435
+ ## To check if there are duplicate headers if result of uniquify+col == col
436
+ ## This check removes headers when all headers are empty or if median of header word count is less than 6
437
+ for x, col in zip(string.ascii_lowercase, new_headers):
438
+ if f' {x!s}' == col:
439
+ counter += 1
440
+ header_char_count = [len(col) for col in new_headers]
441
+
442
+ # if (counter == len(new_headers)) or (statistics.median(header_char_count) < 6):
443
+ # st.write('woooot')
444
+ # df.columns = uniquify(df.iloc[0], (f' {x!s}' for x in string.ascii_lowercase))
445
+ # df = df.iloc[1:,:]
446
+
447
+ df = self.clean_dataframe(df)
448
+
449
+ c3.dataframe(df)
450
+ csv = self.convert_df(df)
451
+ c3.download_button("Download table",
452
+ csv,
453
+ "file.csv",
454
+ "text/csv",
455
+ key='download-csv')
456
+
457
+ return df
458
+
459
+ async def start_process(self, image_path: str, TD_THRESHOLD, TSR_THRESHOLD,
460
+ padd_top, padd_left, padd_bottom, padd_right,
461
+ delta_xmin, delta_ymin, delta_xmax, delta_ymax,
462
+ expand_rowcol_bbox_top, expand_rowcol_bbox_bottom):
463
+ '''
464
+ Initiates process of generating pandas dataframes from raw pdf-page images
465
+
466
+ '''
467
+ image = Image.open(image_path).convert("RGB")
468
+ probas, bboxes_scaled = table_detector(image,
469
+ THRESHOLD_PROBA=TD_THRESHOLD)
470
+
471
+ if bboxes_scaled.nelement() == 0:
472
+ st.write('No table found in the pdf-page image')
473
+ return ''
474
+
475
+ # try:
476
+ # st.write('Document: '+image_path.split('/')[-1])
477
+ c1, c2, c3 = st.columns((1, 1, 1))
478
+
479
+ self.plot_results_detection(c1, table_detection_model, image, probas,
480
+ bboxes_scaled, delta_xmin, delta_ymin,
481
+ delta_xmax, delta_ymax)
482
+ cropped_img_list = self.crop_tables(image, probas, bboxes_scaled,
483
+ delta_xmin, delta_ymin, delta_xmax,
484
+ delta_ymax)
485
+
486
+ for unpadded_table in cropped_img_list:
487
+
488
+ table = self.add_padding(unpadded_table, padd_top, padd_right,
489
+ padd_bottom, padd_left)
490
+ # table = super_res(table)
491
+ # table = binarizeBlur_image(table)
492
+ # table = sharpen_image(table) # Test sharpen image next
493
+ # table = td_postprocess(table)
494
+
495
+ probas, bboxes_scaled = table_struct_recog(
496
+ table, THRESHOLD_PROBA=TSR_THRESHOLD)
497
+ rows, cols = self.generate_structure(c2, table_recognition_model,
498
+ table, probas, bboxes_scaled,
499
+ expand_rowcol_bbox_top,
500
+ expand_rowcol_bbox_bottom)
501
+ # st.write(len(rows), len(cols))
502
+ rows, cols = self.sort_table_featuresv2(rows, cols)
503
+ master_row, cols = self.individual_table_featuresv2(
504
+ table, rows, cols)
505
+
506
+ cells_img, max_cols, max_rows = self.object_to_cellsv2(
507
+ master_row, cols, expand_rowcol_bbox_top,
508
+ expand_rowcol_bbox_bottom, padd_left)
509
+
510
+ sequential_cell_img_list = []
511
+ for k, img_list in cells_img.items():
512
+ for img in img_list:
513
+ # img = super_res(img)
514
+ # img = sharpen_image(img) # Test sharpen image next
515
+ # img = binarizeBlur_image(img)
516
+ # img = self.add_padding(img, 10,10,10,10)
517
+ # plt.imshow(img)
518
+ # c3.pyplot()
519
+ sequential_cell_img_list.append(pytess(img))
520
+
521
+ cell_ocr_res = await asyncio.gather(*sequential_cell_img_list)
522
+
523
+ self.create_dataframe(c3, cell_ocr_res, max_cols, max_rows)
524
+ st.write(
525
+ 'Errors in OCR is due to either quality of the image or performance of the OCR'
526
+ )
527
+ # except:
528
+ # st.write('Either incorrectly identified table or no table, to debug remove try/except')
529
+ # break
530
+ # break
531
+
532
+
533
+ if __name__ == "__main__":
534
+
535
+ img_name = st.file_uploader("Upload an image with table(s)")
536
+ st1, st2 = st.columns((1, 1))
537
+ TD_th = st1.slider('Table detection threshold', 0.0, 1.0, 0.8)
538
+ TSR_th = st2.slider('Table structure recognition threshold', 0.0, 1.0, 0.8)
539
+
540
+ st1, st2, st3, st4 = st.columns((1, 1, 1, 1))
541
+
542
+ padd_top = st1.slider('Padding top', 0, 200, 40)
543
+ padd_left = st2.slider('Padding left', 0, 200, 40)
544
+ padd_right = st3.slider('Padding right', 0, 200, 40)
545
+ padd_bottom = st4.slider('Padding bottom', 0, 200, 40)
546
+
547
+ te = TableExtractionPipeline()
548
+ # for img in image_list:
549
+ if img_name is not None:
550
+ asyncio.run(
551
+ te.start_process(img_name,
552
+ TD_THRESHOLD=TD_th,
553
+ TSR_THRESHOLD=TSR_th,
554
+ padd_top=padd_top,
555
+ padd_left=padd_left,
556
+ padd_bottom=padd_bottom,
557
+ padd_right=padd_right,
558
+ delta_xmin=0,
559
+ delta_ymin=0,
560
+ delta_xmax=0,
561
+ delta_ymax=0,
562
+ expand_rowcol_bbox_top=0,
563
+ expand_rowcol_bbox_bottom=0))
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  dynaconf==3.1.12
2
- pytorch-lightning==2.0.0
 
 
1
  dynaconf==3.1.12
2
+ pytorch-lightning==2.0.0
3
+ vietocr==0.36