# Hugging Face upload-page residue (not Python source) — kept as comments so
# the module remains importable:
# himanshud2611's picture
# Upload folder using huggingface_hub
# 60e3a80 verified
from tenacity import retry, stop_after_attempt, retry_if_exception, wait_fixed
from chromadb.api import ServerAPI
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.db.system import SysDB
from chromadb.quota import QuotaEnforcer
from chromadb.rate_limit import RateLimitEnforcer
from chromadb.segment import SegmentManager, MetadataReader, VectorReader
from chromadb.telemetry.opentelemetry import (
add_attributes_to_current_span,
OpenTelemetryClient,
OpenTelemetryGranularity,
trace_method,
)
from chromadb.telemetry.product import ProductTelemetryClient
from chromadb.ingest import Producer
from chromadb.types import Collection as CollectionModel
from chromadb import __version__
from chromadb.errors import (
InvalidDimensionException,
InvalidCollectionException,
VersionMismatchError,
)
from chromadb.api.types import (
URI,
CollectionMetadata,
Document,
IDs,
Embeddings,
Embedding,
Metadatas,
Documents,
URIs,
Where,
WhereDocument,
Include,
GetResult,
QueryResult,
validate_metadata,
validate_update_metadata,
validate_where,
validate_where_document,
validate_batch,
)
from chromadb.telemetry.product.events import (
CollectionAddEvent,
CollectionDeleteEvent,
CollectionGetEvent,
CollectionUpdateEvent,
CollectionQueryEvent,
ClientCreateCollectionEvent,
)
import chromadb.types as t
from typing import (
Optional,
Sequence,
Generator,
List,
cast,
Set,
Any,
Callable,
TypeVar,
)
from overrides import override
from uuid import UUID, uuid4
from functools import wraps
import time
import logging
import re
# Type variable used by decorators (e.g. `rate_limit`) so they preserve the
# wrapped callable's type.
T = TypeVar("T", bound=Callable[..., Any])

# Module-level logger for this API implementation.
logger = logging.getLogger(__name__)
# mimics s3 bucket requirements for naming
def check_index_name(index_name: str) -> None:
    """Validate a collection name against the S3-bucket-like naming rules.

    Raises:
        ValueError: with a message listing all rules, if any rule is violated.
    """
    msg = (
        "Expected collection name that "
        "(1) contains 3-63 characters, "
        "(2) starts and ends with an alphanumeric character, "
        "(3) otherwise contains only alphanumeric characters, underscores or hyphens (-), "
        "(4) contains no two consecutive periods (..) and "
        "(5) is not a valid IPv4 address, "
        f"got {index_name}"
    )
    # Evaluate each rule; any single violation triggers the same error.
    length_ok = 3 <= len(index_name) <= 63
    shape_ok = (
        re.match("^[a-zA-Z0-9][a-zA-Z0-9._-]*[a-zA-Z0-9]$", index_name) is not None
    )
    no_double_dot = ".." not in index_name
    not_ipv4 = (
        re.match("^[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}\\.[0-9]{1,3}$", index_name)
        is None
    )
    if not (length_ok and shape_ok and no_double_dot and not_ipv4):
        raise ValueError(msg)
def rate_limit(func: T) -> T:
    """Decorator: route each call through the owner's RateLimitEnforcer.

    Assumes the decorated callable is a method whose instance exposes a
    `_rate_limit_enforcer` attribute.
    """

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # args[0] is `self` of the decorated method.
        instance = args[0]
        limited = instance._rate_limit_enforcer.rate_limit(func)
        return limited(*args, **kwargs)

    return wrapper  # type: ignore
class SegmentAPI(ServerAPI):
    """API implementation utilizing the new segment-based internal architecture"""

    # Components resolved from the dependency-injection System in __init__.
    _settings: Settings
    _sysdb: SysDB
    _manager: SegmentManager
    _producer: Producer
    _product_telemetry_client: ProductTelemetryClient
    _opentelemetry_client: OpenTelemetryClient
    # NOTE(review): declared but not assigned in __init__ — presumably set
    # elsewhere; confirm before relying on these attributes.
    _tenant_id: str
    _topic_ns: str
    def __init__(self, system: System):
        """Resolve every required component from the System container."""
        super().__init__(system)
        self._settings = system.settings
        self._sysdb = self.require(SysDB)
        self._manager = self.require(SegmentManager)
        self._quota = self.require(QuotaEnforcer)
        self._product_telemetry_client = self.require(ProductTelemetryClient)
        self._opentelemetry_client = self.require(OpenTelemetryClient)
        self._producer = self.require(Producer)
        self._rate_limit_enforcer = self._system.require(RateLimitEnforcer)
