|
from marker.schema.bbox import rescale_bbox, box_intersection_pct |
|
from marker.schema.page import Page |
|
from sklearn.cluster import DBSCAN |
|
import numpy as np |
|
|
|
from marker.settings import settings |
|
|
|
|
|
def cluster_coords(coords):
    """Merge nearby 1-D coordinates and return one mean value per group.

    Args:
        coords: Collection of hashable numeric coordinates. Duplicates are
            dropped before clustering.

    Returns:
        The mean coordinate of each DBSCAN cluster, in ascending order, or an
        empty list when ``coords`` is empty.
    """
    if len(coords) == 0:
        return []

    # DBSCAN expects a 2-D (n_samples, n_features) array, hence the reshape
    # of the de-duplicated, sorted values into a single column.
    points = np.array(sorted(set(coords))).reshape(-1, 1)
    labels = DBSCAN(eps=5, min_samples=1).fit(points).labels_

    # One representative (the mean) per cluster label.
    cluster_means = [np.mean(points[labels == label]) for label in set(labels)]
    cluster_means.sort()
    return cluster_means
|
|
|
|
|
def find_column_separators(page: Page, table_box, round_factor=4, min_count=1):
    """Estimate the x positions that separate the columns of a table.

    The text-line boxes of ``page`` that overlap ``table_box`` are collected.
    Their left edges, right edges and horizontal centers are each snapped to
    a grid, filtered by how often they occur, and clustered with
    ``cluster_coords``. Whichever of the three candidate sets produces the
    most clusters is used.

    Args:
        page: Page whose ``text_lines`` (``bboxes`` and ``image_bbox``) and
            ``bbox`` are read.
        table_box: Bounding box of the table; a line is considered only if
            ``box_intersection_pct(line, table_box)`` exceeds
            ``settings.BBOX_INTERSECTION_THRESH``.
        round_factor: Grid size. Coordinates are snapped down to a multiple
            of this value so that almost-aligned edges count as the same
            position.
        min_count: A snapped coordinate is kept only if it occurs more than
            ``min_count`` times.

    Returns:
        A list starting with ``page.bbox[0]``, followed by the cluster
        positions in ascending order, and ending with ``page.bbox[2]``.
    """
    from collections import Counter

    def snap(value):
        # The previous ``value / round_factor * round_factor`` was an
        # identity operation, so no rounding happened and only bit-identical
        # coordinates could ever be counted together. Floor division gives
        # the intended grid snapping.
        return value // round_factor * round_factor

    def frequent(values):
        # Keep every occurrence of a value seen more than ``min_count`` times
        # (single-pass counting instead of list.count per element).
        counts = Counter(values)
        return [v for v in values if counts[v] > min_count]

    # NOTE(review): rescale_bbox presumably maps boxes from the detection
    # image's coordinate space into page coordinates - confirm.
    line_boxes = [
        rescale_bbox(page.text_lines.image_bbox, page.bbox, line.bbox)
        for line in page.text_lines.bboxes
    ]
    line_boxes = [
        box for box in line_boxes
        if box_intersection_pct(box, table_box) > settings.BBOX_INTERSECTION_THRESH
    ]

    left_edges = [snap(box[0]) for box in line_boxes]
    right_edges = [snap(box[2]) for box in line_boxes]
    centers = [snap((box[0] + box[2]) / 2) for box in line_boxes]

    candidates = [
        cluster_coords(frequent(left_edges)),
        cluster_coords(frequent(right_edges)),
        cluster_coords(frequent(centers)),
    ]

    # Use the alignment that yields the most clusters; on ties max() keeps
    # the first candidate (left, then right, then center).
    separators = max(candidates, key=len)
    separators.append(page.bbox[2])
    separators.insert(0, page.bbox[0])
    return separators
|
|
|
|
|
def assign_cells_to_columns(page, table_box, rows, round_factor=4, tolerance=4):
    """Order each row's cells by column, pad the rows, and drop empty columns.

    Args:
        page: Passed through to ``find_column_separators``.
        table_box: Passed through to ``find_column_separators``.
        rows: Iterable of rows. Each row is an iterable of cells where
            ``cell[0][0]`` is the cell's left x coordinate and ``cell[1]`` is
            its text (a string; ``.strip()`` is called on it).
        round_factor: Forwarded to ``find_column_separators``.
        tolerance: Slack subtracted from a cell's left edge before it is
            compared with a separator.

    Returns:
        A list of rows, each a list of cell texts. All rows have the same
        length (short rows are padded with ``""``), and columns whose cells
        are blank in every row are removed. An empty ``rows`` yields ``[]``.
    """
    # Nothing to lay out. This also avoids calling max() on an empty
    # sequence further down, which used to raise ValueError.
    if not rows:
        return []

    separators = find_column_separators(page, table_box, round_factor=round_factor)

    new_rows = []
    for row in rows:
        new_row = {}
        last_col_index = -1
        # Counts cells of this row that fell past every separator.
        additional_column_index = 0
        for cell in row:
            left_edge = cell[0][0]
            # First separator to the right of the (tolerance-adjusted) left
            # edge whose index is beyond the previously assigned column.
            column_index = -1
            for i, separator in enumerate(separators):
                if left_edge - tolerance < separator and last_col_index < i:
                    column_index = i
                    break
            if column_index == -1:
                # No separator matched: place the cell in an extra column
                # after all detected ones.
                column_index = len(separators) + additional_column_index
                additional_column_index += 1
            new_row[column_index] = cell[1]
            last_col_index = column_index

        # NOTE(review): only the relative order of the column indices is
        # kept here; gaps between indices are not turned into empty cells.
        # Confirm this is intended.
        new_rows.append([new_row[index] for index in sorted(new_row)])

    # Pad rows on the right so every row has the same number of cells.
    max_row_len = max((len(r) for r in new_rows), default=0)
    for row in new_rows:
        row.extend([""] * (max_row_len - len(row)))

    # Columns in which every cell is empty or whitespace-only are dropped.
    cols_to_remove = {
        idx
        for idx, col in enumerate(zip(*new_rows))
        if not any(cell.strip() for cell in col)
    }

    return [
        [text for idx, text in enumerate(row) if idx not in cols_to_remove]
        for row in new_rows
    ]
|
|