File size: 7,884 Bytes
c8a32e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 |
from collections import defaultdict
from copy import deepcopy
from typing import List
from marker.debug.data import dump_equation_debug_data
from marker.equations.inference import get_total_texify_tokens, get_latex_batched
from marker.pdf.images import render_bbox_image
from marker.schema.bbox import rescale_bbox
from marker.schema.page import Page
from marker.schema.block import Line, Span, Block, bbox_from_lines, split_block_lines, find_insert_block
from marker.settings import settings
def find_equation_blocks(page, processor):
    """Find formula regions on *page* and detach their text lines from the page blocks.

    Formula bboxes are the ``page.layout.bboxes`` entries labelled ``"Formula"``,
    passed through ``rescale_bbox(page.layout.image_bbox, page.bbox, ...)``. A text
    line is assigned to a region when ``line.intersection_pct(region)`` exceeds
    ``settings.BBOX_INTERSECTION_THRESH``. A region is accepted when its joined
    line text counts fewer than ``settings.TEXIFY_MODEL_MAX`` tokens (regions with
    no matching lines count 0 tokens and are always accepted). The lines of
    accepted regions are removed from ``page.blocks`` in place.

    Args:
        page: page whose ``layout``, ``bbox`` and ``blocks`` are read; the
            ``lines`` of affected blocks are replaced.
        processor: forwarded to ``get_total_texify_tokens``.

    Returns:
        One mutable list ``[block_idx, line_idx, token_count, block_text,
        region_bbox]`` per accepted region, in region order. ``(block_idx,
        line_idx)`` is the insertion point expressed against the blocks *after*
        the accepted regions' lines have been removed. The entries are lists so
        that callers can shift ``block_idx`` in place.
    """
    equation_regions = [layout_box.bbox for layout_box in page.layout.bboxes if layout_box.label == "Formula"]
    equation_regions = [rescale_bbox(page.layout.image_bbox, page.bbox, b) for b in equation_regions]

    lines_to_remove = defaultdict(list)
    insert_points = {}
    equation_lines = defaultdict(list)
    for region_idx, region in enumerate(equation_regions):
        for block_idx, block in enumerate(page.blocks):
            for line_idx, line in enumerate(block.lines):
                if line.intersection_pct(region) > settings.BBOX_INTERSECTION_THRESH:
                    # We will remove this line from the block
                    lines_to_remove[region_idx].append((block_idx, line_idx))
                    equation_lines[region_idx].append(line)
                    # Blocks and lines are scanned in order, so the first hit is
                    # the region's earliest line: that is where the formula goes.
                    insert_points.setdefault(region_idx, (block_idx, line_idx))

    # Account for regions where the lines were not detected: insert at the
    # start of the block chosen by find_insert_block.
    for region_idx, region in enumerate(equation_regions):
        if region_idx not in insert_points:
            insert_points[region_idx] = (find_insert_block(page.blocks, region), 0)

    # Pass 1: decide which regions are accepted and collect every line that
    # will be removed, across all accepted regions.
    accepted = []
    block_lines_to_remove = defaultdict(set)
    for region_idx, equation_region in enumerate(equation_regions):
        region_lines = equation_lines.get(region_idx)
        if region_lines:
            block_text = " ".join([line.prelim_text for line in region_lines])
            total_tokens = get_total_texify_tokens(block_text, processor)
        else:
            block_text = ""
            total_tokens = 0

        if total_tokens < settings.TEXIFY_MODEL_MAX:
            for block_idx, line_idx in lines_to_remove[region_idx]:
                block_lines_to_remove[block_idx].add(line_idx)
            accepted.append((region_idx, total_tokens, block_text, equation_region))

    # Pass 2: rebase each insert line index onto the post-removal block. Any
    # removed line (from this or another accepted region) that precedes the
    # insert point in the same block shifts it up by one.
    equation_blocks = []
    for region_idx, total_tokens, block_text, equation_region in accepted:
        insert_block_idx, insert_line_idx = insert_points[region_idx]
        removed_in_block = block_lines_to_remove.get(insert_block_idx, ())
        removed_before = sum(1 for idx in removed_in_block if idx < insert_line_idx)
        equation_blocks.append(
            [insert_block_idx, insert_line_idx - removed_before, total_tokens, block_text, equation_region]
        )

    # Remove the lines from the blocks
    for block_idx, bad_lines in block_lines_to_remove.items():
        block = page.blocks[block_idx]
        block.lines = [line for idx, line in enumerate(block.lines) if idx not in bad_lines]

    return equation_blocks
