Spaces:
Running
Running
from collections import defaultdict | |
from concurrent.futures import ProcessPoolExecutor | |
from typing import List, Optional | |
from PIL import Image | |
import numpy as np | |
from surya.detection import batch_detection | |
from surya.postprocessing.heatmap import keep_largest_boxes, get_and_clean_boxes, get_detected_boxes | |
from surya.schema import LayoutResult, LayoutBox, TextDetectionResult | |
from surya.settings import settings | |
def get_regions_from_detection_result(detection_result: TextDetectionResult, heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment, vertical_line_width=20) -> List[LayoutBox]: | |
logits = np.stack(heatmaps, axis=0) | |
vertical_line_bboxes = detection_result.vertical_lines | |
line_bboxes = detection_result.bboxes | |
# Scale back to processor size | |
for line in vertical_line_bboxes: | |
line.rescale_bbox(orig_size, list(reversed(heatmaps[0].shape))) | |
for line in line_bboxes: | |
line.rescale(orig_size, list(reversed(heatmaps[0].shape))) | |
for bbox in vertical_line_bboxes: | |
# Give some width to the vertical lines | |
vert_bbox = list(bbox.bbox) | |
vert_bbox[2] = min(heatmaps[0].shape[0], vert_bbox[2] + vertical_line_width) | |
logits[:, vert_bbox[1]:vert_bbox[3], vert_bbox[0]:vert_bbox[2]] = 0 # zero out where the column lines are | |
logits[:, logits[0] >= .5] = 0 # zero out where blanks are | |
# Zero out where other segments are | |
for i in range(logits.shape[0]): | |
logits[i, segment_assignment != i] = 0 | |
detected_boxes = [] | |
for heatmap_idx in range(1, len(id2label)): # Skip the blank class | |
heatmap = logits[heatmap_idx] | |
if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD: | |
continue | |
bboxes = get_detected_boxes(heatmap) | |
bboxes = [bbox for bbox in bboxes if bbox.area > 25] | |
for bb in bboxes: | |
bb.fit_to_bounds([0, 0, heatmap.shape[1] - 1, heatmap.shape[0] - 1]) | |
for bbox in bboxes: | |
detected_boxes.append(LayoutBox(polygon=bbox.polygon, label=id2label[heatmap_idx], confidence=1)) | |
detected_boxes = sorted(detected_boxes, key=lambda x: x.confidence, reverse=True) | |
# Expand bbox to cover intersecting lines | |
box_lines = defaultdict(list) | |
used_lines = set() | |
# We try 2 rounds of identifying the correct lines to snap to | |
# First round is majority intersection, second lowers the threshold | |
for thresh in [.5, .4]: | |
for bbox_idx, bbox in enumerate(detected_boxes): | |
for line_idx, line_bbox in enumerate(line_bboxes): | |
if line_bbox.intersection_pct(bbox) > thresh and line_idx not in used_lines: | |
box_lines[bbox_idx].append(line_bbox.bbox) | |
used_lines.add(line_idx) | |
new_boxes = [] | |
for bbox_idx, bbox in enumerate(detected_boxes): | |
if bbox.label == "Picture" and bbox.area < 200: # Remove very small figures | |
continue | |
# Skip if we didn't find any lines to snap to, except for Pictures and Formulas | |
if bbox_idx not in box_lines and bbox.label not in ["Picture", "Formula"]: | |
continue | |
covered_lines = box_lines[bbox_idx] | |
# Snap non-picture layout boxes to correct text boundaries | |
if len(covered_lines) > 0 and bbox.label not in ["Picture"]: | |
min_x = min([line[0] for line in covered_lines]) | |
min_y = min([line[1] for line in covered_lines]) | |
max_x = max([line[2] for line in covered_lines]) | |
max_y = max([line[3] for line in covered_lines]) | |
# Tables and formulas can contain text, but text isn't the whole area | |
if bbox.label in ["Table", "Formula"]: | |
min_x_box = min([b[0] for b in bbox.polygon]) | |
min_y_box = min([b[1] for b in bbox.polygon]) | |
max_x_box = max([b[0] for b in bbox.polygon]) | |
max_y_box = max([b[1] for b in bbox.polygon]) | |
min_x = min(min_x, min_x_box) | |
min_y = min(min_y, min_y_box) | |
max_x = max(max_x, max_x_box) | |
max_y = max(max_y, max_y_box) | |
bbox.polygon = [ | |
[min_x, min_y], | |
[max_x, min_y], | |
[max_x, max_y], | |
[min_x, max_y] | |
] | |
if bbox_idx in box_lines and bbox.label in ["Picture"]: | |
bbox.label = "Figure" | |
new_boxes.append(bbox) | |
# Merge tables together (sometimes one column is detected as a separate table) | |
mergeable_types = ["Table", "Picture", "Figure"] | |
for ftype in mergeable_types: | |
to_remove = set() | |
for bbox_idx, bbox in enumerate(new_boxes): | |
if bbox.label != ftype or bbox_idx in to_remove: | |
continue | |
for bbox_idx2, bbox2 in enumerate(new_boxes): | |
if bbox2.label != ftype or bbox_idx2 in to_remove or bbox_idx == bbox_idx2: | |
continue | |
if bbox.intersection_pct(bbox2, x_margin=.25) > .1: | |
bbox.merge(bbox2) | |
to_remove.add(bbox_idx2) | |
new_boxes = [bbox for idx, bbox in enumerate(new_boxes) if idx not in to_remove] | |
# Ensure we account for all text lines in the layout | |
unused_lines = [line for idx, line in enumerate(line_bboxes) if idx not in used_lines] | |
for bbox in unused_lines: | |
new_boxes.append(LayoutBox(polygon=bbox.polygon, label="Text", confidence=.5)) | |
for bbox in new_boxes: | |
bbox.rescale(list(reversed(heatmaps[0].shape)), orig_size) | |
detected_boxes = [bbox for bbox in new_boxes if bbox.area > 16] | |
# Remove bboxes contained inside others, unless they're captions | |
contained_bbox = [] | |
for i, bbox in enumerate(detected_boxes): | |
for j, bbox2 in enumerate(detected_boxes): | |
if i == j: | |
continue | |
if bbox2.intersection_pct(bbox) >= .95 and bbox2.label not in ["Caption"]: | |
contained_bbox.append(j) | |
detected_boxes = [bbox for idx, bbox in enumerate(detected_boxes) if idx not in contained_bbox] | |
return detected_boxes | |
def get_regions(heatmaps: List[np.ndarray], orig_size, id2label, segment_assignment) -> List[LayoutBox]: | |
bboxes = [] | |
for i in range(1, len(id2label)): # Skip the blank class | |
heatmap = heatmaps[i] | |
assert heatmap.shape == segment_assignment.shape | |
heatmap[segment_assignment != i] = 0 # zero out where another segment is | |
# Skip processing empty labels | |
if np.max(heatmap) < settings.DETECTOR_BLANK_THRESHOLD: | |
continue | |
bbox = get_and_clean_boxes(heatmap, list(reversed(heatmap.shape)), orig_size) | |
for bb in bbox: | |
bboxes.append(LayoutBox(polygon=bb.polygon, label=id2label[i])) | |
bboxes = keep_largest_boxes(bboxes) | |
return bboxes | |
def parallel_get_regions(heatmaps: List[np.ndarray], orig_size, id2label, detection_results=None) -> LayoutResult: | |
logits = np.stack(heatmaps, axis=0) | |
segment_assignment = logits.argmax(axis=0) | |
if detection_results is not None: | |
bboxes = get_regions_from_detection_result(detection_results, heatmaps, orig_size, id2label, | |
segment_assignment) | |
else: | |
bboxes = get_regions(heatmaps, orig_size, id2label, segment_assignment) | |
segmentation_img = Image.fromarray(segment_assignment.astype(np.uint8)) | |
result = LayoutResult( | |
bboxes=bboxes, | |
segmentation_map=segmentation_img, | |
heatmaps=heatmaps, | |
image_bbox=[0, 0, orig_size[0], orig_size[1]] | |
) | |
return result | |
def batch_layout_detection(images: List, model, processor, detection_results: Optional[List[TextDetectionResult]] = None, batch_size=None) -> List[LayoutResult]: | |
layout_generator = batch_detection(images, model, processor, batch_size=batch_size) | |
id2label = model.config.id2label | |
results = [] | |
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) | |
parallelize = not settings.IN_STREAMLIT and len(images) >= settings.DETECTOR_MIN_PARALLEL_THRESH | |
if parallelize: | |
with ProcessPoolExecutor(max_workers=max_workers) as executor: | |
img_idx = 0 | |
for preds, orig_sizes in layout_generator: | |
futures = [] | |
for pred, orig_size in zip(preds, orig_sizes): | |
future = executor.submit( | |
parallel_get_regions, | |
pred, | |
orig_size, | |
id2label, | |
detection_results[img_idx] if detection_results else None | |
) | |
futures.append(future) | |
img_idx += 1 | |
for future in futures: | |
results.append(future.result()) | |
else: | |
img_idx = 0 | |
for preds, orig_sizes in layout_generator: | |
for pred, orig_size in zip(preds, orig_sizes): | |
results.append(parallel_get_regions( | |
pred, | |
orig_size, | |
id2label, | |
detection_results[img_idx] if detection_results else None | |
)) | |
img_idx += 1 | |
return results |