Spaces:
Running
Running
"""Fixtures for the tests.""" | |
import os | |
import socket | |
import tempfile | |
from collections.abc import Generator | |
from pathlib import Path | |
import pytest | |
from sqlalchemy import create_engine, text | |
from raglite import RAGLiteConfig, insert_document | |
# SQLAlchemy URL (pg8000 driver) for the `postgres` database on the `postgres` host. The session
# hook connects here to (re)create the test databases, and the config fixture swaps the trailing
# database name for a per-variant test database.
POSTGRES_URL = (
    "postgresql+pg8000://"
    "raglite_user:raglite_password"
    "@postgres:5432"
    "/postgres"
)
def is_postgres_running(host: str = "postgres", port: int = 5432, timeout: float = 1.0) -> bool:
    """Check whether a PostgreSQL server accepts TCP connections.

    Only a TCP connection is attempted; no PostgreSQL handshake or authentication happens, so
    this reports whether something is listening on ``host:port``.

    Args:
        host: Hostname of the server. Defaults to the ``postgres`` host used by ``POSTGRES_URL``.
        port: TCP port of the server. Defaults to 5432.
        timeout: Seconds to wait for the connection before giving up.

    Returns:
        True if a TCP connection could be established, False otherwise.
    """
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Refused connections, timeouts and DNS resolution failures are all OSError subclasses.
        return False
def is_openai_available() -> bool:
    """Report whether the ``OPENAI_API_KEY`` environment variable holds a non-empty value."""
    api_key = os.environ.get("OPENAI_API_KEY", "")
    return api_key != ""
def pytest_sessionstart(session: pytest.Session) -> None:
    """Recreate the PostgreSQL test databases at the start of the test session.

    When a PostgreSQL server is reachable, this drops and re-creates ``raglite_test_local`` and
    ``raglite_test_remote``, so every session starts from empty databases. SQLite needs no reset
    here: the ``sqlite_url`` fixture hands out a path inside a fresh temporary directory.

    Args:
        session: The pytest session (required by the hook signature; unused).
    """
    if not is_postgres_running():
        return
    # DROP/CREATE DATABASE cannot run inside a transaction block, hence AUTOCOMMIT.
    engine = create_engine(POSTGRES_URL, isolation_level="AUTOCOMMIT")
    try:
        with engine.connect() as conn:
            for variant in ("local", "remote"):
                conn.execute(text(f"DROP DATABASE IF EXISTS raglite_test_{variant}"))
                conn.execute(text(f"CREATE DATABASE raglite_test_{variant}"))
    finally:
        # Close the pooled connection so it does not stay open for the rest of the session.
        engine.dispose()
def sqlite_url() -> Generator[str, None, None]:
    """Yield a SQLite URL pointing at a file inside a fresh temporary directory.

    This is requested by name from the ``database`` fixture. The database file itself is not
    created here. The temporary directory, and anything written into it, is removed when the
    generator resumes or is closed after the yield.
    """
    temp_dir = tempfile.TemporaryDirectory()
    try:
        db_file = Path(temp_dir.name).joinpath("raglite_test.sqlite")
        yield "sqlite:///" + str(db_file)
    finally:
        temp_dir.cleanup()
def database(request: pytest.FixtureRequest) -> str:
    """Resolve the database URL for the current parametrization.

    The sentinel parameter ``"sqlite"`` is swapped for the URL produced by the ``sqlite_url``
    fixture. Any other parameter is assumed to already be a database URL and is returned as-is.
    """
    if request.param != "sqlite":
        return request.param
    sqlite_db_url: str = request.getfixturevalue("sqlite_url")
    return sqlite_db_url
def embedder(request: pytest.FixtureRequest) -> str:
    """Return the embedder model identifier that this parametrization was created with."""
    return request.param
def raglite_test_config(database: str, embedder: str) -> RAGLiteConfig:
    """Create a RAGLite config for a database/embedder pair and index a sample document.

    Each embedder variant gets its own database (presumably so that embeddings from different
    models never mix). Embedders whose identifier starts with ``llama-cpp-python`` use the
    ``local`` variant; all others use the ``remote`` variant.

    For PostgreSQL, the URL's database name is replaced by ``raglite_test_<variant>``, matching
    the databases created in ``pytest_sessionstart``. For a file-backed SQLite database,
    ``_<variant>`` is inserted before the file extension. Other URLs are used unchanged.

    Args:
        database: SQLAlchemy database URL (SQLite or PostgreSQL).
        embedder: Embedder model identifier passed to ``RAGLiteConfig``.

    Returns:
        A ``RAGLiteConfig`` whose database already contains the ``specrel.pdf`` document.
    """
    from sqlalchemy.engine import make_url

    variant = "local" if embedder.startswith("llama-cpp-python") else "remote"
    # Parse the URL rather than doing substring replacement. Matching "postgres" or ".sqlite"
    # anywhere in the string could also hit a username, host, or directory name.
    url = make_url(database)
    backend = url.get_backend_name()
    if backend == "postgresql":
        url = url.set(database=f"raglite_test_{variant}")
    elif backend == "sqlite" and url.database and url.database != ":memory:":
        db_file = Path(url.database)
        variant_file = db_file.with_name(f"{db_file.stem}_{variant}{db_file.suffix}")
        url = url.set(database=str(variant_file))
    # str(url) may mask the password (it does in SQLAlchemy 2.x), so render it explicitly.
    db_url = url.render_as_string(hide_password=False)
    # Create a RAGLite config for the given database and embedder.
    db_config = RAGLiteConfig(db_url=db_url, embedder=embedder)
    # Insert a document and update the index.
    doc_path = Path(__file__).parent / "specrel.pdf"  # Einstein's special relativity paper.
    insert_document(doc_path, config=db_config)
    return db_config