import logging
import urllib.parse
from typing import Any, Dict, Optional, Sequence, Tuple, cast
from uuid import UUID

import httpx
import orjson
from overrides import override

from chromadb.api import ServerAPI
from chromadb.api.base_http_client import BaseHTTPClient
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.api.types import (
    CollectionMetadata,
    Documents,
    Embeddings,
    GetResult,
    IDs,
    Include,
    Metadatas,
    PyEmbeddings,
    QueryResult,
    URIs,
    Where,
    WhereDocument,
    convert_np_embeddings_to_list,
    validate_batch,
)
from chromadb.auth import ClientAuthProvider
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient

logger = logging.getLogger(__name__)
class FastAPI(BaseHTTPClient, ServerAPI):
def __init__(self, system: System):
    """HTTP client implementation of the Chroma ServerAPI.

    Requires the `chroma_server_host` and `chroma_server_http_port`
    settings. Builds a single httpx.Client carrying the configured
    timeout, SSL-verification option, user-supplied headers, and any
    auth-provider headers.
    """
    super().__init__(system)
    system.settings.require("chroma_server_host")
    system.settings.require("chroma_server_http_port")

    self._opentelemetry_client = self.require(OpenTelemetryClient)
    self._product_telemetry_client = self.require(ProductTelemetryClient)
    self._settings = system.settings

    self._api_url = FastAPI.resolve_url(
        chroma_server_host=str(system.settings.chroma_server_host),
        chroma_server_http_port=system.settings.chroma_server_http_port,
        chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
        default_api_path=system.settings.chroma_server_api_default_path,
    )

    # BUG FIX: previously, when chroma_server_ssl_verify was set, a second
    # httpx.Client replaced the first one, silently discarding both the
    # timeout=None setting and any chroma_server_headers already applied.
    # Build the client once with every option instead.
    client_kwargs: Dict[str, Any] = {"timeout": None}
    if self._settings.chroma_server_ssl_verify is not None:
        client_kwargs["verify"] = self._settings.chroma_server_ssl_verify
    self._session = httpx.Client(**client_kwargs)

    self._header = system.settings.chroma_server_headers
    if self._header is not None:
        self._session.headers.update(self._header)

    if system.settings.chroma_client_auth_provider:
        self._auth_provider = self.require(ClientAuthProvider)
        _headers = self._auth_provider.authenticate()
        for header, value in _headers.items():
            # Auth header values are secrets; unwrap them only here.
            self._session.headers[header] = value.get_secret_value()
def _make_request(self, method: str, path: str, **kwargs: Dict[str, Any]) -> Any:
    """Issue an HTTP request against the server and decode the JSON reply.

    A `json` kwarg, when present, is re-serialized with orjson and sent
    as raw `content`, because httpx's built-in JSON serializer is slower.
    Raises a Chroma error for non-success responses.
    """
    if "json" in kwargs:
        kwargs["content"] = orjson.dumps(kwargs.pop("json"))

    # Unlike requests, httpx does not percent-escape the path for us.
    target = self._api_url + urllib.parse.quote(
        path, safe="/", encoding=None, errors=None
    )
    resp = self._session.request(method, target, **cast(Any, kwargs))
    BaseHTTPClient._raise_chroma_error(resp)
    return orjson.loads(resp.text)
def heartbeat(self) -> int:
    """Return the server's current time in nanoseconds (liveness probe)."""
    payload = self._make_request("get", "/heartbeat")
    return int(payload["nanosecond heartbeat"])
def create_database(
    self,
    name: str,
    tenant: str = DEFAULT_TENANT,
) -> None:
    """Create a database under the given tenant."""
    self._make_request(
        "post",
        "/databases",
        params={"tenant": tenant},
        json={"name": name},
    )
def get_database(
    self,
    name: str,
    tenant: str = DEFAULT_TENANT,
) -> Database:
    """Fetch a database by name within the given tenant."""
    payload = self._make_request(
        "get",
        "/databases/" + name,
        params={"tenant": tenant},
    )
    return Database(
        id=payload["id"],
        name=payload["name"],
        tenant=payload["tenant"],
    )
def create_tenant(self, name: str) -> None:
    """Create a tenant with the given name."""
    payload = {"name": name}
    self._make_request("post", "/tenants", json=payload)
def get_tenant(self, name: str) -> Tenant:
    """Fetch a tenant by name."""
    payload = self._make_request("get", "/tenants/" + name)
    return Tenant(name=payload["name"])
