himanshud2611's picture
Upload folder using huggingface_hub
60e3a80 verified
from uuid import UUID
from typing import Dict, Optional, Tuple, Union, cast
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.api.types import Embedding
import chromadb.proto.chroma_pb2 as proto
from chromadb.types import (
Collection,
LogRecord,
Metadata,
Operation,
RequestVersionContext,
ScalarEncoding,
Segment,
SegmentScope,
SeqId,
OperationRecord,
UpdateMetadata,
Vector,
VectorEmbeddingRecord,
VectorQueryResult,
)
import numpy as np
from numpy.typing import NDArray
# TODO: Unit tests for this file, handling optional states etc
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> proto.Vector:
    """Serialize an embedding vector into a ``proto.Vector`` message.

    Args:
        vector: The embedding values. Anything ``np.array`` accepts works here,
            e.g. a numpy array or a plain Python list.
        encoding: Scalar encoding to serialize with; only ``FLOAT32`` and
            ``INT32`` are supported.

    Returns:
        A ``proto.Vector`` holding the raw bytes of the converted values, the
        element count as ``dimension`` and the matching proto encoding.

    Raises:
        ValueError: If ``encoding`` is neither ``FLOAT32`` nor ``INT32``.
    """
    as_array: Union[NDArray[np.float32], NDArray[np.int32]]
    if encoding == ScalarEncoding.FLOAT32:
        as_array = np.array(vector, dtype=np.float32)
        proto_encoding = proto.ScalarEncoding.FLOAT32
    elif encoding == ScalarEncoding.INT32:
        as_array = np.array(vector, dtype=np.int32)
        proto_encoding = proto.ScalarEncoding.INT32
    else:
        # Implicit string concatenation keeps source indentation out of the
        # message (a backslash continuation inside the literal would embed it).
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{ScalarEncoding.FLOAT32} or {ScalarEncoding.INT32}"
        )
    # Take the dimension from the converted array rather than from the input,
    # so plain sequences (which have no ``.size`` attribute) are accepted too.
    return proto.Vector(
        dimension=as_array.size,
        vector=as_array.tobytes(),
        encoding=proto_encoding,
    )
def from_proto_vector(vector: proto.Vector) -> Tuple[Embedding, ScalarEncoding]:
    """Decode a ``proto.Vector`` message into a numpy array and its encoding.

    Args:
        vector: Message whose ``vector`` bytes hold the packed scalars and whose
            ``encoding`` says how to interpret them.

    Returns:
        A ``(array, encoding)`` tuple. The array comes from ``np.frombuffer``
        over the message bytes, so it is read-only.

    Raises:
        ValueError: If the proto encoding is neither ``FLOAT32`` nor ``INT32``.
    """
    encoding = vector.encoding
    as_array: Union[NDArray[np.int32], NDArray[np.float32]]
    if encoding == proto.ScalarEncoding.FLOAT32:
        as_array = np.frombuffer(vector.vector, dtype=np.float32)
        out_encoding = ScalarEncoding.FLOAT32
    elif encoding == proto.ScalarEncoding.INT32:
        as_array = np.frombuffer(vector.vector, dtype=np.int32)
        out_encoding = ScalarEncoding.INT32
    else:
        # Implicit concatenation: no stray indentation inside the message.
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{proto.ScalarEncoding.FLOAT32} or {proto.ScalarEncoding.INT32}"
        )
    return (as_array, out_encoding)
def from_proto_operation(operation: proto.Operation) -> Operation:
    """Translate a protobuf ``Operation`` enum value into a chromadb ``Operation``.

    Raises:
        RuntimeError: If ``operation`` is not ADD, UPDATE, UPSERT or DELETE.
    """
    known_operations = (
        (proto.Operation.ADD, Operation.ADD),
        (proto.Operation.UPDATE, Operation.UPDATE),
        (proto.Operation.UPSERT, Operation.UPSERT),
        (proto.Operation.DELETE, Operation.DELETE),
    )
    for wire_value, domain_value in known_operations:
        if operation == wire_value:
            return domain_value
    # TODO: full error
    raise RuntimeError(f"Unknown operation {operation}")
def from_proto_metadata(metadata: proto.UpdateMetadata) -> Optional[Metadata]:
    """Decode protobuf metadata in which every entry must carry a value.

    Returns ``None`` for an empty map. An entry with no value set raises
    ``ValueError`` (from ``_from_proto_metadata_handle_none``).
    """
    decoded = _from_proto_metadata_handle_none(metadata, is_update=False)
    return cast(Optional[Metadata], decoded)
