|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
This code is refer from: |
|
https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py |
|
""" |
|
|
|
import os |
|
import re |
|
import cv2 |
|
import glob |
|
import copy |
|
import math |
|
import pickle |
|
import numpy as np |
|
|
|
from shapely.geometry import Polygon, MultiPoint |
|
""" |
|
Useful function in matching. |
|
""" |
|
|
|
|
|
def remove_empty_bboxes(bboxes):
    """
    Remove all-zero placeholder bboxes ([0., 0., 0., 0.]) emitted by the
    structure master. ``len(bboxes.shape)`` must be 2.
    :param bboxes: 2-D array-like of 4-element bboxes.
    :return: np.ndarray containing only the non-empty bboxes.
    """
    # Test "all elements are zero" rather than "sum is zero": a sum can be
    # zero even for a non-empty bbox if coordinates cancel out.
    new_bboxes = [bbox for bbox in bboxes if np.any(bbox)]
    return np.array(new_bboxes)
|
|
|
|
|
def xywh2xyxy(bboxes):
    """
    Convert bbox(es) from (cx, cy, w, h) to (x1, y1, x2, y2) format.
    Accepts a single bbox (1-D) or a batch of bboxes (2-D).
    :param bboxes: np.ndarray in xywh format.
    :return: np.ndarray of the same shape in xyxy format.
    """
    if bboxes.ndim == 1:
        half_w, half_h = bboxes[2] / 2, bboxes[3] / 2
        out = np.empty_like(bboxes)
        out[0], out[1] = bboxes[0] - half_w, bboxes[1] - half_h
        out[2], out[3] = bboxes[0] + half_w, bboxes[1] + half_h
        return out
    if bboxes.ndim == 2:
        half_w, half_h = bboxes[:, 2] / 2, bboxes[:, 3] / 2
        out = np.empty_like(bboxes)
        out[:, 0] = bboxes[:, 0] - half_w
        out[:, 1] = bboxes[:, 1] - half_h
        out[:, 2] = bboxes[:, 0] + half_w
        out[:, 3] = bboxes[:, 1] + half_h
        return out
    raise ValueError
|
|
|
|
|
def xyxy2xywh(bboxes):
    """
    Convert bbox(es) from (x1, y1, x2, y2) to (cx, cy, w, h) format.
    Accepts a single bbox (1-D) or a batch of bboxes (2-D).
    :param bboxes: np.ndarray in xyxy format.
    :return: np.ndarray of the same shape in xywh format.
    """
    if bboxes.ndim == 1:
        width, height = bboxes[2] - bboxes[0], bboxes[3] - bboxes[1]
        out = np.empty_like(bboxes)
        out[0], out[1] = bboxes[0] + width / 2, bboxes[1] + height / 2
        out[2], out[3] = width, height
        return out
    if bboxes.ndim == 2:
        width = bboxes[:, 2] - bboxes[:, 0]
        height = bboxes[:, 3] - bboxes[:, 1]
        out = np.empty_like(bboxes)
        out[:, 0] = bboxes[:, 0] + width / 2
        out[:, 1] = bboxes[:, 1] + height / 2
        out[:, 2] = width
        out[:, 3] = height
        return out
    raise ValueError
|
|
|
|
|
def pickle_load(path, prefix='end2end'):
    """
    Load pickled results from a single file, or merge all '{prefix}_*.pkl'
    files found in a directory.
    :param path: pickle file path, or a directory containing pickle files.
    :param prefix: filename prefix used when globbing a directory.
    :return: the unpickled dict (merged across files for a directory).
    :raises ValueError: if ``path`` is neither a file nor a directory.
    """
    if os.path.isfile(path):
        # use a context manager so the file handle is always closed
        with open(path, 'rb') as f:
            data = pickle.load(f)
    elif os.path.isdir(path):
        data = dict()
        search_path = os.path.join(path, '{}_*.pkl'.format(prefix))
        for pkl in glob.glob(search_path):
            with open(pkl, 'rb') as f:
                data.update(pickle.load(f))
    else:
        raise ValueError
    return data
|
|
|
|
|
def convert_coord(xyxy):
    """
    Convert a two-point bbox (x1, y1, x2, y2) to the four-corner format
    [[x1,y1],[x2,y1],[x2,y2],[x1,y2]].
    :param xyxy: sequence of 4 coordinates.
    :return: (4, 2) float32 np.ndarray of corners in clockwise order.
    """
    x1, y1, x2, y2 = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
    return np.array(
        [[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32)
|
|
|
|
|
def cal_iou(bbox1, bbox2):
    """
    Compute the IoU of two four-corner bboxes using their convex hulls.
    The union area is taken as the convex hull of all eight corners.
    :param bbox1: (4, 2) corner array.
    :param bbox2: (4, 2) corner array.
    :return: IoU in [0, 1]; 0 when the hulls do not intersect or the
        union is degenerate.
    """
    poly_a = Polygon(bbox1).convex_hull
    poly_b = Polygon(bbox2).convex_hull
    if not poly_a.intersects(poly_b):
        return 0
    inter_area = poly_a.intersection(poly_b).area
    all_corners = np.concatenate((bbox1, bbox2))
    union_area = MultiPoint(all_corners).convex_hull.area
    if union_area == 0:
        return 0
    return float(inter_area) / union_area
|
|
|
|
|
def cal_distance(p1, p2):
    """
    Euclidean distance between two points.
    :param p1: point (x, y).
    :param p2: point (x, y).
    :return: float distance.
    """
    # math.hypot is the idiomatic (and overflow-safe) form of
    # sqrt(dx**2 + dy**2)
    return math.hypot(p1[0] - p2[0], p1[1] - p2[1])
|
|
|
|
|
def is_inside(center_point, corner_point):
    """
    Check whether ``center_point`` lies inside the bbox described by
    ``corner_point`` (inclusive on all edges).
    :param center_point: center point (x, y).
    :param corner_point: corner points ((x1, y1), (x2, y2)).
    :return: True if the point is inside the bbox, else False.
    """
    (x1, y1), (x2, y2) = corner_point
    cx, cy = center_point[0], center_point[1]
    return (x1 <= cx <= x2) and (y1 <= cy <= y2)
|
|
|
|
|
def find_no_match(match_list, all_end2end_nums, type='end2end'):
    """
    Find the bbox indexes that do not appear in a previous match list.
    :param match_list: matching pairs [[end2end_idx, master_idx], ...].
    :param all_end2end_nums: total number of bboxes to check.
    :param type: 'end2end' reads pair position 0, 'master' position 1.
    :return: list of unmatched bbox indexes.
    :raises ValueError: on an unknown ``type``.
    """
    if type == 'end2end':
        idx = 0
    elif type == 'master':
        idx = 1
    else:
        raise ValueError

    # set membership keeps the scan O(n) instead of O(n*m)
    matched_bbox_indexs = {pair[idx] for pair in match_list}
    return [
        n for n in range(all_end2end_nums) if n not in matched_bbox_indexs
    ]
|
|
|
|
|
def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3):
    """
    Check whether the absolute y-difference (coordinate index 1) of two
    bboxes is below ``threshold`` — i.e. they sit on the same line.
    :param this_bbox: first bbox.
    :param target_bbox: second bbox.
    :param threshold: maximum allowed |y1 - y2| (exclusive).
    :return: True when the bboxes are vertically close enough.
    """
    return abs(this_bbox[1] - target_bbox[1]) < threshold
|
|
|
|
|
def sort_line_bbox(g, bg):
    """
    Sort the bboxes of one line (group) by their x coordinate; 'y' values
    within a group are already close.
    :param g: indexes in the same group.
    :param bg: bboxes in the same group (parallel to ``g``).
    :return: (g_sorted, bg_sorted) sorted together by bbox x.
    """
    # argsort-style stable ordering; the previous implementation used
    # xs_sorted.index(x), which maps duplicate x values to the same slot
    # and leaves None holes in the output.
    order = sorted(range(len(bg)), key=lambda k: bg[k][0])
    g_sorted = [g[k] for k in order]
    bg_sorted = [bg[k] for k in order]
    return g_sorted, bg_sorted
|
|
|
|
|
def flatten(sorted_groups, sorted_bbox_groups):
    """
    Flatten grouped indexes and bboxes into two flat lists, preserving
    group order.
    :param sorted_groups: list of index groups.
    :param sorted_bbox_groups: list of bbox groups (parallel).
    :return: (flat index list, flat bbox list).
    """
    idxs = []
    bboxes = []
    for group, bbox_group in zip(sorted_groups, sorted_bbox_groups):
        idxs.extend(group)
        bboxes.extend(bbox_group)
    return idxs, bboxes
|
|
|
|
|
def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
    """
    Group the unmatched end2end bboxes into rows and sort them in reading
    order (rows top-to-bottom, cells left-to-right within a row).
    :param end2end_xywh_bboxes: bboxes in xywh format (cx, cy, w, h).
    :param no_match_end2end_indexes: original indexes, parallel to bboxes.
    :return: (flat sorted indexes, flat sorted bboxes, row-grouped indexes,
        row-grouped bboxes).
    """
    row_threshold = 3  # max |cy - row_anchor_cy| for joining a row

    # 1. group bboxes into rows by vertical closeness to each row's anchor
    #    (the first bbox placed in that row)
    groups = []
    bbox_groups = []
    for index, this_bbox in zip(no_match_end2end_indexes,
                                end2end_xywh_bboxes):
        for g, bg in zip(groups, bbox_groups):
            if abs(this_bbox[1] - bg[0][1]) < row_threshold:
                g.append(index)
                bg.append(this_bbox)
                break
        else:
            # no existing row is close enough: start a new row
            groups.append([index])
            bbox_groups.append([this_bbox])

    # 2. sort each row by x. Stable argsort fixes the previous
    #    list.index-based reorder, which collided on duplicate x values
    #    and left None holes.
    tmp_groups, tmp_bbox_groups = [], []
    for g, bg in zip(groups, bbox_groups):
        order = sorted(range(len(bg)), key=lambda k: bg[k][0])
        tmp_groups.append([g[k] for k in order])
        tmp_bbox_groups.append([bg[k] for k in order])

    # 3. sort rows top-to-bottom by the y of each row's leftmost bbox
    #    (stable for duplicate y, again unlike list.index)
    row_order = sorted(
        range(len(tmp_bbox_groups)), key=lambda k: tmp_bbox_groups[k][0][1])
    sorted_groups = [tmp_groups[k] for k in row_order]
    sorted_bbox_groups = [tmp_bbox_groups[k] for k in row_order]

    # 4. flatten to reading order
    end2end_sorted_idx_list = [i for row in sorted_groups for i in row]
    end2end_sorted_bbox_list = [b for row in sorted_bbox_groups for b in row]

    return end2end_sorted_idx_list, end2end_sorted_bbox_list, \
        sorted_groups, sorted_bbox_groups
|
|
|
|
|
def get_bboxes_list(end2end_result, structure_master_result):
    """
    Convert end2end results and structure master results into ndarrays of
    bboxes in both xyxy and xywh formats.
    :param end2end_result: list of dicts, each with a 'bbox' key in xyxy
        format.
    :param structure_master_result: dict with a 'bbox' key. NOTE(review):
        the bboxes are treated as xyxy here (kept as-is for the xyxy output
        and converted with xyxy2xywh) even though the upstream docstring
        called them xywh — confirm with the producer of this dict.
    :return: (end2end_xyxy, end2end_xywh, master_xywh, master_xyxy) arrays.
    """
    # end2end bboxes: keep the raw xyxy and derive xywh per item
    end2end_xyxy_list = []
    end2end_xywh_list = []
    for end2end_item in end2end_result:
        src_bbox = end2end_item['bbox']
        end2end_xyxy_list.append(src_bbox)
        xywh_bbox = xyxy2xywh(src_bbox)
        end2end_xywh_list.append(xywh_bbox)
    end2end_xyxy_bboxes = np.array(end2end_xyxy_list)
    end2end_xywh_bboxes = np.array(end2end_xywh_list)

    # structure master bboxes: drop all-zero placeholders first, then
    # derive the xywh variant from the (assumed xyxy) source
    src_bboxes = structure_master_result['bbox']
    src_bboxes = remove_empty_bboxes(src_bboxes)
    structure_master_xyxy_bboxes = src_bboxes
    xywh_bbox = xyxy2xywh(src_bboxes)
    structure_master_xywh_bboxes = xywh_bbox

    return end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes
|
|
|
|
|
def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
    """
    Match each end2end bbox to every structure master bbox that contains
    its center point (inclusive edges).
    :param end2end_xywh_bboxes: end2end bboxes (cx, cy, w, h).
    :param structure_master_xyxy_bboxes: master bboxes (x1, y1, x2, y2).
    :return: match pairs list, e.g. [[0, 1], [1, 2], ...].
    """
    match_pairs_list = []
    for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
        cx, cy = end2end_xywh[0], end2end_xywh[1]
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            x1, y1 = master_xyxy[0], master_xyxy[1]
            x2, y2 = master_xyxy[2], master_xyxy[3]
            # containment test inlined from is_inside
            if x1 <= cx <= x2 and y1 <= cy <= y2:
                match_pairs_list.append([i, j])
    return match_pairs_list
|
|
|
|
|
def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
                   structure_master_xyxy_bboxes):
    """
    Match by IoU: each end2end bbox is paired with the master bbox of
    maximal (strictly positive) IoU. Bboxes with no overlap are skipped.
    :param end2end_xyxy_bboxes: end2end bboxes in xyxy format.
    :param end2end_xyxy_indexes: original end2end indexes (parallel).
    :param structure_master_xyxy_bboxes: master bboxes in xyxy format.
    :return: match pairs list, e.g. [[0, 1], [1, 2], ...].
    """
    match_pair_list = []
    for end2end_xyxy_index, end2end_xyxy in zip(end2end_xyxy_indexes,
                                                end2end_xyxy_bboxes):
        end2end_4xy = convert_coord(end2end_xyxy)
        best_iou = 0
        best_master = None
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            master_4xy = convert_coord(master_xyxy)
            iou = cal_iou(end2end_4xy, master_4xy)
            # strict '>' keeps the first master on ties, as before
            if iou > best_iou:
                best_iou = iou
                best_master = j
        if best_master is None:
            # no overlapping master bbox: leave this end2end bbox unmatched
            continue
        match_pair_list.append([end2end_xyxy_index, best_master])
    return match_pair_list
|
|
|
|
|
def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes,
                        master_bboxes):
    """
    Match leftover master bboxes to the nearest leftover end2end bboxes by
    center distance. Runs only when both leftover sets are non-empty, and
    returns one pair per master bbox.
    :param end2end_indexes: unmatched end2end indexes.
    :param end2end_bboxes: unmatched end2end bboxes (parallel).
    :param master_indexes: unmatched master indexes.
    :param master_bboxes: unmatched master bboxes (parallel).
    :return: match pairs list, e.g. [[0, 1], [1, 2], ...].
    """
    min_match_list = []
    for j, master_bbox in zip(master_indexes, master_bboxes):
        closest_pair = [0, 0]
        closest_dist = np.inf
        for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes):
            # center distance, inlined from cal_distance
            dx = master_bbox[0] - end2end_bbox[0]
            dy = master_bbox[1] - end2end_bbox[1]
            dist = math.sqrt(dx * dx + dy * dy)
            # strict '<' keeps the first end2end bbox on ties, as before
            if dist < closest_dist:
                closest_pair[0], closest_pair[1] = i, j
                closest_dist = dist
        min_match_list.append(closest_pair)
    return min_match_list
|
|
|
|
|
def extra_match(no_match_end2end_indexes, master_bbox_nums):
    """
    Create virtual master bbox slots (numbered after the real ones) and
    pair each remaining unmatched end2end bbox with one of them.
    :param no_match_end2end_indexes: still-unmatched end2end indexes.
    :param master_bbox_nums: number of real master bboxes.
    :return: match pairs list using virtual master indexes.
    """
    return [
        [end2end_index, master_bbox_nums + offset]
        for offset, end2end_index in enumerate(no_match_end2end_indexes)
    ]
|
|
|
|
|
def get_match_dict(match_list):
    """
    Convert a match list to a dict keyed by master bbox index, with the
    list of matched end2end bbox indexes as values.
    :param match_list: pairs [[end2end_idx, master_idx], ...].
    :return: {master_idx: [end2end_idx, ...]}.
    """
    match_dict = dict()
    for end2end_index, master_index in match_list:
        match_dict.setdefault(master_index, []).append(end2end_index)
    return match_dict
|
|
|
|
|
def deal_successive_space(text):
    """
    Normalize successive space characters in ``text``:
    1. replace runs of three spaces with '<space>' (a real space);
    2. drop remaining single spaces (split tokens, not true spaces);
    3. turn '<space>' back into a single real space.
    :param text: raw text.
    :return: cleaned text.
    """
    return (text.replace(' ' * 3, '<space>')
            .replace(' ', '')
            .replace('<space>', ' '))
|
|
|
|
|
def reduce_repeat_bb(text_list, break_token):
    """
    Merge a run of individually-bolded tokens into one bolded string:
    ['<b>Local</b>', '<b>government</b>'] -> ['<b>Local government</b>'].
    Applies only when every element starts with '<b>'; otherwise the list
    is returned unchanged.
    :param text_list: list of text fragments.
    :param break_token: separator inserted between merged fragments.
    :return: merged single-element list, or the original list.
    """
    # all() over an empty list is True, matching the original count check
    if all(text.startswith('<b>') for text in text_list):
        stripped = [
            text.replace('<b>', '').replace('</b>', '') for text in text_list
        ]
        return ['<b>' + break_token.join(stripped) + '</b>']
    return text_list
|
|
|
|
|
def get_match_text_dict(match_dict, end2end_info, break_token=' '):
    """
    Build the text for each master cell by joining the texts of its
    matched end2end bboxes (merging repeated <b></b> runs first).
    :param match_dict: {master_idx: [end2end_idx, ...]}.
    :param end2end_info: per-end2end-bbox info dicts with a 'text' key.
    :param break_token: separator between joined fragments.
    :return: {master_idx: joined text}.
    """
    match_text_dict = dict()
    for master_index, end2end_index_list in match_dict.items():
        fragments = [
            end2end_info[k]['text'] for k in end2end_index_list
        ]
        fragments = reduce_repeat_bb(fragments, break_token)
        match_text_dict[master_index] = break_token.join(fragments)
    return match_text_dict
|
|
|
|
|
def merge_span_token(master_token_list): |
|
""" |
|
Merge the span style token (row span or col span). |
|
:param master_token_list: |
|
:return: |
|
""" |
|
new_master_token_list = [] |
|
pointer = 0 |
|
if master_token_list[-1] != '</tbody>': |
|
master_token_list.append('</tbody>') |
|
while master_token_list[pointer] != '</tbody>': |
|
try: |
|
if master_token_list[pointer] == '<td': |
|
if master_token_list[pointer + 1].startswith( |
|
' colspan=') or master_token_list[ |
|
pointer + 1].startswith(' rowspan='): |
|
""" |
|
example: |
|
pattern <td colspan="3"> |
|
'<td' + 'colspan=" "' + '>' + '</td>' |
|
""" |
|
tmp = ''.join(master_token_list[pointer:pointer + 3 + 1]) |
|
pointer += 4 |
|
new_master_token_list.append(tmp) |
|
|
|
elif master_token_list[pointer + 2].startswith( |
|
' colspan=') or master_token_list[ |
|
pointer + 2].startswith(' rowspan='): |
|
""" |
|
example: |
|
pattern <td rowspan="2" colspan="3"> |
|
'<td' + 'rowspan=" "' + 'colspan=" "' + '>' + '</td>' |
|
""" |
|
tmp = ''.join(master_token_list[pointer:pointer + 4 + 1]) |
|
pointer += 5 |
|
new_master_token_list.append(tmp) |
|
|
|
else: |
|
new_master_token_list.append(master_token_list[pointer]) |
|
pointer += 1 |
|
else: |
|
new_master_token_list.append(master_token_list[pointer]) |
|
pointer += 1 |
|
except: |
|
print("Break in merge...") |
|
break |
|
new_master_token_list.append('</tbody>') |
|
|
|
return new_master_token_list |
|
|
|
|
|
def deal_eb_token(master_token):
    """
    Expand empty-bbox placeholder tokens (<eb></eb>, <eb1></eb1>, ...)
    back into their original <td>...</td> forms.
    emptyBboxTokenDict = {
        "[]": '<eb></eb>',
        "[' ']": '<eb1></eb1>',
        "['<b>', ' ', '</b>']": '<eb2></eb2>',
        "['\\u2028', '\\u2028']": '<eb3></eb3>',
        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
        "['<b>', '</b>']": '<eb5></eb5>',
        "['<i>', ' ', '</i>']": '<eb6></eb6>',
        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
        "['<i>', '</i>']": '<eb9></eb9>',
        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
    }
    :param master_token: structure token string.
    :return: token string with placeholders expanded.
    """
    replacements = (
        ('<eb></eb>', '<td></td>'),
        ('<eb1></eb1>', '<td> </td>'),
        ('<eb2></eb2>', '<td><b> </b></td>'),
        ('<eb3></eb3>', '<td>\u2028\u2028</td>'),
        ('<eb4></eb4>', '<td><sup> </sup></td>'),
        ('<eb5></eb5>', '<td><b></b></td>'),
        ('<eb6></eb6>', '<td><i> </i></td>'),
        ('<eb7></eb7>', '<td><b><i></i></b></td>'),
        ('<eb8></eb8>', '<td><b><i> </i></b></td>'),
        ('<eb9></eb9>', '<td><i></i></td>'),
        ('<eb10></eb10>', '<td><b> \u2028 \u2028 </b></td>'),
    )
    for placeholder, expansion in replacements:
        master_token = master_token.replace(placeholder, expansion)
    return master_token
|
|
|
|
|
def insert_text_to_token(master_token_list, match_text_dict):
    """
    Insert the OCR texts into the structure tokens: the k-th '<td' token
    receives match_text_dict[k]. '<td' tokens with no text entry are
    dropped from the output; all kept tokens have their empty-bbox
    placeholders expanded.
    :param master_token_list: structure tokens split on ','.
    :param match_text_dict: {cell counter: text}.
    :return: single merged HTML string.
    """
    master_token_list = merge_span_token(master_token_list)
    merged_result_list = []
    text_count = 0
    for master_token in master_token_list:
        if master_token.startswith('<td'):
            # skip cells beyond / absent from the text dict (not emitted)
            if text_count > len(match_text_dict) - 1 or \
                    text_count not in match_text_dict:
                text_count += 1
                continue
            master_token = master_token.replace(
                '><', '>{}<'.format(match_text_dict[text_count]))
            text_count += 1
        master_token = deal_eb_token(master_token)
        merged_result_list.append(master_token)

    return ''.join(merged_result_list)
|
|
|
|
|
def deal_isolate_span(thead_part):
    """
    Repair isolated span fragments caused by wrong structure predictions,
    e.g. '<td></td> rowspan="2"></b></td>' -> '<td rowspan="2"></td>'.
    :param thead_part: the '<thead>...</thead>' substring.
    :return: repaired thead string.
    """
    # find the broken fragments (raw strings avoid invalid \d escapes)
    isolate_pattern = r"<td></td> rowspan=\"(\d)+\" colspan=\"(\d)+\"></b></td>|" \
                      r"<td></td> colspan=\"(\d)+\" rowspan=\"(\d)+\"></b></td>|" \
                      r"<td></td> rowspan=\"(\d)+\"></b></td>|" \
                      r"<td></td> colspan=\"(\d)+\"></b></td>"
    isolate_iter = re.finditer(isolate_pattern, thead_part)
    isolate_list = [i.group() for i in isolate_iter]

    # extract the span attribute text from each fragment
    span_pattern = r" rowspan=\"(\d)+\" colspan=\"(\d)+\"|" \
                   r" colspan=\"(\d)+\" rowspan=\"(\d)+\"|" \
                   r" rowspan=\"(\d)+\"|" \
                   r" colspan=\"(\d)+\""
    corrected_list = []
    for isolate_item in isolate_list:
        span_part = re.search(span_pattern, isolate_item)
        # guard the non-match case first: the original called .group()
        # before checking for None, which would raise AttributeError
        if span_part is None:
            corrected_list.append(None)
            continue
        corrected_item = '<td{}></td>'.format(span_part.group())
        corrected_list.append(corrected_item)

    # substitute the repaired cells back into the thead
    for corrected_item, isolate_item in zip(corrected_list, isolate_list):
        if corrected_item is not None:
            thead_part = thead_part.replace(isolate_item, corrected_item)
    return thead_part
|
|
|
|
|
def deal_duplicate_bb(thead_part):
    """
    Collapse duplicated <b></b> pairs inside a single <td></td> of the
    head, keeping exactly one <b></b> per cell.
    :param thead_part: the '<thead>...</thead>' substring.
    :return: cleaned thead string.
    """
    # collect every <td ...>...</td> cell, span variants included
    td_pattern = r"<td rowspan=\"(\d)+\" colspan=\"(\d)+\">(.+?)</td>|" \
                 r"<td colspan=\"(\d)+\" rowspan=\"(\d)+\">(.+?)</td>|" \
                 r"<td rowspan=\"(\d)+\">(.+?)</td>|" \
                 r"<td colspan=\"(\d)+\">(.+?)</td>|" \
                 r"<td>(.*?)</td>"
    td_list = [m.group() for m in re.finditer(td_pattern, thead_part)]

    # rebuild each cell that contains more than one <b> or </b>
    new_td_list = []
    for td_item in td_list:
        if td_item.count('<b>') > 1 or td_item.count('</b>') > 1:
            # strip all bold markers, then wrap the whole cell once
            td_item = td_item.replace('<b>', '').replace('</b>', '')
            td_item = td_item.replace('<td>', '<td><b>') \
                             .replace('</td>', '</b></td>')
        new_td_list.append(td_item)

    # write the rebuilt cells back
    for td_item, new_td_item in zip(td_list, new_td_list):
        thead_part = thead_part.replace(td_item, new_td_item)
    return thead_part
|
|
|
|
|
def deal_bb(result_token):
    """
    In our opinion, <b></b> always occurs in <thead></thead> text's context.
    This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
    :param result_token: merged HTML result string.
    :return: result string with the thead cells bolded.
    """
    # locate the <thead>...</thead> part; nothing to do if absent
    thead_pattern = '<thead>(.*?)</thead>'
    if re.search(thead_pattern, result_token) is None:
        return result_token
    thead_part = re.search(thead_pattern, result_token).group()
    origin_thead_part = copy.deepcopy(thead_part)

    # check whether rowspan/colspan cells occur inside the thead
    span_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">|<td colspan=\"(\d)+\" rowspan=\"(\d)+\">|<td rowspan=\"(\d)+\">|<td colspan=\"(\d)+\">"
    span_iter = re.finditer(span_pattern, thead_part)
    span_list = [s.group() for s in span_iter]
    has_span_in_head = True if len(span_list) > 0 else False

    if not has_span_in_head:
        # branch 1: no span cells.
        # 1. wrap each cell: <td> -> <td><b>, </td> -> </b></td>
        # 2. recognized text may already carry <b>/</b>, so collapse the
        #    doubled markers afterwards
        thead_part = thead_part.replace('<td>', '<td><b>')\
            .replace('</td>', '</b></td>')\
            .replace('<b><b>', '<b>')\
            .replace('</b></b>', '</b>')
    else:
        # branch 2: thead contains rowspan/colspan cells.
        # insert <b> right after each span-bearing opening token
        replaced_span_list = []
        for sp in span_list:
            replaced_span_list.append(sp.replace('>', '><b>'))
        for sp, rsp in zip(span_list, replaced_span_list):
            thead_part = thead_part.replace(sp, rsp)

        # close the bold before each cell end
        thead_part = thead_part.replace('</td>', '</b></td>')

        # collapse any runs of repeated <b> ...
        mb_pattern = "(<b>)+"
        single_b_string = "<b>"
        thead_part = re.sub(mb_pattern, single_b_string, thead_part)

        # ... and repeated </b>
        mgb_pattern = "(</b>)+"
        single_gb_string = "</b>"
        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)

        # finally handle the ordinary (non-span) cells like branch 1
        thead_part = thead_part.replace('<td>', '<td><b>').replace('<b><b>',
                                                                   '<b>')

    # convert <td><b></b></td> back to <td></td>: an empty cell gets no
    # bold pair (a space cell <td> </td> keeps <td><b> </b></td>)
    thead_part = thead_part.replace('<td><b></b></td>', '<td></td>')

    # deduplicate any remaining doubled <b></b> inside single cells
    thead_part = deal_duplicate_bb(thead_part)

    # repair isolated span fragments caused by wrong structure predictions
    thead_part = deal_isolate_span(thead_part)

    # splice the rewritten thead back into the full result
    result_token = result_token.replace(origin_thead_part, thead_part)
    return result_token
|
|
|
|
|
class Matcher:
    # Matches end2end OCR bboxes to structure-master cell bboxes and merges
    # the OCR texts into the predicted table structure tokens.

    def __init__(self, end2end_file, structure_master_file):
        """
        This class process the end2end results and structure recognition results.
        :param end2end_file: end2end results predict by end2end inference.
        :param structure_master_file: structure recognition results predict by structure master inference.
        """
        self.end2end_file = end2end_file
        self.structure_master_file = structure_master_file
        self.end2end_results = pickle_load(end2end_file, prefix='end2end')
        self.structure_master_results = pickle_load(
            structure_master_file, prefix='structure')

    def match(self):
        """
        Match process:
        pre-process : convert end2end and structure master results to xyxy, xywh ndnarray format.
        1. Use pseBbox is inside masterBbox judge rule
        2. Use iou between pseBbox and masterBbox rule
        3. Use min distance of center point rule
        :return: dict keyed by file name with the match lists and row groups.
        """
        match_results = dict()
        for idx, (file_name,
                  end2end_result) in enumerate(self.end2end_results.items()):
            match_list = []
            # skip files without a structure prediction
            if file_name not in self.structure_master_results:
                continue
            structure_master_result = self.structure_master_results[file_name]
            end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes = \
                get_bboxes_list(end2end_result, structure_master_result)

            # rule 1: end2end center point lies inside a master bbox
            center_rule_match_list = \
                center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes)
            match_list.extend(center_rule_match_list)

            # rule 2: max-IoU match for the end2end bboxes still unmatched
            center_no_match_end2end_indexs = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            if len(center_no_match_end2end_indexs) > 0:
                center_no_match_end2end_xyxy = end2end_xyxy_bboxes[
                    center_no_match_end2end_indexs]
                iou_rule_match_list = \
                    iou_rule_match(center_no_match_end2end_xyxy, center_no_match_end2end_indexs, structure_master_xyxy_bboxes)
                match_list.extend(iou_rule_match_list)

            # rule 3: min center distance between the remaining unmatched
            # end2end and master bboxes (runs only when both sides remain)
            centerIou_no_match_end2end_indexs = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            centerIou_no_match_master_indexs = \
                find_no_match(match_list, len(structure_master_xywh_bboxes), type='master')
            if len(centerIou_no_match_master_indexs) > 0 and len(
                    centerIou_no_match_end2end_indexs) > 0:
                centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[
                    centerIou_no_match_end2end_indexs]
                centerIou_no_match_master_xywh = structure_master_xywh_bboxes[
                    centerIou_no_match_master_indexs]
                distance_match_list = distance_rule_match(
                    centerIou_no_match_end2end_indexs,
                    centerIou_no_match_end2end_xywh,
                    centerIou_no_match_master_indexs,
                    centerIou_no_match_master_xywh)
                match_list.extend(distance_match_list)

            # extra rule: any end2end bboxes still unmatched are sorted in
            # reading order and paired with virtual master cells appended
            # after the real ones
            no_match_end2end_indexes = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            if len(no_match_end2end_indexes) > 0:
                no_match_end2end_xywh = end2end_xywh_bboxes[
                    no_match_end2end_indexes]
                # sort leftover bboxes into rows (top-to-bottom,
                # left-to-right)
                end2end_sorted_indexes_list, end2end_sorted_bboxes_list, sorted_groups, sorted_bboxes_groups = \
                    sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
                extra_match_list = extra_match(
                    end2end_sorted_indexes_list,
                    len(structure_master_xywh_bboxes))
                match_list_add_extra_match = copy.deepcopy(match_list)
                match_list_add_extra_match.extend(extra_match_list)
            else:
                # nothing left over: extra list equals the plain match list
                match_list_add_extra_match = copy.deepcopy(match_list)
                sorted_groups = []
                sorted_bboxes_groups = []

            match_result_dict = {
                'match_list': match_list,
                'match_list_add_extra_match': match_list_add_extra_match,
                'sorted_groups': sorted_groups,
                'sorted_bboxes_groups': sorted_bboxes_groups
            }

            # inject virtual master tokens for the extra-matched rows
            match_result_dict = self._format(match_result_dict, file_name)

            match_results[file_name] = match_result_dict

        return match_results

    def _format(self, match_result, file_name):
        """
        Extend the master token(insert virtual master token), and format matching result.
        :param match_result: dict produced in match().
        :param file_name: key into the loaded result dicts.
        :return: match_result with 'matched_master_token_list' added.
        """
        end2end_info = self.end2end_results[file_name]
        master_info = self.structure_master_results[file_name]
        master_token = master_info['text']
        sorted_groups = match_result['sorted_groups']

        # one virtual '<tr>' of empty '<td></td>' cells per leftover row
        virtual_master_token_list = []
        for line_group in sorted_groups:
            tmp_list = ['<tr>']
            item_nums = len(line_group)
            for _ in range(item_nums):
                tmp_list.append('<td></td>')
            tmp_list.append('</tr>')
            virtual_master_token_list.extend(tmp_list)

        # insert the virtual tokens into the predicted token list
        master_token_list = master_token.split(',')
        if master_token_list[-1] == '</tbody>':
            # NOTE(review): master_token_list[:-1] creates a new list, so
            # this extend is a no-op and the virtual tokens are discarded
            # in this branch — looks like a bug (the intent was probably
            # to insert them before '</tbody>'); confirm before changing.
            master_token_list[:-1].extend(virtual_master_token_list)
        elif master_token_list[-1] == '<td></td>':
            # truncated prediction: close the row, then append the virtual
            # rows and the closing '</tbody>'
            master_token_list.append('</tr>')
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append('</tbody>')
        else:
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append('</tbody>')

        match_result.setdefault('matched_master_token_list', master_token_list)
        return match_result

    def get_merge_result(self, match_results):
        """
        Merge the OCR result into structure token to get final results.
        :param match_results: output of match().
        :return: dict keyed by file name with the merged HTML string.
        """
        merged_results = dict()

        # separator inserted between texts merged into the same cell
        break_token = ' '

        for idx, (file_name, match_info) in enumerate(match_results.items()):
            end2end_info = self.end2end_results[file_name]
            master_token_list = match_info['matched_master_token_list']
            match_list = match_info['match_list_add_extra_match']

            # master cell index -> list of end2end indexes -> cell text
            match_dict = get_match_dict(match_list)
            match_text_dict = get_match_text_dict(match_dict, end2end_info,
                                                  break_token)
            merged_result = insert_text_to_token(master_token_list,
                                                 match_text_dict)
            # bold the thead cells
            merged_result = deal_bb(merged_result)

            merged_results[file_name] = merged_result

        return merged_results
|
|
|
|
|
class TableMasterMatcher(Matcher):
    # Callable adapter around Matcher for single-image, in-memory inputs
    # (no pickle files involved).

    def __init__(self):
        # intentionally skip Matcher.__init__: the result dicts are
        # injected directly in __call__ instead of being loaded from disk
        pass

    def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
        """
        Run the matching pipeline for one image.
        :param structure_res: (pred_structures, pred_bboxes) from the
            structure model.
        :param dt_boxes: detected text boxes.
        :param rec_res: recognition results; res[0] is the text string.
        :param img_name: key used internally; any hashable value works.
        :return: full '<html>...</html>' string of the merged table.
        """
        # build the end2end result dict in the format Matcher expects
        end2end_results = {img_name: []}
        for dt_box, res in zip(dt_boxes, rec_res):
            d = dict(
                bbox=np.array(dt_box),
                text=res[0], )
            end2end_results[img_name].append(d)

        self.end2end_results = end2end_results

        # build the structure master result dict; the [3:-3] slice drops
        # wrapper tokens at both ends (presumably <html>/<body>/<table>
        # and their closings — confirm against the structure model output)
        structure_master_result_dict = {img_name: {}}
        pred_structures, pred_bboxes = structure_res
        pred_structures = ','.join(pred_structures[3:-3])
        structure_master_result_dict[img_name]['text'] = pred_structures
        structure_master_result_dict[img_name]['bbox'] = pred_bboxes
        self.structure_master_results = structure_master_result_dict

        # run the standard match + merge pipeline and wrap as full HTML
        match_results = self.match()
        merged_results = self.get_merge_result(match_results)
        pred_html = merged_results[img_name]
        pred_html = '<html><body><table>' + pred_html + '</table></body></html>'
        return pred_html
|
|