Spaces:
Build error
Build error
from concurrent import futures | |
from typing import Any, Dict, cast | |
from uuid import UUID | |
from overrides import overrides | |
from chromadb.api.configuration import CollectionConfigurationInternal | |
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System | |
from chromadb.proto.convert import ( | |
from_proto_metadata, | |
from_proto_update_metadata, | |
from_proto_segment, | |
from_proto_segment_scope, | |
to_proto_collection, | |
to_proto_segment, | |
) | |
import chromadb.proto.chroma_pb2 as proto | |
from chromadb.proto.coordinator_pb2 import ( | |
CreateCollectionRequest, | |
CreateCollectionResponse, | |
CreateDatabaseRequest, | |
CreateDatabaseResponse, | |
CreateSegmentRequest, | |
CreateSegmentResponse, | |
CreateTenantRequest, | |
CreateTenantResponse, | |
DeleteCollectionRequest, | |
DeleteCollectionResponse, | |
DeleteSegmentRequest, | |
DeleteSegmentResponse, | |
GetCollectionsRequest, | |
GetCollectionsResponse, | |
GetDatabaseRequest, | |
GetDatabaseResponse, | |
GetSegmentsRequest, | |
GetSegmentsResponse, | |
GetTenantRequest, | |
GetTenantResponse, | |
ResetStateResponse, | |
UpdateCollectionRequest, | |
UpdateCollectionResponse, | |
UpdateSegmentRequest, | |
UpdateSegmentResponse, | |
) | |
from chromadb.proto.coordinator_pb2_grpc import ( | |
SysDBServicer, | |
add_SysDBServicer_to_server, | |
) | |
import grpc | |
from google.protobuf.empty_pb2 import Empty | |
from chromadb.types import Collection, Metadata, Segment | |
class GrpcMockSysDB(SysDBServicer, Component):
    """A mock sysdb implementation that can be used for testing the grpc client. It stores
    state in simple python data structures instead of a database.

    Segments and collections are keyed by the 32-character ``UUID.hex`` form of
    their id, so lookups agree whether a request carries a hyphenated or a
    plain-hex id string.
    """

    _server: grpc.Server
    _server_port: int
    # segment id (UUID hex) -> segment
    _segments: Dict[str, Segment]
    # tenant name -> database name -> collection id (UUID hex) -> collection
    _tenants_to_databases_to_collections: Dict[str, Dict[str, Dict[str, Collection]]]
    # tenant name -> database name -> database id
    _tenants_to_database_to_id: Dict[str, Dict[str, UUID]]

    def __init__(self, system: System):
        self._server_port = system.settings.require("chroma_server_grpc_port")
        # Per-instance state. These used to be mutable class-level defaults, which
        # every instance shared (and mutated) until reset_state() rebound them.
        self._segments = {}
        self._tenants_to_databases_to_collections = {}
        self._tenants_to_database_to_id = {}
        super().__init__(system)

    def start(self) -> None:
        """Start an insecure gRPC server on ``[::]:<chroma_server_grpc_port>``."""
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        add_SysDBServicer_to_server(self, self._server)  # type: ignore
        self._server.add_insecure_port(f"[::]:{self._server_port}")
        self._server.start()
        return super().start()

    def stop(self) -> None:
        """Stop the gRPC server with no grace period."""
        self._server.stop(None)
        return super().stop()

    def reset_state(self) -> None:
        """Drop every segment, tenant, database and collection, then recreate the
        default tenant with its default database (database id = nil UUID)."""
        self._segments = {}
        self._tenants_to_databases_to_collections = {
            DEFAULT_TENANT: {DEFAULT_DATABASE: {}}
        }
        # Rebuild the id map too; previously only the default tenant's entry was
        # overwritten, so database ids of other tenants survived a reset.
        self._tenants_to_database_to_id = {
            DEFAULT_TENANT: {DEFAULT_DATABASE: UUID(int=0)}
        }
        return super().reset_state()

    def CreateDatabase(
        self, request: CreateDatabaseRequest, context: grpc.ServicerContext
    ) -> CreateDatabaseResponse:
        """Create database ``request.name`` (id ``request.id``) in ``request.tenant``.

        Status 404 if the tenant is unknown, 409 if the database already exists,
        200 on success.
        """
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            return CreateDatabaseResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database in self._tenants_to_databases_to_collections[tenant]:
            return CreateDatabaseResponse(
                status=proto.Status(
                    code=409, reason=f"Database {database} already exists"
                )
            )
        # Parse the id before mutating anything so a malformed id cannot leave
        # the database half-registered.
        database_id = UUID(hex=request.id)
        self._tenants_to_databases_to_collections[tenant][database] = {}
        self._tenants_to_database_to_id.setdefault(tenant, {})[database] = database_id
        return CreateDatabaseResponse(status=proto.Status(code=200))

    def GetDatabase(
        self, request: GetDatabaseRequest, context: grpc.ServicerContext
    ) -> GetDatabaseResponse:
        """Return database ``request.name`` of ``request.tenant`` (404 if either is unknown)."""
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            return GetDatabaseResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in self._tenants_to_databases_to_collections[tenant]:
            return GetDatabaseResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )
        database_id = self._tenants_to_database_to_id[tenant][database]
        return GetDatabaseResponse(
            status=proto.Status(code=200),
            database=proto.Database(id=database_id.hex, name=database, tenant=tenant),
        )

    def CreateTenant(
        self, request: CreateTenantRequest, context: grpc.ServicerContext
    ) -> CreateTenantResponse:
        """Create an empty tenant ``request.name`` (409 if it already exists)."""
        tenant = request.name
        if tenant in self._tenants_to_databases_to_collections:
            return CreateTenantResponse(
                status=proto.Status(code=409, reason=f"Tenant {tenant} already exists")
            )
        self._tenants_to_databases_to_collections[tenant] = {}
        self._tenants_to_database_to_id[tenant] = {}
        return CreateTenantResponse(status=proto.Status(code=200))

    def GetTenant(
        self, request: GetTenantRequest, context: grpc.ServicerContext
    ) -> GetTenantResponse:
        """Return tenant ``request.name`` (404 if unknown)."""
        tenant = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            return GetTenantResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        return GetTenantResponse(
            status=proto.Status(code=200),
            tenant=proto.Tenant(name=tenant),
        )

    # NOTE: the generated proto code has no type annotations for the request and
    # response objects, so method signatures below cannot be checked against the
    # servicer base (hence no strict `overrides` signature checking).
    # TODO: investigate generating types for the request and response objects
    def CreateSegment(
        self, request: CreateSegmentRequest, context: grpc.ServicerContext
    ) -> CreateSegmentResponse:
        """Store ``request.segment`` (409 if a segment with that id already exists)."""
        segment = from_proto_segment(request.segment)
        if segment["id"].hex in self._segments:
            return CreateSegmentResponse(
                status=proto.Status(
                    code=409, reason=f"Segment {segment['id']} already exists"
                )
            )
        self._segments[segment["id"].hex] = segment
        return CreateSegmentResponse(
            status=proto.Status(code=200)
        )  # TODO: how are these codes used? Need to determine the standards for the code and reason.

    def DeleteSegment(
        self, request: DeleteSegmentRequest, context: grpc.ServicerContext
    ) -> DeleteSegmentResponse:
        """Delete segment ``request.id`` (404 if unknown)."""
        # Segments are keyed by UUID hex (see CreateSegment); normalise the
        # request id the same way UpdateSegment does.
        segment_key = UUID(hex=request.id).hex
        if segment_key not in self._segments:
            return DeleteSegmentResponse(
                status=proto.Status(
                    code=404, reason=f"Segment {request.id} not found"
                )
            )
        del self._segments[segment_key]
        return DeleteSegmentResponse(status=proto.Status(code=200))

    def GetSegments(
        self, request: GetSegmentsRequest, context: grpc.ServicerContext
    ) -> GetSegmentsResponse:
        """Return the segments of ``request.collection``, optionally filtered by
        id, type and scope when those fields are set."""
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_type = request.type if request.HasField("type") else None
        target_scope = (
            from_proto_segment_scope(request.scope)
            if request.HasField("scope")
            else None
        )
        target_collection = UUID(hex=request.collection)
        found_segments = []
        for segment in self._segments.values():
            if target_id and segment["id"] != target_id:
                continue
            if target_type and segment["type"] != target_type:
                continue
            if target_scope and segment["scope"] != target_scope:
                continue
            if target_collection and segment["collection"] != target_collection:
                continue
            found_segments.append(segment)
        return GetSegmentsResponse(
            segments=[to_proto_segment(segment) for segment in found_segments]
        )

    def UpdateSegment(
        self, request: UpdateSegmentRequest, context: grpc.ServicerContext
    ) -> UpdateSegmentResponse:
        """Merge ``request.metadata`` into the segment's metadata and/or clear it
        when ``reset_metadata`` is true (404 if the segment is unknown)."""
        id_to_update = UUID(request.id)
        if id_to_update.hex not in self._segments:
            return UpdateSegmentResponse(
                status=proto.Status(
                    code=404, reason=f"Segment {id_to_update} not found"
                )
            )
        segment = self._segments[id_to_update.hex]
        if request.HasField("metadata"):
            # Materialise an empty dict *before* taking the merge target; the old
            # code captured None here and then crashed merging into it.
            if segment["metadata"] is None:
                segment["metadata"] = {}
            target = cast(Dict[str, Any], segment["metadata"])
            self._merge_metadata(target, request.metadata)
        if request.HasField("reset_metadata") and request.reset_metadata:
            segment["metadata"] = {}
        return UpdateSegmentResponse(status=proto.Status(code=200))

    def CreateCollection(
        self, request: CreateCollectionRequest, context: grpc.ServicerContext
    ) -> CreateCollectionResponse:
        """Create collection ``request.name`` in ``request.tenant``/``request.database``.

        Status codes:
        - 404 if the tenant or database is unknown.
        - 409 if the id exists in another tenant/database.
        - 409 if the id or name already exists here and ``get_or_create`` is false.

        With ``get_or_create``, an existing same-named collection is returned
        with ``created=False``.
        """
        collection_name = request.name
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            return CreateCollectionResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in self._tenants_to_databases_to_collections[tenant]:
            return CreateCollectionResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )

        collection_id = UUID(hex=request.id)
        # Collections are keyed by UUID hex so UpdateCollection/DeleteCollection
        # find them no matter how the id string was formatted.
        collection_key = collection_id.hex

        # Check if the collection already exists globally by id
        for (
            search_tenant,
            databases,
        ) in self._tenants_to_databases_to_collections.items():
            for search_database, search_collections in databases.items():
                if collection_key not in search_collections:
                    continue
                if search_tenant != tenant or search_database != database:
                    return CreateCollectionResponse(
                        status=proto.Status(
                            code=409,
                            reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )
                    )
                if not request.get_or_create:
                    # The id exists in this tenant/database and this is not a
                    # get_or_create, so it is a conflict.
                    return CreateCollectionResponse(
                        status=proto.Status(
                            code=409,
                            reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )
                    )

        # Check if the collection already exists in this database by name
        collections = self._tenants_to_databases_to_collections[tenant][database]
        matches = [c for c in collections.values() if c["name"] == collection_name]
        assert len(matches) <= 1
        if matches:
            if request.get_or_create:
                return CreateCollectionResponse(
                    status=proto.Status(code=200),
                    collection=to_proto_collection(matches[0]),
                    created=False,
                )
            return CreateCollectionResponse(
                status=proto.Status(
                    code=409, reason=f"Collection {request.name} already exists"
                )
            )

        configuration = CollectionConfigurationInternal.from_json_str(
            request.configuration_json_str
        )
        new_collection = Collection(
            id=collection_id,
            name=request.name,
            configuration=configuration,
            metadata=from_proto_metadata(request.metadata),
            dimension=request.dimension,
            database=database,
            tenant=tenant,
            version=0,
        )
        collections[collection_key] = new_collection
        return CreateCollectionResponse(
            status=proto.Status(code=200),
            collection=to_proto_collection(new_collection),
            created=True,
        )

    def DeleteCollection(
        self, request: DeleteCollectionRequest, context: grpc.ServicerContext
    ) -> DeleteCollectionResponse:
        """Delete collection ``request.id`` from ``request.tenant``/``request.database``
        (404 if the tenant, database or collection is unknown)."""
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            return DeleteCollectionResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in self._tenants_to_databases_to_collections[tenant]:
            return DeleteCollectionResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )
        collections = self._tenants_to_databases_to_collections[tenant][database]
        collection_key = UUID(hex=request.id).hex
        if collection_key not in collections:
            return DeleteCollectionResponse(
                status=proto.Status(
                    code=404, reason=f"Collection {request.id} not found"
                )
            )
        del collections[collection_key]
        return DeleteCollectionResponse(status=proto.Status(code=200))

    def GetCollections(
        self, request: GetCollectionsRequest, context: grpc.ServicerContext
    ) -> GetCollectionsResponse:
        """Return collections matching the optional id/name filters.

        An empty ``request.tenant`` or ``request.database`` matches every tenant
        or database respectively.
        """
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_name = request.name if request.HasField("name") else None

        all_collections: Dict[str, Collection] = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            if request.tenant != "" and tenant != request.tenant:
                continue
            for database, collections in databases.items():
                if request.database != "" and database != request.database:
                    continue
                all_collections.update(collections)

        found_collections = [
            collection
            for collection in all_collections.values()
            if not (target_id and collection["id"] != target_id)
            and not (target_name and collection["name"] != target_name)
        ]
        return GetCollectionsResponse(
            collections=[
                to_proto_collection(collection) for collection in found_collections
            ]
        )

    def UpdateCollection(
        self, request: UpdateCollectionRequest, context: grpc.ServicerContext
    ) -> UpdateCollectionResponse:
        """Update name, dimension and/or metadata of collection ``request.id``,
        searched across all tenants and databases (404 if not found)."""
        id_to_update = UUID(request.id)
        collection_key = id_to_update.hex
        # Find the collection with this id
        collections: Dict[str, Collection] = {}
        for databases in self._tenants_to_databases_to_collections.values():
            for maybe_collections in databases.values():
                if collection_key in maybe_collections:
                    collections = maybe_collections
        if collection_key not in collections:
            return UpdateCollectionResponse(
                status=proto.Status(
                    code=404, reason=f"Collection {id_to_update} not found"
                )
            )

        collection = collections[collection_key]
        if request.HasField("name"):
            collection["name"] = request.name
        if request.HasField("dimension"):
            collection["dimension"] = request.dimension
        if request.HasField("metadata"):
            # TODO: IN SysDB SQlite we have technical debt where we
            # replace the entire metadata dict with the new one. We should
            # fix that by merging it. For now we just do the same thing here
            update_metadata = from_proto_update_metadata(request.metadata)
            cleaned_metadata = None
            if update_metadata is not None:
                # Keys set to None in an update mean "absent", so drop them.
                cleaned_metadata = {
                    key: value
                    for key, value in update_metadata.items()
                    if value is not None
                }
            collection["metadata"] = cleaned_metadata
        elif request.HasField("reset_metadata"):
            if request.reset_metadata:
                collection["metadata"] = {}
        return UpdateCollectionResponse(status=proto.Status(code=200))

    def ResetState(
        self, request: Empty, context: grpc.ServicerContext
    ) -> ResetStateResponse:
        """RPC wrapper around :meth:`reset_state`."""
        self.reset_state()
        return ResetStateResponse(status=proto.Status(code=200))

    def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
        """Merge ``source`` into ``target`` in place.

        Keys whose update value is None are removed from ``target``; every other
        key is added or overwritten.
        """
        target_metadata = cast(Dict[str, Any], target)
        # from_proto_update_metadata may return None (UpdateCollection guards for
        # it); treat that as an empty update instead of calling dict.update(None).
        source_metadata = cast(
            Dict[str, Any], from_proto_update_metadata(source) or {}
        )
        for key, value in source_metadata.items():
            if value is None:
                target_metadata.pop(key, None)
            else:
                target_metadata[key] = value