Spaces:
Build error
Build error
File size: 2,971 Bytes
60e3a80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import asyncio
import tempfile
from pathlib import Path
from typing import Any, Callable, Generator, cast
from unittest.mock import patch

import chromadb
from chromadb.config import Settings
from chromadb.api import ClientAPI
import chromadb.server.fastapi
import pytest
@pytest.fixture()
def ephemeral_api() -> Generator[ClientAPI, None, None]:
    """Yield a client created by ``chromadb.EphemeralClient()``.

    Teardown calls ``clear_system_cache()`` on the client so cached system
    state is not carried into later tests.
    """
    api = chromadb.EphemeralClient()
    yield api
    api.clear_system_cache()
@pytest.fixture
def persistent_api(tmp_path: Path) -> Generator[ClientAPI, None, None]:
    """Yield a ``chromadb.PersistentClient`` stored under a per-test directory.

    Why ``tmp_path``: a fixed ``<system tmp>/test_server`` directory is shared
    by every run and every concurrently running test process, so on-disk state
    leaks between them and is never removed. pytest creates a unique
    ``tmp_path`` per test and manages its lifetime.

    Teardown calls ``clear_system_cache()`` on the client.
    """
    client = chromadb.PersistentClient(
        # str() keeps the argument a plain path string, as before.
        path=str(tmp_path / "test_server"),
    )
    yield client
    client.clear_system_cache()
# A callable taking HttpClient-style arguments and returning a ClientAPI.
# The async variant below hides the awaitable so both variants are called alike.
HttpAPIFactory = Callable[..., ClientAPI]


@pytest.fixture(params=["sync_client", "async_client"])
def http_api_factory(
    request: pytest.FixtureRequest,
) -> Generator[HttpAPIFactory, None, None]:
    """Yield a client factory, once for the sync and once for the async HTTP API.

    While the factory is in use, the matching ``_validate_tenant_database``
    method is patched out, so client construction skips tenant/database
    validation (presumably a server round-trip; confirm in chromadb).

    For ``"async_client"``, the awaitable returned by
    ``chromadb.AsyncHttpClient`` is run to completion on an event loop owned by
    this fixture. That loop is closed at teardown.
    """
    if request.param == "sync_client":
        with patch("chromadb.api.client.Client._validate_tenant_database"):
            yield chromadb.HttpClient
        return

    # Use a dedicated loop instead of asyncio.get_event_loop():
    # - that call is deprecated when no loop is running, and raises on 3.14+;
    # - the loop it implicitly created was never closed.
    loop = asyncio.new_event_loop()
    try:
        with patch("chromadb.api.async_client.AsyncClient._validate_tenant_database"):

            def factory(*args: Any, **kwargs: Any) -> Any:
                """Build an async HTTP client synchronously on the fixture's loop."""
                return loop.run_until_complete(
                    chromadb.AsyncHttpClient(*args, **kwargs)
                )

            yield cast(HttpAPIFactory, factory)
    finally:
        loop.close()
@pytest.fixture
def http_api(http_api_factory: HttpAPIFactory) -> Generator[ClientAPI, None, None]:
    """Yield a client built by ``http_api_factory`` with no arguments.

    Teardown calls ``clear_system_cache()`` on that client.
    """
    http_client = http_api_factory()
    yield http_client
    http_client.clear_system_cache()
def test_ephemeral_client(ephemeral_api: ClientAPI) -> None:
    """The ephemeral client's settings must report ``is_persistent is False``."""
    assert ephemeral_api.get_settings().is_persistent is False
def test_persistent_client(persistent_api: ClientAPI) -> None:
    """The persistent client's settings must report ``is_persistent is True``."""
    assert persistent_api.get_settings().is_persistent is True
def test_http_client(http_api: ClientAPI) -> None:
    """The HTTP client's configured API implementation must be one of the known ones.

    The fixture is parametrized over sync and async clients, so either the
    sync or the async implementation path is accepted.
    """
    accepted_impls = (
        "chromadb.api.fastapi.FastAPI",
        "chromadb.api.async_fastapi.AsyncFastAPI",
    )
    assert http_api.get_settings().chroma_api_impl in accepted_impls
def test_http_client_with_inconsistent_host_settings(
    http_api_factory: HttpAPIFactory,
) -> None:
    """A host in Settings that differs from the client's host must raise ValueError.

    No ``host`` argument is passed. Judging from the expected message, the
    client's host is then ``localhost``, which conflicts with ``127.0.0.1``
    in Settings.
    """
    # pytest.raises fails the test when no ValueError is raised. The previous
    # bare try/except let the test pass silently in that case.
    with pytest.raises(ValueError) as exc_info:
        http_api_factory(settings=Settings(chroma_server_host="127.0.0.1"))
    assert (
        str(exc_info.value)
        == "Chroma server host provided in settings[127.0.0.1] is different to the one provided in HttpClient: [localhost]"
    )
def test_http_client_with_inconsistent_port_settings(
    http_api_factory: HttpAPIFactory,
) -> None:
    """A port in Settings (8001) that differs from the ``port`` argument (8002) must raise ValueError."""
    # pytest.raises fails the test when no ValueError is raised. The previous
    # bare try/except let the test pass silently in that case.
    with pytest.raises(ValueError) as exc_info:
        http_api_factory(
            port=8002,
            settings=Settings(
                chroma_server_http_port=8001,
            ),
        )
    assert (
        str(exc_info.value)
        == "Chroma server http port provided in settings[8001] is different to the one provided in HttpClient: [8002]"
    )
|