# rag_lite/tests/test_search.py
# EL GHAFRAOUI AYOUB
# C
# 54f5afe
"""Test RAGLite's search functionality."""
import pytest
from raglite import (
RAGLiteConfig,
hybrid_search,
keyword_search,
retrieve_chunks,
retrieve_segments,
vector_search,
)
from raglite._database import Chunk
from raglite._typing import SearchMethod
@pytest.fixture(
    params=[keyword_search, vector_search, hybrid_search],
    ids=["keyword_search", "vector_search", "hybrid_search"],
)
def search_method(request: pytest.FixtureRequest) -> SearchMethod:
    """Supply one of RAGLite's search functions to the requesting test.

    The fixture is parametrized over keyword, vector, and hybrid search, so a
    test that depends on it runs once per search function.
    """
    # Bind the parameter to an annotated local so the declared return type is
    # explicit at the point where the value leaves pytest's request object.
    selected: SearchMethod = request.param
    return selected
def test_search(raglite_test_config: RAGLiteConfig, search_method: SearchMethod) -> None:
    """Verify that a search method returns well-formed and relevant results.

    The test runs a fixed query through the given search method, checks the
    shape and types of the returned chunk ids and scores, resolves the ids to
    chunks, and finally expands them into segments.
    """
    # Run the search method under test on a fixed query.
    num_results = 5
    chunk_ids, scores = search_method(
        "What does it mean for two events to be simultaneous?",
        num_results=num_results,
        config=raglite_test_config,
    )
    # Exactly the requested number of ids and scores must come back.
    assert len(chunk_ids) == num_results
    assert len(scores) == num_results
    for found_id in chunk_ids:
        assert isinstance(found_id, str)
    for found_score in scores:
        assert isinstance(found_score, float)

    # Resolve the ids to chunks; they must come back in the same order.
    chunks = retrieve_chunks(chunk_ids, config=raglite_test_config)
    for retrieved in chunks:
        assert isinstance(retrieved, Chunk)
    # strict=True also fails the test if the two sequences differ in length.
    for expected_id, retrieved in zip(chunk_ids, chunks, strict=True):
        assert expected_id == retrieved.id
    # At least one retrieved chunk must contain the expected heading text.
    assert any("Definition of Simultaneity" in str(chunk) for chunk in chunks)

    # Extend the chunks with their neighbours and group them into contiguous segments.
    segments = retrieve_segments(chunk_ids, neighbors=(-1, 1), config=raglite_test_config)
    for segment in segments:
        assert isinstance(segment, str)