def list_collections(
    self,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
    """List the collections of a database, optionally paginated."""
    query = BaseHTTPClient._clean_params(
        {
            "tenant": tenant,
            "database": database,
            "limit": limit,
            "offset": offset,
        }
    )
    raw_collections = self._make_request("get", "/collections", params=query)
    return [CollectionModel.from_json(item) for item in raw_collections]
def count_collections(
    self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
) -> int:
    """Return the number of collections in a database."""
    count = self._make_request(
        "get",
        "/count_collections",
        params={"tenant": tenant, "database": database},
    )
    return cast(int, count)
def create_collection(
    self,
    name: str,
    configuration: Optional[CollectionConfigurationInternal] = None,
    metadata: Optional[CollectionMetadata] = None,
    get_or_create: bool = False,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> CollectionModel:
    """Create a collection and return its model.

    With get_or_create=True the server returns the existing collection
    instead of failing when the name is already taken.
    """
    config_json = None if configuration is None else configuration.to_json()
    payload = {
        "name": name,
        "metadata": metadata,
        "configuration": config_json,
        "get_or_create": get_or_create,
    }
    resp = self._make_request(
        "post",
        "/collections",
        json=payload,
        params={"tenant": tenant, "database": database},
    )
    return CollectionModel.from_json(resp)
def get_collection(
    self,
    name: str,
    id: Optional[UUID] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> CollectionModel:
    """Fetch a single collection by name, or alternatively by id.

    Exactly one of `name` / `id` must be provided.

    Raises:
        ValueError: if both or neither of `name` and `id` are given.
    """
    if (name is None and id is None) or (name is not None and id is not None):
        raise ValueError("Name or id must be specified, but not both")
    _params = {"tenant": tenant, "database": database}
    if id is not None:
        # NOTE(review): the id is sent under the query key "type" — this
        # looks like a typo for "id" but is preserved as-is; confirm
        # against the server-side route before changing it.
        _params["type"] = str(id)
    # BUG FIX: the original expression
    #   "/collections/" + name if name else str(id)
    # parsed as ("/collections/" + name) if name else str(id), so lookups
    # by id dropped the "/collections/" prefix entirely.
    resp_json = self._make_request(
        "get",
        "/collections/" + (name if name else str(id)),
        params=_params,
    )
    return CollectionModel.from_json(resp_json)
def get_or_create_collection(
    self,
    name: str,
    configuration: Optional[CollectionConfigurationInternal] = None,
    metadata: Optional[CollectionMetadata] = None,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> CollectionModel:
    """Return the named collection, creating it if it does not exist.

    Thin wrapper over create_collection with get_or_create=True.
    """
    return self.create_collection(
        name=name,
        configuration=configuration,
        metadata=metadata,
        get_or_create=True,
        tenant=tenant,
        database=database,
    )
def _modify(
    self,
    id: UUID,
    new_name: Optional[str] = None,
    new_metadata: Optional[CollectionMetadata] = None,
) -> None:
    """Rename a collection and/or replace its metadata."""
    payload = {"new_metadata": new_metadata, "new_name": new_name}
    self._make_request("put", "/collections/" + str(id), json=payload)
def delete_collection(
    self,
    name: str,
    tenant: str = DEFAULT_TENANT,
    database: str = DEFAULT_DATABASE,
) -> None:
    """Delete a collection by name."""
    query = {"tenant": tenant, "database": database}
    self._make_request("delete", "/collections/" + name, params=query)
def _count(
    self,
    collection_id: UUID,
) -> int:
    """Return the number of records stored in a collection."""
    total = self._make_request("get", f"/collections/{collection_id}/count")
    return cast(int, total)
def _peek(
    self,
    collection_id: UUID,
    n: int = 10,
) -> GetResult:
    """Return the first `n` records of a collection, embeddings included."""
    wanted = ["embeddings", "documents", "metadatas"]  # type: ignore[list-item]
    return cast(
        GetResult,
        self._get(collection_id, limit=n, include=wanted),
    )
def _get(
    self,
    collection_id: UUID,
    ids: Optional[IDs] = None,
    where: Optional[Where] = None,
    sort: Optional[str] = None,
    limit: Optional[int] = None,
    offset: Optional[int] = None,
    page: Optional[int] = None,
    page_size: Optional[int] = None,
    where_document: Optional[WhereDocument] = None,
    include: Include = ["metadatas", "documents"],  # type: ignore[list-item]
) -> GetResult:
    """Fetch records from a collection by ids and/or filters.

    When both `page` and `page_size` are set they are translated into
    offset/limit. `include` selects which columns the server returns.
    """
    # FIX: `where`/`where_document` previously defaulted to shared mutable
    # dicts ({}). They now default to None and are normalized here so the
    # request payload is unchanged when the caller omits them.
    # (The `include` list default is read-only in this method.)
    where = {} if where is None else where
    where_document = {} if where_document is None else where_document
    if page and page_size:
        offset = (page - 1) * page_size
        limit = page_size
    resp_json = self._make_request(
        "post",
        "/collections/" + str(collection_id) + "/get",
        json={
            "ids": ids,
            "where": where,
            "sort": sort,
            "limit": limit,
            "offset": offset,
            "where_document": where_document,
            "include": include,
        },
    )
    return GetResult(
        ids=resp_json["ids"],
        embeddings=resp_json.get("embeddings", None),
        metadatas=resp_json.get("metadatas", None),
        documents=resp_json.get("documents", None),
        data=None,
        uris=resp_json.get("uris", None),
        # Older servers may omit "included"; fall back to what we asked for.
        included=resp_json.get("included", include),
    )
def _delete(
    self,
    collection_id: UUID,
    ids: Optional[IDs] = None,
    where: Optional[Where] = None,
    where_document: Optional[WhereDocument] = None,
) -> None:
    """Delete records matching the given ids and/or filters.

    Server errors surface as exceptions from _make_request.
    """
    # FIX: `where`/`where_document` previously defaulted to shared mutable
    # dicts ({}); they now default to None and are normalized here so the
    # request payload is unchanged when the caller omits them.
    where = {} if where is None else where
    where_document = {} if where_document is None else where_document
    self._make_request(
        "post",
        "/collections/" + str(collection_id) + "/delete",
        json={
            "ids": ids,
            "where": where,
            "where_document": where_document,
        },
    )
def _submit_batch(
    self,
    batch: Tuple[
        IDs,
        Optional[PyEmbeddings],
        Optional[Metadatas],
        Optional[Documents],
        Optional[URIs],
    ],
    url: str,
) -> None:
    """POST one column-oriented batch of records to the given endpoint."""
    ids, embeddings, metadatas, documents, uris = batch
    self._make_request(
        "post",
        url,
        json={
            "ids": ids,
            "embeddings": embeddings,
            "metadatas": metadatas,
            "documents": documents,
            "uris": uris,
        },
    )
def _add(
    self,
    ids: IDs,
    collection_id: UUID,
    embeddings: Embeddings,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
    uris: Optional[URIs] = None,
) -> bool:
    """Add a column-oriented batch of records to a collection.

    The batch is validated against the server's max batch size first.
    Always returns True; failures raise.
    """
    payload = (
        ids,
        convert_np_embeddings_to_list(embeddings),
        metadatas,
        documents,
        uris,
    )
    validate_batch(payload, {"max_batch_size": self.get_max_batch_size()})
    self._submit_batch(payload, f"/collections/{collection_id}/add")
    return True
def _update(
    self,
    collection_id: UUID,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
    uris: Optional[URIs] = None,
) -> bool:
    """Update existing records from column-oriented lists.

    Embeddings are optional here, unlike _add/_upsert.
    Always returns True; failures raise.
    """
    if embeddings is None:
        emb_list = None
    else:
        emb_list = convert_np_embeddings_to_list(embeddings)
    payload = (ids, emb_list, metadatas, documents, uris)
    validate_batch(payload, {"max_batch_size": self.get_max_batch_size()})
    self._submit_batch(payload, f"/collections/{collection_id}/update")
    return True
def _upsert(
    self,
    collection_id: UUID,
    ids: IDs,
    embeddings: Embeddings,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
    uris: Optional[URIs] = None,
) -> bool:
    """Insert-or-update a column-oriented batch of records.

    Always returns True; failures raise.
    """
    payload = (
        ids,
        convert_np_embeddings_to_list(embeddings),
        metadatas,
        documents,
        uris,
    )
    validate_batch(payload, {"max_batch_size": self.get_max_batch_size()})
    self._submit_batch(payload, f"/collections/{collection_id}/upsert")
    return True
def _query(
    self,
    collection_id: UUID,
    query_embeddings: Embeddings,
    n_results: int = 10,
    where: Optional[Where] = None,
    where_document: Optional[WhereDocument] = None,
    include: Include = ["metadatas", "documents", "distances"],  # type: ignore[list-item]
) -> QueryResult:
    """Return the `n_results` nearest neighbors for each query embedding.

    `include` selects which columns the server returns alongside ids.
    """
    # FIX: `where`/`where_document` previously defaulted to shared mutable
    # dicts ({}); they now default to None and are normalized here so the
    # request payload is unchanged when the caller omits them.
    where = {} if where is None else where
    where_document = {} if where_document is None else where_document
    resp_json = self._make_request(
        "post",
        "/collections/" + str(collection_id) + "/query",
        json={
            "query_embeddings": convert_np_embeddings_to_list(query_embeddings)
            if query_embeddings is not None
            else None,
            "n_results": n_results,
            "where": where,
            "where_document": where_document,
            "include": include,
        },
    )
    return QueryResult(
        ids=resp_json["ids"],
        distances=resp_json.get("distances", None),
        embeddings=resp_json.get("embeddings", None),
        metadatas=resp_json.get("metadatas", None),
        documents=resp_json.get("documents", None),
        uris=resp_json.get("uris", None),
        data=None,
        # Older servers may omit "included"; fall back to what we asked for.
        included=resp_json.get("included", include),
    )
def reset(self) -> bool:
    """Reset the entire database; returns the server's success flag."""
    result = self._make_request("post", "/reset")
    return cast(bool, result)
def get_version(self) -> str:
    """Return the server's version string."""
    version = self._make_request("get", "/version")
    return cast(str, version)
def get_settings(self) -> Settings:
    """Return the Settings object this client was constructed with."""
    return self._settings
def get_max_batch_size(self) -> int:
    """Return the server's maximum batch size, fetched once and cached.

    NOTE(review): relies on `self._max_batch_size` starting at -1; its
    initialization is not visible in this file (presumably the base
    class) — confirm.
    """
    if self._max_batch_size == -1:
        preflight = self._make_request("get", "/pre-flight-checks")
        self._max_batch_size = cast(int, preflight["max_batch_size"])
    return self._max_batch_size