|
|
|
"""OCR template.""" |
|
|
|
from __future__ import annotations |
|
|
|
import logging |
|
import os |
|
from enum import Enum |
|
from pathlib import Path |
|
from queue import SimpleQueue |
|
from typing import Any, Final, Iterable, Optional, TypeAlias |
|
|
|
import cv2 as cv2 |
|
import numpy as np |
|
import numpy.typing as npt |
|
import pandas as pd |
|
import pdf2image |
|
import rerun as rr |
|
import rerun.blueprint as rrb |
|
from paddleocr import PPStructure |
|
from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes |
|
|
|
# Directory that contains this script.
EXAMPLE_DIR: Final = Path(__file__).parent
# Dataset folder located next to the script.
DATASET_DIR: Final = EXAMPLE_DIR.joinpath("dataset")

# Sample input image(s).
SAMPLE_IMAGE_URLs = [
    "https://storage.googleapis.com/rerun-example-datasets/ocr/paper.png",
]

# Page budget used by `detect_and_log_layouts` (default window size and upper bound check).
PAGE_LIMIT = 10

# Shape of the tuple returned by `process_layout_records`, in order:
#   1. "-<record path>/**" filter strings, one per record
#   2. "-<record path>/Detections/**" filter strings, one per record
#   3. zoomed views for FIGURE records
#   4. zoomed views for TABLE records
#   5. zoomed views for TEXT records
LayoutStructure: TypeAlias = tuple[
    list[str],
    list[str],
    list[rrb.Spatial2DView],
    list[rrb.Spatial2DView],
    list[rrb.Spatial2DView],
]
|
|
|
|
|
|
|
|
|
class Color:
    """Named RGB triples (one integer per channel) used to colour layout classes."""

    # Primaries.
    Red: tuple[int, int, int] = (255, 0, 0)
    Green: tuple[int, int, int] = (0, 255, 0)
    Blue: tuple[int, int, int] = (0, 0, 255)

    # Two-channel mixes.
    Yellow: tuple[int, int, int] = (255, 255, 0)
    Cyan: tuple[int, int, int] = (0, 255, 255)
    Magenta: tuple[int, int, int] = (255, 0, 255)

    # Others.
    Purple: tuple[int, int, int] = (128, 0, 128)
    Orange: tuple[int, int, int] = (255, 165, 0)
|
|
|
|
|
""" |
|
LayoutType:
    Defines an enumeration for the different types of document layout elements, each associated with a unique number
    and name, and with a color. Types:
    - UNKNOWN: Default type for undefined or unrecognized elements, represented by purple.
    - TITLE: Represents the title of a document, represented by red.
    - TEXT: Represents plain text content within the document, represented by green.
    - FIGURE: Represents graphical or image content, represented by blue.
    - FIGURE_CAPTION: Represents captions for figures, represented by yellow.
    - TABLE: Represents tabular data, represented by cyan.
    - TABLE_CAPTION: Represents captions for tables, represented by magenta.
    - REFERENCE: Represents citation references within the document, also represented by purple.
    - FOOTER: Represents the footer of the document, represented by orange.
|
""" |
|
|
|
|
|
class LayoutType(Enum):
    """Closed set of layout element kinds.

    Every member's value is a ``(number, name, color)`` tuple; the ``number``,
    ``type`` and ``color`` properties expose the three parts.
    """

    UNKNOWN = (0, "unknown", Color.Purple)
    TITLE = (1, "title", Color.Red)
    TEXT = (2, "text", Color.Green)
    FIGURE = (3, "figure", Color.Blue)
    FIGURE_CAPTION = (4, "figure_caption", Color.Yellow)
    TABLE = (5, "table", Color.Cyan)
    TABLE_CAPTION = (6, "table_caption", Color.Magenta)
    REFERENCE = (7, "reference", Color.Purple)
    FOOTER = (8, "footer", Color.Orange)

    def __str__(self) -> str:
        """Return the member's name part (e.g. ``"figure_caption"``)."""
        return str(self.value[1])

    @property
    def number(self) -> int:
        """Numeric id of the layout type (first element of the value tuple)."""
        return self.value[0]

    @property
    def type(self) -> str:
        """Name of the layout type (second element of the value tuple)."""
        return self.value[1]

    @property
    def color(self) -> tuple[int, int, int]:
        """Color triple of the layout type (third element of the value tuple)."""
        return self.value[2]

    @staticmethod
    def get_class_id(text: str) -> int:
        """Return the numeric id for the layout type named ``text``.

        Lookup and fallback are shared with :meth:`get_type`, so an unrecognized
        name logs a warning and yields the id of ``UNKNOWN`` (0).
        """
        return LayoutType.get_type(text).number

    @staticmethod
    def get_type(text: str) -> LayoutType:
        """Return the member whose name matches ``text`` case-insensitively.

        Anything that is not a known name (including non-string input such as
        ``None``) logs a warning and returns ``LayoutType.UNKNOWN``.
        """
        # Guard the `.upper()` call so a missing/None type degrades to UNKNOWN instead of raising.
        member = LayoutType.__members__.get(text.upper()) if isinstance(text, str) else None
        if member is None:
            logging.warning(f"Invalid layout type {text}")
            return LayoutType.UNKNOWN
        return member

    @classmethod
    def get_annotation(cls) -> list[tuple[int, str, tuple[int, int, int]]]:
        """Return one ``(number, name, color)`` tuple per member, in definition order."""
        return [(layout.number, layout.type, layout.color) for layout in cls]
