File size: 6,689 Bytes
99d12ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from collections import deque
from pathlib import Path
import logging
from typing import Iterator, Self
import re

from llama_index.core.schema import BaseNode

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Number given to the root (document) node of a parsed structure tree. The
# landscape parser numbers its abstract header nodes downward from this value
# (DOCUMENT_NODE_NUMBER - 1, DOCUMENT_NODE_NUMBER - 2, ...).
DOCUMENT_NODE_NUMBER: int = -1


class TreeNode:
    def __init__(self, name: str, number: int | None = None):
        self.name = name
        self.number = number
        self.children: list[Self] = []
        self.parent: Self | None = None

    def add_child(self, child: Self) -> None:
        self.children.append(child)

    def set_parent(self, parent: Self) -> None:
        if self.parent is not None:
            raise ValueError("parent has already been set")
        else:
            self.parent = parent

    def remove_parent(self) -> None:
        self.parent = None

    def __str__(self, level: int = 0) -> str:
        ret = "  " * level + self.name
        if self.number is not None:
            ret += f" [{self.number}]"
        ret += "\n"
        for child in self.children:
            ret += child.__str__(level + 1)
        return ret

    def bfs(self) -> Iterator[Self]:
        """Perform Breadth-First traversal of the tree."""
        queue = deque([self])
        while queue:
            node = queue.popleft()
            yield node
            queue.extend(node.children)

    def remove_child(self, child: Self) -> bool:
        if child in self.children:
            child.remove_parent()
            self.children.remove(child)
            return True
        return False

    def __iter__(self):
        return self.bfs()


def parse_landscape_structure(document: BaseNode) -> TreeNode:
    """Build the section/page tree of a landscape-format document.

    The tree is parsed line by line from ``document.metadata["structure"]``:

    * ``# Title`` lines (one or more ``#``) become abstract section nodes,
      nested according to the number of ``#`` characters and numbered with
      negative integers counting down from ``DOCUMENT_NODE_NUMBER - 1``.
    * ``- Page N: Title`` lines become page nodes numbered ``N``, attached to
      the most recent header (or to the root when no header has been seen).

    Blank lines and lines matching neither form are ignored. A page entry is
    skipped with a warning when its number was already seen, is below 1, or
    exceeds ``nb_pages``. Pages in ``1..nb_pages`` that never appear are
    collected, in ascending order, under an extra "Uncategorized" child of
    the root.

    Args:
        document: Node whose ``metadata`` provides ``format`` (must be
            ``"landscape"``), ``nb_pages``, ``structure`` and ``filename``.

    Returns:
        The root node, named after the stem of ``filename`` and numbered
        ``DOCUMENT_NODE_NUMBER``.

    Raises:
        ValueError: if ``format`` is not ``"landscape"`` or if ``nb_pages``,
            ``structure`` or ``filename`` is missing or empty.
    """
    page_pattern = re.compile(r"^-\s*Page\s+(\d+)\s*:\s*(.+)$")
    header_pattern = re.compile(r"^(#+)\s+(.+)$")

    doc_format = document.metadata.get("format", "")
    if doc_format != "landscape":
        raise ValueError(f"Unsupported format {doc_format}")

    number_pages = document.metadata.get("nb_pages", None)
    structure = document.metadata.get("structure", "")
    filename = document.metadata.get("filename", "")
    # Explicit raise rather than `assert`: assertions are stripped under -O.
    if not (number_pages and structure and filename):
        raise ValueError(
            "Landscape document metadata must provide non-empty 'nb_pages', 'structure' and 'filename'"
        )

    root = TreeNode(name=Path(filename).stem, number=DOCUMENT_NODE_NUMBER)
    # (node, level) pairs for the currently open headers. The root sits at
    # level 0 and header levels are >= 1, so the stack is never emptied.
    stack = [(root, 0)]
    abstract_node_number = DOCUMENT_NODE_NUMBER - 1
    processed_page_numbers: set[int] = set()

    for raw_line in structure.splitlines():
        line = raw_line.strip()
        if not line:
            continue

        header_match = header_pattern.match(line)
        if header_match:
            level = len(header_match.group(1))
            title = header_match.group(2).strip()
            new_node = TreeNode(name=title, number=abstract_node_number)
            abstract_node_number -= 1
            # Close every open header at the same or a deeper level.
            while stack[-1][1] >= level:
                stack.pop()
            parent = stack[-1][0]
            parent.add_child(new_node)
            new_node.set_parent(parent)
            stack.append((new_node, level))
            continue

        page_match = page_pattern.match(line)
        if not page_match:
            continue

        page_num = int(page_match.group(1))
        title = page_match.group(2).strip()
        if page_num in processed_page_numbers:
            logger.warning(
                "Filename %s Page %s already processed. Skipping %s.",
                filename,
                page_num,
                title,
            )
        elif page_num < 1:
            # Valid pages are 1..nb_pages (see the left-out computation below).
            logger.warning(
                "Filename %s Page number %s is less than 1. Skipping %s.",
                filename,
                page_num,
                title,
            )
        elif page_num > number_pages:
            logger.warning(
                "Filename %s Page number %s is greater than the number of pages in the document. Skipping %s.",
                filename,
                page_num,
                title,
            )
        else:
            processed_page_numbers.add(page_num)
            new_node = TreeNode(name=title, number=page_num)
            # Pages hang off the innermost open header.
            parent = stack[-1][0]
            parent.add_child(new_node)
            new_node.set_parent(parent)

    leftout_page_numbers = set(range(1, number_pages + 1)) - processed_page_numbers
    if leftout_page_numbers:
        logger.warning("Filename %s Page numbers %s are not processed.", filename, leftout_page_numbers)
        uncategorized_node = TreeNode(name="Uncategorized", number=abstract_node_number)
        root.add_child(uncategorized_node)
        uncategorized_node.set_parent(root)
        # Sorted so the children order does not depend on set iteration order.
        for page_num in sorted(leftout_page_numbers):
            new_node = TreeNode(name=f"Page number {page_num}", number=page_num)
            uncategorized_node.add_child(new_node)
            new_node.set_parent(uncategorized_node)

    return root