def from_proto_update_metadata(
    metadata: proto.UpdateMetadata,
) -> Optional[UpdateMetadata]:
    """Decode protobuf metadata for an update, where entries may be ``None``.

    Entries with no value set decode to ``None``. An empty map decodes to
    ``None`` as a whole.
    """
    decoded = _from_proto_metadata_handle_none(metadata, is_update=True)
    return cast(Optional[UpdateMetadata], decoded)
def _from_proto_metadata_handle_none(
    metadata: proto.UpdateMetadata, is_update: bool
) -> Optional[Union[UpdateMetadata, Metadata]]:
    """Decode a protobuf metadata map into a plain ``dict``.

    Args:
        metadata: Message whose ``metadata`` map holds one value message per key.
        is_update: If true, an entry with none of the typed value fields set
            decodes to ``None``. If false, such an entry is an error.

    Returns:
        ``None`` when the map is empty, otherwise a dict of decoded values.

    Raises:
        ValueError: If ``is_update`` is false and an entry has no value set.
    """
    entries = metadata.metadata
    if not entries:
        return None

    # Typed value fields, probed in this order; the first one that is set wins.
    typed_fields = ("bool_value", "string_value", "int_value", "float_value")
    decoded: Dict[str, Union[str, int, float, bool, None]] = {}
    for key, value in entries.items():
        present = next((name for name in typed_fields if value.HasField(name)), None)
        if present is not None:
            decoded[key] = getattr(value, present)
        elif is_update:
            decoded[key] = None
        else:
            raise ValueError(f"Metadata key {key} value cannot be None")
    return decoded
def to_proto_update_metadata(metadata: UpdateMetadata) -> proto.UpdateMetadata:
    """Encode a metadata dict (values may be ``None``) as ``proto.UpdateMetadata``.

    Each value is encoded with ``to_proto_metadata_update_value``.
    """
    encoded: Dict[str, proto.UpdateMetadataValue] = {}
    for key, value in metadata.items():
        encoded[key] = to_proto_metadata_update_value(value)
    return proto.UpdateMetadata(metadata=encoded)
def from_proto_submit(
    operation_record: proto.OperationRecord, seq_id: SeqId
) -> LogRecord:
    """Convert a protobuf ``OperationRecord`` into a ``LogRecord``.

    Args:
        operation_record: The incoming record message.
        seq_id: Stored as the resulting ``LogRecord``'s ``log_offset``.

    Returns:
        A ``LogRecord`` wrapping an ``OperationRecord``. The vector is always
        passed through ``from_proto_vector``. Metadata is decoded with update
        semantics, so entries without a value become ``None``.
    """
    embedding, encoding = from_proto_vector(operation_record.vector)
    decoded_metadata = from_proto_update_metadata(operation_record.metadata)
    decoded_operation = from_proto_operation(operation_record.operation)
    inner = OperationRecord(
        id=operation_record.id,
        embedding=embedding,
        encoding=encoding,
        metadata=decoded_metadata,
        operation=decoded_operation,
    )
    return LogRecord(log_offset=seq_id, record=inner)
def from_proto_segment(segment: proto.Segment) -> Segment:
    """Build a ``Segment`` from its protobuf representation.

    ``id`` and ``collection`` are parsed from hex strings into UUIDs. Metadata
    is decoded only when the message has that field set; otherwise it is
    ``None``.
    """
    segment_id = UUID(hex=segment.id)
    scope = from_proto_segment_scope(segment.scope)
    collection_id = UUID(hex=segment.collection)
    if segment.HasField("metadata"):
        metadata = from_proto_metadata(segment.metadata)
    else:
        metadata = None
    return Segment(
        id=segment_id,
        type=segment.type,
        scope=scope,
        collection=collection_id,
        metadata=metadata,
    )
def to_proto_segment(segment: Segment) -> proto.Segment:
    """Serialize a ``Segment`` into a ``proto.Segment`` message.

    UUID fields are sent as hex strings. Metadata is passed as ``None`` when
    the segment has none.
    """
    segment_id = segment["id"].hex
    segment_type = segment["type"]
    scope = to_proto_segment_scope(segment["scope"])
    collection_id = segment["collection"].hex
    raw_metadata = segment["metadata"]
    if raw_metadata is None:
        encoded_metadata = None
    else:
        encoded_metadata = to_proto_update_metadata(raw_metadata)
    return proto.Segment(
        id=segment_id,
        type=segment_type,
        scope=scope,
        collection=collection_id,
        metadata=encoded_metadata,
    )
