from uuid import UUID
from typing import Dict, Optional, Tuple, Union, cast
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.api.types import Embedding
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import (
    Collection,
    LogRecord,
    Metadata,
    Operation,
    RequestVersionContext,
    ScalarEncoding,
    Segment,
    SegmentScope,
    SeqId,
    OperationRecord,
    UpdateMetadata,
    Vector,
    VectorEmbeddingRecord,
    VectorQueryResult,
)
import numpy as np
from numpy.typing import NDArray


# TODO: Unit tests for this file, handling optional states etc
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector:
    """Serialize ``vector`` into a ``proto.Vector`` using ``encoding``.

    The values are converted to a numpy array of the matching dtype
    (float32 or int32) and stored as raw bytes; ``dimension`` is the element
    count of that converted array.

    Raises:
        ValueError: if ``encoding`` is neither FLOAT32 nor INT32.
    """
    as_array: Union[NDArray[np.int32], NDArray[np.float32]]
    if encoding == ScalarEncoding.FLOAT32:
        as_array = np.array(vector, dtype=np.float32)
        proto_encoding = proto.ScalarEncoding.FLOAT32
    elif encoding == ScalarEncoding.INT32:
        as_array = np.array(vector, dtype=np.int32)
        proto_encoding = proto.ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{ScalarEncoding.FLOAT32} or {ScalarEncoding.INT32}"
        )

    # Take the size from the converted array rather than from ``vector`` so
    # inputs without a ``.size`` attribute (e.g. plain sequences) also work.
    return proto.Vector(
        dimension=as_array.size,
        vector=as_array.tobytes(),
        encoding=proto_encoding,
    )


def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]:
    """Decode a ``proto.Vector`` into ``(array, encoding)``.

    The byte payload is interpreted with ``np.frombuffer`` using the dtype
    that matches the proto encoding.

    Raises:
        ValueError: if the proto encoding is neither FLOAT32 nor INT32.
    """
    encoding = vector.encoding
    as_array: Union[NDArray[np.int32], NDArray[np.float32]]
    if encoding == proto.ScalarEncoding.FLOAT32:
        as_array = np.frombuffer(vector.vector, dtype=np.float32)
        out_encoding = ScalarEncoding.FLOAT32
    elif encoding == proto.ScalarEncoding.INT32:
        as_array = np.frombuffer(vector.vector, dtype=np.int32)
        out_encoding = ScalarEncoding.INT32
    else:
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}"
        )

    return (as_array, out_encoding)


def from_proto_operation(operation: proto.Operation) -> Operation:
    """Map a proto ``Operation`` enum value to the internal ``Operation``.

    Raises:
        RuntimeError: if ``operation`` is not ADD, UPDATE, UPSERT or DELETE.
    """
    if operation == proto.Operation.ADD:
        return Operation.ADD
    elif operation == proto.Operation.UPDATE:
        return Operation.UPDATE
    elif operation == proto.Operation.UPSERT:
        return Operation.UPSERT
    elif operation == proto.Operation.DELETE:
        return Operation.DELETE
    else:
        # TODO: full error
        raise RuntimeError(f"Unknown operation {operation}")


def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]:
    """Convert proto metadata to ``Metadata``; unset values are rejected.

    Returns ``None`` when the proto metadata map is empty.

    Raises:
        ValueError: if any key has no value field set.
    """
    return cast(Optional[Metadata], _from_proto_metadata_handle_none(metadata, False))


def from_proto_update_metadata(
    metadata: proto.UpdateMetadata,
) -> Optional[UpdateMetadata]:
    """Convert proto metadata to ``UpdateMetadata``; unset values become ``None``.

    Returns ``None`` when the proto metadata map is empty.
    """
    return cast(
        Optional[UpdateMetadata], _from_proto_metadata_handle_none(metadata, True)
    )


def _from_proto_metadata_handle_none(
    metadata: proto.UpdateMetadata, is_update: bool
) -> Optional[Union[UpdateMetadata, Metadata]]:
    """Shared implementation for the two ``from_proto_*metadata`` converters.

    Each entry is decoded according to which value field is set (bool, string,
    int or float). An entry with none of those fields set is stored as ``None``
    when ``is_update`` is true and rejected otherwise.

    Returns:
        ``None`` if the proto metadata map is empty, else the decoded dict.

    Raises:
        ValueError: if an entry has no value set and ``is_update`` is false.
    """
    if not metadata.metadata:
        return None

    out_metadata: Dict[str, Union[str, int, float, bool, None]] = {}
    for key, value in metadata.metadata.items():
        if value.HasField("bool_value"):
            out_metadata[key] = value.bool_value
        elif value.HasField("string_value"):
            out_metadata[key] = value.string_value
        elif value.HasField("int_value"):
            out_metadata[key] = value.int_value
        elif value.HasField("float_value"):
            out_metadata[key] = value.float_value
        elif is_update:
            out_metadata[key] = None
        else:
            raise ValueError(f"Metadata key {key} value cannot be None")
    return out_metadata