def parse_portrait_structure(document: BaseNode) -> TreeNode:
    """Build the header tree of a portrait-format document.

    The tree is parsed line by line from ``document.metadata["structure"]``.
    Each line starting with ``<#...> <title> [line <n>]`` becomes a node named
    ``title`` and numbered ``n``, nested according to the number of ``#``
    characters. Blank lines and non-matching lines are ignored.

    Args:
        document: Node whose ``metadata`` provides ``format`` (must be
            ``"portrait"``), ``structure``, ``filename`` and ``created_toc``.
            ``created_toc`` is only checked for being non-empty here.

    Returns:
        The root node, named after the stem of ``filename`` and numbered
        ``DOCUMENT_NODE_NUMBER``.

    Raises:
        ValueError: if ``format`` is not ``"portrait"``; if ``structure``,
            ``filename`` or ``created_toc`` is missing or empty; if no header
            entry is found; if the first header's line number is not 0; or if
            the header line numbers are not in non-decreasing order.
    """
    header_pattern = re.compile(r"(#+)\s+(.*?)\s+\[line\s+(\d+)\]")

    doc_format = document.metadata.get("format", "")
    if doc_format != "portrait":
        raise ValueError(f"Unsupported format {doc_format}")

    structure = document.metadata.get("structure", "")
    filename = document.metadata.get("filename", "")
    created_toc = document.metadata.get("created_toc", "")
    # Explicit raise rather than `assert`: assertions are stripped under -O.
    if not (structure and filename and created_toc):
        raise ValueError(
            "Portrait document metadata must provide non-empty 'structure', 'filename' and 'created_toc'"
        )

    root = TreeNode(name=Path(filename).stem, number=DOCUMENT_NODE_NUMBER)
    # (node, level) pairs for the currently open headers. The root sits at
    # level 0 and header levels are >= 1, so the stack is never emptied.
    stack = [(root, 0)]
    processed_line_numbers: list[int] = []

    for raw_line in structure.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        header_match = header_pattern.match(line)
        if not header_match:
            continue

        level = len(header_match.group(1))
        title = header_match.group(2).strip()
        line_number = int(header_match.group(3))
        processed_line_numbers.append(line_number)
        new_node = TreeNode(name=title, number=line_number)
        # Close every open header at the same or a deeper level.
        while stack[-1][1] >= level:
            stack.pop()
        parent = stack[-1][0]
        parent.add_child(new_node)
        new_node.set_parent(parent)
        stack.append((new_node, level))

    # Sanity checks on the header line numbers; the empty case is reported
    # explicitly instead of surfacing as an IndexError.
    if not processed_line_numbers:
        raise ValueError(f"Filename {filename}: no header entry found in structure")
    if processed_line_numbers[0] != 0:
        raise ValueError(
            f"Filename {filename}: first header must be at line 0, got {processed_line_numbers[0]}"
        )
    if any(prev > nxt for prev, nxt in zip(processed_line_numbers, processed_line_numbers[1:])):
        raise ValueError(f"Filename {filename}: header line numbers are not in non-decreasing order")

    return root


def parse_structure(document: BaseNode) -> TreeNode:
    """Parse a document's structure with the parser matching its format.

    Reads ``document.metadata["format"]`` (empty string when absent) and
    delegates to ``parse_landscape_structure`` or ``parse_portrait_structure``.

    Raises:
        ValueError: if the format is neither ``"landscape"`` nor ``"portrait"``.
    """
    layout = document.metadata.get("format", "")
    if layout == "landscape":
        return parse_landscape_structure(document)
    if layout == "portrait":
        return parse_portrait_structure(document)
    raise ValueError(f"Unsupported format {layout}")