|
|
|
|
|
""" |
|
Layout Class:
    The main purpose of this class is to:
    1. Keep track of the layout types (including type and numbering)
    2. Save the detections for each layout element (text, image or table)
    3. Save the bounding box of each detected layout element
    4. Generate the recovery text document
|
""" |
|
|
|
|
|
class Layout:
    """Collect detected layout elements and build a Markdown recovery document.

    Instance fields:
        counts: number of elements added so far, per ``LayoutType``.
        records: list of saved record dicts, per ``LayoutType``.
        recovery: Markdown text accumulated by :meth:`add`.
        show_unknown: whether ``UNKNOWN`` elements appear in ``recovery``.
    """

    def __init__(self, show_unknown: bool = False) -> None:
        # Counters double as the numbering suffix of record names ("text1", "text2", ...).
        self.counts = {layout_type: 0 for layout_type in LayoutType}
        self.records: dict[LayoutType, Any] = {layout_type: [] for layout_type in LayoutType}
        self.recovery = ""
        self.show_unknown = show_unknown

    def add(
        self,
        layout_type: LayoutType,
        bounding_box: list[int],
        detections: Optional[Iterable[dict[str, Any]]] = None,
        table: Optional[str] = None,
        figure: Optional[dict[str, Any]] = None,
    ) -> None:
        """Save one layout element and append it to the recovery document.

        Args:
            layout_type: kind of the element; anything that is not a ``LayoutType``
                member is rejected with a warning.
            bounding_box: box of the element, stored as given.
            detections: per-text detections; each must provide a ``"text"`` key when
                the element is written to the recovery document.
            table: Markdown rendering of the table (used for ``TABLE`` elements).
            figure: figure payload, stored on the record under ``"figure"``.
        """
        # `isinstance` instead of `in LayoutType`: Enum containment with a non-member
        # raises TypeError on Python 3.8-3.11 and matches by value from 3.12 on.
        if not isinstance(layout_type, LayoutType):
            logging.warning(f"Invalid layout type detected: {layout_type}")
            return

        # Materialise once: the iterable is both stored and walked below and may be
        # single-use; None becomes an empty list so consumers can always iterate it.
        detection_list = [] if detections is None else list(detections)

        self.counts[layout_type] += 1
        name = f"{layout_type}{self.counts[layout_type]}"
        logging.info(f"Saved layout type {layout_type} with name: {name}")
        self.records[layout_type].append({
            "type": layout_type,
            "name": name,
            "bounding_box": bounding_box,
            "detections": detection_list,
            "table": table,
            "figure": figure,
        })

        # Unknown elements are recorded but kept out of the recovery text unless requested.
        if layout_type == LayoutType.UNKNOWN and not self.show_unknown:
            return

        # NOTE(review): these links start at "Image/..." while `process_layout_records`
        # logs entities under "<page_path>/Image/..." - confirm the links resolve.
        path = f"recording://Image/{layout_type.type.title()}/{name.title()}"
        parts = [f"\n\n## [{name.title()}]({path})\n\n"]
        if layout_type == LayoutType.TABLE:
            if table:
                parts.append(table)
        else:
            parts.extend(
                f' [{detection["text"]}]({path}/Detections/{index})'
                for index, detection in enumerate(detection_list)
            )
        self.recovery += "".join(parts)

    def get_count(self, layout_type: LayoutType) -> int:
        """Return how many elements of ``layout_type`` were added.

        Raises:
            ValueError: if ``layout_type`` is not a ``LayoutType`` member.
        """
        if not isinstance(layout_type, LayoutType):
            raise ValueError("Invalid layout type")
        return self.counts[layout_type]

    def get_records(self) -> dict[LayoutType, list[dict[str, Any]]]:
        """Return the per-type record lists (the live mapping, not a copy)."""
        return self.records

    def save_all_layouts(self, results: list[dict[str, Any]]) -> None:
        """Save every entry of ``results`` and log the resulting per-type counts."""
        for line in results:
            self.save_layout_data(line)
        for layout_type in LayoutType:
            logging.info(f"Number of detections for type {layout_type}: {self.counts[layout_type]}")

    def save_layout_data(self, line: dict[str, Any]) -> None:
        """Convert one raw result dict into a saved layout element.

        Reads ``"type"`` (default ``"empty"``) and ``"bbox"`` (default
        ``[0, 0, 0, 0]``); tables get a Markdown rendering, every other type gets
        its detections, and figures additionally carry ``line["img"]``.
        """
        type_name = line.get("type", "empty")
        box = line.get("bbox", [0, 0, 0, 0])
        layout_type = LayoutType.get_type(type_name)

        detections: list[dict[str, Any]] = []
        table = None
        img = None
        if layout_type == LayoutType.TABLE:
            table = self.get_table_markdown(line)
        else:
            detections = self.get_detections(line)
            if layout_type == LayoutType.FIGURE:
                img = line.get("img")

        self.add(layout_type, box, detections=detections, table=table, figure=img)

    @staticmethod
    def get_detections(line: dict[str, Any]) -> list[dict[str, Any]]:
        """Extract the detections listed under ``line["res"]``.

        Each output dict has ``id`` (position in ``res``), ``text``, ``confidence``
        and ``box`` as ``[x_min, y_min, x_max, y_max]``.
        """
        detections = []
        results = line.get("res")
        if results is not None:
            for i, result in enumerate(results):
                # The first and third points of "text_region" are used as the two box corners
                # (presumably top-left and bottom-right of a four-point polygon - confirm).
                region = result.get("text_region")
                x_min, y_min = region[0]
                x_max, y_max = region[2]
                detections.append({
                    "id": i,
                    "text": result.get("text"),
                    "confidence": result.get("confidence"),
                    "box": [x_min, y_min, x_max, y_max],
                })
        return detections

    @staticmethod
    def get_table_markdown(line: dict[str, Any]) -> str:
        """Render the HTML table under ``line["res"]["html"]`` as Markdown.

        Best effort: never raises; returns a short explanatory message when no
        table is present or when parsing/conversion fails.
        """
        # Local import keeps the module's import block untouched.
        from io import StringIO

        try:
            html_table = line.get("res", {}).get("html")
            if not html_table:
                return "No table found."

            # pandas >= 2.1 deprecates passing literal HTML; wrap it in a file-like object.
            dataframes = pd.read_html(StringIO(html_table))
            if not dataframes:
                return "No data extracted from the table."

            return dataframes[0].to_markdown()
        except Exception as e:
            # Deliberately broad: the message ends up in the recovery document instead of aborting the page.
            return f"Error processing the table: {str(e)}"