def from_proto_segment_scope(segment_scope: proto.SegmentScope) -> SegmentScope:
    """Translate a protobuf ``SegmentScope`` into a chromadb ``SegmentScope``.

    Raises:
        RuntimeError: For any scope other than VECTOR, METADATA or RECORD.
    """
    scope_pairs = (
        (proto.SegmentScope.VECTOR, SegmentScope.VECTOR),
        (proto.SegmentScope.METADATA, SegmentScope.METADATA),
        (proto.SegmentScope.RECORD, SegmentScope.RECORD),
    )
    for wire_scope, domain_scope in scope_pairs:
        if segment_scope == wire_scope:
            return domain_scope
    raise RuntimeError(f"Unknown segment scope {segment_scope}")
def to_proto_segment_scope(segment_scope: SegmentScope) -> proto.SegmentScope:
    """Translate a chromadb ``SegmentScope`` into the protobuf enum value.

    Raises:
        RuntimeError: For any scope other than VECTOR, METADATA or RECORD.
    """
    if segment_scope == SegmentScope.VECTOR:
        return proto.SegmentScope.VECTOR
    if segment_scope == SegmentScope.METADATA:
        return proto.SegmentScope.METADATA
    if segment_scope == SegmentScope.RECORD:
        return proto.SegmentScope.RECORD
    raise RuntimeError(f"Unknown segment scope {segment_scope}")
def to_proto_metadata_update_value(
    value: Union[str, int, float, bool, None]
) -> proto.UpdateMetadataValue:
    """Encode a single metadata value as a ``proto.UpdateMetadataValue``.

    Args:
        value: A str, int, float, bool, or ``None``. ``None`` is encoded as a
            message with no value field set; it is used to delete the key.

    Returns:
        The encoded value message.

    Raises:
        ValueError: If ``value`` has any other type.
    """
    # Be careful with the order here. Since bools are a subtype of int in python,
    # isinstance(value, bool) and isinstance(value, int) both return true
    # for a value of bool type.
    if isinstance(value, bool):
        return proto.UpdateMetadataValue(bool_value=value)
    elif isinstance(value, str):
        return proto.UpdateMetadataValue(string_value=value)
    elif isinstance(value, int):
        return proto.UpdateMetadataValue(int_value=value)
    elif isinstance(value, float):
        return proto.UpdateMetadataValue(float_value=value)
    # None is used to delete the metadata key.
    elif value is None:
        return proto.UpdateMetadataValue()
    else:
        # List every accepted type (bool was missing). Implicit concatenation
        # avoids embedding source indentation in the message.
        raise ValueError(
            f"Unknown metadata value type {type(value)}, expected one of str, "
            "int, float, bool, or None"
        )
def from_proto_collection(collection: proto.Collection) -> Collection:
    """Build a ``Collection`` from its protobuf representation.

    - The id is parsed from a hex string.
    - The configuration is parsed from its JSON string form.
    - Metadata is decoded only when the field is set.
    - ``dimension`` becomes ``None`` when the field is unset or zero.
    """
    collection_id = UUID(hex=collection.id)
    configuration = CollectionConfigurationInternal.from_json_str(
        collection.configuration_json_str
    )

    metadata: Optional[Metadata] = None
    if collection.HasField("metadata"):
        metadata = from_proto_metadata(collection.metadata)

    # A dimension of 0 is treated the same as an unset field.
    dimension: Optional[int] = None
    if collection.HasField("dimension") and collection.dimension:
        dimension = collection.dimension

    return Collection(
        id=collection_id,
        name=collection.name,
        configuration=configuration,
        metadata=metadata,
        dimension=dimension,
        database=collection.database,
        tenant=collection.tenant,
        version=collection.version,
        log_position=collection.log_position,
    )
def to_proto_collection(collection: Collection) -> proto.Collection:
    """Serialize a ``Collection`` into a ``proto.Collection`` message.

    - The id is sent as a hex string.
    - The configuration is sent as its JSON string form.
    - Metadata is passed as ``None`` when the collection has none.
    - Every field that ``from_proto_collection`` reads back is populated.
    """
    return proto.Collection(
        id=collection["id"].hex,
        name=collection["name"],
        configuration_json_str=collection.get_configuration().to_json_str(),
        metadata=None
        if collection["metadata"] is None
        else to_proto_update_metadata(collection["metadata"]),
        dimension=collection["dimension"],
        tenant=collection["tenant"],
        database=collection["database"],
        version=collection["version"],
        # from_proto_collection reads log_position from this message, so it
        # must be sent; otherwise it would be left unset on a round trip.
        log_position=collection["log_position"],
    )