def increment_insert_points(page_equation_blocks, insert_block_idx, insert_count):
    """Shift pending insert points that sit at or after a newly inserted block.

    Each entry of *page_equation_blocks* is a mutable 5-element list
    ``[block_idx, line_idx, token_count, block_text, equation_bbox]``. Every entry
    whose ``block_idx`` is ``>= insert_block_idx`` has *insert_count* added to its
    ``block_idx``, in place. Nothing is returned.
    """
    for entry in page_equation_blocks:
        target_block, _line, _tokens, _text, _bbox = entry
        if target_block < insert_block_idx:
            continue
        entry[0] += insert_count
def _rebase_after_split(page_equation_blocks, split_block_idx, split_line_idx):
    """Re-target pending insert points after one block is split around a new formula.

    The block at *split_block_idx* is replaced by ``[head, formula, tail]``.
    Here ``head`` is taken to be its first *split_line_idx* lines and ``tail`` the
    rest. NOTE(review): this assumes ``split_block_lines(block, k)`` returns
    ``(lines[:k], lines[k:])``; confirm.

    - Points in later blocks move by 2.
    - Points inside the split block move to ``tail`` (line index rebased) when
      they are at or after the split line; otherwise they stay in ``head``.
    """
    for entry in page_equation_blocks:
        if entry[0] > split_block_idx:
            entry[0] += 2
        elif entry[0] == split_block_idx and entry[1] >= split_line_idx:
            entry[0] += 2
            entry[1] -= split_line_idx


def insert_latex_block(page_blocks: Page, page_equation_blocks, predictions, pnum, processor):
    """Insert one ``Formula`` block per equation entry into *page_blocks.blocks*.

    A prediction is accepted only when all of these hold:

    - it counts fewer than ``settings.TEXIFY_MODEL_MAX`` tokens;
    - it is longer than 70% of the original text;
    - it is not blank.

    Otherwise the formula block keeps the original text. Newlines are flattened
    to spaces in both cases.

    Args:
        page_blocks: page whose ``blocks`` list is modified.
        page_equation_blocks: ``[block_idx, line_idx, token_count, block_text,
            equation_bbox]`` lists (as produced by ``find_equation_blocks``).
            Their pending insert positions are updated in place as blocks are
            added.
        predictions: LaTeX strings, one per entry, in the same order.
        pnum: page number; stored on each new block and used in span ids.
        processor: forwarded to ``get_total_texify_tokens``.

    Returns:
        ``(success_count, fail_count, converted_spans)``. ``converted_spans``
        holds deep copies of the spans whose text was replaced by LaTeX.
    """
    converted_spans = []
    success_count = 0
    fail_count = 0
    for block_number, (insert_block_idx, insert_line_idx, token_count, block_text, equation_bbox) in enumerate(page_equation_blocks):
        latex_text = predictions[block_number]
        conditions = [
            get_total_texify_tokens(latex_text, processor) < settings.TEXIFY_MODEL_MAX,  # Make sure we didn't get to the overall token max, indicates run-on
            len(latex_text) > len(block_text) * .7,
            len(latex_text.strip()) > 0
        ]

        new_block = Block(
            lines=[Line(
                spans=[
                    Span(
                        text=block_text.replace("\n", " "),
                        bbox=equation_bbox,
                        # Index by equation so span ids are unique within the page
                        # (previously a constant 0 was used for every span).
                        span_id=f"{pnum}_{block_number}_fixeq",
                        font="Latex",
                        font_weight=0,
                        font_size=0
                    )
                ],
                bbox=equation_bbox
            )],
            bbox=equation_bbox,
            block_type="Formula",
            pnum=pnum
        )

        if not all(conditions):
            fail_count += 1
        else:
            success_count += 1
            new_block.lines[0].spans[0].text = latex_text.replace("\n", " ")
            converted_spans.append(deepcopy(new_block.lines[0].spans[0]))

        # Add in the new LaTeX block
        blocks = page_blocks.blocks
        if insert_line_idx == 0:
            # Before the target block: it and everything after shift by one.
            blocks.insert(insert_block_idx, new_block)
            increment_insert_points(page_equation_blocks, insert_block_idx, 1)
        elif insert_line_idx >= len(blocks[insert_block_idx].lines):
            # Past the target block's last line: append right after it.
            blocks.insert(insert_block_idx + 1, new_block)
            increment_insert_points(page_equation_blocks, insert_block_idx + 1, 1)
        else:
            # Mid-block: split the target block and put the formula in between.
            split_block = split_block_lines(blocks[insert_block_idx], insert_line_idx)
            page_blocks.blocks = (
                blocks[:insert_block_idx]
                + [split_block[0], new_block, split_block[1]]
                + blocks[insert_block_idx + 1:]
            )
            _rebase_after_split(page_equation_blocks, insert_block_idx, insert_line_idx)
    return success_count, fail_count, converted_spans
