"""Pytest fixtures that provide Chroma ``System`` instances for tests.

Each fixture generator below builds one flavour of system (in-process
segment API, persistent segment API, or a FastAPI client talking to a
server launched in a child process) and tears it down again after use.
"""
from chromadb.config import Settings, System
from chromadb.api import API
import chromadb.server.fastapi
from requests.exceptions import ConnectionError
import hypothesis
import tempfile
import os
import uvicorn
import time
import pytest
from typing import Generator, List, Callable, Optional, Tuple
import shutil
import logging
import socket
import multiprocessing

root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)  # This will only run when testing

logger = logging.getLogger(__name__)

hypothesis.settings.register_profile(
    "dev",
    deadline=45000,
    suppress_health_check=[
        hypothesis.HealthCheck.data_too_large,
        hypothesis.HealthCheck.large_base_example,
        hypothesis.HealthCheck.function_scoped_fixture,
    ],
)
hypothesis.settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev"))


def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port and return its number.

    The socket is closed before returning, so another process could in
    principle grab the port before the caller binds it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # The option must be set before bind() to have any effect on it.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        port: int = s.getsockname()[1]
        return port


def _segment_api_settings(persist_directory: Optional[str] = None) -> Settings:
    """Build Settings for the sqlite-backed segment API.

    A non-empty ``persist_directory`` produces a persistent configuration
    rooted at that directory; otherwise the configuration is non-persistent.
    """
    if persist_directory:
        return Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            is_persistent=True,
            persist_directory=persist_directory,
            allow_reset=True,
        )
    return Settings(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        is_persistent=False,
        allow_reset=True,
    )


def _run_server(
    port: int, is_persistent: bool = False, persist_directory: Optional[str] = None
) -> None:
    """Run a Chroma server locally.

    Blocks inside ``uvicorn.run``; intended to be the target of a child
    process. The server is persistent only when ``is_persistent`` is true
    and ``persist_directory`` is non-empty.
    """
    settings = _segment_api_settings(persist_directory if is_persistent else None)
    server = chromadb.server.fastapi.FastAPI(settings)
    uvicorn.run(server.app(), host="0.0.0.0", port=port, log_level="error")


def _await_server(api: API, attempts: int = 0) -> None:
    """Poll ``api.heartbeat()`` until it stops raising ``ConnectionError``.

    Sleeps four seconds between tries. Once ``attempts`` exceeds 15 the
    last ``ConnectionError`` is logged and re-raised. ``attempts`` is the
    number of tries already made before this call.
    """
    while True:
        try:
            api.heartbeat()
            return
        except ConnectionError:
            if attempts > 15:
                logger.error("Test server failed to start after 15 attempts")
                raise
            logger.info("Waiting for server to start...")
            time.sleep(4)
            attempts += 1


def _fastapi_fixture(is_persistent: bool = False) -> Generator[System, None, None]:
    """Fixture generator that launches a server in a separate process, and
    yields a system whose API is a fastapi client connected to it.

    Teardown stops the client system, kills the server process and, for
    persistent servers, removes the temporary persistence directory. It
    runs even if startup fails part-way through.
    """
    port = find_free_port()
    logger.info(f"Running test FastAPI server on port {port}")
    ctx = multiprocessing.get_context("spawn")
    persist_directory: Optional[str] = tempfile.mkdtemp() if is_persistent else None
    args: Tuple[int, bool, Optional[str]] = (port, is_persistent, persist_directory)
    try:
        proc = ctx.Process(target=_run_server, args=args, daemon=True)
        proc.start()
        try:
            settings = Settings(
                chroma_api_impl="chromadb.api.fastapi.FastAPI",
                chroma_server_host="localhost",
                chroma_server_http_port=str(port),
                allow_reset=True,
            )
            system = System(settings)
            api = system.instance(API)
            system.start()
            try:
                _await_server(api)
                yield system
            finally:
                system.stop()
        finally:
            proc.kill()
            # Reap the child so it does not linger as a zombie.
            proc.join()
    finally:
        if persist_directory is not None and os.path.exists(persist_directory):
            shutil.rmtree(persist_directory)


def fastapi() -> Generator[System, None, None]:
    """Fixture generator for a client of a non-persistent local server."""
    return _fastapi_fixture(is_persistent=False)


def fastapi_persistent() -> Generator[System, None, None]:
    """Fixture generator for a client of a persistent local server."""
    return _fastapi_fixture(is_persistent=True)


def integration() -> Generator[System, None, None]:
    """Fixture generator for returning a client configured via environment
    variables, intended for externally configured integration tests
    """
    settings = Settings(allow_reset=True)
    system = System(settings)
    system.start()
    try:
        yield system
    finally:
        system.stop()


def sqlite() -> Generator[System, None, None]:
    """Fixture generator for segment-based API using in-memory Sqlite"""
    system = System(_segment_api_settings())
    system.start()
    try:
        yield system
    finally:
        system.stop()


def sqlite_persistent() -> Generator[System, None, None]:
    """Fixture generator for segment-based API using persistent Sqlite"""
    save_path = tempfile.mkdtemp()
    try:
        system = System(_segment_api_settings(save_path))
        system.start()
        try:
            yield system
        finally:
            system.stop()
    finally:
        if os.path.exists(save_path):
            shutil.rmtree(save_path)


def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generators the ``system`` fixture is parametrized by.

    ``CHROMA_INTEGRATION_TEST`` in the environment adds the ``integration``
    generator; ``CHROMA_INTEGRATION_TEST_ONLY`` replaces the list with it.
    """
    fixtures = [fastapi, fastapi_persistent, sqlite, sqlite_persistent]
    if "CHROMA_INTEGRATION_TEST" in os.environ:
        fixtures.append(integration)
    if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ:
        fixtures = [integration]
    return fixtures


@pytest.fixture(scope="module", params=system_fixtures())
def system(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Module-scoped system, parametrized over every generator in
    ``system_fixtures()``.
    """
    # Delegate with ``yield from`` so the selected generator is resumed at
    # teardown and its cleanup code (stop / kill / rmtree) actually runs.
    yield from request.param()


@pytest.fixture(scope="function")
def api(system: System) -> Generator[API, None, None]:
    """Per-test API instance, taken from ``system`` after resetting its state."""
    system.reset_state()
    api = system.instance(API)
    yield api