import orjson
import logging
from typing import Any, Dict, Optional, cast, Tuple
from typing import Sequence
from uuid import UUID
import httpx
import urllib.parse
from overrides import override

from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.api.base_http_client import BaseHTTPClient
from chromadb.types import Database, Tenant, Collection as CollectionModel
from chromadb.api import ServerAPI
from chromadb.api.types import (
    Documents,
    Embeddings,
    PyEmbeddings,
    IDs,
    Include,
    Metadatas,
    URIs,
    Where,
    WhereDocument,
    GetResult,
    QueryResult,
    CollectionMetadata,
    validate_batch,
    convert_np_embeddings_to_list,
)
from chromadb.auth import (
    ClientAuthProvider,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient

logger = logging.getLogger(__name__)


class FastAPI(BaseHTTPClient, ServerAPI):
    """``ServerAPI`` implementation that forwards every call over HTTP (via httpx)
    to a remote Chroma server."""

    def __init__(self, system: System):
        """Configure the HTTP session from ``system.settings``.

        Requires ``chroma_server_host`` and ``chroma_server_http_port`` to be set.
        Applies optional TLS verification, static headers
        (``chroma_server_headers``) and, if ``chroma_client_auth_provider`` is
        configured, the headers produced by the ``ClientAuthProvider``.
        """
        super().__init__(system)
        system.settings.require("chroma_server_host")
        system.settings.require("chroma_server_http_port")

        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._settings = system.settings

        self._api_url = FastAPI.resolve_url(
            chroma_server_host=str(system.settings.chroma_server_host),
            chroma_server_http_port=system.settings.chroma_server_http_port,
            chroma_server_ssl_enabled=system.settings.chroma_server_ssl_enabled,
            default_api_path=system.settings.chroma_server_api_default_path,
        )

        # Create the client exactly once with every option. Re-creating it
        # later would discard the headers set below, fall back to httpx's
        # default timeout, and leave the first client unclosed.
        if self._settings.chroma_server_ssl_verify is not None:
            self._session = httpx.Client(
                timeout=None, verify=self._settings.chroma_server_ssl_verify
            )
        else:
            self._session = httpx.Client(timeout=None)

        self._header = system.settings.chroma_server_headers
        if self._header is not None:
            self._session.headers.update(self._header)

        if system.settings.chroma_client_auth_provider:
            self._auth_provider = self.require(ClientAuthProvider)
            _headers = self._auth_provider.authenticate()
            for header, value in _headers.items():
                self._session.headers[header] = value.get_secret_value()

    def _make_request(self, method: str, path: str, **kwargs: Any) -> Any:
        """Send a request to the Chroma server and return the decoded JSON body.

        Args:
            method: HTTP method name, e.g. ``"get"`` or ``"post"``.
            path: API path appended to the resolved API URL. It is
                percent-escaped here, keeping ``/`` as a separator.
            **kwargs: Forwarded to ``httpx.Client.request``. A ``json`` entry is
                serialized with orjson and sent as ``content`` instead.

        Returns:
            The response body parsed with ``orjson.loads``.

        Raises:
            Whatever ``BaseHTTPClient._raise_chroma_error`` raises for an
            error response.
        """
        # If the request has json in kwargs, use orjson to serialize it,
        # remove it from kwargs, and add it to the content parameter.
        # This is because httpx uses a slower json serializer.
        if "json" in kwargs:
            data = orjson.dumps(kwargs.pop("json"))
            kwargs["content"] = data
            # httpx only adds the JSON Content-Type automatically for ``json=``;
            # set it explicitly now that the body goes through ``content=``.
            # Caller-supplied headers (if any) still take precedence.
            headers = {"Content-Type": "application/json"}
            headers.update(kwargs.pop("headers", None) or {})
            kwargs["headers"] = headers

        # Unlike requests, httpx does not automatically escape the path
        escaped_path = urllib.parse.quote(path, safe="/", encoding=None, errors=None)
        url = self._api_url + escaped_path

        response = self._session.request(method, url, **cast(Any, kwargs))
        BaseHTTPClient._raise_chroma_error(response)
        # orjson parses bytes directly; no need to decode to str first.
        return orjson.loads(response.content)

    @trace_method("FastAPI.heartbeat", OpenTelemetryGranularity.OPERATION)
    @override
    def heartbeat(self) -> int:
        """Returns the current server time in nanoseconds to check if the server is alive"""
        resp_json = self._make_request("get", "/heartbeat")
        return int(resp_json["nanosecond heartbeat"])

    @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
    @override
    def create_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> None:
        """Creates a database named ``name`` inside ``tenant``."""
        self._make_request(
            "post",
            "/databases",
            json={"name": name},
            params={"tenant": tenant},
        )

    @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
    @override
    def get_database(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
    ) -> Database:
        """Returns the database ``name`` of ``tenant`` as a ``Database`` model."""
        resp_json = self._make_request(
            "get",
            "/databases/" + name,
            params={"tenant": tenant},
        )
        return Database(
            id=resp_json["id"], name=resp_json["name"], tenant=resp_json["tenant"]
        )

    @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def create_tenant(self, name: str) -> None:
        """Creates a tenant named ``name``."""
        self._make_request("post", "/tenants", json={"name": name})

    @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
    @override
    def get_tenant(self, name: str) -> Tenant:
        """Returns the tenant ``name`` as a ``Tenant`` model."""
        resp_json = self._make_request("get", "/tenants/" + name)
        return Tenant(name=resp_json["name"])

    @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        """Returns the collections in ``tenant``/``database``.

        ``limit``/``offset`` are passed through ``BaseHTTPClient._clean_params``
        before being sent as query parameters.
        """
        json_collections = self._make_request(
            "get",
            "/collections",
            params=BaseHTTPClient._clean_params(
                {
                    "tenant": tenant,
                    "database": database,
                    "limit": limit,
                    "offset": offset,
                }
            ),
        )
        collection_models = [
            CollectionModel.from_json(json_collection)
            for json_collection in json_collections
        ]

        return collection_models

    @trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
    @override
    def count_collections(
        self, tenant: str = DEFAULT_TENANT, database: str = DEFAULT_DATABASE
    ) -> int:
        """Returns the number of collections in ``tenant``/``database``."""
        resp_json = self._make_request(
            "get",
            "/count_collections",
            params={"tenant": tenant, "database": database},
        )
        return cast(int, resp_json)

    @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfigurationInternal] = None,
        metadata: Optional[CollectionMetadata] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Creates a collection (or, with ``get_or_create``, asks the server to
        return an existing one) and returns it as a ``CollectionModel``.

        ``configuration`` is serialized with ``to_json()`` when provided.
        """
        resp_json = self._make_request(
            "post",
            "/collections",
            json={
                "name": name,
                "metadata": metadata,
                "configuration": configuration.to_json() if configuration else None,
                "get_or_create": get_or_create,
            },
            params={"tenant": tenant, "database": database},
        )
        model = CollectionModel.from_json(resp_json)
        return model

    @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def get_collection(
        self,
        name: str,
        id: Optional[UUID] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Returns a collection looked up by exactly one of ``name`` or ``id``.

        Raises:
            ValueError: if both or neither of ``name`` and ``id`` are given.
        """
        if (name is None and id is None) or (name is not None and id is not None):
            raise ValueError("Name or id must be specified, but not both")

        _params = {"tenant": tenant, "database": database}
        if id is not None:
            # NOTE(review): the id is sent under the "type" query key; confirm
            # this matches what the server expects.
            _params["type"] = str(id)

        # Both lookups live under "/collections/"; the explicit parentheses
        # keep the prefix for the id branch too. The validation above
        # guarantees exactly one of name / id is not None.
        path = "/collections/" + (name if name is not None else str(id))

        resp_json = self._make_request(
            "get",
            path,
            params=_params,
        )
        model = CollectionModel.from_json(resp_json)
        return model

    @trace_method(
        "FastAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
    )
    @override
    def get_or_create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfigurationInternal] = None,
        metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Shorthand for ``create_collection(..., get_or_create=True)``."""
        return self.create_collection(
            name=name,
            metadata=metadata,
            configuration=configuration,
            get_or_create=True,
            tenant=tenant,
            database=database,
        )

    @trace_method("FastAPI._modify", OpenTelemetryGranularity.OPERATION)
    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Updates the name and/or metadata of the collection with ``id``."""
        self._make_request(
            "put",
            "/collections/" + str(id),
            json={"new_metadata": new_metadata, "new_name": new_name},
        )

    @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
    @override
    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Deletes the collection ``name`` in ``tenant``/``database``."""
        self._make_request(
            "delete",
            "/collections/" + name,
            params={"tenant": tenant, "database": database},
        )

    @trace_method("FastAPI._count", OpenTelemetryGranularity.OPERATION)
    @override
    def _count(
        self,
        collection_id: UUID,
    ) -> int:
        """Returns the number of embeddings in the collection ``collection_id``."""
        resp_json = self._make_request(
            "get",
            "/collections/" + str(collection_id) + "/count",
        )
        return cast(int, resp_json)

    @trace_method("FastAPI._peek", OpenTelemetryGranularity.OPERATION)
    @override
    def _peek(
        self,
        collection_id: UUID,
        n: int = 10,
    ) -> GetResult:
        """Returns up to ``n`` records (with embeddings, documents, metadatas)."""
        return cast(
            GetResult,
            self._get(
                collection_id,
                limit=n,
                include=["embeddings", "documents", "metadatas"],  # type: ignore[list-item]
            ),
        )

    @trace_method("FastAPI._get", OpenTelemetryGranularity.OPERATION)
    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents"],  # type: ignore[list-item]
    ) -> GetResult:
        """Fetches records from a collection.

        When both ``page`` and ``page_size`` are truthy they override
        ``offset``/``limit`` (``offset = (page - 1) * page_size``). The
        ``{}``/list defaults are only serialized, never mutated here.
        """
        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        resp_json = self._make_request(
            "post",
            "/collections/" + str(collection_id) + "/get",
            json={
                "ids": ids,
                "where": where,
                "sort": sort,
                "limit": limit,
                "offset": offset,
                "where_document": where_document,
                "include": include,
            },
        )

        return GetResult(
            ids=resp_json["ids"],
            embeddings=resp_json.get("embeddings", None),
            metadatas=resp_json.get("metadatas", None),
            documents=resp_json.get("documents", None),
            data=None,
            uris=resp_json.get("uris", None),
            included=resp_json.get("included", include),
        )

    @trace_method("FastAPI._delete", OpenTelemetryGranularity.OPERATION)
    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
    ) -> None:
        """Deletes records matching ``ids``/``where``/``where_document``."""
        self._make_request(
            "post",
            "/collections/" + str(collection_id) + "/delete",
            json={
                "ids": ids,
                "where": where,
                "where_document": where_document,
            },
        )
        return None

    @trace_method("FastAPI._submit_batch", OpenTelemetryGranularity.ALL)
    def _submit_batch(
        self,
        batch: Tuple[
            IDs,
            Optional[PyEmbeddings],
            Optional[Metadatas],
            Optional[Documents],
            Optional[URIs],
        ],
        url: str,
    ) -> None:
        """POSTs an ``(ids, embeddings, metadatas, documents, uris)`` batch to ``url``."""
        self._make_request(
            "post",
            url,
            json={
                "ids": batch[0],
                "embeddings": batch[1],
                "metadatas": batch[2],
                "documents": batch[3],
                "uris": batch[4],
            },
        )

    @trace_method("FastAPI._add", OpenTelemetryGranularity.ALL)
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Adds a column-oriented batch of records; returns True once submitted.

        The batch is checked with ``validate_batch`` against the server's
        max batch size before being sent.
        """
        batch = (
            ids,
            convert_np_embeddings_to_list(embeddings),
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(batch, "/collections/" + str(collection_id) + "/add")
        return True

    @trace_method("FastAPI._update", OpenTelemetryGranularity.ALL)
    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Updates a column-oriented batch of records; returns True once submitted.

        ``embeddings`` is optional here and sent as ``None`` when omitted.
        """
        batch = (
            ids,
            convert_np_embeddings_to_list(embeddings)
            if embeddings is not None
            else None,
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(batch, "/collections/" + str(collection_id) + "/update")
        return True

    @trace_method("FastAPI._upsert", OpenTelemetryGranularity.ALL)
    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Upserts a column-oriented batch of records; returns True once submitted."""
        batch = (
            ids,
            convert_np_embeddings_to_list(embeddings),
            metadatas,
            documents,
            uris,
        )
        validate_batch(batch, {"max_batch_size": self.get_max_batch_size()})
        self._submit_batch(batch, "/collections/" + str(collection_id) + "/upsert")
        return True

    @trace_method("FastAPI._query", OpenTelemetryGranularity.ALL)
    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = {},
        where_document: Optional[WhereDocument] = {},
        include: Include = ["metadatas", "documents", "distances"],  # type: ignore[list-item]
    ) -> QueryResult:
        """Returns the ``n_results`` nearest neighbors for each query embedding."""
        resp_json = self._make_request(
            "post",
            "/collections/" + str(collection_id) + "/query",
            json={
                "query_embeddings": convert_np_embeddings_to_list(query_embeddings)
                if query_embeddings is not None
                else None,
                "n_results": n_results,
                "where": where,
                "where_document": where_document,
                "include": include,
            },
        )

        return QueryResult(
            ids=resp_json["ids"],
            distances=resp_json.get("distances", None),
            embeddings=resp_json.get("embeddings", None),
            metadatas=resp_json.get("metadatas", None),
            documents=resp_json.get("documents", None),
            uris=resp_json.get("uris", None),
            data=None,
            included=resp_json.get("included", include),
        )

    @trace_method("FastAPI.reset", OpenTelemetryGranularity.ALL)
    @override
    def reset(self) -> bool:
        """Asks the server to reset the database; returns its response as a bool."""
        resp_json = self._make_request("post", "/reset")
        return cast(bool, resp_json)

    @trace_method("FastAPI.get_version", OpenTelemetryGranularity.OPERATION)
    @override
    def get_version(self) -> str:
        """Returns the version of the server"""
        resp_json = self._make_request("get", "/version")
        return cast(str, resp_json)

    @override
    def get_settings(self) -> Settings:
        """Returns the settings of the client"""
        return self._settings

    @trace_method("FastAPI.get_max_batch_size", OpenTelemetryGranularity.OPERATION)
    @override
    def get_max_batch_size(self) -> int:
        """Returns the server's max batch size, fetching it on first use.

        ``self._max_batch_size == -1`` is treated as "not fetched yet"; the
        value from ``/pre-flight-checks`` is cached on the instance afterwards.
        """
        if self._max_batch_size == -1:
            resp_json = self._make_request("get", "/pre-flight-checks")
            self._max_batch_size = cast(int, resp_json["max_batch_size"])
        return self._max_batch_size