# File size: 6,689 Bytes (99d12ec)
from collections import deque
from pathlib import Path
import logging
from typing import Iterator, Self
import re
from llama_index.core.schema import BaseNode
# Module-level logger; the structure parsers below use it for their warnings.
logger = logging.getLogger(__name__)
# Number assigned to the root (document) node of every parsed tree. The landscape
# parser numbers its header nodes downwards starting at DOCUMENT_NODE_NUMBER - 1.
DOCUMENT_NODE_NUMBER: int = -1
class TreeNode:
def __init__(self, name: str, number: int | None = None):
self.name = name
self.number = number
self.children: list[Self] = []
self.parent: Self | None = None
def add_child(self, child: Self) -> None:
self.children.append(child)
def set_parent(self, parent: Self) -> None:
if self.parent is not None:
raise ValueError("parent has already been set")
else:
self.parent = parent
def remove_parent(self) -> None:
self.parent = None
def __str__(self, level: int = 0) -> str:
ret = " " * level + self.name
if self.number is not None:
ret += f" [{self.number}]"
ret += "\n"
for child in self.children:
ret += child.__str__(level + 1)
return ret
def bfs(self) -> Iterator[Self]:
"""Perform Breadth-First traversal of the tree."""
queue = deque([self])
while queue:
node = queue.popleft()
yield node
queue.extend(node.children)
def remove_child(self, child: Self) -> bool:
if child in self.children:
child.remove_parent()
self.children.remove(child)
return True
return False
def __iter__(self):
return self.bfs()
def parse_landscape_structure(document: BaseNode) -> TreeNode:
    """Build the outline tree of a landscape-format document from its metadata.

    The function reads ``format``, ``nb_pages``, ``structure`` and ``filename``
    from ``document.metadata`` and parses ``structure`` line by line:

    * ``# Title`` lines (one or more ``#``) become header nodes, nested according
      to the number of ``#`` characters and numbered with negative values
      counting down from ``DOCUMENT_NODE_NUMBER - 1``.
    * ``- Page N: Title`` lines become page nodes numbered ``N`` and are attached
      to the most recently opened header (or to the root if there is none).
    * Any other line is ignored.

    A page listed more than once, or with a number above ``nb_pages``, is
    skipped with a warning. Pages in ``1..nb_pages`` that never appear are
    collected, in ascending page order, under an extra ``"Uncategorized"`` node
    appended to the root.

    Args:
        document: object whose ``metadata`` mapping holds the keys listed above.

    Returns:
        The root node, named after the stem of ``filename`` and numbered
        ``DOCUMENT_NODE_NUMBER``.

    Raises:
        ValueError: if ``metadata["format"]`` is not ``"landscape"``.
        AssertionError: if ``nb_pages``, ``structure`` or ``filename`` is
            missing or falsy.
    """
    page_pattern = re.compile(r"^-\s*Page\s+(\d+)\s*:\s*(.+)$")
    header_pattern = re.compile(r"^(#+)\s+(.+)$")

    doc_format = document.metadata.get("format", "")
    if doc_format != "landscape":
        raise ValueError(f"Unsupported format {doc_format}")

    number_pages = document.metadata.get("nb_pages", None)
    structure = document.metadata.get("structure", "")
    filename = document.metadata.get("filename", "")
    # NOTE: kept as an assert so callers still see AssertionError; be aware that
    # this check disappears when Python runs with -O.
    assert number_pages and structure and filename

    root = TreeNode(name=Path(filename).stem, number=DOCUMENT_NODE_NUMBER)
    stack = [(root, 0)]  # (node, level) pairs; the root sits at level 0
    abstract_node_number = DOCUMENT_NODE_NUMBER - 1
    processed_page_numbers: set[int] = set()

    for line in structure.splitlines():
        line = line.strip()
        if not line:
            continue

        # Check if it's a header
        header_match = header_pattern.match(line)
        if header_match:
            level = len(header_match.group(1))
            title = header_match.group(2).strip()
            new_node = TreeNode(name=title, number=abstract_node_number)
            abstract_node_number -= 1

            # Close every open header at the same or a deeper level so the new
            # header hangs off its nearest shallower ancestor.
            while stack and stack[-1][1] >= level:
                stack.pop()
            if stack:
                stack[-1][0].add_child(new_node)
                new_node.set_parent(stack[-1][0])
            stack.append((new_node, level))
            continue

        # Check if it's a page entry
        page_match = page_pattern.match(line)
        if not page_match:
            continue

        page_num = int(page_match.group(1))
        title = page_match.group(2).strip()
        if page_num in processed_page_numbers:
            logger.warning(f"Filename {filename} Page {page_num} already processed. Skipping {title}.")
        elif page_num > number_pages:
            logger.warning(
                f"Filename {filename} Page number {page_num} is greater than the number of pages in the document. Skipping {title}."
            )
        else:
            processed_page_numbers.add(page_num)
            new_node = TreeNode(name=title, number=page_num)
            # Add to last header in stack
            if stack:
                stack[-1][0].add_child(new_node)
                new_node.set_parent(stack[-1][0])

    leftout_page_numbers = set(range(1, number_pages + 1)) - processed_page_numbers
    if leftout_page_numbers:
        logger.warning(f"Filename {filename} Page numbers {leftout_page_numbers} are not processed.")
        uncategorized_node = TreeNode(name="Uncategorized", number=abstract_node_number)
        root.add_child(uncategorized_node)
        uncategorized_node.set_parent(root)
        # Sets have no guaranteed iteration order; sort so the uncategorized
        # pages always appear in ascending page order.
        for page_num in sorted(leftout_page_numbers):
            new_node = TreeNode(name=f"Page number {page_num}", number=page_num)
            uncategorized_node.add_child(new_node)
            new_node.set_parent(uncategorized_node)

    return root
