|
from concurrent import futures |
|
from typing import Any, Dict, cast |
|
from uuid import UUID |
|
from overrides import overrides |
|
from chromadb.api.configuration import CollectionConfigurationInternal |
|
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System |
|
from chromadb.proto.convert import ( |
|
from_proto_metadata, |
|
from_proto_update_metadata, |
|
from_proto_segment, |
|
from_proto_segment_scope, |
|
to_proto_collection, |
|
to_proto_segment, |
|
) |
|
import chromadb.proto.chroma_pb2 as proto |
|
from chromadb.proto.coordinator_pb2 import ( |
|
CreateCollectionRequest, |
|
CreateCollectionResponse, |
|
CreateDatabaseRequest, |
|
CreateDatabaseResponse, |
|
CreateSegmentRequest, |
|
CreateSegmentResponse, |
|
CreateTenantRequest, |
|
CreateTenantResponse, |
|
DeleteCollectionRequest, |
|
DeleteCollectionResponse, |
|
DeleteSegmentRequest, |
|
DeleteSegmentResponse, |
|
GetCollectionsRequest, |
|
GetCollectionsResponse, |
|
GetDatabaseRequest, |
|
GetDatabaseResponse, |
|
GetSegmentsRequest, |
|
GetSegmentsResponse, |
|
GetTenantRequest, |
|
GetTenantResponse, |
|
ResetStateResponse, |
|
UpdateCollectionRequest, |
|
UpdateCollectionResponse, |
|
UpdateSegmentRequest, |
|
UpdateSegmentResponse, |
|
) |
|
from chromadb.proto.coordinator_pb2_grpc import ( |
|
SysDBServicer, |
|
add_SysDBServicer_to_server, |
|
) |
|
import grpc |
|
from google.protobuf.empty_pb2 import Empty |
|
from chromadb.types import Collection, Metadata, Segment |
|
|
|
|
|
class GrpcMockSysDB(SysDBServicer, Component): |
|
"""A mock sysdb implementation that can be used for testing the grpc client. It stores |
|
state in simple python data structures instead of a database.""" |
|
|
|
_server: grpc.Server |
|
_server_port: int |
|
_segments: Dict[str, Segment] = {} |
|
_tenants_to_databases_to_collections: Dict[ |
|
str, Dict[str, Dict[str, Collection]] |
|
] = {} |
|
_tenants_to_database_to_id: Dict[str, Dict[str, UUID]] = {} |
|
|
|
def __init__(self, system: System): |
|
self._server_port = system.settings.require("chroma_server_grpc_port") |
|
return super().__init__(system) |
|
|
|
@overrides |
|
def start(self) -> None: |
|
self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) |
|
add_SysDBServicer_to_server(self, self._server) |
|
self._server.add_insecure_port(f"[::]:{self._server_port}") |
|
self._server.start() |
|
return super().start() |
|
|
|
@overrides |
|
def stop(self) -> None: |
|
self._server.stop(None) |
|
return super().stop() |
|
|
|
@overrides |
|
def reset_state(self) -> None: |
|
self._segments = {} |
|
self._tenants_to_databases_to_collections = {} |
|
|
|
self._tenants_to_databases_to_collections[DEFAULT_TENANT] = {} |
|
self._tenants_to_databases_to_collections[DEFAULT_TENANT][DEFAULT_DATABASE] = {} |
|
self._tenants_to_database_to_id[DEFAULT_TENANT] = {} |
|
self._tenants_to_database_to_id[DEFAULT_TENANT][DEFAULT_DATABASE] = UUID(int=0) |
|
return super().reset_state() |
|
|
|
@overrides(check_signature=False) |
|
def CreateDatabase( |
|
self, request: CreateDatabaseRequest, context: grpc.ServicerContext |
|
) -> CreateDatabaseResponse: |
|
tenant = request.tenant |
|
database = request.name |
|
if tenant not in self._tenants_to_databases_to_collections: |
|
return CreateDatabaseResponse( |
|
status=proto.Status(code=404, reason=f"Tenant {tenant} not found") |
|
) |
|
if database in self._tenants_to_databases_to_collections[tenant]: |
|
return CreateDatabaseResponse( |
|
status=proto.Status( |
|
code=409, reason=f"Database {database} already exists" |
|
) |
|
) |
|
self._tenants_to_databases_to_collections[tenant][database] = {} |
|
self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id) |
|
return CreateDatabaseResponse(status=proto.Status(code=200)) |
|
|
|
@overrides(check_signature=False) |
|
def GetDatabase( |
|
self, request: GetDatabaseRequest, context: grpc.ServicerContext |
|
) -> GetDatabaseResponse: |
|
tenant = request.tenant |
|
database = request.name |
|
if tenant not in self._tenants_to_databases_to_collections: |
|
return GetDatabaseResponse( |
|
status=proto.Status(code=404, reason=f"Tenant {tenant} not found") |
|
) |
|
if database not in self._tenants_to_databases_to_collections[tenant]: |
|
return GetDatabaseResponse( |
|
status=proto.Status(code=404, reason=f"Database {database} not found") |
|
) |
|
id = self._tenants_to_database_to_id[tenant][database] |
|
return GetDatabaseResponse( |
|
status=proto.Status(code=200), |
|
database=proto.Database(id=id.hex, name=database, tenant=tenant), |
|
) |
|
|
|
@overrides(check_signature=False) |
|
def CreateTenant( |
|
self, request: CreateTenantRequest, context: grpc.ServicerContext |
|
) -> CreateTenantResponse: |
|
tenant = request.name |
|
if tenant in self._tenants_to_databases_to_collections: |
|
return CreateTenantResponse( |
|
status=proto.Status(code=409, reason=f"Tenant {tenant} already exists") |
|
) |
|
self._tenants_to_databases_to_collections[tenant] = {} |
|
self._tenants_to_database_to_id[tenant] = {} |
|
return CreateTenantResponse(status=proto.Status(code=200)) |
|
|
|
@overrides(check_signature=False) |
|
def GetTenant( |
|
self, request: GetTenantRequest, context: grpc.ServicerContext |
|
) -> GetTenantResponse: |
|
tenant = request.name |
|
if tenant not in self._tenants_to_databases_to_collections: |
|
return GetTenantResponse( |
|
status=proto.Status(code=404, reason=f"Tenant {tenant} not found") |
|
) |
|
return GetTenantResponse( |
|
status=proto.Status(code=200), |
|
tenant=proto.Tenant(name=tenant), |
|
) |
|
|
|
|
|
|
|
|
|
@overrides(check_signature=False) |
|
def CreateSegment( |
|
self, request: CreateSegmentRequest, context: grpc.ServicerContext |
|
) -> CreateSegmentResponse: |
|
segment = from_proto_segment(request.segment) |
|
if segment["id"].hex in self._segments: |
|
return CreateSegmentResponse( |
|
status=proto.Status( |
|
code=409, reason=f"Segment {segment['id']} already exists" |
|
) |
|
) |
|
self._segments[segment["id"].hex] = segment |
|
return CreateSegmentResponse( |
|
status=proto.Status(code=200) |
|
) |
|
|
|
@overrides(check_signature=False) |
|
def DeleteSegment( |
|
self, request: DeleteSegmentRequest, context: grpc.ServicerContext |
|
) -> DeleteSegmentResponse: |
|
id_to_delete = request.id |
|
if id_to_delete in self._segments: |
|
del self._segments[id_to_delete] |
|
return DeleteSegmentResponse(status=proto.Status(code=200)) |
|
else: |
|
return DeleteSegmentResponse( |
|
status=proto.Status( |
|
code=404, reason=f"Segment {id_to_delete} not found" |
|
) |
|
) |
|
|
|
@overrides(check_signature=False) |
|
def GetSegments( |
|
self, request: GetSegmentsRequest, context: grpc.ServicerContext |
|
) -> GetSegmentsResponse: |
|
target_id = UUID(hex=request.id) if request.HasField("id") else None |
|
target_type = request.type if request.HasField("type") else None |
|
target_scope = ( |
|
from_proto_segment_scope(request.scope) |
|
if request.HasField("scope") |
|
else None |
|
) |
|
target_collection = UUID(hex=request.collection) |
|
|
|
found_segments = [] |
|
for segment in self._segments.values(): |
|
if target_id and segment["id"] != target_id: |
|
continue |
|
if target_type and segment["type"] != target_type: |
|
continue |
|
if target_scope and segment["scope"] != target_scope: |
|
continue |
|
if target_collection and segment["collection"] != target_collection: |
|
continue |
|
found_segments.append(segment) |
|
return GetSegmentsResponse( |
|
segments=[to_proto_segment(segment) for segment in found_segments] |
|
) |
|
|
|
@overrides(check_signature=False) |
|
def UpdateSegment( |
|
self, request: UpdateSegmentRequest, context: grpc.ServicerContext |
|
) -> UpdateSegmentResponse: |
|
id_to_update = UUID(request.id) |
|
if id_to_update.hex not in self._segments: |
|
return UpdateSegmentResponse( |
|
status=proto.Status( |
|
code=404, reason=f"Segment {id_to_update} not found" |
|
) |
|
) |
|
else: |
|
segment = self._segments[id_to_update.hex] |
|
if request.HasField("metadata"): |
|
target = cast(Dict[str, Any], segment["metadata"]) |
|
if segment["metadata"] is None: |
|
segment["metadata"] = {} |
|
self._merge_metadata(target, request.metadata) |
|
if request.HasField("reset_metadata") and request.reset_metadata: |
|
segment["metadata"] = {} |
|
return UpdateSegmentResponse(status=proto.Status(code=200)) |
|
|
|
@overrides(check_signature=False) |
|
def CreateCollection( |
|
self, request: CreateCollectionRequest, context: grpc.ServicerContext |
|
) -> CreateCollectionResponse: |
|
collection_name = request.name |
|
tenant = request.tenant |
|
database = request.database |
|
if tenant not in self._tenants_to_databases_to_collections: |
|
return CreateCollectionResponse( |
|
status=proto.Status(code=404, reason=f"Tenant {tenant} not found") |
|
) |
|
if database not in self._tenants_to_databases_to_collections[tenant]: |
|
return CreateCollectionResponse( |
|
status=proto.Status(code=404, reason=f"Database {database} not found") |
|
) |
|
|
|
|
|
for ( |
|
search_tenant, |
|
databases, |
|
) in self._tenants_to_databases_to_collections.items(): |
|
for search_database, search_collections in databases.items(): |
|
if request.id in search_collections: |
|
if ( |
|
search_tenant != request.tenant |
|
or search_database != request.database |
|
): |
|
return CreateCollectionResponse( |
|
status=proto.Status( |
|
code=409, |
|
reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}", |
|
) |
|
) |
|
elif not request.get_or_create: |
|
|
|
|
|
return CreateCollectionResponse( |
|
status=proto.Status( |
|
code=409, |
|
reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}", |
|
) |
|
) |
|
|
|
|
|
collections = self._tenants_to_databases_to_collections[tenant][database] |
|
matches = [c for c in collections.values() if c["name"] == collection_name] |
|
assert len(matches) <= 1 |
|
if len(matches) > 0: |
|
if request.get_or_create: |
|
existing_collection = matches[0] |
|
return CreateCollectionResponse( |
|
status=proto.Status(code=200), |
|
collection=to_proto_collection(existing_collection), |
|
created=False, |
|
) |
|
return CreateCollectionResponse( |
|
status=proto.Status( |
|
code=409, reason=f"Collection {request.name} already exists" |
|
) |
|
) |
|
|
|
configuration = CollectionConfigurationInternal.from_json_str( |
|
request.configuration_json_str |
|
) |
|
|
|
id = UUID(hex=request.id) |
|
new_collection = Collection( |
|
id=id, |
|
name=request.name, |
|
configuration=configuration, |
|
metadata=from_proto_metadata(request.metadata), |
|
dimension=request.dimension, |
|
database=database, |
|
tenant=tenant, |
|
version=0, |
|
) |
|
collections[request.id] = new_collection |
|
return CreateCollectionResponse( |
|
status=proto.Status(code=200), |
|
collection=to_proto_collection(new_collection), |
|
created=True, |
|
) |
|
|
|
@overrides(check_signature=False) |
|
def DeleteCollection( |
|
self, request: DeleteCollectionRequest, context: grpc.ServicerContext |
|
) -> DeleteCollectionResponse: |
|
collection_id = request.id |
|
tenant = request.tenant |
|
database = request.database |
|
if tenant not in self._tenants_to_databases_to_collections: |
|
return DeleteCollectionResponse( |
|
status=proto.Status(code=404, reason=f"Tenant {tenant} not found") |
|
) |
|
if database not in self._tenants_to_databases_to_collections[tenant]: |
|
return DeleteCollectionResponse( |
|
status=proto.Status(code=404, reason=f"Database {database} not found") |
|
) |
|
collections = self._tenants_to_databases_to_collections[tenant][database] |
|
if collection_id in collections: |
|
del collections[collection_id] |
|
return DeleteCollectionResponse(status=proto.Status(code=200)) |
|
else: |
|
return DeleteCollectionResponse( |
|
status=proto.Status( |
|
code=404, reason=f"Collection {collection_id} not found" |
|
) |
|
) |
|
|
|
@overrides(check_signature=False) |
|
def GetCollections( |
|
self, request: GetCollectionsRequest, context: grpc.ServicerContext |
|
) -> GetCollectionsResponse: |
|
target_id = UUID(hex=request.id) if request.HasField("id") else None |
|
target_name = request.name if request.HasField("name") else None |
|
|
|
allCollections = {} |
|
for tenant, databases in self._tenants_to_databases_to_collections.items(): |
|
for database, collections in databases.items(): |
|
if request.tenant != "" and tenant != request.tenant: |
|
continue |
|
if request.database != "" and database != request.database: |
|
continue |
|
allCollections.update(collections) |
|
print( |
|
f"Tenant: {tenant}, Database: {database}, Collections: {collections}" |
|
) |
|
found_collections = [] |
|
for collection in allCollections.values(): |
|
if target_id and collection["id"] != target_id: |
|
continue |
|
if target_name and collection["name"] != target_name: |
|
continue |
|
found_collections.append(collection) |
|
return GetCollectionsResponse( |
|
collections=[ |
|
to_proto_collection(collection) for collection in found_collections |
|
] |
|
) |
|
|
|
@overrides(check_signature=False) |
|
def UpdateCollection( |
|
self, request: UpdateCollectionRequest, context: grpc.ServicerContext |
|
) -> UpdateCollectionResponse: |
|
id_to_update = UUID(request.id) |
|
|
|
collections = {} |
|
for tenant, databases in self._tenants_to_databases_to_collections.items(): |
|
for database, maybe_collections in databases.items(): |
|
if id_to_update.hex in maybe_collections: |
|
collections = maybe_collections |
|
|
|
if id_to_update.hex not in collections: |
|
return UpdateCollectionResponse( |
|
status=proto.Status( |
|
code=404, reason=f"Collection {id_to_update} not found" |
|
) |
|
) |
|
else: |
|
collection = collections[id_to_update.hex] |
|
if request.HasField("name"): |
|
collection["name"] = request.name |
|
if request.HasField("dimension"): |
|
collection["dimension"] = request.dimension |
|
if request.HasField("metadata"): |
|
|
|
|
|
|
|
|
|
update_metadata = from_proto_update_metadata(request.metadata) |
|
cleaned_metadata = None |
|
if update_metadata is not None: |
|
cleaned_metadata = {} |
|
for key, value in update_metadata.items(): |
|
if value is not None: |
|
cleaned_metadata[key] = value |
|
|
|
collection["metadata"] = cleaned_metadata |
|
elif request.HasField("reset_metadata"): |
|
if request.reset_metadata: |
|
collection["metadata"] = {} |
|
|
|
return UpdateCollectionResponse(status=proto.Status(code=200)) |
|
|
|
@overrides(check_signature=False) |
|
def ResetState( |
|
self, request: Empty, context: grpc.ServicerContext |
|
) -> ResetStateResponse: |
|
self.reset_state() |
|
return ResetStateResponse(status=proto.Status(code=200)) |
|
|
|
def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None: |
|
target_metadata = cast(Dict[str, Any], target) |
|
source_metadata = cast(Dict[str, Any], from_proto_update_metadata(source)) |
|
target_metadata.update(source_metadata) |
|
|
|
for key, value in source_metadata.items(): |
|
if value is None and key in target: |
|
del target_metadata[key] |
|
|