|
|
|
|
|
def process_layout_records(log_queue: SimpleQueue[Any], layout: Layout, page_path: str) -> LayoutStructure:
    """Queue log messages for every saved record and collect blueprint inputs.

    For each record, in ``LayoutType`` order, this queues its bounding box,
    delegates its detections to ``log_detections`` and lets ``update_zoom_paths``
    register a zoomed view.

    Returns:
        ``(record filters, detection filters, figure views, table views, text views)``.
    """
    record_filters: list[str] = []
    detection_filters: list[str] = []
    all_views: list[rrb.Spatial2DView] = []
    figure_views: list[rrb.Spatial2DView] = []
    table_views: list[rrb.Spatial2DView] = []
    text_views: list[rrb.Spatial2DView] = []

    for layout_type in LayoutType:
        type_root = f"{page_path}/Image/{layout_type.type.title()}"
        for record in layout.records[layout_type]:
            title = record["name"].title()
            entity_path = f"{type_root}/{title}"

            # Leading "-" marks these as exclusion filters for view contents.
            record_filters.append(f"-{entity_path}/**")
            detection_filters.append(f"-{entity_path}/Detections/**")

            outline = rr.Boxes2D(
                array=record["bounding_box"],
                array_format=rr.Box2DFormat.XYXY,
                labels=[str(layout_type.type)],
                class_ids=[str(layout_type.number)],
            )
            log_queue.put(["log", entity_path, [outline, rr.AnyValues(name=title)]])

            log_detections(log_queue, layout_type, record, entity_path)

            # Receives the filter list as built so far (current record included).
            update_zoom_paths(
                layout=layout,
                layout_type=layout_type,
                record=record,
                paths=record_filters,
                page_path=page_path,
                zoom_paths=all_views,
                zoom_paths_figures=figure_views,
                zoom_paths_tables=table_views,
                zoom_paths_texts=text_views,
            )

    return record_filters, detection_filters, figure_views, table_views, text_views
