from chromadb.api import API
from chromadb.config import Settings, System
from chromadb.db.system import SysDB
from chromadb.segment import SegmentManager, MetadataReader, VectorReader
from chromadb.telemetry import Telemetry
from chromadb.ingest import Producer
from chromadb.api.models.Collection import Collection
from chromadb import __version__
from chromadb.errors import InvalidDimensionException, InvalidCollectionException
import chromadb.utils.embedding_functions as ef

from chromadb.api.types import (
    CollectionMetadata,
    EmbeddingFunction,
    IDs,
    Embeddings,
    Embedding,
    Metadatas,
    Documents,
    Where,
    WhereDocument,
    Include,
    GetResult,
    QueryResult,
    validate_metadata,
    validate_update_metadata,
    validate_where,
    validate_where_document,
)
from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent

import chromadb.types as t

from typing import Optional, Sequence, Generator, List, cast, Set, Dict
from overrides import override
from uuid import UUID, uuid4
import time
import logging
import re


logger = logging.getLogger(__name__)


# mimics s3 bucket requirements for naming
def check_index_name(index_name: str) -> None:
    """Validate a collection name, raising ``ValueError`` if it is not acceptable.

    Rules: 3-63 characters; starts and ends with an alphanumeric character;
    otherwise only alphanumerics, underscores, hyphens and periods; no ``..``
    run; and not shaped like a dotted-quad IPv4 address.

    Raises:
        ValueError: if any rule is violated. The message lists every rule.
    """
    msg = (
        "Expected collection name that "
        "(1) contains 3-63 characters, "
        "(2) starts and ends with an alphanumeric character, "
        "(3) otherwise contains only alphanumeric characters, underscores, "
        "hyphens (-) or periods (.), "
        "(4) contains no two consecutive periods (..) and "
        "(5) is not a valid IPv4 address, "
        f"got {index_name}"
    )
    if len(index_name) < 3 or len(index_name) > 63:
        raise ValueError(msg)
    # fullmatch rather than match + "$": "$" also matches just before a trailing
    # newline, which would let names such as "abc\n" through.
    if not re.fullmatch(r"[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]", index_name):
        raise ValueError(msg)
    if ".." in index_name:
        raise ValueError(msg)
    if re.fullmatch(r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}", index_name):
        raise ValueError(msg)


