Spaces:
Sleeping
Sleeping
from chromadb.config import Settings, System | |
from chromadb.api import API | |
import chromadb.server.fastapi | |
from requests.exceptions import ConnectionError | |
import hypothesis | |
import tempfile | |
import os | |
import uvicorn | |
import time | |
import pytest | |
from typing import Generator, List, Callable, Optional, Tuple | |
import shutil | |
import logging | |
import socket | |
import multiprocessing | |
# Emit everything at DEBUG on the root logger. This will only run when testing.
root_logger = logging.getLogger()
root_logger.setLevel(logging.DEBUG)
logger = logging.getLogger(__name__)

# Health checks silenced by the "dev" Hypothesis profile registered below.
_DEV_SUPPRESSED_HEALTH_CHECKS = [
    hypothesis.HealthCheck.data_too_large,
    hypothesis.HealthCheck.large_base_example,
    hypothesis.HealthCheck.function_scoped_fixture,
]
hypothesis.settings.register_profile(
    "dev",
    deadline=45000,
    suppress_health_check=_DEV_SUPPRESSED_HEALTH_CHECKS,
)
# HYPOTHESIS_PROFILE selects the active profile; "dev" is used when it is unset.
hypothesis.settings.load_profile(os.getenv("HYPOTHESIS_PROFILE", "dev"))
def find_free_port() -> int:
    """Ask the OS for a currently unused TCP port and return its number.

    A throwaway IPv4 socket is bound to port 0 so the OS picks the port; the
    socket is closed again before returning, so the port is only known to have
    been free at the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        port: int = probe.getsockname()[1]
        return port
    finally:
        probe.close()
def _run_server(
    port: int, is_persistent: bool = False, persist_directory: Optional[str] = None
) -> None:
    """Run a Chroma server locally on ``port`` (blocks inside ``uvicorn.run``).

    Persistence is enabled only when ``is_persistent`` is truthy *and* a
    ``persist_directory`` is supplied; in every other case the server is
    configured with ``is_persistent=False`` and no directory is passed.
    """
    # Only the persistence-related settings differ between the two modes.
    if is_persistent and persist_directory:
        persistence = {
            "is_persistent": is_persistent,
            "persist_directory": persist_directory,
        }
    else:
        persistence = {"is_persistent": False}

    server_settings = Settings(
        chroma_api_impl="chromadb.api.segment.SegmentAPI",
        chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
        chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
        allow_reset=True,
        **persistence,
    )
    asgi_app = chromadb.server.fastapi.FastAPI(server_settings).app()
    uvicorn.run(asgi_app, host="0.0.0.0", port=port, log_level="error")
def _await_server(api: API, attempts: int = 0) -> None: | |
try: | |
api.heartbeat() | |
except ConnectionError as e: | |
if attempts > 15: | |
logger.error("Test server failed to start after 15 attempts") | |
raise e | |
else: | |
logger.info("Waiting for server to start...") | |
time.sleep(4) | |
_await_server(api, attempts + 1) | |
def _fastapi_fixture(is_persistent: bool = False) -> Generator[System, None, None]:
    """Launch a Chroma server in a separate process and yield a client System.

    The server runs ``_run_server`` in a spawned daemon process on a port
    obtained from ``find_free_port``. When ``is_persistent`` is set, a fresh
    temporary directory is created and handed to the server as its persist
    directory. The yielded ``System`` is configured with the FastAPI client
    implementation pointing at ``localhost:<port>`` and is only yielded once
    ``_await_server`` has seen a successful heartbeat.

    The statements after the ``yield`` (stop the client system, kill the
    server process, remove the temporary directory) run only if the caller
    resumes this generator after consuming the yielded value.
    """
    port = find_free_port()
    logger.info(f"Running test FastAPI server on port {port}")

    # Work out the server arguments (and temp dir, if any) before spawning.
    persist_directory = tempfile.mkdtemp() if is_persistent else None
    server_args: Tuple[int, bool, Optional[str]] = (
        (port, is_persistent, persist_directory)
        if is_persistent
        else (port, False, None)
    )
    server_proc = multiprocessing.get_context("spawn").Process(
        target=_run_server, args=server_args, daemon=True
    )
    server_proc.start()

    client_system = System(
        Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            chroma_server_host="localhost",
            chroma_server_http_port=str(port),
            allow_reset=True,
        )
    )
    client_api = client_system.instance(API)
    client_system.start()
    _await_server(client_api)

    yield client_system

    # Teardown: client first, then the server process, then on-disk state.
    client_system.stop()
    server_proc.kill()
    if (
        is_persistent
        and persist_directory is not None
        and os.path.exists(persist_directory)
    ):
        shutil.rmtree(persist_directory)
def fastapi() -> Generator[System, None, None]:
    """Fixture generator for a client talking to a non-persistent test server."""
    server_fixture = _fastapi_fixture(is_persistent=False)
    return server_fixture
def fastapi_persistent() -> Generator[System, None, None]:
    """Fixture generator for a client talking to a persistent test server."""
    server_fixture = _fastapi_fixture(is_persistent=True)
    return server_fixture
def integration() -> Generator[System, None, None]:
    """Fixture generator for externally configured integration tests.

    Only ``allow_reset=True`` is set explicitly; every other setting is left
    to ``Settings`` itself (intended to be supplied via environment
    variables). The system is started before it is yielded and stopped when
    the generator is resumed afterwards.
    """
    external_system = System(Settings(allow_reset=True))
    external_system.start()
    yield external_system
    external_system.stop()
def sqlite() -> Generator[System, None, None]:
    """Fixture generator for the segment-based API on non-persistent Sqlite.

    The system is started before it is yielded and stopped when the generator
    is resumed afterwards.
    """
    local_system = System(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            is_persistent=False,
            allow_reset=True,
        )
    )
    local_system.start()
    yield local_system
    local_system.stop()
def sqlite_persistent() -> Generator[System, None, None]:
    """Fixture generator for the segment-based API on persistent Sqlite.

    Data is persisted to a freshly created temporary directory. When the
    generator is resumed after the yield, the system is stopped and the
    directory is removed if it still exists.
    """
    save_path = tempfile.mkdtemp()
    local_system = System(
        Settings(
            chroma_api_impl="chromadb.api.segment.SegmentAPI",
            chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
            chroma_segment_manager_impl="chromadb.segment.impl.manager.local.LocalSegmentManager",
            allow_reset=True,
            is_persistent=True,
            persist_directory=save_path,
        )
    )
    local_system.start()
    yield local_system
    local_system.stop()
    if not os.path.exists(save_path):
        return
    shutil.rmtree(save_path)
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the fixture generator functions to run the test suite against.

    * ``CHROMA_INTEGRATION_TEST_ONLY`` set: only ``integration``.
    * ``CHROMA_INTEGRATION_TEST`` set: the four local fixtures plus
      ``integration``.
    * neither set: the four local fixtures.

    Only the presence of the variables is checked, not their values.
    """
    # The "only" switch wins regardless of the other variable.
    if "CHROMA_INTEGRATION_TEST_ONLY" in os.environ:
        return [integration]

    selected = [fastapi, fastapi_persistent, sqlite, sqlite_persistent]
    if "CHROMA_INTEGRATION_TEST" in os.environ:
        selected.append(integration)
    return selected
def system(request: pytest.FixtureRequest) -> Generator[System, None, None]:
    """Yield the System produced by the fixture generator in ``request.param``.

    ``request.param`` is called with no arguments and must return a generator
    that yields a started ``System`` exactly once, optionally followed by
    teardown code.

    The inner generator is delegated to with ``yield from`` rather than being
    advanced once with ``next()``: a generator that is advanced once and then
    dropped never executes the statements after its ``yield``, so the inner
    fixture's teardown (stopping the system, killing the server process,
    deleting temporary directories) would be skipped. With delegation, that
    teardown runs when this generator is resumed after the test.

    NOTE(review): no ``@pytest.fixture(...)`` decorator is visible on this
    function in the reviewed text; presumably it is parametrized with the
    callables from ``system_fixtures()`` -- confirm.
    """
    yield from request.param()
def api(system: System) -> Generator[API, None, None]:
    """Reset ``system``'s state, then yield its ``API`` instance.

    The reset happens before the instance is looked up, so each consumer
    starts from freshly reset state. There is no teardown after the yield.
    """
    system.reset_state()
    client = system.instance(API)
    yield client