@override
def heartbeat(self) -> int:
return int(time.time_ns())
@trace_method("SegmentAPI.create_database", OpenTelemetryGranularity.OPERATION)
@override
def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
if len(name) < 3:
raise ValueError("Database name must be at least 3 characters long")
self._sysdb.create_database(
id=uuid4(),
name=name,
tenant=tenant,
)
@trace_method("SegmentAPI.get_database", OpenTelemetryGranularity.OPERATION)
@override
def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> t.Database:
return self._sysdb.get_database(name=name, tenant=tenant)
@trace_method("SegmentAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
@override
def create_tenant(self, name: str) -> None:
if len(name) < 3:
raise ValueError("Tenant name must be at least 3 characters long")
self._sysdb.create_tenant(
name=name,
)
@trace_method("SegmentAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
@override
def get_tenant(self, name: str) -> t.Tenant:
return self._sysdb.get_tenant(name=name)
# TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
# necessary because changing the value type from `Any` to`` `Union[str, int, float]`
# causes the system to somehow convert all values to strings.
@trace_method("SegmentAPI.create_collection", OpenTelemetryGranularity.OPERATION)
@override
def create_collection(
self,
name: str,
configuration: Optional[CollectionConfigurationInternal] = None,
metadata: Optional[CollectionMetadata] = None,
get_or_create: bool = False,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
if metadata is not None:
validate_metadata(metadata)
# TODO: remove backwards compatibility in naming requirements
check_index_name(name)
id = uuid4()
model = CollectionModel(
id=id,
name=name,
metadata=metadata,
configuration=configuration
if configuration is not None
else CollectionConfigurationInternal(), # Use default configuration if none is provided
tenant=tenant,
database=database,
dimension=None,
)
# TODO: Let sysdb create the collection directly from the model
coll, created = self._sysdb.create_collection(
id=model.id,
name=model.name,
configuration=model.get_configuration(),
metadata=model.metadata,
dimension=None, # This is lazily populated on the first add
get_or_create=get_or_create,
tenant=tenant,
database=database,
)
# TODO: wrap sysdb call in try except and log error if it fails
if created:
segments = self._manager.create_segments(coll)
for segment in segments:
self._sysdb.create_segment(segment)
else:
logger.debug(
f"Collection {name} already exists, returning existing collection."
)
# TODO: This event doesn't capture the get_or_create case appropriately
# TODO: Re-enable embedding function tracking in create_collection
self._product_telemetry_client.capture(
ClientCreateCollectionEvent(
collection_uuid=str(id),
# embedding_function=embedding_function.__class__.__name__,
)
)
add_attributes_to_current_span({"collection_uuid": str(id)})
return coll
@trace_method(
"SegmentAPI.get_or_create_collection", OpenTelemetryGranularity.OPERATION
)
@override
def get_or_create_collection(
self,
name: str,
configuration: Optional[CollectionConfigurationInternal] = None,
metadata: Optional[CollectionMetadata] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
return self.create_collection(
name=name,
metadata=metadata,
configuration=configuration,
get_or_create=True,
tenant=tenant,
database=database,
)
# TODO: Actually fix CollectionMetadata type to remove type: ignore flags. This is
# necessary because changing the value type from `Any` to`` `Union[str, int, float]`
# causes the system to somehow convert all values to strings
@trace_method("SegmentAPI.get_collection", OpenTelemetryGranularity.OPERATION)
@override
def get_collection(
self,
name: Optional[str] = None,
id: Optional[UUID] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> CollectionModel:
if id is None and name is None or (id is not None and name is not None):
raise ValueError("Name or id must be specified, but not both")
existing = self._sysdb.get_collections(
id=id, name=name, tenant=tenant, database=database
)
if existing:
return existing[0]
else:
raise InvalidCollectionException(f"Collection {name} does not exist.")
@trace_method("SegmentAPI.list_collection", OpenTelemetryGranularity.OPERATION)
@override
def list_collections(
self,
limit: Optional[int] = None,
offset: Optional[int] = None,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> Sequence[CollectionModel]:
return self._sysdb.get_collections(
limit=limit, offset=offset, tenant=tenant, database=database
)
@trace_method("SegmentAPI.count_collections", OpenTelemetryGranularity.OPERATION)
@override
def count_collections(
self,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> int:
collection_count = len(
self._sysdb.get_collections(tenant=tenant, database=database)
)
return collection_count
@trace_method("SegmentAPI._modify", OpenTelemetryGranularity.OPERATION)
@override
def _modify(
self,
id: UUID,
new_name: Optional[str] = None,
new_metadata: Optional[CollectionMetadata] = None,
) -> None:
if new_name:
# backwards compatibility in naming requirements (for now)
check_index_name(new_name)
if new_metadata:
validate_update_metadata(new_metadata)
# Ensure the collection exists
_ = self._get_collection(id)
# TODO eventually we'll want to use OptionalArgument and Unspecified in the
# signature of `_modify` but not changing the API right now.
if new_name and new_metadata:
self._sysdb.update_collection(id, name=new_name, metadata=new_metadata)
elif new_name:
self._sysdb.update_collection(id, name=new_name)
elif new_metadata:
self._sysdb.update_collection(id, metadata=new_metadata)
@trace_method("SegmentAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
@override
def delete_collection(
self,
name: str,
tenant: str = DEFAULT_TENANT,
database: str = DEFAULT_DATABASE,
) -> None:
existing = self._sysdb.get_collections(
name=name, tenant=tenant, database=database
)
if existing:
self._sysdb.delete_collection(
existing[0].id, tenant=tenant, database=database
)
for s in self._manager.delete_segments(existing[0].id):
self._sysdb.delete_segment(existing[0].id, s)
else:
raise ValueError(f"Collection {name} does not exist.")
    @trace_method("SegmentAPI._add", OpenTelemetryGranularity.OPERATION)
    @override
    def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Submit ADD records for `ids` to the collection's ingest log.

        Returns True after the records are handed to the producer.
        """
        # Quota is enforced before any other work happens.
        self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.ADD)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.ADD,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        # Dimension check happens before submission so bad batches never
        # reach the log.
        self._validate_embedding_record_set(coll, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)

        self._product_telemetry_client.capture(
            CollectionAddEvent(
                collection_uuid=str(collection_id),
                add_amount=len(ids),
                with_metadata=len(ids) if metadatas is not None else 0,
                with_documents=len(ids) if documents is not None else 0,
                with_uris=len(ids) if uris is not None else 0,
            )
        )
        return True
    @trace_method("SegmentAPI._update", OpenTelemetryGranularity.OPERATION)
    @override
    def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Submit UPDATE records for existing `ids` to the ingest log.

        Returns True after the records are handed to the producer.
        """
        # Quota is enforced before any other work happens.
        self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPDATE)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.UPDATE,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        # Dimension check happens before submission.
        self._validate_embedding_record_set(coll, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)

        # Counts below use truthiness, so empty lists are reported as 0.
        self._product_telemetry_client.capture(
            CollectionUpdateEvent(
                collection_uuid=str(collection_id),
                update_amount=len(ids),
                with_embeddings=len(embeddings) if embeddings else 0,
                with_metadata=len(metadatas) if metadatas else 0,
                with_documents=len(documents) if documents else 0,
                with_uris=len(uris) if uris else 0,
            )
        )
        return True
    @trace_method("SegmentAPI._upsert", OpenTelemetryGranularity.OPERATION)
    @override
    def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        """Submit UPSERT (insert-or-update) records to the ingest log.

        NOTE(review): unlike _add/_update, no product-telemetry event is
        emitted here — confirm whether that is intentional.
        """
        # Quota is enforced before any other work happens.
        self._quota.static_check(metadatas, documents, embeddings, str(collection_id))
        coll = self._get_collection(collection_id)
        self._manager.hint_use_collection(collection_id, t.Operation.UPSERT)
        validate_batch(
            (ids, embeddings, metadatas, documents, uris),
            {"max_batch_size": self.get_max_batch_size()},
        )
        records_to_submit = list(
            _records(
                t.Operation.UPSERT,
                ids=ids,
                embeddings=embeddings,
                metadatas=metadatas,
                documents=documents,
                uris=uris,
            )
        )
        # Dimension check happens before submission.
        self._validate_embedding_record_set(coll, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)

        return True
    @trace_method("SegmentAPI._get", OpenTelemetryGranularity.OPERATION)
    # Retried because the collection version may change between reading it
    # and querying the segments (see the comment on _query).
    @retry(  # type: ignore[misc]
        retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
        wait=wait_fixed(2),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    @override
    def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        # NOTE(review): mutable default arguments ({} and the include list);
        # they appear never to be mutated here, but confirm before refactoring.
        where: Optional[Where] = {},
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = {},
        include: Include = ["embeddings", "metadatas", "documents"],  # type: ignore[list-item]
    ) -> GetResult:
        """Fetch records from a collection filtered by ids/where/where_document.

        Only the fields named in `include` are populated in the result; the
        rest are None.
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )

        coll = self._get_collection(collection_id)
        # Pin the collection version/log position for consistent segment reads.
        request_version_context = t.RequestVersionContext(
            collection_version=coll.version,
            log_position=coll.log_position,
        )

        # Normalize empty filters to None so the segment skips filtering.
        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        metadata_segment = self._manager.get_segment(collection_id, MetadataReader)

        if sort is not None:
            raise NotImplementedError("Sorting is not yet supported")

        # Page-based pagination is translated into offset/limit.
        if page and page_size:
            offset = (page - 1) * page_size
            limit = page_size

        records = metadata_segment.get_metadata(
            where=where,
            where_document=where_document,
            ids=ids,
            limit=limit,
            offset=offset,
            request_version_context=request_version_context,
        )

        if len(records) == 0:
            # Nothing to return if there are no records
            return GetResult(
                ids=[],
                embeddings=[] if "embeddings" in include else None,
                metadatas=[] if "metadatas" in include else None,
                documents=[] if "documents" in include else None,
                uris=[] if "uris" in include else None,
                data=[] if "data" in include else None,
                included=include,
            )

        # Embeddings live in the vector segment and are fetched separately.
        vectors: Sequence[t.VectorEmbeddingRecord] = []
        if "embeddings" in include:
            vector_ids = [r["id"] for r in records]
            vector_segment = self._manager.get_segment(collection_id, VectorReader)
            vectors = vector_segment.get_vectors(
                ids=vector_ids, request_version_context=request_version_context
            )

        # TODO: Fix type so we don't need to ignore
        # It is possible to have a set of records, some with metadata and some without
        # Same with documents
        metadatas = [r["metadata"] for r in records]

        # Documents and uris are stored inside the metadata maps under
        # chroma-internal keys; extract them when requested.
        if "documents" in include:
            documents = [_doc(m) for m in metadatas]

        if "uris" in include:
            uris = [_uri(m) for m in metadatas]

        ids_amount = len(ids) if ids else 0
        self._product_telemetry_client.capture(
            CollectionGetEvent(
                collection_uuid=str(collection_id),
                ids_count=ids_amount,
                limit=limit if limit else 0,
                include_metadata=ids_amount if "metadatas" in include else 0,
                include_documents=ids_amount if "documents" in include else 0,
                include_uris=ids_amount if "uris" in include else 0,
            )
        )

        return GetResult(
            ids=[r["id"] for r in records],
            embeddings=[r["embedding"] for r in vectors]
            if "embeddings" in include
            else None,
            metadatas=_clean_metadatas(metadatas)
            if "metadatas" in include
            else None,  # type: ignore
            documents=documents if "documents" in include else None,  # type: ignore
            uris=uris if "uris" in include else None,  # type: ignore
            data=None,
            included=include,
        )
    @trace_method("SegmentAPI._delete", OpenTelemetryGranularity.OPERATION)
    @override
    def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> None:
        """Delete records matched by ids and/or filters from a collection.

        Raises:
            ValueError: if none of ids, where, where_document is provided.
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "ids_count": len(ids) if ids else 0,
            }
        )

        # Normalize empty filters to None.
        where = validate_where(where) if where is not None and len(where) > 0 else None
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else None
        )

        # You must have at least one of non-empty ids, where, or where_document.
        if (
            (ids is None or (ids is not None and len(ids) == 0))
            and (where is None or (where is not None and len(where) == 0))
            and (
                where_document is None
                or (where_document is not None and len(where_document) == 0)
            )
        ):
            raise ValueError(
                """
                You must provide either ids, where, or where_document to delete. If
                you want to delete all data in a collection you can delete the
                collection itself using the delete_collection method. Or alternatively,
                you can get() all the relevant ids and then delete them.
                """
            )

        coll = self._get_collection(collection_id)
        # Pin version/log position for consistent segment reads.
        request_version_context = t.RequestVersionContext(
            collection_version=coll.version,
            log_position=coll.log_position,
        )
        self._manager.hint_use_collection(collection_id, t.Operation.DELETE)

        # Filter-based deletes (or a delete without explicit ids) must first
        # resolve the matching ids from the metadata segment.
        if (where or where_document) or not ids:
            metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
            records = metadata_segment.get_metadata(
                where=where,
                where_document=where_document,
                ids=ids,
                request_version_context=request_version_context,
            )
            ids_to_delete = [r["id"] for r in records]
        else:
            ids_to_delete = ids

        # Nothing matched — no log traffic, no telemetry.
        if len(ids_to_delete) == 0:
            return

        records_to_submit = list(
            _records(operation=t.Operation.DELETE, ids=ids_to_delete)
        )
        self._validate_embedding_record_set(coll, records_to_submit)
        self._producer.submit_embeddings(collection_id, records_to_submit)

        self._product_telemetry_client.capture(
            CollectionDeleteEvent(
                collection_uuid=str(collection_id), delete_amount=len(ids_to_delete)
            )
        )
@trace_method("SegmentAPI._count", OpenTelemetryGranularity.OPERATION)
@retry( # type: ignore[misc]
retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
wait=wait_fixed(2),
stop=stop_after_attempt(5),
reraise=True,
)
@override
def _count(self, collection_id: UUID) -> int:
add_attributes_to_current_span({"collection_id": str(collection_id)})
coll = self._get_collection(collection_id)
request_version_context = t.RequestVersionContext(
collection_version=coll.version,
log_position=coll.log_position,
)
metadata_segment = self._manager.get_segment(collection_id, MetadataReader)
return metadata_segment.count(request_version_context)
    @trace_method("SegmentAPI._query", OpenTelemetryGranularity.OPERATION)
    # We retry on version mismatch errors because the version of the collection
    # may have changed between the time we got the version and the time we
    # actually query the collection on the FE. We are fine with fixed
    # wait time because the version mismatch error is not a error due to
    # network issues or other transient issues. It is a result of the
    # collection being updated between the time we got the version and
    # the time we actually query the collection on the FE.
    @retry(  # type: ignore[misc]
        retry=retry_if_exception(lambda e: isinstance(e, VersionMismatchError)),
        wait=wait_fixed(2),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    @override
    def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        # NOTE(review): mutable default arguments; the names are rebound
        # below rather than mutated, but confirm before refactoring.
        where: Where = {},
        where_document: WhereDocument = {},
        include: Include = ["documents", "metadatas", "distances"],  # type: ignore[list-item]
    ) -> QueryResult:
        """KNN-query a collection with one or more query embeddings.

        Returns per-query parallel lists of ids plus whichever of
        distances/embeddings/documents/metadatas/uris `include` requests.
        """
        add_attributes_to_current_span(
            {
                "collection_id": str(collection_id),
                "n_results": n_results,
                "where": str(where),
            }
        )

        query_amount = len(query_embeddings)
        self._product_telemetry_client.capture(
            CollectionQueryEvent(
                collection_uuid=str(collection_id),
                query_amount=query_amount,
                n_results=n_results,
                with_metadata_filter=query_amount if where is not None else 0,
                with_document_filter=query_amount if where_document is not None else 0,
                include_metadatas=query_amount if "metadatas" in include else 0,
                include_documents=query_amount if "documents" in include else 0,
                include_uris=query_amount if "uris" in include else 0,
                include_distances=query_amount if "distances" in include else 0,
            )
        )

        # Unlike _get, empty filters are passed through unchanged here.
        where = validate_where(where) if where is not None and len(where) > 0 else where
        where_document = (
            validate_where_document(where_document)
            if where_document is not None and len(where_document) > 0
            else where_document
        )

        allowed_ids = None

        coll = self._get_collection(collection_id)
        # Pin version/log position for consistent reads across segments.
        request_version_context = t.RequestVersionContext(
            collection_version=coll.version,
            log_position=coll.log_position,
        )

        # Every query vector must match the collection's dimensionality;
        # update=False means a dimensionless collection is left untouched.
        for embedding in query_embeddings:
            self._validate_dimension(coll, len(embedding), update=False)

        # Metadata/document filters pre-restrict the candidate id set.
        if where or where_document:
            metadata_reader = self._manager.get_segment(collection_id, MetadataReader)
            records = metadata_reader.get_metadata(
                where=where,
                where_document=where_document,
                include_metadata=False,
                request_version_context=request_version_context,
            )
            allowed_ids = [r["id"] for r in records]

        ids: List[List[str]] = []
        distances: List[List[float]] = []
        embeddings: List[Embeddings] = []
        documents: List[List[Document]] = []
        uris: List[List[URI]] = []
        metadatas: List[List[t.Metadata]] = []

        # If where conditions returned empty list then no need to proceed
        # further and can simply return an empty result set here.
        if allowed_ids is not None and allowed_ids == []:
            for em in range(len(query_embeddings)):
                ids.append([])
                if "distances" in include:
                    distances.append([])
                if "embeddings" in include:
                    embeddings.append([])
                if "documents" in include:
                    documents.append([])
                if "metadatas" in include:
                    metadatas.append([])
                if "uris" in include:
                    uris.append([])
        else:
            query = t.VectorQuery(
                vectors=query_embeddings,
                k=n_results,
                allowed_ids=allowed_ids,
                include_embeddings="embeddings" in include,
                options=None,
                request_version_context=request_version_context,
            )

            vector_reader = self._manager.get_segment(collection_id, VectorReader)
            results = vector_reader.query_vectors(query)

            # One result list per query embedding.
            for result in results:
                ids.append([r["id"] for r in result])
                if "distances" in include:
                    distances.append([r["distance"] for r in result])
                if "embeddings" in include:
                    embeddings.append([cast(Embedding, r["embedding"]) for r in result])

            if "documents" in include or "metadatas" in include or "uris" in include:
                # Fetch metadata for the union of all returned ids in one call.
                all_ids: Set[str] = set()
                for id_list in ids:
                    all_ids.update(id_list)
                metadata_reader = self._manager.get_segment(
                    collection_id, MetadataReader
                )
                records = metadata_reader.get_metadata(
                    ids=list(all_ids),
                    include_metadata=True,
                    request_version_context=request_version_context,
                )
                metadata_by_id = {r["id"]: r["metadata"] for r in records}
                for id_list in ids:
                    # In the segment based architecture, it is possible for one segment
                    # to have a record that another segment does not have. This results in
                    # data inconsistency. For the case of the local segments and the
                    # local segment manager, there is a case where a thread writes
                    # a record to the vector segment but not the metadata segment.
                    # Then a query'ing thread reads from the vector segment and
                    # queries the metadata segment. The metadata segment does not have
                    # the record. In this case we choose to return potentially
                    # incorrect data in the form of None.
                    metadata_list = [metadata_by_id.get(id, None) for id in id_list]
                    if "metadatas" in include:
                        metadatas.append(_clean_metadatas(metadata_list))  # type: ignore
                    if "documents" in include:
                        doc_list = [_doc(m) for m in metadata_list]
                        documents.append(doc_list)  # type: ignore
                    if "uris" in include:
                        uri_list = [_uri(m) for m in metadata_list]
                        uris.append(uri_list)  # type: ignore

        return QueryResult(
            ids=ids,
            distances=distances if distances else None,
            metadatas=metadatas if metadatas else None,
            embeddings=embeddings if embeddings else None,
            documents=documents if documents else None,
            uris=uris if uris else None,
            data=None,
            included=include,
        )
@trace_method("SegmentAPI._peek", OpenTelemetryGranularity.OPERATION)
@override
def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
add_attributes_to_current_span({"collection_id": str(collection_id)})
return self._get(collection_id, limit=n) # type: ignore
@override
def get_version(self) -> str:
return __version__
    @override
    def reset_state(self) -> None:
        # Intentionally a no-op. NOTE(review): component-level state appears
        # to be reset via self._system.reset_state() in reset() — confirm.
        pass
@override
def reset(self) -> bool:
self._system.reset_state()
return True
@override
def get_settings(self) -> Settings:
return self._settings
@override
def get_max_batch_size(self) -> int:
return self._producer.max_batch_size
# TODO: This could potentially cause race conditions in a distributed version of the
# system, since the cache is only local.
# TODO: promote collection -> topic to a base class method so that it can be
# used for channel assignment in the distributed version of the system.
@trace_method(
"SegmentAPI._validate_embedding_record_set", OpenTelemetryGranularity.ALL
)
def _validate_embedding_record_set(
self, collection: t.Collection, records: List[t.OperationRecord]
) -> None:
"""Validate the dimension of an embedding record before submitting it to the system."""
add_attributes_to_current_span({"collection_id": str(collection["id"])})
for record in records:
if record["embedding"] is not None:
self._validate_dimension(
collection, len(record["embedding"]), update=True
)
# This method is intentionally left untraced because otherwise it can emit thousands of spans for requests containing many embeddings.
def _validate_dimension(
self, collection: t.Collection, dim: int, update: bool
) -> None:
"""Validate that a collection supports records of the given dimension. If update
is true, update the collection if the collection doesn't already have a
dimension."""
if collection["dimension"] is None:
if update:
id = collection.id
self._sysdb.update_collection(id=id, dimension=dim)
elif collection["dimension"] != dim:
raise InvalidDimensionException(
f"Embedding dimension {dim} does not match collection dimensionality {collection['dimension']}"
)
else:
return # all is well
@trace_method("SegmentAPI._get_collection", OpenTelemetryGranularity.ALL)
def _get_collection(self, collection_id: UUID) -> t.Collection:
collections = self._sysdb.get_collections(id=collection_id)
if not collections or len(collections) == 0:
raise InvalidCollectionException(
f"Collection {collection_id} does not exist."
)
return collections[0]
def _records(
    operation: t.Operation,
    ids: IDs,
    embeddings: Optional[Embeddings] = None,
    metadatas: Optional[Metadatas] = None,
    documents: Optional[Documents] = None,
    uris: Optional[URIs] = None,
) -> Generator[t.OperationRecord, None, None]:
    """Convert parallel lists of embeddings, metadatas and documents to a sequence of
    SubmitEmbeddingRecords"""
    # Presumes that callers were invoked via the Collection model, so the
    # parallel lists are already normalized and consistently sized.
    if embeddings == []:
        embeddings = None
    for idx, record_id in enumerate(ids):
        metadata = metadatas[idx] if metadatas else None
        # Documents and uris are folded into the metadata map under
        # chroma-internal keys.
        if documents:
            metadata = {**(metadata or {}), "chroma:document": documents[idx]}
        if uris:
            metadata = {**(metadata or {}), "chroma:uri": uris[idx]}
        yield t.OperationRecord(
            id=record_id,
            embedding=embeddings[idx] if embeddings is not None else None,
            encoding=t.ScalarEncoding.FLOAT32,  # Hardcode for now
            metadata=metadata,
            operation=operation,
        )
def _doc(metadata: Optional[t.Metadata]) -> Optional[str]:
"""Retrieve the document (if any) from a Metadata map"""
if metadata and "chroma:document" in metadata:
return str(metadata["chroma:document"])
return None
def _uri(metadata: Optional[t.Metadata]) -> Optional[str]:
"""Retrieve the uri (if any) from a Metadata map"""
if metadata and "chroma:uri" in metadata:
return str(metadata["chroma:uri"])
return None
def _clean_metadatas(
metadata: List[Optional[t.Metadata]],
) -> List[Optional[t.Metadata]]:
"""Remove any chroma-specific metadata keys that the client shouldn't see from a
list of metadata maps."""
return [_clean_metadata(m) for m in metadata]
def _clean_metadata(metadata: Optional[t.Metadata]) -> Optional[t.Metadata]:
"""Remove any chroma-specific metadata keys that the client shouldn't see from a
metadata map."""
if not metadata:
return None
result = {}
for k, v in metadata.items():
if not k.startswith("chroma:"):
result[k] = v
if len(result) == 0:
return None
return result