def to_proto_update_metadata(metadata: UpdateMetadata) -> proto.UpdateMetadata:
    """Build a ``proto.UpdateMetadata`` by converting every value of ``metadata``."""
    return proto.UpdateMetadata(
        metadata={k: to_proto_metadata_update_value(v) for k, v in metadata.items()}
    )


def from_proto_submit(
    operation_record: proto.OperationRecord, seq_id: SeqId
) -> LogRecord:
    """Convert a proto ``OperationRecord`` into a ``LogRecord`` at offset ``seq_id``."""
    embedding, encoding = from_proto_vector(operation_record.vector)
    record = LogRecord(
        log_offset=seq_id,
        record=OperationRecord(
            id=operation_record.id,
            embedding=embedding,
            encoding=encoding,
            metadata=from_proto_update_metadata(operation_record.metadata),
            operation=from_proto_operation(operation_record.operation),
        ),
    )
    return record


def from_proto_segment(segment: proto.Segment) -> Segment:
    """Convert a proto ``Segment`` into the internal ``Segment``.

    ``id`` and ``collection`` are parsed from hex strings into UUIDs; metadata
    is only decoded when the proto field is present.
    """
    return Segment(
        id=UUID(hex=segment.id),
        type=segment.type,
        scope=from_proto_segment_scope(segment.scope),
        collection=UUID(hex=segment.collection),
        metadata=from_proto_metadata(segment.metadata)
        if segment.HasField("metadata")
        else None,
    )


def to_proto_segment(segment: Segment) -> proto.Segment:
    """Convert an internal ``Segment`` into a proto ``Segment``.

    UUIDs are written as hex strings; ``None`` metadata is passed through as
    ``None``.
    """
    return proto.Segment(
        id=segment["id"].hex,
        type=segment["type"],
        scope=to_proto_segment_scope(segment["scope"]),
        collection=segment["collection"].hex,
        metadata=None
        if segment["metadata"] is None
        else to_proto_update_metadata(segment["metadata"]),
    )


def from_proto_segment_scope(segment_scope: proto.SegmentScope) -> SegmentScope:
    """Map a proto ``SegmentScope`` to the internal ``SegmentScope``.

    Raises:
        RuntimeError: if the scope is not VECTOR, METADATA or RECORD.
    """
    if segment_scope == proto.SegmentScope.VECTOR:
        return SegmentScope.VECTOR
    elif segment_scope == proto.SegmentScope.METADATA:
        return SegmentScope.METADATA
    elif segment_scope == proto.SegmentScope.RECORD:
        return SegmentScope.RECORD
    else:
        raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_segment_scope(segment_scope: SegmentScope) -> proto.SegmentScope:
    """Map an internal ``SegmentScope`` to the proto ``SegmentScope``.

    Raises:
        RuntimeError: if the scope is not VECTOR, METADATA or RECORD.
    """
    if segment_scope == SegmentScope.VECTOR:
        return proto.SegmentScope.VECTOR
    elif segment_scope == SegmentScope.METADATA:
        return proto.SegmentScope.METADATA
    elif segment_scope == SegmentScope.RECORD:
        return proto.SegmentScope.RECORD
    else:
        raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_metadata_update_value(
    value: Union[str, int, float, bool, None]
) -> proto.UpdateMetadataValue:
    """Wrap a single metadata value in a ``proto.UpdateMetadataValue``.

    ``None`` produces a message with no value field set.

    Raises:
        ValueError: if ``value`` is not a bool, str, int, float or ``None``.
    """
    # Be careful with the order here. Since bools are a subtype of int in python,
    # isinstance(value, bool) and isinstance(value, int) both return true
    # for a value of bool type.
    if isinstance(value, bool):
        return proto.UpdateMetadataValue(bool_value=value)
    elif isinstance(value, str):
        return proto.UpdateMetadataValue(string_value=value)
    elif isinstance(value, int):
        return proto.UpdateMetadataValue(int_value=value)
    elif isinstance(value, float):
        return proto.UpdateMetadataValue(float_value=value)
    # None is used to delete the metadata key.
    elif value is None:
        return proto.UpdateMetadataValue()
    else:
        raise ValueError(
            f"Unknown metadata value type {type(value)}, expected one of "
            "str, int, float, bool, or None"
        )


