Spaces:
Build error
Build error
from typing import Optional, Sequence | |
from uuid import UUID | |
from overrides import override | |
import httpx | |
from chromadb.api import AdminAPI, ClientAPI, ServerAPI | |
from chromadb.api.configuration import CollectionConfiguration | |
from chromadb.api.shared_system_client import SharedSystemClient | |
from chromadb.api.types import ( | |
CollectionMetadata, | |
DataLoader, | |
Documents, | |
Embeddable, | |
EmbeddingFunction, | |
Embeddings, | |
GetResult, | |
IDs, | |
Include, | |
Loadable, | |
Metadatas, | |
QueryResult, | |
URIs, | |
) | |
from chromadb.config import Settings, System | |
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE | |
from chromadb.api.models.Collection import Collection | |
from chromadb.errors import ChromaError | |
from chromadb.types import Database, Tenant, Where, WhereDocument | |
import chromadb.utils.embedding_functions as ef | |
class Client(SharedSystemClient, ClientAPI):
    """A client for Chroma. This is the main entrypoint for interacting with Chroma.

    A client internally stores its tenant and database and proxies calls to a
    Server API instance of Chroma. It treats the Server API and corresponding System
    as a singleton, so multiple clients connecting to the same resource will share the
    same API instance.

    Client implementations should implement their own API-caching strategies.
    """

    # NOTE(review): `override` is imported at file level but no decorator is
    # visible on any method here -- presumably decorators were stripped from
    # this copy; confirm against the upstream file before re-adding them.

    tenant: str = DEFAULT_TENANT
    database: str = DEFAULT_DATABASE

    _server: ServerAPI
    # An internal admin client for verifying that databases and tenants exist
    _admin_client: AdminAPI

    # region Initialization
    def __init__(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        settings: Settings = Settings(),
    ) -> None:
        """Create a client bound to ``tenant`` and ``database``.

        Args:
            tenant: Tenant that tenant-scoped calls are issued against.
            database: Database that database-scoped calls are issued against.
            settings: Settings forwarded to ``SharedSystemClient.__init__``.
                The default instance is built once, when the function is
                defined.

        Raises:
            ValueError: Propagated from ``_validate_tenant_database``.
            ChromaError: Propagated from ``_validate_tenant_database``.
        """
        super().__init__(settings=settings)
        self.tenant = tenant
        self.database = database

        # Create an admin client for verifying that databases and tenants exist.
        # It must exist before _validate_tenant_database is called below.
        self._admin_client = AdminClient.from_system(self._system)
        self._validate_tenant_database(tenant=tenant, database=database)

        # Get the root system component we want to interact with
        self._server = self._system.instance(ServerAPI)

        self._submit_client_start_event()

    # Must be a classmethod: it is invoked on the class and builds the
    # instance through ``cls(...)``.
    @classmethod
    def from_system(
        cls,
        system: System,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "Client":
        """Alternate constructor that builds a client from an existing System.

        ``system`` is handed to ``SharedSystemClient._populate_data_from_system``
        and the client is then constructed with ``system.settings``.
        """
        SharedSystemClient._populate_data_from_system(system)
        return cls(tenant=tenant, database=database, settings=system.settings)

    # endregion

    # region BaseAPI Methods
    # Note - we could do this in less verbose ways, but they break type checking
    def heartbeat(self) -> int:
        """Return the result of the server's ``heartbeat()``."""
        return self._server.heartbeat()

    def list_collections(
        self, limit: Optional[int] = None, offset: Optional[int] = None
    ) -> Sequence[Collection]:
        """List collections in this client's tenant and database.

        Each model returned by the server is wrapped in a ``Collection`` that
        talks to the same server instance.
        """
        models = self._server.list_collections(
            limit, offset, tenant=self.tenant, database=self.database
        )
        return [Collection(client=self._server, model=model) for model in models]

    def count_collections(self) -> int:
        """Return the server's collection count for this tenant and database."""
        return self._server.count_collections(
            tenant=self.tenant, database=self.database
        )

    def create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
    ) -> Collection:
        """Create a collection in this client's tenant and database.

        ``name``, ``metadata``, ``configuration`` and ``get_or_create`` are
        forwarded to the server; ``embedding_function`` and ``data_loader`` are
        only attached to the returned ``Collection`` wrapper.
        """
        model = self._server.create_collection(
            name=name,
            metadata=metadata,
            tenant=self.tenant,
            database=self.database,
            get_or_create=get_or_create,
            configuration=configuration,
        )
        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    def get_collection(
        self,
        name: str,
        id: Optional[UUID] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        """Fetch a collection from this client's tenant and database.

        ``name`` and ``id`` are forwarded to the server; ``embedding_function``
        and ``data_loader`` are only attached to the returned ``Collection``.
        """
        model = self._server.get_collection(
            id=id,
            name=name,
            tenant=self.tenant,
            database=self.database,
        )
        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    def get_or_create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        """Call the server's ``get_or_create_collection`` for this tenant/database.

        The resulting model is wrapped in a ``Collection`` carrying the given
        ``embedding_function`` and ``data_loader``.
        """
        model = self._server.get_or_create_collection(
            name=name,
            metadata=metadata,
            tenant=self.tenant,
            database=self.database,
            configuration=configuration,
        )
        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Forward a collection rename / metadata change to the server."""
        return self._server._modify(
            id=id,
            new_name=new_name,
            new_metadata=new_metadata,
        )

    def delete_collection(
        self,
        name: str,
    ) -> None:
        """Ask the server to delete ``name`` in this tenant and database."""
        return self._server.delete_collection(
            name=name,
            tenant=self.tenant,
            database=self.database,
        )

    #
    # ITEM METHODS
    #
    # These are straight pass-throughs: the collection is addressed by
    # ``collection_id`` and no tenant/database is supplied by this class.

    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward to the server's ``_add`` and return its result."""
        return self._server._add(
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward to the server's ``_update`` and return its result."""
        return self._server._update(
            collection_id=collection_id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward to the server's ``_upsert`` and return its result."""
        return self._server._upsert(
            collection_id=collection_id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    def _count(self, collection_id: UUID) -> int:
        """Forward to the server's ``_count`` and return its result."""
        return self._server._count(
            collection_id=collection_id,
        )

    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """Forward to the server's ``_peek`` and return its result."""
        return self._server._peek(
            collection_id=collection_id,
            n=n,
        )

    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],  # type: ignore[list-item]
    ) -> GetResult:
        """Forward to the server's ``_get`` and return its result.

        The dict/list defaults are shared objects; this method only passes them
        along and never mutates them.
        """
        return self._server._get(
            collection_id=collection_id,
            ids=ids,
            where=where,
            sort=sort,
            limit=limit,
            offset=offset,
            page=page,
            page_size=page_size,
            where_document=where_document,
            include=include,
        )

    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> None:
        """Forward to the server's ``_delete``; its result is discarded."""
        self._server._delete(
            collection_id=collection_id,
            ids=ids,
            where=where,
            where_document=where_document,
        )

    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["embeddings", "metadatas", "documents", "distances"],  # type: ignore[list-item]
    ) -> QueryResult:
        """Forward to the server's ``_query`` and return its result.

        The dict/list defaults are shared objects; this method only passes them
        along and never mutates them.
        """
        return self._server._query(
            collection_id=collection_id,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=include,
        )

    def reset(self) -> bool:
        """Return the result of the server's ``reset()``."""
        return self._server.reset()

    def get_version(self) -> str:
        """Return the result of the server's ``get_version()``."""
        return self._server.get_version()

    def get_settings(self) -> Settings:
        """Return the result of the server's ``get_settings()``."""
        return self._server.get_settings()

    def get_max_batch_size(self) -> int:
        """Return the result of the server's ``get_max_batch_size()``."""
        return self._server.get_max_batch_size()

    # endregion

    # region ClientAPI Methods

    def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
        """Switch this client to ``tenant`` and ``database``.

        The pair is validated first; the stored values are only changed when
        validation does not raise.
        """
        self._validate_tenant_database(tenant=tenant, database=database)
        self.tenant = tenant
        self.database = database

    def set_database(self, database: str) -> None:
        """Switch this client to ``database`` within the current tenant.

        The stored database is only changed when validation does not raise.
        """
        self._validate_tenant_database(tenant=self.tenant, database=database)
        self.database = database

    def _validate_tenant_database(self, tenant: str, database: str) -> None:
        """Check ``tenant`` and ``database`` by looking them up via the admin client.

        Raises:
            ValueError: If either lookup raises ``httpx.ConnectError``, or if
                the tenant lookup raises any exception other than
                ``ChromaError``. The triggering exception is attached as
                ``__cause__``.
            ChromaError: Re-raised unchanged from the tenant lookup.

        Any exception other than ``httpx.ConnectError`` raised by the database
        lookup is not handled here and propagates as-is.
        """
        try:
            self._admin_client.get_tenant(name=tenant)
        except httpx.ConnectError as err:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from err
        # Propagate ChromaErrors
        except ChromaError:
            raise
        except Exception as err:
            raise ValueError(
                f"Could not connect to tenant {tenant}. Are you sure it exists?"
            ) from err

        try:
            self._admin_client.get_database(name=database, tenant=tenant)
        except httpx.ConnectError as err:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from err

    # endregion
class AdminClient(SharedSystemClient, AdminAPI):
    """Admin client that proxies tenant and database calls to a Server API instance."""

    _server: ServerAPI

    def __init__(self, settings: Settings = Settings()) -> None:
        """Initialise the shared system and look up its ``ServerAPI`` instance.

        Args:
            settings: Settings forwarded to ``SharedSystemClient.__init__``.
                The default instance is built once, when the function is
                defined.
        """
        super().__init__(settings)
        self._server = self._system.instance(ServerAPI)

    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Forward to the server's ``create_database``."""
        return self._server.create_database(name=name, tenant=tenant)

    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Return the result of the server's ``get_database``."""
        return self._server.get_database(name=name, tenant=tenant)

    def create_tenant(self, name: str) -> None:
        """Forward to the server's ``create_tenant``."""
        return self._server.create_tenant(name=name)

    def get_tenant(self, name: str) -> Tenant:
        """Return the result of the server's ``get_tenant``."""
        return self._server.get_tenant(name=name)

    # Must be a classmethod: it is invoked on the class (for example as
    # ``AdminClient.from_system(system)``) and builds the instance via ``cls(...)``.
    @classmethod
    def from_system(
        cls,
        system: System,
    ) -> "AdminClient":
        """Alternate constructor that builds an admin client from an existing System.

        ``system`` is handed to ``SharedSystemClient._populate_data_from_system``
        and the client is then constructed with ``system.settings``.
        """
        SharedSystemClient._populate_data_from_system(system)
        return cls(settings=system.settings)