rag_lite / tests /test_rag.py
EL GHAFRAOUI AYOUB
C
54f5afe
"""Test RAGLite's RAG functionality."""
import os
from typing import TYPE_CHECKING
import pytest
from llama_cpp import llama_supports_gpu_offload
from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks
if TYPE_CHECKING:
from raglite._database import Chunk
from raglite._typing import SearchMethod
def is_accelerator_available() -> bool:
    """Report whether this machine is considered capable enough to run the RAG test.

    Returns a truthy result when llama.cpp reports GPU offload support, or otherwise
    when the machine exposes at least eight CPUs.
    """
    min_cpu_count = 8
    if llama_supports_gpu_offload():
        return True
    # os.cpu_count() may return None when the count is undetermined; treat that as one CPU.
    cpu_count = os.cpu_count() or 1
    return cpu_count >= min_cpu_count
@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
def test_rag(raglite_test_config: RAGLiteConfig) -> None:
    """Check that RAG streams a relevant answer for each supported kind of search input.

    The same prompt is answered three times, passing `rag` a search method, chunk ids,
    and retrieved chunks in turn. Every streamed update must be a string, and the
    concatenated answer must mention the word "simultaneous".
    """
    config = raglite_test_config
    question = "What does it mean for two events to be simultaneous?"
    # Prepare the search inputs in the order in which they are exercised below.
    chunk_ids = hybrid_search(question, config=config)[0]
    chunks = retrieve_chunks(hybrid_search(question, config=config)[0], config=config)
    search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [
        hybrid_search,  # A search method as input.
        chunk_ids,  # Chunk ids as input.
        chunks,  # Chunks as input.
    ]
    # Run RAG once per search input and collect the streamed answer.
    for search_input in search_inputs:
        fragments = []
        for fragment in rag(question, search=search_input, config=config):
            assert isinstance(fragment, str)
            fragments.append(fragment)
        answer = "".join(fragments)
        assert "simultaneous" in answer.lower()