|
|
|
|
|
def log_detections(log_queue: SimpleQueue, layout_type: LayoutType, record: dict[str, Any], page_path: str) -> None:
    """Queue the detail log messages for one saved record.

    Tables are queued as a single Markdown text document; every other type gets
    one box (plus id/text/confidence values) per detection.

    Args:
        log_queue: queue receiving ``["log", entity_path, [archetypes]]`` messages.
        layout_type: kind of the record.
        record: record dict; reads ``"name"`` and ``"table"`` for tables,
            ``"detections"`` otherwise.
        page_path: prefix for the per-detection entity paths
            (``process_layout_records`` passes the record's own entity path here).
    """
    if layout_type == LayoutType.TABLE:
        # NOTE(review): this entity path carries no page prefix, so equally named tables
        # from different pages would share one entity - confirm this is intended.
        log_queue.put([
            "log",
            f"Extracted{record['name']}",
            [rr.TextDocument(record["table"], media_type=rr.MediaType.MARKDOWN)],
        ])
        return

    # `.get(key, [])` still yields None when the key is present with a None value
    # (the default of `Layout.add`), so normalise before iterating.
    for detection in record.get("detections") or []:
        log_queue.put([
            "log",
            f"{page_path}/Detections/{detection['id']}",
            [
                rr.Boxes2D(
                    array=detection["box"], array_format=rr.Box2DFormat.XYXY, class_ids=[str(layout_type.number)]
                ),
                rr.AnyValues(
                    DetectionID=detection["id"], Text=detection["text"], Confidence=detection["confidence"]
                ),
            ],
        ])
