Spaces:
Running
Running
"""Test RAGLite's reranking functionality.""" | |
import pytest | |
from rerankers.models.ranker import BaseRanker | |
from raglite import RAGLiteConfig, hybrid_search, rerank_chunks, retrieve_chunks | |
from raglite._database import Chunk | |
from raglite._flashrank import PatchedFlashRankRanker as FlashRankRanker | |
def reranker(
    request: pytest.FixtureRequest,
) -> BaseRanker | tuple[tuple[str, BaseRanker], ...] | None:
    """Provide the reranker configuration that the current test run should use.

    The value is taken verbatim from ``request.param``: either a single ranker,
    a tuple of ``(str, ranker)`` pairs, or ``None``.
    """
    # NOTE(review): the `@pytest.fixture(params=...)` decorator that would supply
    # `request.param` is not visible in this view -- confirm it sits above this function.
    return request.param
def test_reranker(
    raglite_test_config: RAGLiteConfig,
    reranker: BaseRanker | tuple[tuple[str, BaseRanker], ...] | None,
) -> None:
    """Check that reranking a reversed search result puts the top hit first again.

    The test runs a hybrid search for three results, retrieves the matching
    chunks, and then reranks them twice in reversed order: once passing the
    chunk objects and once passing only the chunk ids. The top-hit assertion is
    only made when a reranker is configured and the embedder name does not
    contain "text-embedding-3-small".
    """
    # Derive a config that keeps the fixture's database and embedder but uses the given reranker.
    config = RAGLiteConfig(
        db_url=raglite_test_config.db_url,
        embedder=raglite_test_config.embedder,
        reranker=reranker,
    )
    # Run the search that produces the reference ordering.
    query = "What does it mean for two events to be simultaneous?"
    chunk_ids, _ = hybrid_search(query, num_results=3, config=config)
    # Fetch the chunks and verify that they line up one-to-one with the returned ids.
    chunks = retrieve_chunks(chunk_ids, config=config)
    for chunk in chunks:
        assert isinstance(chunk, Chunk)
    for chunk_id, chunk in zip(chunk_ids, chunks, strict=True):
        assert chunk_id == chunk.id
    # NOTE(review): the "text-embedding-3-small" exclusion is presumably because the search order
    # with that embedder is not a reliable reference -- confirm with the test's author.
    expect_top_hit_first = not (reranker is None or "text-embedding-3-small" in config.embedder)
    # Rerank the inverted order, first from the chunk objects and then from the chunk ids alone.
    for inverted_input in (chunks[::-1], chunk_ids[::-1]):
        reranked_chunks = rerank_chunks(query, inverted_input, config=config)
        if expect_top_hit_first:
            assert reranked_chunks[0] == chunks[0]