def parse_portrait_structure(document: BaseNode) -> TreeNode:
    """Build the outline tree of a portrait-format document from its metadata.

    The function reads ``format``, ``structure``, ``filename`` and
    ``created_toc`` from ``document.metadata``. Each non-empty line of
    ``structure`` that starts with ``#... Title [line N]`` becomes a header
    node named ``Title`` and numbered ``N``, nested according to the number of
    ``#`` characters. Lines that do not match are ignored. ``created_toc`` is
    only checked for truthiness; it is not otherwise used here.

    Args:
        document: object whose ``metadata`` mapping holds the keys listed above.

    Returns:
        The root node, named after the stem of ``filename`` and numbered
        ``DOCUMENT_NODE_NUMBER``.

    Raises:
        ValueError: if ``metadata["format"]`` is not ``"portrait"``.
        AssertionError: if ``structure``, ``filename`` or ``created_toc`` is
            missing or falsy, or if the parsed line numbers are empty, do not
            start at 0, or are not in non-decreasing order.
    """
    header_pattern = re.compile(r"(#+)\s+(.*?)\s+\[line\s+(\d+)\]")

    doc_format = document.metadata.get("format", "")
    if doc_format != "portrait":
        raise ValueError(f"Unsupported format {doc_format}")

    structure = document.metadata.get("structure", "")
    filename = document.metadata.get("filename", "")
    created_toc = document.metadata.get("created_toc", "")
    assert structure and filename and created_toc

    root = TreeNode(name=Path(filename).stem, number=DOCUMENT_NODE_NUMBER)
    stack = [(root, 0)]  # (node, level) pairs; the root sits at level 0
    processed_line_numbers: list[int] = []

    for line in structure.splitlines():
        line = line.strip()
        if not line:
            continue

        # Only header lines carry information in this format.
        header_match = header_pattern.match(line)
        if not header_match:
            continue

        level = len(header_match.group(1))
        title = header_match.group(2).strip()
        line_number = int(header_match.group(3))
        processed_line_numbers.append(line_number)
        new_node = TreeNode(name=title, number=line_number)

        # Close every open header at the same or a deeper level so the new
        # header hangs off its nearest shallower ancestor.
        while stack and stack[-1][1] >= level:
            stack.pop()
        if stack:
            stack[-1][0].add_child(new_node)
            new_node.set_parent(stack[-1][0])
        stack.append((new_node, level))

    # Sanity check on the parsed headers. The emptiness guard makes a structure
    # without any recognisable header fail this check instead of raising an
    # unrelated IndexError on processed_line_numbers[0].
    assert (
        processed_line_numbers
        and processed_line_numbers[0] == 0
        and all(
            earlier <= later
            for earlier, later in zip(processed_line_numbers, processed_line_numbers[1:])
        )
    ), f"Unexpected header line numbers: {processed_line_numbers}"

    return root
def parse_structure(document: BaseNode) -> TreeNode:
    """Parse a document's outline using the parser that matches its format.

    The ``format`` entry of ``document.metadata`` selects the parser:
    ``"landscape"`` uses ``parse_landscape_structure`` and ``"portrait"`` uses
    ``parse_portrait_structure``.

    Raises:
        ValueError: if ``format`` is missing or is any other value.
    """
    layout = document.metadata.get("format", "")
    if layout == "landscape":
        return parse_landscape_structure(document)
    if layout == "portrait":
        return parse_portrait_structure(document)
    raise ValueError(f"Unsupported format {layout}")
|