|
|
|
|
|
def update_zoom_paths(
    layout: Layout,
    layout_type: LayoutType,
    record: dict[str, Any],
    paths: list[str],
    page_path: str,
    zoom_paths: list[rrb.Spatial2DView],
    zoom_paths_figures: list[rrb.Spatial2DView],
    zoom_paths_tables: list[rrb.Spatial2DView],
    zoom_paths_texts: list[rrb.Spatial2DView],
) -> None:
    """Create a zoomed 2D view for a figure, table or text record.

    Records of any other type are ignored. The new view is appended to
    ``zoom_paths`` and to the list matching its type; the lists are mutated in place.

    Raises:
        ValueError: if the record's own exclusion filter is not in ``paths``.
    """
    if layout_type not in (LayoutType.FIGURE, LayoutType.TABLE, LayoutType.TEXT):
        return

    title = record["name"].title()

    # Keep every other record's exclusion filter, but drop this record's own so it stays visible.
    own_filter = f"-{page_path}/Image/{layout_type.type.title()}/{title}/**"
    other_filters = list(paths)
    other_filters.remove(own_filter)

    # Visible area: the record's box with a 10-unit margin on each side.
    box = record["bounding_box"]
    bounds = rrb.VisualBounds2D(
        x_range=[box[0] - 10, box[2] + 10],
        y_range=[box[1] - 10, box[3] + 10],
    )

    view = rrb.Spatial2DView(
        name=title,
        contents=[f"{page_path}/Image/**", *other_filters],
        visual_bounds=bounds,
    )
    zoom_paths.append(view)

    if layout_type == LayoutType.FIGURE:
        zoom_paths_figures.append(view)
    elif layout_type == LayoutType.TABLE:
        zoom_paths_tables.append(view)
    elif layout_type != LayoutType.UNKNOWN or layout.show_unknown:
        zoom_paths_texts.append(view)
|
|
|
|
|
def generate_blueprint(
    layouts: list[Layout],
    page_paths: list[str],
    processed_layouts: list[LayoutStructure],
) -> rrb.Blueprint:
    """Build a blueprint with one tab per page.

    Each page tab stacks an overview row (layout view, detections view,
    progress/recovery documents) above a row of "Figures"/"Tables"/"Texts" tab
    groups; empty groups are left out. The three inputs are walked in lockstep.
    """
    page_tabs = []
    for _layout, page_path, processed_layout in zip(layouts, page_paths, processed_layouts):
        _record_filters, detections_paths, figure_views, table_views, text_views = processed_layout
        image_glob = f"{page_path}/Image/**"

        sections = (
            ("Figures", figure_views),
            ("Tables", table_views),
            ("Texts", text_views),
        )
        section_tabs = [rrb.Tabs(*views, name=section) for section, views in sections if views]

        layout_view = rrb.Spatial2DView(
            name="Layout",
            origin=f"{page_path}/Image/",
            contents=[image_glob] + detections_paths,
        )
        detections_view = rrb.Spatial2DView(name="Detections", contents=[image_glob])
        documents = rrb.Vertical(
            rrb.TextDocumentView(name="Progress", contents=["progress/**"]),
            rrb.TextDocumentView(name="Recovery", contents=f"{page_path}/Recovery"),
            row_shares=[1, 4],
        )

        page_tabs.append(
            rrb.Vertical(
                rrb.Horizontal(layout_view, detections_view, documents),
                rrb.Horizontal(*section_tabs),
                name=page_path,
                row_shares=[4, 3],
            )
        )

    return rrb.Blueprint(
        rrb.Tabs(*page_tabs),
        collapse_panels=True,
    )
