Spaces:
Sleeping
Sleeping
from chromadb.api import API | |
from chromadb.config import Settings, System | |
from chromadb.db.system import SysDB | |
from chromadb.segment import SegmentManager, MetadataReader, VectorReader | |
from chromadb.telemetry import Telemetry | |
from chromadb.ingest import Producer | |
from chromadb.api.models.Collection import Collection | |
from chromadb import __version__ | |
from chromadb.errors import InvalidDimensionException, InvalidCollectionException | |
import chromadb.utils.embedding_functions as ef | |
from chromadb.api.types import ( | |
CollectionMetadata, | |
EmbeddingFunction, | |
IDs, | |
Embeddings, | |
Embedding, | |
Metadatas, | |
Documents, | |
Where, | |
WhereDocument, | |
Include, | |
GetResult, | |
QueryResult, | |
validate_metadata, | |
validate_update_metadata, | |
validate_where, | |
validate_where_document, | |
) | |
from chromadb.telemetry.events import CollectionAddEvent, CollectionDeleteEvent | |
import chromadb.types as t | |
from typing import Optional, Sequence, Generator, List, cast, Set, Dict | |
from overrides import override | |
from uuid import UUID, uuid4 | |
import time | |
import logging | |
import re | |
logger = logging.getLogger(__name__)

# Collection-name patterns, compiled once. They are applied with fullmatch():
# re.match(...) with a "$" anchor would also accept a name followed by a
# trailing "\n", which must be rejected.
_NAME_PATTERN = re.compile(r"[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]")
_IPV4_PATTERN = re.compile(r"[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}")


# mimics s3 bucket requirements for naming
def check_index_name(index_name: str) -> None:
    """Validate a collection name against S3-bucket-style naming rules.

    The name must be 3-63 characters long, start and end with an alphanumeric
    character, otherwise contain only alphanumerics, underscores, hyphens or
    periods, contain no ".." and not look like a dotted IPv4 address.

    Raises:
        ValueError: if ``index_name`` breaks any of the rules above; the
            message lists every rule and echoes the rejected name.
    """
    msg = (
        "Expected collection name that "
        "(1) contains 3-63 characters, "
        "(2) starts and ends with an alphanumeric character, "
        "(3) otherwise contains only alphanumeric characters, underscores, "
        "hyphens (-) or periods (.), "
        "(4) contains no two consecutive periods (..) and "
        "(5) is not a valid IPv4 address, "
        f"got {index_name}"
    )
    if not 3 <= len(index_name) <= 63:
        raise ValueError(msg)
    if not _NAME_PATTERN.fullmatch(index_name):
        raise ValueError(msg)
    if ".." in index_name:
        raise ValueError(msg)
    if _IPV4_PATTERN.fullmatch(index_name):
        raise ValueError(msg)