class SegmentAPI(API):
    """API implementation utilizing the new segment-based internal architecture"""

    _settings: Settings
    _sysdb: SysDB
    _manager: SegmentManager
    _producer: Producer
    # TODO: fire telemetry events
    _telemetry_client: Telemetry
    _tenant_id: str
    _topic_ns: str
    # Local, per-process read-through cache of collection rows (see _get_collection).
    _collection_cache: Dict[UUID, t.Collection]

    def __init__(self, system: System):
        super().__init__(system)
        self._settings = system.settings
        self._sysdb = self.require(SysDB)
        self._manager = self.require(SegmentManager)
        self._telemetry_client = self.require(Telemetry)
        self._producer = self.require(Producer)
        self._tenant_id = system.settings.tenant_id
        self._topic_ns = system.settings.topic_namespace
        self._collection_cache = {}

    @override
    def heartbeat(self) -> int:
        """Return the current time in nanoseconds since the epoch."""
        return time.time_ns()

    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings.
    @override
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
    ) -> Collection:
        """Create a collection named ``name`` and return a client-side handle to it.

        If a collection with this name already exists:
          * with ``get_or_create`` the existing collection is returned. If
            ``metadata`` is given and differs from the stored metadata, the
            stored metadata is updated first.
          * otherwise ``ValueError`` is raised.

        A new collection gets a fresh UUID, a producer topic, and the segments
        the segment manager creates for it, all registered in the sysdb.

        Raises:
            ValueError: if the collection exists and ``get_or_create`` is False,
                or if ``name`` fails ``check_index_name``.
        """
        existing = self._sysdb.get_collections(name=name)

        if metadata is not None:
            validate_metadata(metadata)

        if existing:
            if not get_or_create:
                raise ValueError(f"Collection {name} already exists.")
            if metadata and existing[0]["metadata"] != metadata:
                self._modify(id=existing[0]["id"], new_metadata=metadata)
                # Re-read so the returned handle reflects the updated metadata.
                existing = self._sysdb.get_collections(id=existing[0]["id"])
            return Collection(
                client=self,
                id=existing[0]["id"],
                name=existing[0]["name"],
                metadata=existing[0]["metadata"],  # type: ignore
                embedding_function=embedding_function,
            )

        # TODO: remove backwards compatibility in naming requirements
        check_index_name(name)

        collection_id = uuid4()
        coll = t.Collection(
            id=collection_id,
            name=name,
            metadata=metadata,
            topic=self._topic(collection_id),
            dimension=None,  # fixed by the first embedding written (_validate_dimension)
        )
        self._producer.create_topic(coll["topic"])
        segments = self._manager.create_segments(coll)
        self._sysdb.create_collection(coll)
        for segment in segments:
            self._sysdb.create_segment(segment)

        return Collection(
            client=self,
            id=collection_id,
            name=name,
            metadata=metadata,
            embedding_function=embedding_function,
        )

    @override
    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Shorthand for ``create_collection(..., get_or_create=True)``."""
        return self.create_collection(
            name=name,
            metadata=metadata,
            embedding_function=embedding_function,
            get_or_create=True,
        )

    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings.
    @override
    def get_collection(
        self,
        name: str,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Return a handle to the existing collection named ``name``.

        Raises:
            ValueError: if no collection with that name exists.
        """
        existing = self._sysdb.get_collections(name=name)

        if not existing:
            raise ValueError(f"Collection {name} does not exist.")
        return Collection(
            client=self,
            id=existing[0]["id"],
            name=existing[0]["name"],
            metadata=existing[0]["metadata"],  # type: ignore
            embedding_function=embedding_function,
        )

    @override
    def list_collections(self) -> Sequence[Collection]:
        """Return a handle for every collection known to the sysdb.

        The handles are built without an explicit ``embedding_function``.
        """
        return [
            Collection(
                client=self,
                id=db_collection["id"],
                name=db_collection["name"],
                metadata=db_collection["metadata"],  # type: ignore
            )
            for db_collection in self._sysdb.get_collections()
        ]

    @override
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Rename a collection and/or replace its metadata in the sysdb.

        Falsy ``new_name`` / ``new_metadata`` values are treated as "leave
        unchanged". After an update, the cached copy of the collection is
        dropped so later reads fetch the new values.

        Raises:
            ValueError: if ``new_name`` fails ``check_index_name``.
        """
        if new_name:
            # backwards compatibility in naming requirements (for now)
            check_index_name(new_name)

        if new_metadata:
            validate_update_metadata(new_metadata)

        # TODO eventually we'll want to use OptionalArgument and Unspecified in the
        # signature of `_modify` but not changing the API right now.
        if new_name and new_metadata:
            self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
        elif new_name:
            self._sysdb.update_collection(id, name=new_name)
        elif new_metadata:
            self._sysdb.update_collection(id, metadata=new_metadata)
        else:
            return  # nothing to change

        # The read-through cache would otherwise keep serving the old name/metadata.
        self._collection_cache.pop(id, None)

    @override
    def delete_collection(self, name: str) -> None:
        """Delete the collection named ``name``.

        This removes the sysdb row, its segments and its producer topic, and
        evicts the collection from the local cache.

        Raises:
            ValueError: if no collection with that name exists.
        """
        existing = self._sysdb.get_collections(name=name)

        if not existing:
            raise ValueError(f"Collection {name} does not exist.")

        collection_id = existing[0]["id"]
        self._sysdb.delete_collection(collection_id)
        for s in self._manager.delete_segments(collection_id):
            self._sysdb.delete_segment(s)
        self._producer.delete_topic(existing[0]["topic"])
        self._collection_cache.pop(collection_id, None)

    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Submit ADD records to the collection's topic.

        Each record's embedding dimension is validated first. A
        ``CollectionAddEvent`` is captured afterwards. Always returns True.
        """
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.ADD)

        for r in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
            self._validate_embedding_record(coll, r)
            self._producer.submit_embedding(coll["topic"], r)

        self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids)))
        return True

    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Submit UPDATE records to the collection's topic.

        Each record's embedding dimension is validated first. Always returns True.
        """
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)

        for r in _records(t.Operation.UPDATE, ids, embeddings, metadatas, documents):
            self._validate_embedding_record(coll, r)
            self._producer.submit_embedding(coll["topic"], r)

        return True

    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Submit UPSERT records to the collection's topic.

        Each record's embedding dimension is validated first. Always returns True.
        """
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)

        for r in _records(t.Operation.UPSERT, ids, embeddings, metadatas, documents):
            self._validate_embedding_record(coll, r)
            self._producer.submit_embedding(coll["topic"], r)

        return True

    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},  # never mutated; empty means "no filter"
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},  # never mutated
        include: Include = ["embeddings", "metadatas", "documents"],  # never mutated
    ) -> GetResult:
        """Fetch records from the metadata segment, optionally filtered and paged.

        When both ``page`` and ``page_size`` are truthy, they override
        ``offset``/``limit`` (``page`` is 1-based). Embeddings are read from
        the vector segment only when requested in ``include``.

        Raises:
            NotImplementedError: if ``sort`` is given.
        """
        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)

        if sort is not None:
            raise NotImplementedError("Sorting is not yet supported")

        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        records = metadata_segment.get_metadata(
            where=where,
            where_document=where_document,
            ids=ids,
            limit=limit,
            offset=offset,
        )

        vectors: Sequence[t.VectorEmbeddingRecord] = []
        if "embeddings" in include:
            vector_ids = [r["id"] for r in records]
            vector_segment = self._manager.get_segment(collection_id, VectorReader)
            # NOTE(review): assumes get_vectors returns vectors in the same order
            # as vector_ids, so embeddings line up with ids below — confirm.
            vectors = vector_segment.get_vectors(ids=vector_ids)

        # TODO: Fix type so we don't need to ignore
        # It is possible to have a set of records, some with metadata and some without
        # Same with documents
        metadatas = [r["metadata"] for r in records]
        documents = [_doc(m) for m in metadatas] if "documents" in include else None

        return GetResult(
            ids=[r["id"] for r in records],
            embeddings=[r["embedding"] for r in vectors]
            if "embeddings" in include
            else None,
            metadatas=_clean_metadatas(metadatas) if "metadatas" in include else None,  # type: ignore
            documents=documents,  # type: ignore
        )

    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> IDs:
        """Submit DELETE records for the selected ids and return those ids.

        If a ``where``/``where_document`` filter is given, or ``ids`` is
        empty/None, the ids are resolved by querying the metadata segment.
        Presumably, with no filters and no ids, this selects every record.
        Otherwise ``ids`` is used as-is. A ``CollectionDeleteEvent`` is captured.
        """
        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

        # TODO: Do we want to warn the user that unrestricted _delete() is 99% of the
        # time a bad idea?
        if (where or where_document) or not ids:
            metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
            records = metadata_segment.get_metadata(
                where=where, where_document=where_document, ids=ids
            )
            ids_to_delete = [r["id"] for r in records]
        else:
            ids_to_delete = ids

        for r in _records(t.Operation.DELETE, ids_to_delete):
            self._validate_embedding_record(coll, r)
            self._producer.submit_embedding(coll["topic"], r)

        self._telemetry_client.capture(
            CollectionDeleteEvent(str(collection_id), len(ids_to_delete))
        )
        return ids_to_delete

    @override
    def _count(self, collection_id: UUID) -> int:
        """Return the record count reported by the collection's metadata segment."""
        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
        return metadata_segment.count()

    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},  # never mutated; empty means "no filter"
        where_document: WhereDocument = {},  # never mutated
        include: Include = ["documents", "metadatas", "distances"],  # never mutated
    ) -> QueryResult:
        """Run a nearest-neighbour query for each embedding in ``query_embeddings``.

        Any ``where``/``where_document`` filter is resolved through the
        metadata segment into an ``allowed_ids`` list for the vector query.
        Metadatas and documents are then looked up by id. Ids missing from the
        metadata segment yield ``None`` entries (see comment below). Result
        fields not requested in ``include`` are returned as None.

        Raises:
            InvalidDimensionException: if a query embedding's length differs
                from the collection's dimension.
        """
        where = validate_where(where) if where is not None and len(where) > 0 else where
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else where_document
        )

        allowed_ids = None

        coll = self._get_collection(collection_id)
        for embedding in query_embeddings:
            self._validate_dimension(coll, len(embedding), update=False)

        metadata_reader = self._manager.get_segment(collection_id, MetadataReader)

        if where or where_document:
            records = metadata_reader.get_metadata(
                where=where, where_document=where_document
            )
            allowed_ids = [r["id"] for r in records]

        query = t.VectorQuery(
            vectors=query_embeddings,
            k=n_results,
            allowed_ids=allowed_ids,
            include_embeddings="embeddings" in include,
            options=None,
        )

        vector_reader = self._manager.get_segment(collection_id, VectorReader)
        results = vector_reader.query_vectors(query)

        ids: List[List[str]] = []
        distances: List[List[float]] = []
        embeddings: List[List[Embedding]] = []
        documents: List[List[str]] = []
        metadatas: List[List[t.Metadata]] = []

        for result in results:
            ids.append([r["id"] for r in result])
            if "distances" in include:
                distances.append([r["distance"] for r in result])
            if "embeddings" in include:
                embeddings.append([cast(Embedding, r["embedding"]) for r in result])

        if "documents" in include or "metadatas" in include:
            all_ids: Set[str] = {
                record_id for id_list in ids for record_id in id_list
            }
            # Skip the metadata read entirely when there is nothing to look up.
            records = (
                metadata_reader.get_metadata(ids=list(all_ids)) if all_ids else []
            )
            metadata_by_id = {r["id"]: r["metadata"] for r in records}
            for id_list in ids:
                # In the segment based architecture, it is possible for one segment
                # to have a record that another segment does not have. This results in
                # data inconsistency. For the case of the local segments and the
                # local segment manager, there is a case where a thread writes
                # a record to the vector segment but not the metadata segment.
                # Then a querying thread reads from the vector segment and
                # queries the metadata segment. The metadata segment does not have
                # the record. In this case we choose to return potentially
                # incorrect data in the form of None.
                metadata_list = [metadata_by_id.get(record_id) for record_id in id_list]
                if "metadatas" in include:
                    metadatas.append(_clean_metadatas(metadata_list))  # type: ignore
                if "documents" in include:
                    documents.append([_doc(m) for m in metadata_list])  # type: ignore

        return QueryResult(
            ids=ids,
            distances=distances if distances else None,
            metadatas=metadatas if metadatas else None,
            embeddings=embeddings if embeddings else None,
            documents=documents if documents else None,
        )

    @override
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """Return up to ``n`` records using the default ``_get`` includes."""
        return self._get(collection_id, limit=n)

    @override
    def get_version(self) -> str:
        """Return the installed chromadb package version."""
        return __version__

    @override
    def reset_state(self) -> None:
        """Forget all locally cached collection rows."""
        self._collection_cache = {}

    @override
    def reset(self) -> bool:
        """Reset the whole system via ``System.reset_state``. Always returns True."""
        self._system.reset_state()
        return True

    @override
    def get_settings(self) -> Settings:
        """Return the settings this API was constructed with."""
        return self._settings

    def _topic(self, collection_id: UUID) -> str:
        """Return the producer topic name for a collection."""
        return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}"

    # TODO: This could potentially cause race conditions in a distributed version of the
    # system, since the cache is only local.
    def _validate_embedding_record(
        self, collection: t.Collection, record: t.SubmitEmbeddingRecord
    ) -> None:
        """Validate the dimension of an embedding record before submitting it to the system.

        Records without a (truthy) embedding, such as DELETE records, are not checked.
        """
        if record["embedding"]:
            self._validate_dimension(collection, len(record["embedding"]), update=True)

    def _validate_dimension(
        self, collection: t.Collection, dim: int, update: bool
    ) -> None:
        """Validate that a collection supports records of the given dimension.

        If the collection has no dimension yet and ``update`` is True, ``dim``
        becomes its dimension (persisted to the sysdb and reflected locally).

        Raises:
            InvalidDimensionException: if the collection already has a different
                dimension.
        """
        if collection["dimension"] is None:
            if update:
                collection_id = collection["id"]
                self._sysdb.update_collection(id=collection_id, dimension=dim)
                # Update the object we were handed (normally the cache entry itself)
                # so later records in the same batch validate against it. The cache
                # entry may have been evicted (reset_state / delete_collection /
                # _modify), so update it only if present instead of risking a
                # KeyError.
                collection["dimension"] = dim
                cached = self._collection_cache.get(collection_id)
                if cached is not None:
                    cached["dimension"] = dim
        elif collection["dimension"] != dim:
            raise InvalidDimensionException(
                f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
            )
        else:
            return  # all is well

    def _get_collection(self, collection_id: UUID) -> t.Collection:
        """Read-through cache for collection data.

        Raises:
            InvalidCollectionException: if the sysdb has no such collection.
        """
        if collection_id not in self._collection_cache:
            collections = self._sysdb.get_collections(id=collection_id)
            if not collections:
                raise InvalidCollectionException(
                    f"Collection {collection_id} does not exist."
                )
            self._collection_cache[collection_id] = collections[0]
        return self._collection_cache[collection_id]


