himanshud2611's picture
Upload folder using huggingface_hub
60e3a80 verified
from typing import Optional, Sequence
from uuid import UUID
from overrides import override
import httpx
from chromadb.api import AdminAPI, ClientAPI, ServerAPI
from chromadb.api.configuration import CollectionConfiguration
from chromadb.api.shared_system_client import SharedSystemClient
from chromadb.api.types import (
CollectionMetadata,
DataLoader,
Documents,
Embeddable,
EmbeddingFunction,
Embeddings,
GetResult,
IDs,
Include,
Loadable,
Metadatas,
QueryResult,
URIs,
)
from chromadb.config import Settings, System
from chromadb.config import DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.api.models.Collection import Collection
from chromadb.errors import ChromaError
from chromadb.types import Database, Tenant, Where, WhereDocument
import chromadb.utils.embedding_functions as ef
class Client(SharedSystemClient, ClientAPI):
    """A client for Chroma. This is the main entrypoint for interacting with Chroma.

    A client internally stores its tenant and database and proxies calls to a
    Server API instance of Chroma. It treats the Server API and corresponding System
    as a singleton, so multiple clients connecting to the same resource will share the
    same API instance.

    Client implementations should implement their own API-caching strategies.
    """

    # Tenant and database that collection-level calls are scoped to. They are
    # validated against the server in __init__, set_tenant and set_database.
    tenant: str = DEFAULT_TENANT
    database: str = DEFAULT_DATABASE

    # ServerAPI component resolved from the shared System; every call is forwarded to it.
    _server: ServerAPI

    # An internal admin client for verifying that databases and tenants exist
    _admin_client: AdminAPI

    # region Initialization
    def __init__(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        # NOTE: this default Settings() instance is created once, when the class
        # body executes, and is shared by every call that omits `settings`.
        settings: Settings = Settings(),
    ) -> None:
        """Create a client scoped to ``tenant``/``database``.

        Args:
            tenant: Tenant that collection operations are scoped to.
            database: Database (within ``tenant``) that collection operations
                are scoped to.
            settings: Passed to ``SharedSystemClient.__init__``, which provides
                ``self._system``.

        Raises:
            ValueError: If the server cannot be reached, or if the tenant lookup
                fails with a non-``ChromaError`` exception.
            ChromaError: Propagated unchanged from the tenant lookup.
        """
        super().__init__(settings=settings)
        self.tenant = tenant
        self.database = database
        # Create an admin client for verifying that databases and tenants exist
        self._admin_client = AdminClient.from_system(self._system)
        self._validate_tenant_database(tenant=tenant, database=database)
        # Get the root system component we want to interact with
        self._server = self._system.instance(ServerAPI)
        # Defined on SharedSystemClient (not shown here); presumably telemetry.
        self._submit_client_start_event()

    @classmethod
    @override
    def from_system(
        cls,
        system: System,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "Client":
        """Build a client on top of an existing ``System``.

        The system is first handed to
        ``SharedSystemClient._populate_data_from_system``. The client is then
        constructed from ``system.settings``. Presumably this lets the
        constructor reuse ``system`` rather than create a new one; confirm in
        SharedSystemClient.
        """
        SharedSystemClient._populate_data_from_system(system)
        instance = cls(tenant=tenant, database=database, settings=system.settings)
        return instance

    # endregion

    # region BaseAPI Methods
    # Note - we could do this in less verbose ways, but they break type checking
    @override
    def heartbeat(self) -> int:
        """Return the value of ``ServerAPI.heartbeat()``."""
        return self._server.heartbeat()

    @override
    def list_collections(
        self, limit: Optional[int] = None, offset: Optional[int] = None
    ) -> Sequence[Collection]:
        """List collections in this client's tenant/database.

        ``limit`` and ``offset`` are forwarded unchanged. Each returned model
        is wrapped in a ``Collection`` bound to the server. No embedding
        function or data loader is passed, so ``Collection``'s own defaults
        apply.
        """
        return [
            Collection(client=self._server, model=model)
            for model in self._server.list_collections(
                limit, offset, tenant=self.tenant, database=self.database
            )
        ]

    @override
    def count_collections(self) -> int:
        """Return the server's collection count for this tenant/database."""
        return self._server.count_collections(
            tenant=self.tenant, database=self.database
        )

    @override
    def create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        # NOTE: the default embedding function is instantiated once, when the
        # class body executes, and that single instance is shared across calls.
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
    ) -> Collection:
        """Create a collection in this tenant/database and wrap it.

        Args:
            name: Collection name.
            configuration: Optional collection configuration, forwarded as-is.
            metadata: Optional collection metadata, forwarded as-is.
            embedding_function: Attached to the returned ``Collection``.
            data_loader: Attached to the returned ``Collection``.
            get_or_create: Forwarded to the server's ``create_collection``.

        Returns:
            A ``Collection`` bound to the server for the created model.
        """
        model = self._server.create_collection(
            name=name,
            metadata=metadata,
            tenant=self.tenant,
            database=self.database,
            get_or_create=get_or_create,
            configuration=configuration,
        )
        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    def get_collection(
        self,
        name: str,
        id: Optional[UUID] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        """Fetch an existing collection and wrap it.

        Both ``name`` and ``id`` are forwarded to the server. How it resolves
        them is not shown here. ``embedding_function`` and ``data_loader`` are
        attached to the returned ``Collection``.
        """
        model = self._server.get_collection(
            id=id,
            name=name,
            tenant=self.tenant,
            database=self.database,
        )
        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    def get_or_create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> Collection:
        """Delegate to the server's ``get_or_create_collection`` and wrap the result.

        ``embedding_function`` and ``data_loader`` are attached to the returned
        ``Collection``.
        """
        model = self._server.get_or_create_collection(
            name=name,
            metadata=metadata,
            tenant=self.tenant,
            database=self.database,
            configuration=configuration,
        )
        return Collection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Forward a rename and/or metadata change for collection ``id``.

        Unlike the name-based methods, no tenant or database is passed here.
        """
        return self._server._modify(
            id=id,
            new_name=new_name,
            new_metadata=new_metadata,
        )

    @override
    def delete_collection(
        self,
        name: str,
    ) -> None:
        """Delete the collection called ``name`` in this tenant/database."""
        return self._server.delete_collection(
            name=name,
            tenant=self.tenant,
            database=self.database,
        )

    #
    # ITEM METHODS
    #
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward an add of records to ``collection_id``; return the server's result."""
        return self._server._add(
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward an update of records in ``collection_id``; return the server's result."""
        return self._server._update(
            collection_id=collection_id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Forward an upsert of records into ``collection_id``; return the server's result."""
        return self._server._upsert(
            collection_id=collection_id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
        )

    @override
    def _count(self, collection_id: UUID) -> int:
        """Return the server's record count for ``collection_id``."""
        return self._server._count(
            collection_id=collection_id,
        )

    @override
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """Forward a peek of ``n`` records from ``collection_id``."""
        return self._server._peek(
            collection_id=collection_id,
            n=n,
        )

    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        # NOTE: the {} / list defaults below are single shared objects; they are
        # passed straight to the server, which must not mutate them.
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],  # type: ignore[list-item]
    ) -> GetResult:
        """Forward a filtered/paginated get on ``collection_id``; all arguments pass through."""
        return self._server._get(
            collection_id=collection_id,
            ids=ids,
            where=where,
            sort=sort,
            limit=limit,
            offset=offset,
            page=page,
            page_size=page_size,
            where_document=where_document,
            include=include,
        )

    # NOTE(review): not decorated with @override, unlike the surrounding
    # methods; confirm whether that is intentional.
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> None:
        """Forward a delete on ``collection_id``; the server's return value is discarded."""
        self._server._delete(
            collection_id=collection_id,
            ids=ids,
            where=where,
            where_document=where_document,
        )

    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["embeddings", "metadatas", "documents", "distances"],  # type: ignore[list-item]
    ) -> QueryResult:
        """Forward a query of ``query_embeddings`` against ``collection_id``."""
        return self._server._query(
            collection_id=collection_id,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=include,
        )

    @override
    def reset(self) -> bool:
        """Return the result of ``ServerAPI.reset()``."""
        return self._server.reset()

    @override
    def get_version(self) -> str:
        """Return the server's version string."""
        return self._server.get_version()

    @override
    def get_settings(self) -> Settings:
        """Return the server's settings."""
        return self._server.get_settings()

    @override
    def get_max_batch_size(self) -> int:
        """Return the server's maximum batch size."""
        return self._server.get_max_batch_size()

    # endregion

    # region ClientAPI Methods
    @override
    def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
        """Switch to ``tenant``/``database`` after validating both.

        Validation runs before assignment, so on failure the current tenant and
        database are left unchanged.
        """
        self._validate_tenant_database(tenant=tenant, database=database)
        self.tenant = tenant
        self.database = database

    @override
    def set_database(self, database: str) -> None:
        """Switch to ``database`` within the current tenant after validating it."""
        self._validate_tenant_database(tenant=self.tenant, database=database)
        self.database = database

    def _validate_tenant_database(self, tenant: str, database: str) -> None:
        """Check that ``tenant`` and then ``database`` can be fetched via the admin client.

        The two lookups translate errors differently:

        Raises:
            ValueError: If either lookup raises ``httpx.ConnectError``, or if the
                tenant lookup raises any other non-``ChromaError`` exception.
                The original exception is chained as ``__cause__``.
            ChromaError: Re-raised unchanged from the tenant lookup.
            Exception: Any non-connection error from the database lookup
                propagates unchanged.
        """
        try:
            self._admin_client.get_tenant(name=tenant)
        except httpx.ConnectError as e:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from e
        # Propagate ChromaErrors untranslated (bare raise keeps the traceback intact).
        except ChromaError:
            raise
        except Exception as e:
            raise ValueError(
                f"Could not connect to tenant {tenant}. Are you sure it exists?"
            ) from e

        try:
            self._admin_client.get_database(name=database, tenant=tenant)
        except httpx.ConnectError as e:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from e

    # endregion
