from tenacity import retry, stop_after_attempt, retry_if_exception, wait_fixed

from chromadb.api import ServerAPI
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.db.system import SysDB
from chromadb.quota import QuotaEnforcer
from chromadb.rate_limit import RateLimitEnforcer
from chromadb.segment import SegmentManager, MetadataReader, VectorReader
from chromadb.telemetry.opentelemetry import (
    add_attributes_to_current_span,
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.ingest import Producer
from chromadb.types import Collection as CollectionModel
from chromadb import __version__
from chromadb.errors import (
    InvalidDimensionException,
    InvalidCollectionException,
    VersionMismatchError,
)
from chromadb.api.types import (
    URI,
    CollectionMetadata,
    Document,
    IDs,
    Embeddings,
    Embedding,
    Metadatas,
    Documents,
    URIs,
    Where,
    WhereDocument,
    Include,
    GetResult,
    QueryResult,
    validate_metadata,
    validate_update_metadata,
    validate_where,
    validate_where_document,
    validate_batch,
)
from chromadb.telemetry.product.events import (
    CollectionAddEvent,
    CollectionDeleteEvent,
    CollectionGetEvent,
    CollectionUpdateEvent,
    CollectionQueryEvent,
    ClientCreateCollectionEvent,
)
import chromadb.types as t
from typing import (
    Optional,
    Sequence,
    Generator,
    List,
    cast,
    Set,
    Any,
    Callable,
    TypeVar,
)
from overrides import override
from uuid import UUID, uuid4
from functools import wraps

import time
import logging
import re

T = TypeVar("T", bound=Callable[..., Any])

logger = logging.getLogger(__name__)

# Mimics the S3 bucket naming requirements.
def check_index_name(index_name: str) -> None:
    msg = (
        "Expected collection name that "
        "(1) contains 3-63 characters, "
        "(2) starts and ends with an alphanumeric character, "
        "(3) otherwise contains only alphanumeric characters, underscores, hyphens (-) or periods (.), "
        "(4) contains no two consecutive periods (..) and "
        "(5) is not a valid IPv4 address, "
        f"got {index_name}"
    )
    if len(index_name) < 3 or len(index_name) > 63:
        raise ValueError(msg)
    if not re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name):
        raise ValueError(msg)
    if ".." in index_name:
        raise ValueError(msg)
    if re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name):
        raise ValueError(msg)
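

# Illustrative examples of the rules above (these calls are not in the source):
#   check_index_name("my-collection_01")  # ok
#   check_index_name("ab")                # ValueError: fewer than 3 characters
#   check_index_name("a..b")              # ValueError: consecutive periods
#   check_index_name("192.168.0.1")       # ValueError: a valid IPv4 address
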
def rate_limit(func: T) -> T:
    @wraps(func)  # preserve the wrapped function's name and docstring
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        self = args[0]
        return self._rate_limit_enforcer.rate_limit(func)(*args, **kwargs)

    return wrapper  # type: ignore
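

# Intended composition (illustrative, not from the source): SegmentAPI methods
# can be decorated so every call is funneled through the configured
# RateLimitEnforcer before it executes, e.g.
#
#     @rate_limit
#     def _add(self, ...) -> bool: ...
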
class SegmentAPI(ServerAPI):
    """API implementation utilizing the new segment-based internal architecture"""

    _settings: Settings
    _sysdb: SysDB
    _manager: SegmentManager
    _producer: Producer
    _product_telemetry_client: ProductTelemetryClient
    _opentelemetry_client: OpenTelemetryClient
    _tenant_id: str
    _topic_ns: str

    def __init__(self, system: System):
        super().__init__(system)
        self._settings = system.settings
        self._sysdb = self.require(SysDB)
        self._manager = self.require(SegmentManager)
        self._quota = self.require(QuotaEnforcer)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._producer = self.require(Producer)
        self._rate_limit_enforcer = self._system.require(RateLimitEnforcer)

    def heartbeat(self) -> int:
        return int(time.time_ns())

    def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        if len(name) < 3:
            raise ValueError("Database name must be at least 3 characters long")
        self._sysdb.create_database(
            id=uuid4(),
            name=name,
            tenant=tenant,
        )

    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
        return self._sysdb.get_database(name=name, tenant=tenant)

    def create_tenant(self, name: str) -> None:
        if len(name) < 3:
            raise ValueError("Tenant name must be at least 3 characters long")
        self._sysdb.create_tenant(
            name=name,
        )

    def get_tenant(self, name: str) -> t.Tenant:
        return self._sysdb.get_tenant(name=name)
    # TODO: Actually fix the CollectionMetadata type to remove the type: ignore
    # flags. This is necessary because changing the value type from `Any` to
    # `Union[str, int, float]` causes the system to somehow convert all values
    # to strings.
    def create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfigurationInternal] = None,
        metadata: Optional[CollectionMetadata] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        if metadata is not None:
            validate_metadata(metadata)

        # TODO: remove backwards compatibility in naming requirements
        check_index_name(name)

        id = uuid4()
        model = CollectionModel(
            id=id,
            name=name,
            metadata=metadata,
            configuration=configuration
            if configuration is not None
            else CollectionConfigurationInternal(),  # Use default configuration if none is provided
            tenant=tenant,
            database=database,
            dimension=None,
        )

        # TODO: Let sysdb create the collection directly from the model
        coll, created = self._sysdb.create_collection(
            id=model.id,
            name=model.name,
            configuration=model.get_configuration(),
            metadata=model.metadata,
            dimension=None,  # This is lazily populated on the first add
            get_or_create=get_or_create,
            tenant=tenant,
            database=database,
        )

        # TODO: wrap sysdb call in try except and log error if it fails
        if created:
            segments = self._manager.create_segments(coll)
            for segment in segments:
                self._sysdb.create_segment(segment)
        else:
            logger.debug(
                f"Collection {name} already exists, returning existing collection."
            )

        # TODO: This event doesn't capture the get_or_create case appropriately
        # TODO: Re-enable embedding function tracking in create_collection
        self._product_telemetry_client.capture(
            ClientCreateCollectionEvent(
                collection_uuid=str(id),
                # embedding_function=embedding_function.__class__.__name__,
            )
        )
        add_attributes_to_current_span({"collection_uuid": str(id)})
        return coll

    def get_or_create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfigurationInternal] = None,
        metadata: Optional[CollectionMetadata] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        return self.create_collection(
            name=name,
            metadata=metadata,
            configuration=configuration,
            get_or_create=True,
            tenant=tenant,
            database=database,
        )
    # TODO: Actually fix the CollectionMetadata type to remove the type: ignore
    # flags. This is necessary because changing the value type from `Any` to
    # `Union[str, int, float]` causes the system to somehow convert all values
    # to strings.
    def get_collection(
        self,
        name: Optional[str] = None,
        id: Optional[UUID] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        if (id is None and name is None) or (id is not None and name is not None):
            raise ValueError("Exactly one of name or id must be specified")
        existing = self._sysdb.get_collections(
            id=id, name=name, tenant=tenant, database=database
        )
        if existing:
            return existing[0]
        else:
            raise InvalidCollectionException(f"Collection {name} does not exist.")
    def list_collections(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        return self._sysdb.get_collections(
            limit=limit, offset=offset, tenant=tenant, database=database
        )

    def count_collections(
        self,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        collection_count = len(
            self._sysdb.get_collections(tenant=tenant, database=database)
        )
        return collection_count
    def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        if new_name:
            # backwards compatibility in naming requirements (for now)
            check_index_name(new_name)
        if new_metadata:
            validate_update_metadata(new_metadata)

        # Ensure the collection exists
        _ = self._get_collection(id)

        # TODO: eventually we'll want to use OptionalArgument and Unspecified in
        # the signature of `_modify`, but we are not changing the API right now.
        if new_name and new_metadata:
            self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
        elif new_name:
            self._sysdb.update_collection(id, name=new_name)
        elif new_metadata:
            self._sysdb.update_collection(id, metadata=new_metadata)

    def delete_collection(
        self,
        name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        existing = self._sysdb.get_collections(
            name=name, tenant=tenant, database=database
        )
        if existing:
            self._sysdb.delete_collection(
                existing[0].id, tenant=tenant, database=database
            )
            for s in self._manager.delete_segments(existing[0].id):
                self._sysdb.delete_segment(existing[0].id, s)
        else:
            raise ValueError(f"Collection {name} does not exist.")
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.ADD)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.ADD,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        self._validate_embedding_record_set(coll, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)

        self._product_telemetry_client.capture(
            CollectionAddEvent(
                collection_uuid=str(collection_id),
                add_amount=len(ids),
                with_metadata=len(ids) if metadatas is not None else 0,
                with_documents=len(ids) if documents is not None else 0,
                with_uris=len(ids) if uris is not None else 0,
            )
        )
        return True
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.UPDATE,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        self._validate_embedding_record_set(coll, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)

        self._product_telemetry_client.capture(
            CollectionUpdateEvent(
                collection_uuid=str(collection_id),
                update_amount=len(ids),
                with_embeddings=len(embeddings) if embeddings else 0,
                with_metadata=len(metadatas) if metadatas else 0,
                with_documents=len(documents) if documents else 0,
                with_uris=len(uris) if uris else 0,
            )
        )
        return True
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.UPSERT,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        self._validate_embedding_record_set(coll, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)
        return True
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],  # type: ignore[list-item]
    ) -> GetResult:
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )

        coll = self._get_collection(collection_id)
        request_version_context = t.RequestVersionContext(
            collection_version=coll.version,
            log_position=coll.log_position,
        )

        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)

        if sort is not None:
            raise NotImplementedError("Sorting is not yet supported")

        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size
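            # e.g. page=3, page_size=20 -> offset = (3 - 1) * 20 = 40, limit = 20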
        records = metadata_segment.get_metadata(
            where=where,
            where_document=where_document,
            ids=ids,
            limit=limit,
            offset=offset,
            request_version_context=request_version_context,
        )

        if len(records) == 0:
            # Nothing to return if there are no records
            return GetResult(
                ids=[],
                embeddings=[] if "embeddings" in include else None,
                metadatas=[] if "metadatas" in include else None,
                documents=[] if "documents" in include else None,
                uris=[] if "uris" in include else None,
                data=[] if "data" in include else None,
                included=include,
            )

        vectors: Sequence[t.VectorEmbeddingRecord] = []
        if "embeddings" in include:
            vector_ids = [r["id"] for r in records]
            vector_segment = self._manager.get_segment(collection_id, VectorReader)
            vectors = vector_segment.get_vectors(
                ids=vector_ids, request_version_context=request_version_context
            )

        # TODO: Fix type so we don't need to ignore
        # It is possible to have a set of records, some with metadata and some without
        # Same with documents
        metadatas = [r["metadata"] for r in records]

        if "documents" in include:
            documents = [_doc(m) for m in metadatas]

        if "uris" in include:
            uris = [_uri(m) for m in metadatas]

        ids_amount = len(ids) if ids else 0
        self._product_telemetry_client.capture(
            CollectionGetEvent(
                collection_uuid=str(collection_id),
                ids_count=ids_amount,
                limit=limit if limit else 0,
                include_metadata=ids_amount if "metadatas" in include else 0,
                include_documents=ids_amount if "documents" in include else 0,
                include_uris=ids_amount if "uris" in include else 0,
            )
        )

        return GetResult(
            ids=[r["id"] for r in records],
            embeddings=[r["embedding"] for r in vectors]
            if "embeddings" in include
            else None,
            metadatas=_clean_metadatas(metadatas)
            if "metadatas" in include
            else None,  # type: ignore
            documents=documents if "documents" in include else None,  # type: ignore
            uris=uris if "uris" in include else None,  # type: ignore
            data=None,
            included=include,
        )
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> None:
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )

        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        # You must have at least one of non-empty ids, where, or where_document.
        if (
            (ids is None or len(ids) == 0)
            and (where is None or len(where) == 0)
            and (where_document is None or len(where_document) == 0)
        ):
            raise ValueError(
                """
                You must provide either ids, where, or where_document to delete. If
                you want to delete all data in a collection you can delete the
                collection itself using the delete_collection method. Or alternatively,
                you can get() all the relevant ids and then delete them.
                """
            )

        coll = self._get_collection(collection_id)
        request_version_context = t.RequestVersionContext(
            collection_version=coll.version,
            log_position=coll.log_position,
        )
        self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

        if (where or where_document) or not ids:
            metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
            records = metadata_segment.get_metadata(
                where=where,
                where_document=where_document,
                ids=ids,
                request_version_context=request_version_context,
            )
            ids_to_delete = [r["id"] for r in records]
        else:
            ids_to_delete = ids

        if len(ids_to_delete) == 0:
            return

        records_to_submit = list(
            _records(operation=t.Operation.DELETE, ids=ids_to_delete)
        )
        self._validate_embedding_record_set(coll, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)

        self._product_telemetry_client.capture(
            CollectionDeleteEvent(
                collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
            )
        )
    def _count(self, collection_id: UUID) -> int:
        add_attributes_to_current_span({"collection_id": str(collection_id)})
        coll = self._get_collection(collection_id)
        request_version_context = t.RequestVersionContext(
            collection_version=coll.version,
            log_position=coll.log_position,
        )
        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
        return metadata_segment.count(request_version_context)
    # We retry on version mismatch errors because the collection version may
    # have changed between the time we read it and the time we actually query
    # the collection on the FE. A fixed wait time is fine because a version
    # mismatch is not an error caused by network issues or other transient
    # failures; it simply means the collection was updated in that window.
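    # NOTE: the decorator below is a sketch matching the comment above and the
    # tenacity imports at the top of this file; the wait time and attempt count
    # are assumptions, not values taken from the source.
    @retry(
        retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
        wait=wait_fixed(0.5),
        stop=stop_after_attempt(3),
        reraise=True,
    )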
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["documents", "metadatas", "distances"],  # type: ignore[list-item]
    ) -> QueryResult:
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "n_results": n_results,
                "where": str(where),
            }
        )

        query_amount = len(query_embeddings)
        self._product_telemetry_client.capture(
            CollectionQueryEvent(
                collection_uuid=str(collection_id),
                query_amount=query_amount,
                n_results=n_results,
                with_metadata_filter=query_amount if where is not None else 0,
                with_document_filter=query_amount if where_document is not None else 0,
                include_metadatas=query_amount if "metadatas" in include else 0,
                include_documents=query_amount if "documents" in include else 0,
                include_uris=query_amount if "uris" in include else 0,
                include_distances=query_amount if "distances" in include else 0,
            )
        )

        where = validate_where(where) if where is not None and len(where) > 0 else where
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else where_document
        )
        allowed_ids = None

        coll = self._get_collection(collection_id)
        request_version_context = t.RequestVersionContext(
            collection_version=coll.version,
            log_position=coll.log_position,
        )

        for embedding in query_embeddings:
            self._validate_dimension(coll, len(embedding), update=False)

        if where or where_document:
            metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
            records = metadata_reader.get_metadata(
                where=where,
                where_document=where_document,
                include_metadata=False,
                request_version_context=request_version_context,
            )
            allowed_ids = [r["id"] for r in records]

        ids: List[List[str]] = []
        distances: List[List[float]] = []
        embeddings: List[Embeddings] = []
        documents: List[List[Document]] = []
        uris: List[List[URI]] = []
        metadatas: List[List[t.Metadata]] = []

        # If the where conditions returned an empty list, there is no need to
        # proceed further; we can simply return an empty result set here.
        if allowed_ids is not None and allowed_ids == []:
            for _ in range(len(query_embeddings)):
                ids.append([])
                if "distances" in include:
                    distances.append([])
                if "embeddings" in include:
                    embeddings.append([])
                if "documents" in include:
                    documents.append([])
                if "metadatas" in include:
                    metadatas.append([])
                if "uris" in include:
                    uris.append([])
        else:
            query = t.VectorQuery(
                vectors=query_embeddings,
                k=n_results,
                allowed_ids=allowed_ids,
                include_embeddings="embeddings" in include,
                options=None,
                request_version_context=request_version_context,
            )

            vector_reader = self._manager.get_segment(collection_id, VectorReader)
            results = vector_reader.query_vectors(query)

            for result in results:
                ids.append([r["id"] for r in result])
                if "distances" in include:
                    distances.append([r["distance"] for r in result])
                if "embeddings" in include:
                    embeddings.append([cast(Embedding, r["embedding"]) for r in result])

            if "documents" in include or "metadatas" in include or "uris" in include:
                all_ids: Set[str] = set()
                for id_list in ids:
                    all_ids.update(id_list)
                metadata_reader = self._manager.get_segment(
                    collection_id, MetadataReader
                )
                records = metadata_reader.get_metadata(
                    ids=list(all_ids),
                    include_metadata=True,
                    request_version_context=request_version_context,
                )
                metadata_by_id = {r["id"]: r["metadata"] for r in records}
                for id_list in ids:
                    # In the segment-based architecture, it is possible for one
                    # segment to have a record that another segment does not.
                    # This results in data inconsistency. For the local segments
                    # and the local segment manager, there is a case where a
                    # thread writes a record to the vector segment but not the
                    # metadata segment. A querying thread then reads from the
                    # vector segment and queries the metadata segment, which
                    # does not have the record. In this case we choose to return
                    # potentially incorrect data in the form of None.
                    metadata_list = [metadata_by_id.get(id, None) for id in id_list]
                    if "metadatas" in include:
                        metadatas.append(_clean_metadatas(metadata_list))  # type: ignore
                    if "documents" in include:
                        doc_list = [_doc(m) for m in metadata_list]
                        documents.append(doc_list)  # type: ignore
                    if "uris" in include:
                        uri_list = [_uri(m) for m in metadata_list]
                        uris.append(uri_list)  # type: ignore

        return QueryResult(
            ids=ids,
            distances=distances if distances else None,
            metadatas=metadatas if metadatas else None,
            embeddings=embeddings if embeddings else None,
            documents=documents if documents else None,
            uris=uris if uris else None,
            data=None,
            included=include,
        )
    def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        add_attributes_to_current_span({"collection_id": str(collection_id)})
        return self._get(collection_id, limit=n)  # type: ignore

    def get_version(self) -> str:
        return __version__

    def reset_state(self) -> None:
        pass

    def reset(self) -> bool:
        self._system.reset_state()
        return True

    def get_settings(self) -> Settings:
        return self._settings

    def get_max_batch_size(self) -> int:
        return self._producer.max_batch_size

    # TODO: This could potentially cause race conditions in a distributed version
    # of the system, since the cache is only local.
    # TODO: promote collection -> topic to a base class method so that it can be
    # used for channel assignment in the distributed version of the system.
    def _validate_embedding_record_set(
        self, collection: t.Collection, records: List[t.OperationRecord]
    ) -> None:
        """Validate the dimension of an embedding record before submitting it to the system."""
        add_attributes_to_current_span({"collection_id": str(collection["id"])})
        for record in records:
            if record["embedding"] is not None:
                self._validate_dimension(
                    collection, len(record["embedding"]), update=True
                )

    # This method is intentionally left untraced because otherwise it can emit
    # thousands of spans for requests containing many embeddings.
    def _validate_dimension(
        self, collection: t.Collection, dim: int, update: bool
    ) -> None:
        """Validate that a collection supports records of the given dimension. If
        update is true, update the collection if the collection doesn't already
        have a dimension."""
        if collection["dimension"] is None:
            if update:
                id = collection["id"]
                self._sysdb.update_collection(id=id, dimension=dim)
        elif collection["dimension"] != dim:
            raise InvalidDimensionException(
                f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
            )
        else:
            return  # all is well
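    # Dimension lifecycle, illustrated (not from the source): the first add of a
    # 384-dim embedding lazily sets the collection's dimension to 384; a later
    # record of dimension 512 then raises InvalidDimensionException.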
    def _get_collection(self, collection_id: UUID) -> t.Collection:
        collections = self._sysdb.get_collections(id=collection_id)
        if not collections or len(collections) == 0:
            raise InvalidCollectionException(
                f"Collection {collection_id} does not exist."
            )
        return collections[0]

def _records(
    operation: t.Operation,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
    uris: Optional[URIs] = None,
) -> Generator[t.OperationRecord, None, None]:
    """Convert parallel lists of embeddings, metadatas and documents to a sequence
    of OperationRecords"""
    # Presumes that callers were invoked via the Collection model, which means
    # that the embeddings, metadatas and documents have already been normalized
    # and are guaranteed to be consistently named lists.
    if embeddings == []:
        embeddings = None

    for i, id in enumerate(ids):
        metadata = None
        if metadatas:
            metadata = metadatas[i]

        if documents:
            document = documents[i]
            if metadata:
                metadata = {**metadata, "chroma:document": document}
            else:
                metadata = {"chroma:document": document}

        if uris:
            uri = uris[i]
            if metadata:
                metadata = {**metadata, "chroma:uri": uri}
            else:
                metadata = {"chroma:uri": uri}

        record = t.OperationRecord(
            id=id,
            embedding=embeddings[i] if embeddings is not None else None,
            encoding=t.ScalarEncoding.FLOAT32,  # Hardcode for now
            metadata=metadata,
            operation=operation,
        )
        yield record
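

# Illustrative example (not from the source): for an ADD of a single id with a
# document, the document is folded into the record's metadata under a reserved
# key, roughly:
#
#     next(_records(t.Operation.ADD, ids=["id1"],
#                   embeddings=[[0.1, 0.2]], documents=["hello"]))
#     # -> OperationRecord(id="id1", embedding=[0.1, 0.2],
#     #                    metadata={"chroma:document": "hello"},
#     #                    encoding=FLOAT32, operation=ADD)
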
def _doc(metadata: Optional[t.Metadata]) -> Optional[str]:
    """Retrieve the document (if any) from a Metadata map"""
    if metadata and "chroma:document" in metadata:
        return str(metadata["chroma:document"])
    return None


def _uri(metadata: Optional[t.Metadata]) -> Optional[str]:
    """Retrieve the uri (if any) from a Metadata map"""
    if metadata and "chroma:uri" in metadata:
        return str(metadata["chroma:uri"])
    return None


def _clean_metadatas(
    metadata: List[Optional[t.Metadata]],
) -> List[Optional[t.Metadata]]:
    """Remove any chroma-specific metadata keys that the client shouldn't see from a
    list of metadata maps."""
    return [_clean_metadata(m) for m in metadata]


def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]:
    """Remove any chroma-specific metadata keys that the client shouldn't see from a
    metadata map."""
    if not metadata:
        return None
    result = {}
    for k, v in metadata.items():
        if not k.startswith("chroma:"):
            result[k] = v
    if len(result) == 0:
        return None
    return result
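

# Illustrative example (not from the source):
#     _clean_metadata({"chroma:document": "hello", "page": 1})  # -> {"page": 1}
#     _clean_metadata({"chroma:document": "hello"})             # -> None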