|
|
|
|
|
def detect_and_log_layouts(log_queue: SimpleQueue[Any], file_path: str, start_page: int = 1, end_page: int | None = -1) -> None:
    """Run layout detection on a PDF or image file and queue the results.

    Args:
        log_queue: queue receiving ``["log", entity_path, [archetypes]]`` and
            ``["blueprint", blueprint]`` messages.
        file_path: path of the input; a ``.pdf`` suffix (any case) selects PDF
            conversion, anything else is read as a single image.
        start_page: first PDF page to convert.
        end_page: last PDF page to convert (inclusive). ``-1`` or ``None`` selects
            a window of ``PAGE_LIMIT`` pages starting at ``start_page``; a value
            below ``start_page`` is raised to ``start_page``.

    If more than ``PAGE_LIMIT`` pages are produced, or the image cannot be read,
    an error message is queued under ``"error"`` and nothing else is processed.
    """
    if end_page is None or end_page == -1:
        # The page range is inclusive on both ends, hence the "- 1".
        end_page = start_page + PAGE_LIMIT - 1
    if end_page < start_page:
        end_page = start_page

    images: list[npt.NDArray[np.uint8]] = []
    if file_path.lower().endswith(".pdf"):
        pages = pdf2image.convert_from_path(file_path, first_page=start_page, last_page=end_page)
        images.extend(np.array(page, dtype=np.uint8) for page in pages)
        if len(images) > PAGE_LIMIT:
            log_queue.put([
                "log",
                "error",
                [rr.TextLog(f"Too many pages requested: {len(images)} requested but the limit is {PAGE_LIMIT}")],
            ])
            # Enforce the limit instead of running detection on every page anyway.
            return
    else:
        img = cv2.imread(file_path)
        if img is None:
            # cv2.imread signals an unreadable/missing file with None rather than raising.
            log_queue.put([
                "log",
                "error",
                [rr.TextLog(f"Could not read image file: {file_path}")],
            ])
            return
        # OpenCV loads BGR; convert to RGB before logging/detection.
        coloured_image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        images.append(coloured_image.astype(np.uint8))

    layouts: list[Layout] = []
    page_paths = [f"page_{page_number}" for page_number in range(start_page, start_page + len(images))]
    processed_layouts: list[LayoutStructure] = []
    for image, page_path in zip(images, page_paths):
        layout = detect_and_log_layout(log_queue, image, page_path)
        layouts.append(layout)
        processed_layouts.append(process_layout_records(log_queue, layout, page_path))

    logging.info("Sending blueprint...")
    blueprint = generate_blueprint(layouts, page_paths, processed_layouts)
    log_queue.put(["blueprint", blueprint])
    logging.info("Blueprint sent...")
|
|
|
|
|
def detect_and_log_layout(log_queue: SimpleQueue, coloured_image: npt.NDArray[np.uint8], page_path: str = "") -> Layout:
    """Detect the layout of one image, queue it for logging and return the result.

    Queues the image and a static annotation context under ``<page_path>/Image``,
    runs the structure model, stores every result in a fresh ``Layout`` and queues
    the recovery Markdown under ``<page_path>/Recovery``.
    """
    layout = Layout()
    image_entity = f"{page_path}/Image"

    log_queue.put(["log", image_entity, [rr.Image(coloured_image)]])
    # The annotation context (class id -> name/color) is sent as static data.
    log_queue.put([
        "log",
        image_entity,
        [rr.AnnotationContext(LayoutType.get_annotation())],
        {"static": True},
    ])

    logging.info("Start detection... (It usually takes more than 10-20 seconds per page)")
    # NOTE(review): the model is constructed on every call, i.e. once per page.
    ocr_model_pp = PPStructure(show_log=False, recovery=True)
    logging.info("model loaded")
    raw_results = ocr_model_pp(coloured_image)
    # Three-way unpack: the image is expected to have exactly three dimensions.
    _height, width, _channels = coloured_image.shape
    # presumably reorders the regions using the image width - see paddleocr's recovery module
    ordered_results = sorted_layout_boxes(raw_results, width)
    logging.info("Detection finished...")

    layout.save_all_layouts(ordered_results)
    logging.info("All results are saved...")

    recovery_document = rr.TextDocument(layout.recovery, media_type=rr.MediaType.MARKDOWN)
    log_queue.put(["log", f"{page_path}/Recovery", [recovery_document]])

    return layout
|
|