Spaces:
Running
Running
"""Test RAGLite's RAG functionality.""" | |
import os | |
from typing import TYPE_CHECKING | |
import pytest | |
from llama_cpp import llama_supports_gpu_offload | |
from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks | |
if TYPE_CHECKING: | |
from raglite._database import Chunk | |
from raglite._typing import SearchMethod | |
def is_accelerator_available() -> bool:
    """Return whether a GPU accelerator or a many-core CPU is available.

    The first check is whether llama.cpp reports GPU offload support. If it
    does not, the machine still qualifies when it has at least 8 logical CPUs.
    ``os.cpu_count()`` may return ``None``; in that case one CPU is assumed.
    """
    gpu_offload = llama_supports_gpu_offload()
    if gpu_offload:
        return gpu_offload
    logical_cpus = os.cpu_count() or 1
    return logical_cpus >= 8  # noqa: PLR2004
def test_rag(raglite_test_config: RAGLiteConfig) -> None:
    """Test Retrieval-Augmented Generation with each supported kind of search input.

    ``rag`` is given, in turn, a search method, a list of chunk ids, and a list of
    chunks. For each input, every streamed update must be a ``str``. The
    concatenated answer must mention "simultaneous".
    """
    prompt = "What does it mean for two events to be simultaneous?"
    # Run the search once and reuse its chunk ids for both the id-based and the chunk-based
    # inputs. This avoids executing the same search twice. It also guarantees that both
    # inputs refer to exactly the same retrieved chunk ids.
    chunk_ids = hybrid_search(prompt, config=raglite_test_config)[0]
    # Assemble different types of search inputs for RAG.
    search_inputs: list[SearchMethod | list[str] | list[Chunk]] = [
        hybrid_search,  # A search method as input.
        chunk_ids,  # Chunk ids as input.
        retrieve_chunks(chunk_ids, config=raglite_test_config),  # Chunks as input.
    ]
    # Answer a question with RAG.
    for index, search_input in enumerate(search_inputs):
        stream = rag(prompt, search=search_input, config=raglite_test_config)
        updates: list[str] = []
        for update in stream:
            assert isinstance(update, str), (
                f"Search input #{index}: expected str update, got {type(update).__name__}"
            )
            updates.append(update)
        answer = "".join(updates)
        assert "simultaneous" in answer.lower(), (
            f"Search input #{index}: answer does not mention 'simultaneous': {answer!r}"
        )