File size: 4,770 Bytes
54f5afe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Split a document into semantic chunks."""

import re

import numpy as np
from scipy.optimize import linprog
from scipy.sparse import coo_matrix

from raglite._typing import FloatMatrix


def split_chunks(  # noqa: C901, PLR0915
    sentences: list[str],
    sentence_embeddings: FloatMatrix,
    sentence_window_size: int = 3,
    max_size: int = 1440,
) -> tuple[list[str], list[FloatMatrix]]:
    """Split sentences into optimal semantic chunks with corresponding sentence embeddings."""
    # Validate the input.
    sentence_length = np.asarray([len(sentence) for sentence in sentences])
    if not np.all(sentence_length <= max_size):
        error_message = "Sentence with length larger than chunk max_size detected."
        raise ValueError(error_message)
    if not np.all(np.linalg.norm(sentence_embeddings, axis=1) > 0.0):
        error_message = "Sentence embeddings with zero norm detected."
        raise ValueError(error_message)
    # Exit early if there is only one chunk to return.
    if len(sentences) <= 1 or sum(sentence_length) <= max_size:
        return ["".join(sentences)] if sentences else sentences, [sentence_embeddings]
    # Normalise the sentence embeddings to unit norm.
    X = sentence_embeddings.astype(np.float32)  # noqa: N806
    X = X / np.linalg.norm(X, axis=1, keepdims=True)  # noqa: N806
    # Select nonoutlying sentences and remove the discourse vector.
    q15, q85 = np.quantile(sentence_length, [0.15, 0.85])
    nonoutlying_sentences = (q15 <= sentence_length) & (sentence_length <= q85)
    discourse = np.mean(X[nonoutlying_sentences, :], axis=0)
    discourse = discourse / np.linalg.norm(discourse)
    if not np.any(np.linalg.norm(X - discourse[np.newaxis, :], axis=1) <= np.finfo(X.dtype).eps):
        X = X - np.outer(X @ discourse, discourse)  # noqa: N806
        X = X / np.linalg.norm(X, axis=1, keepdims=True)  # noqa: N806
    # For each partition point in the list of sentences, compute the similarity of the windows
    # before and after the partition point. Sentence embeddings are assumed to be of the sentence
    # itself and at most the (sentence_window_size - 1) sentences that preceed it.
    sentence_window_size = min(len(sentences) - 1, sentence_window_size)
    windows_before = X[:-sentence_window_size]
    windows_after = X[sentence_window_size:]
    partition_similarity = np.ones(len(sentences) - 1, dtype=X.dtype)
    partition_similarity[: len(windows_before)] = np.sum(windows_before * windows_after, axis=1)
    # Make partition similarity nonnegative before modification and optimisation.
    partition_similarity = np.maximum(
        (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps)
    )
    # Modify the partition similarity to encourage splitting on Markdown headings.
    prev_sentence_is_heading = True
    for i, sentence in enumerate(sentences[:-1]):
        is_heading = bool(re.match(r"^#+\s", sentence.replace("\n", "").strip()))
        if is_heading:
            # Encourage splitting before a heading.
            if not prev_sentence_is_heading:
                partition_similarity[i - 1] = partition_similarity[i - 1] / 4
            # Don't split immediately after a heading.
            partition_similarity[i] = 1.0
        prev_sentence_is_heading = is_heading
    # Solve an optimisation problem to find the best partition points.
    sentence_length_cumsum = np.cumsum(sentence_length)
    row_indices = []
    col_indices = []
    data = []
    for i in range(len(sentences) - 1):
        r = sentence_length_cumsum[i - 1] if i > 0 else 0
        idx = np.searchsorted(sentence_length_cumsum - r, max_size)
        assert idx > i
        if idx == len(sentence_length_cumsum):
            break
        cols = list(range(i, idx))
        col_indices.extend(cols)
        row_indices.extend([i] * len(cols))
        data.extend([1] * len(cols))
    A = coo_matrix(  # noqa: N806
        (data, (row_indices, col_indices)),
        shape=(max(row_indices) + 1, len(sentences) - 1),
        dtype=np.float32,
    )
    b_ub = np.ones(A.shape[0], dtype=np.float32)
    res = linprog(
        partition_similarity,
        A_ub=-A,
        b_ub=-b_ub,
        bounds=(0, 1),
        integrality=[1] * A.shape[1],
    )
    if not res.success:
        error_message = "Optimization of chunk partitions failed."
        raise ValueError(error_message)
    # Split the sentences and their window embeddings into optimal chunks.
    partition_indices = (np.where(res.x)[0] + 1).tolist()
    chunks = [
        "".join(sentences[i:j])
        for i, j in zip([0, *partition_indices], [*partition_indices, len(sentences)], strict=True)
    ]
    chunk_embeddings = np.split(sentence_embeddings, partition_indices)
    return chunks, chunk_embeddings