File size: 3,439 Bytes
54f5afe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Fixtures for the tests."""

import os
import socket
import tempfile
from collections.abc import Generator
from pathlib import Path

import pytest
from sqlalchemy import create_engine, text

from raglite import RAGLiteConfig, insert_document

POSTGRES_URL = "postgresql+pg8000://raglite_user:raglite_password@postgres:5432/postgres"


def is_postgres_running(
    host: str = "postgres", port: int = 5432, timeout: float = 1.0
) -> bool:
    """Check whether a TCP connection to the PostgreSQL server can be opened.

    This is a reachability probe only: it opens and immediately closes a TCP connection and does
    not speak the PostgreSQL protocol or authenticate.

    Args:
        host: Hostname or IP address to connect to.
        port: TCP port to connect to.
        timeout: Maximum number of seconds to wait for the connection.

    Returns:
        True if the connection was established, False if it failed for any OS-level reason
        (name resolution failure, connection refused, timeout, ...).
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # socket.gaierror and socket.timeout are both subclasses of OSError.
        return False


def is_openai_available() -> bool:
    """Report whether the OPENAI_API_KEY environment variable is set to a non-empty value."""
    api_key = os.environ.get("OPENAI_API_KEY", "")
    return len(api_key) > 0


def pytest_sessionstart(session: pytest.Session) -> None:
    """Reset the PostgreSQL test databases at the start of the test session.

    If PostgreSQL is reachable, the databases ``raglite_test_local`` and ``raglite_test_remote``
    are dropped (if they exist) and recreated empty. If PostgreSQL is not reachable, this hook
    does nothing. Only PostgreSQL is touched by this hook.

    Args:
        session: The pytest session object (unused).
    """
    if not is_postgres_running():
        return
    # PostgreSQL refuses DROP/CREATE DATABASE inside a transaction block, hence AUTOCOMMIT.
    engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT")
    try:
        with engine.connect() as conn:
            for variant in ["local", "remote"]:
                conn.execute(text(f"DROP DATABASE IF EXISTS raglite_test_{variant}"))
                conn.execute(text(f"CREATE DATABASE raglite_test_{variant}"))
    finally:
        # Release the pooled connection so it does not stay open for the rest of the session.
        engine.dispose()


@pytest.fixture(scope="session")
def sqlite_url() -> Generator[str, None, None]:
    """Yield a SQLite database URL whose file lives in a temporary directory.

    The directory is created on first use and removed again when the fixture is torn down.
    The database file itself is not created here.
    """
    with tempfile.TemporaryDirectory() as scratch_dir:
        sqlite_path = Path(scratch_dir, "raglite_test.sqlite")
        yield "sqlite:///" + str(sqlite_path)


@pytest.fixture(
    scope="session",
    params=[
        pytest.param("sqlite", id="sqlite"),
        pytest.param(
            POSTGRES_URL,
            id="postgres",
            marks=pytest.mark.skipif(not is_postgres_running(), reason="PostgreSQL is not running"),
        ),
    ],
)
def database(request: pytest.FixtureRequest) -> str:
    """Provide the database URL for the current parametrization.

    The "sqlite" parameter is resolved lazily through the ``sqlite_url`` fixture; any other
    parameter is already a complete URL and is returned unchanged.
    """
    if request.param != "sqlite":
        explicit_url: str = request.param
        return explicit_url
    # Only request the SQLite fixture when it is actually needed.
    temp_sqlite_url: str = request.getfixturevalue("sqlite_url")
    return temp_sqlite_url


@pytest.fixture(
    scope="session",
    params=[
        pytest.param(
            "llama-cpp-python/lm-kit/bge-m3-gguf/*Q4_K_M.gguf",
            id="bge_m3",
        ),
        pytest.param(
            "text-embedding-3-small",
            id="openai_text_embedding_3_small",
            marks=pytest.mark.skipif(not is_openai_available(), reason="OpenAI API key is not set"),
        ),
    ],
)
def embedder(request: pytest.FixtureRequest) -> str:
    """Provide the embedder model identifier for the current parametrization."""
    # The annotated local gives the untyped `request.param` a concrete `str` type.
    model_id: str = request.param
    return model_id


@pytest.fixture(scope="session")
def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig:
    """Create a RAGLite config for the given database and embedder, seeded with one document.

    A separate database is used per embedder variant: "local" when the embedder identifier
    starts with ``llama-cpp-python``, otherwise "remote". For PostgreSQL URLs the trailing
    ``/postgres`` database name becomes ``/raglite_test_<variant>``; for SQLite URLs the
    ``.sqlite`` file extension becomes ``_<variant>.sqlite``.

    Args:
        database: Base database URL (PostgreSQL or SQLite).
        embedder: Embedder model identifier.

    Returns:
        The config, after ``specrel.pdf`` (located next to this file) has been inserted with it.
    """
    # Select the database based on the embedder.
    variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
    # Decide on the URL scheme rather than a substring match, so that a SQLite file path that
    # happens to contain "postgres" is not mistaken for a PostgreSQL URL. Rewrite only the last
    # occurrence so that other parts of the URL or path are left untouched.
    if database.startswith("postgres"):
        head, sep, tail = database.rpartition("/postgres")
        if sep:
            database = f"{head}/raglite_test_{variant}{tail}"
    elif database.startswith("sqlite"):
        head, sep, tail = database.rpartition(".sqlite")
        if sep:
            database = f"{head}_{variant}.sqlite{tail}"
    # Create a RAGLite config for the given database and embedder.
    db_config = RAGLiteConfig(db_url=database, embedder=embedder)
    # Insert a document and update the index.
    doc_path = Path(__file__).parent / "specrel.pdf"  # Einstein's special relativity paper.
    insert_document(doc_path, config=db_config)
    return db_config