def from_proto_collection(collection: proto.Collection) -> Collection:
    """Convert a proto ``Collection`` into the internal ``Collection``.

    Metadata is decoded only when the proto field is present. ``dimension`` is
    ``None`` when the proto field is absent or holds a falsy value (0).
    """
    return Collection(
        id=UUID(hex=collection.id),
        name=collection.name,
        configuration=CollectionConfigurationInternal.from_json_str(
            collection.configuration_json_str
        ),
        metadata=from_proto_metadata(collection.metadata)
        if collection.HasField("metadata")
        else None,
        dimension=collection.dimension
        if collection.HasField("dimension") and collection.dimension
        else None,
        database=collection.database,
        tenant=collection.tenant,
        version=collection.version,
        log_position=collection.log_position,
    )


def to_proto_collection(collection: Collection) -> proto.Collection:
    """Convert an internal ``Collection`` into a proto ``Collection``.

    The configuration is serialized to a JSON string and the id to a hex
    string; ``None`` metadata is passed through as ``None``.
    """
    return proto.Collection(
        id=collection["id"].hex,
        name=collection["name"],
        configuration_json_str=collection.get_configuration().to_json_str(),
        metadata=None
        if collection["metadata"] is None
        else to_proto_update_metadata(collection["metadata"]),
        dimension=collection["dimension"],
        tenant=collection["tenant"],
        database=collection["database"],
        version=collection["version"],
    )


def to_proto_operation(operation: Operation) -> proto.Operation:
    """Map an internal ``Operation`` to the proto ``Operation`` enum.

    Raises:
        ValueError: if ``operation`` is not ADD, UPDATE, UPSERT or DELETE.
    """
    if operation == Operation.ADD:
        return proto.Operation.ADD
    elif operation == Operation.UPDATE:
        return proto.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return proto.Operation.UPSERT
    elif operation == Operation.DELETE:
        return proto.Operation.DELETE
    else:
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, "
            f"{Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )


def to_proto_submit(
    submit_record: OperationRecord,
) -> proto.OperationRecord:
    """Convert an ``OperationRecord`` into a proto ``OperationRecord``.

    The vector is only populated when both ``embedding`` and ``encoding`` are
    present; metadata is only populated when it is not ``None``.
    """
    vector = None
    if submit_record["embedding"] is not None and submit_record["encoding"] is not None:
        vector = to_proto_vector(submit_record["embedding"], submit_record["encoding"])

    metadata = None
    if submit_record["metadata"] is not None:
        metadata = to_proto_update_metadata(submit_record["metadata"])

    return proto.OperationRecord(
        id=submit_record["id"],
        vector=vector,
        metadata=metadata,
        operation=to_proto_operation(submit_record["operation"]),
    )


def from_proto_vector_embedding_record(
    embedding_record: proto.VectorEmbeddingRecord,
) -> VectorEmbeddingRecord:
    """Convert a proto ``VectorEmbeddingRecord``; the decoded encoding is dropped."""
    return VectorEmbeddingRecord(
        id=embedding_record.id,
        embedding=from_proto_vector(embedding_record.vector)[0],
    )


def to_proto_vector_embedding_record(
    embedding_record: VectorEmbeddingRecord,
    encoding: ScalarEncoding,
) -> proto.VectorEmbeddingRecord:
    """Convert a ``VectorEmbeddingRecord`` to proto, encoding its vector with ``encoding``."""
    return proto.VectorEmbeddingRecord(
        id=embedding_record["id"],
        vector=to_proto_vector(embedding_record["embedding"], encoding),
    )


def from_proto_vector_query_result(
    vector_query_result: proto.VectorQueryResult,
) -> VectorQueryResult:
    """Convert a proto ``VectorQueryResult``; the decoded encoding is dropped."""
    return VectorQueryResult(
        id=vector_query_result.id,
        distance=vector_query_result.distance,
        embedding=from_proto_vector(vector_query_result.vector)[0],
    )


def from_proto_request_version_context(
    request_version_context: proto.RequestVersionContext,
) -> RequestVersionContext:
    """Copy ``collection_version`` and ``log_position`` out of the proto message."""
    return RequestVersionContext(
        collection_version=request_version_context.collection_version,
        log_position=request_version_context.log_position,
    )


def to_proto_request_version_context(
    request_version_context: RequestVersionContext,
) -> proto.RequestVersionContext:
    """Copy ``collection_version`` and ``log_position`` into a proto message."""
    return proto.RequestVersionContext(
        collection_version=request_version_context["collection_version"],
        log_position=request_version_context["log_position"],
    )