def _records(
    operation: t.Operation,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
) -> Generator[t.SubmitEmbeddingRecord, None, None]:
    """Convert parallel lists of embeddings, metadatas and documents to a
    sequence of SubmitEmbeddingRecords.

    A document, when present, is stored inside the record's metadata under the
    reserved ``"chroma:document"`` key.
    """

    # Presumes that callers were invoked via Collection model, which means
    # that we know that the embeddings, metadatas and documents have already been
    # normalized and are guaranteed to be consistently named lists.

    # TODO: Fix API types to make it explicit that they've already been normalized

    for i, record_id in enumerate(ids):
        metadata = None
        if metadatas:
            metadata = metadatas[i]

        if documents:
            document = documents[i]
            if metadata:
                metadata = {**metadata, "chroma:document": document}
            else:
                metadata = {"chroma:document": document}

        yield t.SubmitEmbeddingRecord(
            id=record_id,
            embedding=embeddings[i] if embeddings else None,
            encoding=t.ScalarEncoding.FLOAT32,  # Hardcode for now
            metadata=metadata,
            operation=operation,
        )


def _doc(metadata: Optional[t.Metadata]) -> Optional[str]:
    """Retrieve the document (if any) from a Metadata map"""
    if metadata and "chroma:document" in metadata:
        return str(metadata["chroma:document"])
    return None


def _clean_metadatas(
    metadata: List[Optional[t.Metadata]],
) -> List[Optional[t.Metadata]]:
    """Remove any chroma-specific metadata keys that the client shouldn't see
    from a list of metadata maps."""
    return [_clean_metadata(m) for m in metadata]


def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]:
    """Remove any chroma-specific metadata keys that the client shouldn't see
    from a metadata map.

    Returns None when the input is empty/None or only had ``chroma:`` keys.
    """
    if not metadata:
        return None
    result = {k: v for k, v in metadata.items() if not k.startswith("chroma:")}
    return result or None