"""Split a document into semantic chunks.""" | |
import re | |
import numpy as np | |
from scipy.optimize import linprog | |
from scipy.sparse import coo_matrix | |
from raglite._typing import FloatMatrix | |
def split_chunks( # noqa: C901, PLR0915 | |
sentences: list[str], | |
sentence_embeddings: FloatMatrix, | |
sentence_window_size: int = 3, | |
max_size: int = 1440, | |
) -> tuple[list[str], list[FloatMatrix]]: | |
"""Split sentences into optimal semantic chunks with corresponding sentence embeddings.""" | |
# Validate the input. | |
sentence_length = np.asarray([len(sentence) for sentence in sentences]) | |
if not np.all(sentence_length <= max_size): | |
error_message = "Sentence with length larger than chunk max_size detected." | |
raise ValueError(error_message) | |
if not np.all(np.linalg.norm(sentence_embeddings, axis=1) > 0.0): | |
error_message = "Sentence embeddings with zero norm detected." | |
raise ValueError(error_message) | |
# Exit early if there is only one chunk to return. | |
if len(sentences) <= 1 or sum(sentence_length) <= max_size: | |
return ["".join(sentences)] if sentences else sentences, [sentence_embeddings] | |
# Normalise the sentence embeddings to unit norm. | |
X = sentence_embeddings.astype(np.float32) # noqa: N806 | |
X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806 | |
# Select nonoutlying sentences and remove the discourse vector. | |
q15, q85 = np.quantile(sentence_length, [0.15, 0.85]) | |
nonoutlying_sentences = (q15 <= sentence_length) & (sentence_length <= q85) | |
discourse = np.mean(X[nonoutlying_sentences, :], axis=0) | |
discourse = discourse / np.linalg.norm(discourse) | |
if not np.any(np.linalg.norm(X - discourse[np.newaxis, :], axis=1) <= np.finfo(X.dtype).eps): | |
X = X - np.outer(X @ discourse, discourse) # noqa: N806 | |
X = X / np.linalg.norm(X, axis=1, keepdims=True) # noqa: N806 | |
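    # Conceptually, the removal above is the projection X <- X - (X @ d) d^T followed by
    # renormalisation: it strips the direction that all (nonoutlying) sentences share, so
    # that the cosine similarities computed below reflect differences between sentences
    # rather than the document-wide topic. This is akin to common-component removal for
    # sentence embeddings.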
    # For each partition point in the list of sentences, compute the similarity of the windows
    # before and after the partition point. Sentence embeddings are assumed to be of the sentence
    # itself and at most the (sentence_window_size - 1) sentences that precede it.
    sentence_window_size = min(len(sentences) - 1, sentence_window_size)
    windows_before = X[:-sentence_window_size]
    windows_after = X[sentence_window_size:]
    partition_similarity = np.ones(len(sentences) - 1, dtype=X.dtype)
    partition_similarity[: len(windows_before)] = np.sum(windows_before * windows_after, axis=1)
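    # Given a window size w, element j of the row-wise dot product above compares the window
    # ending at sentence j (at most the w sentences up to and including j) with the window
    # ending at sentence j + w (the w sentences that follow j), i.e. the text on either side
    # of a split placed directly after sentence j. The last w - 1 partition points keep the
    # maximal similarity of 1, discouraging splits near the end of the document.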
    # Make partition similarity nonnegative before modification and optimisation.
    partition_similarity = np.maximum(
        (partition_similarity + 1) / 2, np.sqrt(np.finfo(X.dtype).eps)
    )
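    # The affine map (s + 1) / 2 rescales cosine similarity from [-1, 1] to [0, 1], and the
    # floor of sqrt(eps) keeps every cost strictly positive so that the optimisation below
    # never gets to insert partition points for free.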
    # Modify the partition similarity to encourage splitting on Markdown headings.
    prev_sentence_is_heading = True
    for i, sentence in enumerate(sentences[:-1]):
        is_heading = bool(re.match(r"^#+\s", sentence.replace("\n", "").strip()))
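        # For example, "## Results\n" is detected as a heading, while "#tag" is not,
        # because the pattern requires whitespace after the leading #'s.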
        if is_heading:
            # Encourage splitting before a heading.
            if not prev_sentence_is_heading:
                partition_similarity[i - 1] = partition_similarity[i - 1] / 4
            # Don't split immediately after a heading.
            partition_similarity[i] = 1.0
        prev_sentence_is_heading = is_heading
    # Solve a binary covering problem to select the optimal partition points: minimise the total
    # similarity of the chosen partition points, subject to the constraint that every chunk stays
    # within max_size characters.
    sentence_length_cumsum = np.cumsum(sentence_length)
    row_indices = []
    col_indices = []
    data = []
    for i in range(len(sentences) - 1):
        r = sentence_length_cumsum[i - 1] if i > 0 else 0
        idx = np.searchsorted(sentence_length_cumsum - r, max_size)
        assert idx > i
        if idx == len(sentence_length_cumsum):
            break
        cols = list(range(i, idx))
        col_indices.extend(cols)
        row_indices.extend([i] * len(cols))
        data.extend([1] * len(cols))
    A = coo_matrix(  # noqa: N806
        (data, (row_indices, col_indices)),
        shape=(max(row_indices) + 1, len(sentences) - 1),
        dtype=np.float32,
    )
    b_ub = np.ones(A.shape[0], dtype=np.float32)
    res = linprog(
        partition_similarity,
        A_ub=-A,
        b_ub=-b_ub,
        bounds=(0, 1),
        integrality=[1] * A.shape[1],
    )
    if not res.success:
        error_message = "Optimization of chunk partitions failed."
        raise ValueError(error_message)
    # Split the sentences and their window embeddings into optimal chunks.
    partition_indices = (np.where(res.x)[0] + 1).tolist()
    chunks = [
        "".join(sentences[i:j])
        for i, j in zip([0, *partition_indices], [*partition_indices, len(sentences)], strict=True)
    ]
    chunk_embeddings = np.split(sentence_embeddings, partition_indices)
    return chunks, chunk_embeddings
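

# Minimal usage sketch under assumed inputs: the sentences and random embeddings below are
# illustrative stand-ins, not part of raglite's API. Real callers would pass sentences from a
# sentence splitter and embeddings from an embedding model.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    demo_sentences = [f"This is sentence number {i} of the demo document. " for i in range(64)]
    demo_embeddings = rng.standard_normal((64, 8)).astype(np.float32)
    demo_chunks, demo_chunk_embeddings = split_chunks(demo_sentences, demo_embeddings, max_size=256)
    print(f"Split {len(demo_sentences)} sentences into {len(demo_chunks)} chunks.")
    # Chunks concatenate back to the original text, and each respects max_size.
    assert "".join(demo_chunks) == "".join(demo_sentences)
    assert all(len(chunk) <= 256 for chunk in demo_chunks)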