class SegmentAPI(API):
    """API implementation utilizing the new segment-based internal architecture.

    Collection and segment records are registered in the SysDB, writes are
    submitted as embedding records to each collection's producer topic, and
    reads are served by the metadata and vector segments obtained from the
    SegmentManager.
    """

    _settings: Settings
    _sysdb: SysDB
    _manager: SegmentManager
    _producer: Producer
    # TODO: fire telemetry events
    _telemetry_client: Telemetry
    _tenant_id: str
    _topic_ns: str
    # Process-local read-through cache of collection records keyed by collection
    # id (see _get_collection). Entries are dropped on modify/delete so that the
    # next read reloads them from the SysDB.
    _collection_cache: Dict[UUID, t.Collection]

    def __init__(self, system: System):
        super().__init__(system)
        self._settings = system.settings
        self._sysdb = self.require(SysDB)
        self._manager = self.require(SegmentManager)
        self._telemetry_client = self.require(Telemetry)
        self._producer = self.require(Producer)
        self._tenant_id = system.settings.tenant_id
        self._topic_ns = system.settings.topic_namespace
        self._collection_cache = {}

    def heartbeat(self) -> int:
        """Return the current wall-clock time in nanoseconds (``time.time_ns()``)."""
        return int(time.time_ns())

    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings.
    def create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
        get_or_create: bool = False,
    ) -> Collection:
        """Create the collection ``name`` and return a client-side handle to it.

        If a collection with this name already exists and ``get_or_create`` is
        true, that collection is returned instead. Its metadata is first
        replaced when a non-empty ``metadata`` that differs from the stored one
        is given.

        A brand-new collection gets its producer topic created, its segments
        created by the segment manager, and both the collection and its segments
        registered in the SysDB.

        Raises:
            ValueError: if the collection exists and ``get_or_create`` is false,
                or if ``name`` fails ``check_index_name``. Errors from
                ``validate_metadata`` propagate unchanged.
        """
        existing = self._sysdb.get_collections(name=name)

        if metadata is not None:
            validate_metadata(metadata)

        if existing:
            if not get_or_create:
                raise ValueError(f"Collection {name} already exists.")
            current = existing[0]
            if metadata and current["metadata"] != metadata:
                self._modify(id=current["id"], new_metadata=metadata)
                # Re-read so the returned handle carries the updated metadata.
                current = self._sysdb.get_collections(id=current["id"])[0]
            return Collection(
                client=self,
                id=current["id"],
                name=current["name"],
                metadata=current["metadata"],  # type: ignore
                embedding_function=embedding_function,
            )

        # TODO: remove backwards compatibility in naming requirements
        # The name is only validated on the create path, so existing collections
        # (possibly named under older rules) can still be fetched above.
        check_index_name(name)

        collection_id = uuid4()
        coll = t.Collection(
            id=collection_id,
            name=name,
            metadata=metadata,
            topic=self._topic(collection_id),
            dimension=None,  # fixed by the first embedding written (_validate_dimension)
        )
        self._producer.create_topic(coll["topic"])
        segments = self._manager.create_segments(coll)
        self._sysdb.create_collection(coll)
        for segment in segments:
            self._sysdb.create_segment(segment)

        return Collection(
            client=self,
            id=collection_id,
            name=name,
            metadata=metadata,
            embedding_function=embedding_function,
        )

    def get_or_create_collection(
        self,
        name: str,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Shorthand for ``create_collection(..., get_or_create=True)``."""
        return self.create_collection(
            name=name,
            metadata=metadata,
            embedding_function=embedding_function,
            get_or_create=True,
        )

    # TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
    # necessary because changing the value type from `Any` to `Union[str, int, float]`
    # causes the system to somehow convert all values to strings
    def get_collection(
        self,
        name: str,
        embedding_function: Optional[EmbeddingFunction] = ef.DefaultEmbeddingFunction(),
    ) -> Collection:
        """Return a handle to the existing collection ``name``.

        Raises:
            ValueError: if no collection with that name exists.
        """
        existing = self._sysdb.get_collections(name=name)
        if not existing:
            raise ValueError(f"Collection {name} does not exist.")
        found = existing[0]
        return Collection(
            client=self,
            id=found["id"],
            name=found["name"],
            metadata=found["metadata"],  # type: ignore
            embedding_function=embedding_function,
        )

    def list_collections(self) -> Sequence[Collection]:
        """Return handles for every collection in the SysDB.

        No embedding function is passed, so each handle uses ``Collection``'s
        own default.
        """
        return [
            Collection(
                client=self,
                id=db_collection["id"],
                name=db_collection["name"],
                metadata=db_collection["metadata"],  # type: ignore
            )
            for db_collection in self._sysdb.get_collections()
        ]

    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        """Rename a collection and/or replace its metadata in the SysDB.

        Empty or None arguments are treated as "leave unchanged".

        Raises:
            ValueError: if ``new_name`` fails ``check_index_name``. Errors from
                ``validate_update_metadata`` propagate unchanged.
        """
        if new_name:
            # backwards compatibility in naming requirements (for now)
            check_index_name(new_name)
        if new_metadata:
            validate_update_metadata(new_metadata)

        # TODO eventually we'll want to use OptionalArgument and Unspecified in the
        # signature of `_modify` but not changing the API right now.
        if new_name and new_metadata:
            self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
        elif new_name:
            self._sysdb.update_collection(id, name=new_name)
        elif new_metadata:
            self._sysdb.update_collection(id, metadata=new_metadata)

        if new_name or new_metadata:
            # Drop the cached record so later reads do not see the old
            # name/metadata; _get_collection reloads it from the SysDB.
            self._collection_cache.pop(id, None)

    def delete_collection(self, name: str) -> None:
        """Delete collection ``name`` together with its segments and topic.

        Raises:
            ValueError: if no collection with that name exists.
        """
        existing = self._sysdb.get_collections(name=name)
        if not existing:
            raise ValueError(f"Collection {name} does not exist.")

        target = existing[0]
        self._sysdb.delete_collection(target["id"])
        for segment in self._manager.delete_segments(target["id"]):
            self._sysdb.delete_segment(segment)
        self._producer.delete_topic(target["topic"])
        self._collection_cache.pop(target["id"], None)

    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Submit ADD records for ``ids`` to the collection's topic.

        Each record's embedding dimension is validated (and fixes the
        collection's dimension if it has none yet) before the record is
        submitted. Records are submitted one at a time, so a dimension error
        part-way through leaves the earlier records already submitted.
        """
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.ADD)

        for record in _records(t.Operation.ADD, ids, embeddings, metadatas, documents):
            self._validate_embedding_record(coll, record)
            self._producer.submit_embedding(coll["topic"], record)

        self._telemetry_client.capture(CollectionAddEvent(str(collection_id), len(ids)))
        return True

    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Submit UPDATE records for ``ids``; validation works as in ``_add``."""
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)

        for record in _records(
            t.Operation.UPDATE, ids, embeddings, metadatas, documents
        ):
            self._validate_embedding_record(coll, record)
            self._producer.submit_embedding(coll["topic"], record)

        return True

    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
    ) -> bool:
        """Submit UPSERT records for ``ids``; validation works as in ``_add``."""
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)

        for record in _records(
            t.Operation.UPSERT, ids, embeddings, metadatas, documents
        ):
            self._validate_embedding_record(coll, record)
            self._producer.submit_embedding(coll["topic"], record)

        return True

    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],
    ) -> GetResult:
        """Fetch records from the metadata segment, plus embeddings if requested.

        Empty ``where``/``where_document`` filters are treated as no filter.
        When both ``page`` and ``page_size`` are truthy, they override
        ``offset``/``limit``; ``page`` is 1-based. Embeddings are returned in
        the same order as the ids. An id whose vector is not found in the
        vector segment gets None.

        Raises:
            NotImplementedError: if ``sort`` is given.
        """
        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)

        if sort is not None:
            raise NotImplementedError("Sorting is not yet supported")

        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        records = metadata_segment.get_metadata(
            where=where,
            where_document=where_document,
            ids=ids,
            limit=limit,
            offset=offset,
        )
        result_ids = [r["id"] for r in records]

        embeddings = None
        if "embeddings" in include:
            vector_segment = self._manager.get_segment(collection_id, VectorReader)
            vectors = vector_segment.get_vectors(ids=result_ids)
            # Align by id, not by position. Nothing guarantees that the vector
            # segment returns every requested id, or returns them in request
            # order; a positional zip would pair embeddings with the wrong ids.
            embedding_by_id = {v["id"]: v["embedding"] for v in vectors}
            embeddings = [embedding_by_id.get(record_id) for record_id in result_ids]

        # TODO: Fix type so we don't need to ignore
        # It is possible to have a set of records, some with metadata and some without
        # Same with documents
        metadatas = [r["metadata"] for r in records]
        documents = [_doc(m) for m in metadatas] if "documents" in include else None

        return GetResult(
            ids=result_ids,
            embeddings=embeddings,  # type: ignore
            metadatas=_clean_metadatas(metadatas) if "metadatas" in include else None,  # type: ignore
            documents=documents,  # type: ignore
        )

    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> IDs:
        """Submit DELETE records and return the ids that were targeted.

        If a non-empty filter is given, or ``ids`` is empty/None, the targets
        are resolved through the metadata segment, with ``ids`` also passed as
        a filter. With no filters and no ids, that lookup is unrestricted.
        Otherwise ``ids`` is used as-is.
        """
        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

        # TODO: Do we want to warn the user that unrestricted _delete() is 99% of the
        # time a bad idea?
        if (where or where_document) or not ids:
            metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
            records = metadata_segment.get_metadata(
                where=where, where_document=where_document, ids=ids
            )
            ids_to_delete = [r["id"] for r in records]
        else:
            ids_to_delete = ids

        for record in _records(t.Operation.DELETE, ids_to_delete):
            self._validate_embedding_record(coll, record)
            self._producer.submit_embedding(coll["topic"], record)

        self._telemetry_client.capture(
            CollectionDeleteEvent(str(collection_id), len(ids_to_delete))
        )
        return ids_to_delete

    def _count(self, collection_id: UUID) -> int:
        """Return the record count reported by the collection's metadata segment."""
        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
        return metadata_segment.count()

    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["documents", "metadatas", "distances"],
    ) -> QueryResult:
        """Run a nearest-neighbour query for each embedding in ``query_embeddings``.

        Filters are first resolved to an id allow-list via the metadata segment.
        The vector segment is then queried for ``n_results`` neighbours per query
        embedding. Every field of the result holds one inner list per query
        embedding; fields that were not requested are None.

        Raises:
            InvalidDimensionException: if a query embedding's length differs
                from the collection's (already fixed) dimension.
        """
        where = validate_where(where) if where is not None and len(where) > 0 else where
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else where_document
        )

        allowed_ids = None

        coll = self._get_collection(collection_id)
        for embedding in query_embeddings:
            # update=False: a query must not fix the dimension of an empty collection.
            self._validate_dimension(coll, len(embedding), update=False)

        metadata_reader = self._manager.get_segment(collection_id, MetadataReader)

        if where or where_document:
            records = metadata_reader.get_metadata(
                where=where, where_document=where_document
            )
            allowed_ids = [r["id"] for r in records]

        query = t.VectorQuery(
            vectors=query_embeddings,
            k=n_results,
            allowed_ids=allowed_ids,
            include_embeddings="embeddings" in include,
            options=None,
        )

        vector_reader = self._manager.get_segment(collection_id, VectorReader)
        results = vector_reader.query_vectors(query)

        ids: List[List[str]] = []
        distances: List[List[float]] = []
        embeddings: List[List[Embedding]] = []
        documents: List[List[str]] = []
        metadatas: List[List[t.Metadata]] = []

        for result in results:
            ids.append([r["id"] for r in result])
            if "distances" in include:
                distances.append([r["distance"] for r in result])
            if "embeddings" in include:
                embeddings.append([cast(Embedding, r["embedding"]) for r in result])

        if "documents" in include or "metadatas" in include:
            # One metadata lookup covering the hits of every query embedding.
            all_ids: Set[str] = set()
            for id_list in ids:
                all_ids.update(id_list)
            records = metadata_reader.get_metadata(ids=list(all_ids))
            metadata_by_id = {r["id"]: r["metadata"] for r in records}
            for id_list in ids:
                # In the segment based architecture, it is possible for one segment
                # to have a record that another segment does not have. This results in
                # data inconsistency. For the case of the local segments and the
                # local segment manager, there is a case where a thread writes
                # a record to the vector segment but not the metadata segment.
                # Then a query'ing thread reads from the vector segment and
                # queries the metadata segment. The metadata segment does not have
                # the record. In this case we choose to return potentially
                # incorrect data in the form of None.
                metadata_list = [metadata_by_id.get(hit_id, None) for hit_id in id_list]
                if "metadatas" in include:
                    metadatas.append(_clean_metadatas(metadata_list))  # type: ignore
                if "documents" in include:
                    doc_list = [_doc(m) for m in metadata_list]
                    documents.append(doc_list)  # type: ignore

        return QueryResult(
            ids=ids,
            distances=distances if distances else None,
            metadatas=metadatas if metadatas else None,
            embeddings=embeddings if embeddings else None,
            documents=documents if documents else None,
        )

    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        """Return up to ``n`` records via ``_get`` with its default ``include``."""
        return self._get(collection_id, limit=n)

    def get_version(self) -> str:
        """Return the installed chromadb package version."""
        return __version__

    def reset_state(self) -> None:
        """Clear this instance's local collection cache."""
        self._collection_cache = {}

    def reset(self) -> bool:
        """Ask the owning System to reset its state; always returns True."""
        self._system.reset_state()
        return True

    def get_settings(self) -> Settings:
        """Return the Settings this API was constructed with."""
        return self._settings

    def _topic(self, collection_id: UUID) -> str:
        """Build the producer topic name for a collection from tenant/namespace settings."""
        return f"persistent://{self._tenant_id}/{self._topic_ns}/{collection_id}"

    # TODO: This could potentially cause race conditions in a distributed version of the
    # system, since the cache is only local.
    def _validate_embedding_record(
        self, collection: t.Collection, record: t.SubmitEmbeddingRecord
    ) -> None:
        """Validate the dimension of an embedding record before submitting it to the system."""
        # Records without an embedding (e.g. DELETE) are not dimension-checked.
        if record["embedding"]:
            self._validate_dimension(collection, len(record["embedding"]), update=True)

    def _validate_dimension(
        self, collection: t.Collection, dim: int, update: bool
    ) -> None:
        """Validate that a collection supports records of the given dimension.

        If the collection has no dimension yet and ``update`` is true, ``dim``
        becomes its dimension. The change is persisted to the SysDB and
        reflected in the local record and cache.

        Raises:
            InvalidDimensionException: if the collection already has a different
                dimension.
        """
        if collection["dimension"] is None:
            if update:
                collection_id = collection["id"]
                self._sysdb.update_collection(id=collection_id, dimension=dim)
                # Write through the record we were handed, which is normally the
                # cached object itself, so later records in the same batch see the
                # new dimension. Also patch the cache entry when it is a different
                # (or reloaded) object. Do not assume it exists, which would risk a
                # KeyError.
                collection["dimension"] = dim
                cached = self._collection_cache.get(collection_id)
                if cached is not None:
                    cached["dimension"] = dim
        elif collection["dimension"] != dim:
            raise InvalidDimensionException(
                f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
            )

    def _get_collection(self, collection_id: UUID) -> t.Collection:
        """Read-through cache for collection data.

        Raises:
            InvalidCollectionException: if the SysDB has no such collection.
        """
        if collection_id not in self._collection_cache:
            collections = self._sysdb.get_collections(id=collection_id)
            if not collections:
                raise InvalidCollectionException(
                    f"Collection {collection_id} does not exist."
                )
            self._collection_cache[collection_id] = collections[0]
        return self._collection_cache[collection_id]
