File size: 2,689 Bytes
54f5afe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
"""Test RAGLite's chunk splitting functionality."""

import numpy as np
import pytest

from raglite._split_chunks import split_chunks


@pytest.mark.parametrize(
    "sentences",
    [
        pytest.param([], id="one_chunk:no_sentences"),
        pytest.param(["Hello world"], id="one_chunk:one_sentence"),
        pytest.param(["Hello world"] * 2, id="one_chunk:two_sentences"),
        pytest.param(["Hello world"] * 3, id="one_chunk:three_sentences"),
        pytest.param(["Hello world"] * 100, id="one_chunk:many_sentences"),
        pytest.param(["Hello world", "X" * 1000], id="n_chunks:two_sentences_a"),
        pytest.param(["X" * 1000, "Hello world"], id="n_chunks:two_sentences_b"),
        pytest.param(["Hello world", "X" * 1000, "X" * 1000], id="n_chunks:three_sentences_a"),
        pytest.param(["X" * 1000, "Hello world", "X" * 1000], id="n_chunks:three_sentences_b"),
        pytest.param(["X" * 1000, "X" * 1000, "Hello world"], id="n_chunks:three_sentences_c"),
        pytest.param(["X" * 1000] * 100, id="n_chunks:many_sentences_a"),
        pytest.param(["X" * 100] * 1000, id="n_chunks:many_sentences_b"),
    ],
)
def test_edge_cases(sentences: list[str]) -> None:
    """Verify the structural invariants of chunk splitting on edge-case inputs."""
    embeddings = np.ones((len(sentences), 768), dtype=np.float16)
    chunks, chunk_embeddings = split_chunks(
        sentences, embeddings, sentence_window_size=3, max_size=1440
    )
    # Both outputs are plain lists.
    assert isinstance(chunks, list)
    assert isinstance(chunk_embeddings, list)
    # One embedding group per chunk — except that an empty input still
    # produces a single (empty) embedding group.
    expected_groups = len(chunks) if sentences else 1
    assert len(chunk_embeddings) == expected_groups
    # Every chunk is a string; every embedding group is an ndarray that
    # preserves the input dtype and embedding dimensionality.
    for chunk in chunks:
        assert isinstance(chunk, str)
    for group in chunk_embeddings:
        assert isinstance(group, np.ndarray)
        assert group.dtype == embeddings.dtype
        assert group.shape[1] == embeddings.shape[1]
    # The groups partition the sentence embeddings: row counts sum to the input.
    assert sum(group.shape[0] for group in chunk_embeddings) == embeddings.shape[0]


@pytest.mark.parametrize(
    "sentences",
    [
        pytest.param(["Hello world" * 1000] + ["X"] * 100, id="first"),
        pytest.param(["X"] * 50 + ["Hello world" * 1000] + ["X"] * 50, id="middle"),
        pytest.param(["X"] * 100 + ["Hello world" * 1000], id="last"),
    ],
)
def test_long_sentence(sentences: list[str]) -> None:
    """An oversized sentence must raise ValueError, wherever it appears."""
    embeddings = np.ones((len(sentences), 768), dtype=np.float16)
    expected_error = "Sentence with length larger than chunk max_size detected."
    with pytest.raises(ValueError, match=expected_error):
        split_chunks(sentences, embeddings, sentence_window_size=3, max_size=1440)