def replace_equations(doc, pages: List[Page], texify_model, batch_multiplier=1):
    """Predict LaTeX for every formula region and splice the results into *pages*.

    Works in three phases:

    1. Detect formula regions on each page with ``find_equation_blocks``.
    2. Render every region from *doc* and run a single ``get_latex_batched``
       call over all of them.
    3. Hand each page its slice of predictions via ``insert_latex_block``.

    Args:
        doc: indexed by page number; each page is passed to ``render_bbox_image``.
        pages: parsed pages; modified in place.
        texify_model: passed to ``get_latex_batched``; its ``processor`` is used
            for token counting.
        batch_multiplier: forwarded to ``get_latex_batched``.

    Returns:
        ``(pages, stats)``. ``stats`` has the keys ``"successful_ocr"``,
        ``"unsuccessful_ocr"`` and ``"equations"`` (total regions found).
    """
    # Phase 1: find candidate equation regions on every page.
    equation_blocks = [find_equation_blocks(page, texify_model.processor) for page in pages]
    eq_count = sum(len(page_eqs) for page_eqs in equation_blocks)

    # Phase 2: render each region (page order, then region order) and predict.
    images = []
    token_counts = []
    for page_idx, page_eqs in enumerate(equation_blocks):
        doc_page = doc[page_idx]
        for _block_idx, _line_idx, tokens, _text, region_bbox in page_eqs:
            images.append(render_bbox_image(doc_page, pages[page_idx], region_bbox))
            token_counts.append(tokens)
    predictions = get_latex_batched(images, token_counts, texify_model, batch_multiplier=batch_multiplier)

    # Phase 3: predictions are flat, so walk an offset through them page by page.
    successful_ocr = 0
    unsuccessful_ocr = 0
    converted_spans = []
    offset = 0
    for page_idx, page_eqs in enumerate(equation_blocks):
        n_eqs = len(page_eqs)
        ok, failed, page_spans = insert_latex_block(
            pages[page_idx],
            page_eqs,
            predictions[offset:offset + n_eqs],
            page_idx,
            texify_model.processor
        )
        converted_spans.extend(page_spans)
        offset += n_eqs
        successful_ocr += ok
        unsuccessful_ocr += failed

    # Called unconditionally. Per the original note it only dumps in debug mode;
    # presumably the helper checks that itself (TODO confirm).
    dump_equation_debug_data(doc, images, converted_spans)
    stats = {"successful_ocr": successful_ocr, "unsuccessful_ocr": unsuccessful_ocr, "equations": eq_count}
    return pages, stats
|