def _records(
    operation: t.Operation,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
) -> Generator[t.SubmitEmbeddingRecord, None, None]:
    """Lazily yield one SubmitEmbeddingRecord per id, zipping together the
    parallel embedding, metadata and document lists.

    When documents are supplied, each document is stored in the record's
    metadata under the reserved "chroma:document" key. A new dict is built
    every time, so the caller's metadata dicts are never mutated.
    """
    # Callers reach this through the Collection model, so the optional lists are
    # assumed to be normalized already: same length and order as ``ids``.
    # TODO: Fix API types to make it explicit that they've already been normalized
    for position, record_id in enumerate(ids):
        record_metadata = metadatas[position] if metadatas else None
        if documents:
            document_entry = {"chroma:document": documents[position]}
            record_metadata = (
                {**record_metadata, **document_entry}
                if record_metadata
                else document_entry
            )
        yield t.SubmitEmbeddingRecord(
            id=record_id,
            embedding=embeddings[position] if embeddings else None,
            encoding=t.ScalarEncoding.FLOAT32,  # Hardcode for now
            metadata=record_metadata,
            operation=operation,
        )
def _doc(metadata: Optional[t.Metadata]) -> Optional[str]:
    """Return the document stored under "chroma:document" as a string.

    Returns None when the metadata map is None/empty or lacks that key.
    """
    if not metadata or "chroma:document" not in metadata:
        return None
    return str(metadata["chroma:document"])
def _clean_metadatas(
    metadata: List[Optional[t.Metadata]],
) -> List[Optional[t.Metadata]]:
    """Apply _clean_metadata to each map in ``metadata``, preserving order."""
    return list(map(_clean_metadata, metadata))
def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]:
    """Drop reserved "chroma:"-prefixed keys that clients should not see.

    Returns a new dict. Returns None when the input is None/empty or when no
    user-visible keys remain after filtering.
    """
    if not metadata:
        return None
    visible = {
        key: value
        for key, value in metadata.items()
        if not key.startswith("chroma:")
    }
    return visible or None