File size: 1,501 Bytes
54f5afe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Test RAGLite's RAG functionality."""

import os
from typing import TYPE_CHECKING

import pytest
from llama_cpp import llama_supports_gpu_offload

from raglite import RAGLiteConfig, hybrid_search, rag, retrieve_chunks

if TYPE_CHECKING:
    from raglite._database import Chunk
    from raglite._typing import SearchMethod


def is_accelerator_available() -> bool:
    """Check if an accelerator is available.

    Reports the result of llama.cpp's GPU-offload support check when that is
    truthy. Otherwise it falls back to treating a machine with at least 8 logical
    CPUs as capable enough to run the RAG tests.
    """
    gpu_offload = llama_supports_gpu_offload()
    if gpu_offload:
        return gpu_offload
    # os.cpu_count() returns None when the CPU count cannot be determined.
    logical_cpus = os.cpu_count() or 1
    return logical_cpus >= 8  # noqa: PLR2004


@pytest.mark.skipif(not is_accelerator_available(), reason="No accelerator available")
def test_rag(raglite_test_config: RAGLiteConfig) -> None:
    """Test Retrieval-Augmented Generation.

    Runs `rag` once for each form of its `search` argument exercised here:
    a search method, a list of chunk ids, and a list of chunks. For each, checks
    that every streamed update is a string and that the full answer mentions
    "simultaneous".
    """
    prompt = "What does it mean for two events to be simultaneous?"
    # Search once and reuse the chunk ids for both precomputed inputs. This
    # avoids repeating the search, and both inputs come from the same results.
    chunk_ids = hybrid_search(prompt, config=raglite_test_config)[0]
    # Assemble different types of search inputs for RAG, each with a label
    # used in assertion messages so a failure identifies the input type.
    search_inputs: list[tuple[str, SearchMethod | list[str] | list[Chunk]]] = [
        ("search method", hybrid_search),  # A search method as input.
        ("chunk ids", chunk_ids),  # Chunk ids as input.
        ("chunks", retrieve_chunks(chunk_ids, config=raglite_test_config)),  # Chunks as input.
    ]
    # Answer a question with RAG.
    for description, search_input in search_inputs:
        updates = list(rag(prompt, search=search_input, config=raglite_test_config))
        for update in updates:
            assert isinstance(update, str), (
                f"Non-string update {update!r} when searching with {description}"
            )
        answer = "".join(updates)
        assert "simultaneous" in answer.lower(), (
            f"Answer when searching with {description} does not mention 'simultaneous': "
            f"{answer!r}"
        )