# rag_lite/tests/test_split_chunks.py
"""Test RAGLite's chunk splitting functionality."""
import numpy as np
import pytest
from raglite._split_chunks import split_chunks
@pytest.mark.parametrize(
"sentences",
[
pytest.param([], id="one_chunk:no_sentences"),
pytest.param(["Hello world"], id="one_chunk:one_sentence"),
pytest.param(["Hello world"] * 2, id="one_chunk:two_sentences"),
pytest.param(["Hello world"] * 3, id="one_chunk:three_sentences"),
pytest.param(["Hello world"] * 100, id="one_chunk:many_sentences"),
pytest.param(["Hello world", "X" * 1000], id="n_chunks:two_sentences_a"),
pytest.param(["X" * 1000, "Hello world"], id="n_chunks:two_sentences_b"),
pytest.param(["Hello world", "X" * 1000, "X" * 1000], id="n_chunks:three_sentences_a"),
pytest.param(["X" * 1000, "Hello world", "X" * 1000], id="n_chunks:three_sentences_b"),
pytest.param(["X" * 1000, "X" * 1000, "Hello world"], id="n_chunks:three_sentences_c"),
pytest.param(["X" * 1000] * 100, id="n_chunks:many_sentences_a"),
pytest.param(["X" * 100] * 1000, id="n_chunks:many_sentences_b"),
],
)
def test_edge_cases(sentences: list[str]) -> None:
"""Test chunk splitting edge cases."""
    # Construct dummy sentence embeddings with the same shape and dtype as real embeddings.
    sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
    chunks, chunk_embeddings = split_chunks(
        sentences, sentence_embeddings, sentence_window_size=3, max_size=1440
    )
    # The result is a pair of lists: chunk strings and per-chunk embedding arrays. With no
    # sentences, a single chunk embedding is still expected.
    assert isinstance(chunks, list)
    assert isinstance(chunk_embeddings, list)
    assert len(chunk_embeddings) == (len(chunks) if sentences else 1)
    assert all(isinstance(chunk, str) for chunk in chunks)
    assert all(isinstance(chunk_embedding, np.ndarray) for chunk_embedding in chunk_embeddings)
    # Chunk embeddings keep the dtype and embedding dimension of the sentence embeddings,
    # and together they account for every sentence embedding exactly once.
    assert all(ce.dtype == sentence_embeddings.dtype for ce in chunk_embeddings)
    assert sum(ce.shape[0] for ce in chunk_embeddings) == sentence_embeddings.shape[0]
    assert all(ce.shape[1] == sentence_embeddings.shape[1] for ce in chunk_embeddings)
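

# A sketch of an additional check one might add (an assumption, not RAGLite's documented
# behaviour): with non-uniform sentence embeddings, split_chunks should still partition
# every sentence into exactly one chunk. It relies only on the signature exercised above.
def test_random_embeddings_sketch() -> None:
    """Sketch: split_chunks partitions sentences and their embeddings (assumed invariants)."""
    sentences = [f"Sentence number {i}. " for i in range(20)]
    rng = np.random.default_rng(42)
    sentence_embeddings = rng.standard_normal((len(sentences), 768)).astype(np.float16)
    chunks, chunk_embeddings = split_chunks(
        sentences, sentence_embeddings, sentence_window_size=3, max_size=1440
    )
    # Every sentence embedding ends up in exactly one chunk embedding (row counts add up).
    assert sum(ce.shape[0] for ce in chunk_embeddings) == len(sentences)
    # Assumed invariant: concatenating the chunks reproduces the concatenated sentences.
    assert "".join(chunks) == "".join(sentences)

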
@pytest.mark.parametrize(
"sentences",
[
pytest.param(["Hello world" * 1000] + ["X"] * 100, id="first"),
pytest.param(["X"] * 50 + ["Hello world" * 1000] + ["X"] * 50, id="middle"),
pytest.param(["X"] * 100 + ["Hello world" * 1000], id="last"),
],
)
def test_long_sentence(sentences: list[str]) -> None:
"""Test chunking on sentences that are too long."""
    sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
    # A single sentence longer than max_size cannot fit in any chunk, so splitting should raise.
    with pytest.raises(
        ValueError, match="Sentence with length larger than chunk max_size detected."
    ):
_ = split_chunks(sentences, sentence_embeddings, sentence_window_size=3, max_size=1440)
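

# A hedged workaround sketch (not part of RAGLite's API): a caller could pre-split any
# sentence longer than max_size into max_size-sized pieces and repeat its embedding for
# each piece, so that split_chunks never sees an overlong sentence. The helper logic and
# test below are illustrative assumptions only.
def test_long_sentence_presplit_workaround() -> None:
    """Sketch: pre-splitting overlong sentences avoids the max_size ValueError."""
    max_size = 1440
    sentences = ["X"] * 50 + ["Hello world" * 1000] + ["X"] * 50
    sentence_embeddings = np.ones((len(sentences), 768)).astype(np.float16)
    # Break each sentence into pieces of at most max_size characters and repeat its
    # embedding once per piece (a simplistic assumption, for illustration only).
    pieces: list[str] = []
    piece_embeddings: list[np.ndarray] = []
    for sentence, embedding in zip(sentences, sentence_embeddings):
        sentence_pieces = [sentence[i : i + max_size] for i in range(0, len(sentence), max_size)]
        pieces.extend(sentence_pieces)
        piece_embeddings.extend([embedding] * len(sentence_pieces))
    chunks, chunk_embeddings = split_chunks(
        pieces,
        np.asarray(piece_embeddings, dtype=np.float16),
        sentence_window_size=3,
        max_size=max_size,
    )
    # No ValueError is raised, and every piece is accounted for in exactly one chunk.
    assert sum(ce.shape[0] for ce in chunk_embeddings) == len(pieces)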