class AdminClient(SharedSystemClient, AdminAPI):
    """Administrative client for tenants and databases.

    Each operation is delegated unchanged to the ``ServerAPI`` component of the
    System provided by ``SharedSystemClient``.
    """

    # Server component that every admin call is delegated to.
    _server: ServerAPI

    def __init__(self, settings: Settings = Settings()) -> None:
        """Attach to the System for ``settings`` and resolve its ``ServerAPI``."""
        super().__init__(settings)
        server_component = self._system.instance(ServerAPI)
        self._server = server_component

    # --- tenants -----------------------------------------------------------

    @override
    def create_tenant(self, name: str) -> None:
        """Ask the server to create tenant ``name``."""
        return self._server.create_tenant(name=name)

    @override
    def get_tenant(self, name: str) -> Tenant:
        """Fetch tenant ``name`` from the server."""
        return self._server.get_tenant(name=name)

    # --- databases ---------------------------------------------------------

    @override
    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        """Ask the server to create database ``name`` under ``tenant``."""
        return self._server.create_database(name=name, tenant=tenant)

    @override
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        """Fetch database ``name`` belonging to ``tenant`` from the server."""
        return self._server.get_database(name=name, tenant=tenant)

    # --- construction ------------------------------------------------------

    @classmethod
    @override
    def from_system(
        cls,
        system: System,
    ) -> "AdminClient":
        """Build an admin client on top of an existing ``System``.

        ``system`` is registered through
        ``SharedSystemClient._populate_data_from_system`` before the client is
        constructed from ``system.settings``.
        """
        SharedSystemClient._populate_data_from_system(system)
        return cls(settings=system.settings)