def to_proto_operation(operation: Operation) -> proto.Operation:
    """Translate a chromadb ``Operation`` into the protobuf enum value.

    Raises:
        ValueError: If ``operation`` is not ADD, UPDATE, UPSERT or DELETE.
    """
    if operation == Operation.ADD:
        return proto.Operation.ADD
    elif operation == Operation.UPDATE:
        return proto.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return proto.Operation.UPSERT
    elif operation == Operation.DELETE:
        return proto.Operation.DELETE
    else:
        # Name each accepted operation once: UPSERT was missing and UPDATE was
        # listed twice. Implicit concatenation keeps indentation out of the text.
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, "
            f"{Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )
def to_proto_submit(
    submit_record: OperationRecord,
) -> proto.OperationRecord:
    """Serialize an ``OperationRecord`` into a ``proto.OperationRecord``.

    The vector is built only when both the embedding and its encoding are
    present. Metadata is encoded only when present. Anything missing is
    passed to the message constructor as ``None``.
    """
    vector: Optional[proto.Vector] = None
    embedding = submit_record["embedding"]
    if embedding is not None:
        encoding = submit_record["encoding"]
        if encoding is not None:
            vector = to_proto_vector(embedding, encoding)

    raw_metadata = submit_record["metadata"]
    metadata: Optional[proto.UpdateMetadata] = (
        to_proto_update_metadata(raw_metadata) if raw_metadata is not None else None
    )

    return proto.OperationRecord(
        id=submit_record["id"],
        vector=vector,
        metadata=metadata,
        operation=to_proto_operation(submit_record["operation"]),
    )
def from_proto_vector_embedding_record(
    embedding_record: proto.VectorEmbeddingRecord,
) -> VectorEmbeddingRecord:
    """Convert a protobuf ``VectorEmbeddingRecord`` into its chromadb form.

    Only the decoded embedding is kept; the encoding returned by
    ``from_proto_vector`` is discarded.
    """
    record_id = embedding_record.id
    embedding, _unused_encoding = from_proto_vector(embedding_record.vector)
    return VectorEmbeddingRecord(id=record_id, embedding=embedding)
def to_proto_vector_embedding_record(
    embedding_record: VectorEmbeddingRecord,
    encoding: ScalarEncoding,
) -> proto.VectorEmbeddingRecord:
    """Serialize a ``VectorEmbeddingRecord`` into its protobuf form.

    Args:
        embedding_record: Record providing the ``id`` and ``embedding``.
        encoding: Scalar encoding used when packing the embedding.
    """
    record_id = embedding_record["id"]
    packed = to_proto_vector(embedding_record["embedding"], encoding)
    return proto.VectorEmbeddingRecord(id=record_id, vector=packed)
def from_proto_vector_query_result(
    vector_query_result: proto.VectorQueryResult,
) -> VectorQueryResult:
    """Convert a protobuf ``VectorQueryResult`` into its chromadb form.

    Copies the id and distance, and decodes the vector, keeping only the
    embedding (the encoding is dropped).
    """
    result_id = vector_query_result.id
    distance = vector_query_result.distance
    decoded = from_proto_vector(vector_query_result.vector)
    return VectorQueryResult(id=result_id, distance=distance, embedding=decoded[0])
def from_proto_request_version_context(
    request_version_context: proto.RequestVersionContext,
) -> RequestVersionContext:
    """Copy ``collection_version`` and ``log_position`` out of the protobuf message."""
    collection_version = request_version_context.collection_version
    log_position = request_version_context.log_position
    return RequestVersionContext(
        collection_version=collection_version,
        log_position=log_position,
    )
def to_proto_request_version_context(
    request_version_context: RequestVersionContext,
) -> proto.RequestVersionContext:
    """Copy ``collection_version`` and ``log_position`` into a protobuf message."""
    collection_version = request_version_context["collection_version"]
    log_position = request_version_context["log_position"]
    return proto.RequestVersionContext(
        collection_version=collection_version,
        log_position=log_position,
    )