# himanshud2611's picture
# Upload folder using huggingface_hub
# 60e3a80 verified
from concurrent import futures
from typing import Any, Dict, cast
from uuid import UUID
from overrides import overrides
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System
from chromadb.proto.convert import (
from_proto_metadata,
from_proto_update_metadata,
from_proto_segment,
from_proto_segment_scope,
to_proto_collection,
to_proto_segment,
)
import chromadb.proto.chroma_pb2 as proto
from chromadb.proto.coordinator_pb2 import (
CreateCollectionRequest,
CreateCollectionResponse,
CreateDatabaseRequest,
CreateDatabaseResponse,
CreateSegmentRequest,
CreateSegmentResponse,
CreateTenantRequest,
CreateTenantResponse,
DeleteCollectionRequest,
DeleteCollectionResponse,
DeleteSegmentRequest,
DeleteSegmentResponse,
GetCollectionsRequest,
GetCollectionsResponse,
GetDatabaseRequest,
GetDatabaseResponse,
GetSegmentsRequest,
GetSegmentsResponse,
GetTenantRequest,
GetTenantResponse,
ResetStateResponse,
UpdateCollectionRequest,
UpdateCollectionResponse,
UpdateSegmentRequest,
UpdateSegmentResponse,
)
from chromadb.proto.coordinator_pb2_grpc import (
SysDBServicer,
add_SysDBServicer_to_server,
)
import grpc
from google.protobuf.empty_pb2 import Empty
from chromadb.types import Collection, Metadata, Segment
class GrpcMockSysDB(SysDBServicer, Component):
    """A mock sysdb implementation that can be used for testing the grpc client. It stores
    state in simple python data structures instead of a database.

    All state lives in three per-instance dictionaries (segments, collections
    grouped by tenant and database, and database ids grouped by tenant). Every
    RPC handler reports its outcome through a ``proto.Status`` whose ``code``
    mirrors HTTP conventions as used in this file: 200 (ok), 404 (not found)
    and 409 (already exists).
    """

    _server: grpc.Server
    _server_port: int
    # Segments keyed by the hex form of the segment id.
    _segments: Dict[str, Segment]
    # tenant name -> database name -> collection id (as sent in the create
    # request) -> collection.
    _tenants_to_databases_to_collections: Dict[
        str, Dict[str, Dict[str, Collection]]
    ]
    # tenant name -> database name -> database id.
    _tenants_to_database_to_id: Dict[str, Dict[str, UUID]]

    def __init__(self, system: System):
        """Read the grpc port from the system settings and create empty state.

        The containers are created per instance (rather than as class-level
        defaults) so that two mock instances never share state. The default
        tenant and database are only created by ``reset_state``.
        """
        self._server_port = system.settings.require("chroma_server_grpc_port")
        self._segments = {}
        self._tenants_to_databases_to_collections = {}
        self._tenants_to_database_to_id = {}
        super().__init__(system)

    @overrides
    def start(self) -> None:
        """Start a grpc server (10 worker threads) serving this object on
        the configured port over an insecure channel."""
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        add_SysDBServicer_to_server(self, self._server)  # type: ignore
        self._server.add_insecure_port(f"[::]:{self._server_port}")
        self._server.start()
        return super().start()

    @overrides
    def stop(self) -> None:
        """Stop the grpc server without a grace period."""
        self._server.stop(None)
        return super().stop()

    @overrides
    def reset_state(self) -> None:
        """Drop all stored state and recreate the default tenant/database.

        Every container is rebound to a fresh dict, including the database-id
        map, so that no tenant or database id from before the reset survives.
        The default database is given the all-zero UUID.
        """
        self._segments = {}
        # Create defaults
        self._tenants_to_databases_to_collections = {
            DEFAULT_TENANT: {DEFAULT_DATABASE: {}}
        }
        self._tenants_to_database_to_id = {
            DEFAULT_TENANT: {DEFAULT_DATABASE: UUID(int=0)}
        }
        return super().reset_state()

    @overrides(check_signature=False)
    def CreateDatabase(
        self, request: CreateDatabaseRequest, context: grpc.ServicerContext
    ) -> CreateDatabaseResponse:
        """Create an empty database under an existing tenant.

        Returns status 404 when the tenant is unknown, 409 when the database
        already exists for that tenant, and 200 on success. ``request.id`` is
        parsed as a hex UUID and stored as the database id.
        """
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            return CreateDatabaseResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database in self._tenants_to_databases_to_collections[tenant]:
            return CreateDatabaseResponse(
                status=proto.Status(
                    code=409, reason=f"Database {database} already exists"
                )
            )
        self._tenants_to_databases_to_collections[tenant][database] = {}
        self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id)
        return CreateDatabaseResponse(status=proto.Status(code=200))

    @overrides(check_signature=False)
    def GetDatabase(
        self, request: GetDatabaseRequest, context: grpc.ServicerContext
    ) -> GetDatabaseResponse:
        """Look up a database by tenant and name.

        Returns status 404 when either the tenant or the database is unknown,
        otherwise status 200 together with the database's id, name and tenant.
        """
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            return GetDatabaseResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in self._tenants_to_databases_to_collections[tenant]:
            return GetDatabaseResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )
        database_id = self._tenants_to_database_to_id[tenant][database]
        return GetDatabaseResponse(
            status=proto.Status(code=200),
            database=proto.Database(id=database_id.hex, name=database, tenant=tenant),
        )

    @overrides(check_signature=False)
    def CreateTenant(
        self, request: CreateTenantRequest, context: grpc.ServicerContext
    ) -> CreateTenantResponse:
        """Create a tenant with no databases.

        Returns status 409 when the tenant already exists, 200 otherwise.
        """
        tenant = request.name
        if tenant in self._tenants_to_databases_to_collections:
            return CreateTenantResponse(
                status=proto.Status(code=409, reason=f"Tenant {tenant} already exists")
            )
        self._tenants_to_databases_to_collections[tenant] = {}
        self._tenants_to_database_to_id[tenant] = {}
        return CreateTenantResponse(status=proto.Status(code=200))

    @overrides(check_signature=False)
    def GetTenant(
        self, request: GetTenantRequest, context: grpc.ServicerContext
    ) -> GetTenantResponse:
        """Look up a tenant by name.

        Returns status 404 when the tenant is unknown, otherwise status 200
        together with the tenant.
        """
        tenant = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            return GetTenantResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        return GetTenantResponse(
            status=proto.Status(code=200),
            tenant=proto.Tenant(name=tenant),
        )

    # We are forced to use check_signature=False because the generated proto code
    # does not have type annotations for the request and response objects.
    # TODO: investigate generating types for the request and response objects
    @overrides(check_signature=False)
    def CreateSegment(
        self, request: CreateSegmentRequest, context: grpc.ServicerContext
    ) -> CreateSegmentResponse:
        """Store a new segment, keyed by the hex form of its id.

        Returns status 409 when a segment with the same id is already stored,
        200 otherwise.
        """
        segment = from_proto_segment(request.segment)
        segment_key = segment["id"].hex
        if segment_key in self._segments:
            return CreateSegmentResponse(
                status=proto.Status(
                    code=409, reason=f"Segment {segment['id']} already exists"
                )
            )
        self._segments[segment_key] = segment
        return CreateSegmentResponse(
            status=proto.Status(code=200)
        )  # TODO: how are these codes used? Need to determine the standards for the code and reason.

    @overrides(check_signature=False)
    def DeleteSegment(
        self, request: DeleteSegmentRequest, context: grpc.ServicerContext
    ) -> DeleteSegmentResponse:
        """Delete the segment stored under ``request.id``.

        The id is used verbatim as the lookup key; segments are stored under
        the hex form of their id. Returns status 404 when no such key exists,
        200 otherwise.
        """
        id_to_delete = request.id
        if id_to_delete not in self._segments:
            return DeleteSegmentResponse(
                status=proto.Status(
                    code=404, reason=f"Segment {id_to_delete} not found"
                )
            )
        del self._segments[id_to_delete]
        return DeleteSegmentResponse(status=proto.Status(code=200))

    @overrides(check_signature=False)
    def GetSegments(
        self, request: GetSegmentsRequest, context: grpc.ServicerContext
    ) -> GetSegmentsResponse:
        """Return every stored segment matching the request's filters.

        ``id``, ``type`` and ``scope`` only filter when the field is present
        on the request; ``collection`` is always parsed as a hex UUID and
        applied. The response carries no status.
        """
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_type = request.type if request.HasField("type") else None
        target_scope = (
            from_proto_segment_scope(request.scope)
            if request.HasField("scope")
            else None
        )
        target_collection = UUID(hex=request.collection)

        found_segments = []
        for segment in self._segments.values():
            if target_id and segment["id"] != target_id:
                continue
            if target_type and segment["type"] != target_type:
                continue
            if target_scope and segment["scope"] != target_scope:
                continue
            if target_collection and segment["collection"] != target_collection:
                continue
            found_segments.append(segment)
        return GetSegmentsResponse(
            segments=[to_proto_segment(segment) for segment in found_segments]
        )

    @overrides(check_signature=False)
    def UpdateSegment(
        self, request: UpdateSegmentRequest, context: grpc.ServicerContext
    ) -> UpdateSegmentResponse:
        """Merge or reset the metadata of an existing segment.

        When ``metadata`` is set on the request it is merged into the
        segment's metadata (keys whose new value is None are removed). When
        ``reset_metadata`` is set and true the metadata becomes an empty dict.
        Returns status 404 when the segment is unknown, 200 otherwise.
        """
        id_to_update = UUID(request.id)
        if id_to_update.hex not in self._segments:
            return UpdateSegmentResponse(
                status=proto.Status(
                    code=404, reason=f"Segment {id_to_update} not found"
                )
            )

        segment = self._segments[id_to_update.hex]
        if request.HasField("metadata"):
            # Make sure there is a dict to merge into *before* taking a
            # reference to it; otherwise a segment without metadata would
            # hand None to _merge_metadata.
            if segment["metadata"] is None:
                segment["metadata"] = {}
            target = cast(Dict[str, Any], segment["metadata"])
            self._merge_metadata(target, request.metadata)
        if request.HasField("reset_metadata") and request.reset_metadata:
            segment["metadata"] = {}
        return UpdateSegmentResponse(status=proto.Status(code=200))

    @overrides(check_signature=False)
    def CreateCollection(
        self, request: CreateCollectionRequest, context: grpc.ServicerContext
    ) -> CreateCollectionResponse:
        """Create a collection, or return the existing one for get_or_create.

        Outcomes:
        - 404 when the tenant or database is unknown.
        - 409 when the id is already used in a different tenant/database, or
          in this one while ``get_or_create`` is false.
        - 409 when the name is already used in this database and
          ``get_or_create`` is false.
        - 200 with ``created=False`` and the existing collection when the
          name is already used and ``get_or_create`` is true.
        - 200 with ``created=True`` and the new collection otherwise.
        """
        collection_name = request.name
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            return CreateCollectionResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in self._tenants_to_databases_to_collections[tenant]:
            return CreateCollectionResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )

        # Check if the collection already exists globally by id
        for (
            search_tenant,
            databases,
        ) in self._tenants_to_databases_to_collections.items():
            for search_database, search_collections in databases.items():
                if request.id not in search_collections:
                    continue
                in_other_location = (
                    search_tenant != request.tenant
                    or search_database != request.database
                )
                # An id clash elsewhere is always an error; a clash in this
                # tenant and database is only tolerated for get_or_create.
                if in_other_location or not request.get_or_create:
                    return CreateCollectionResponse(
                        status=proto.Status(
                            code=409,
                            reason=f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )
                    )

        # Check if the collection already exists in this database by name
        collections = self._tenants_to_databases_to_collections[tenant][database]
        matches = [c for c in collections.values() if c["name"] == collection_name]
        assert len(matches) <= 1
        if len(matches) > 0:
            if request.get_or_create:
                existing_collection = matches[0]
                return CreateCollectionResponse(
                    status=proto.Status(code=200),
                    collection=to_proto_collection(existing_collection),
                    created=False,
                )
            return CreateCollectionResponse(
                status=proto.Status(
                    code=409, reason=f"Collection {request.name} already exists"
                )
            )

        configuration = CollectionConfigurationInternal.from_json_str(
            request.configuration_json_str
        )
        collection_id = UUID(hex=request.id)
        new_collection = Collection(
            id=collection_id,
            name=request.name,
            configuration=configuration,
            metadata=from_proto_metadata(request.metadata),
            dimension=request.dimension,
            database=database,
            tenant=tenant,
            version=0,
        )
        collections[request.id] = new_collection
        return CreateCollectionResponse(
            status=proto.Status(code=200),
            collection=to_proto_collection(new_collection),
            created=True,
        )

    @overrides(check_signature=False)
    def DeleteCollection(
        self, request: DeleteCollectionRequest, context: grpc.ServicerContext
    ) -> DeleteCollectionResponse:
        """Delete a collection from the given tenant and database.

        Returns status 404 when the tenant, the database or the collection id
        is unknown, 200 otherwise.
        """
        collection_id = request.id
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            return DeleteCollectionResponse(
                status=proto.Status(code=404, reason=f"Tenant {tenant} not found")
            )
        if database not in self._tenants_to_databases_to_collections[tenant]:
            return DeleteCollectionResponse(
                status=proto.Status(code=404, reason=f"Database {database} not found")
            )
        collections = self._tenants_to_databases_to_collections[tenant][database]
        if collection_id not in collections:
            return DeleteCollectionResponse(
                status=proto.Status(
                    code=404, reason=f"Collection {collection_id} not found"
                )
            )
        del collections[collection_id]
        return DeleteCollectionResponse(status=proto.Status(code=200))

    @overrides(check_signature=False)
    def GetCollections(
        self, request: GetCollectionsRequest, context: grpc.ServicerContext
    ) -> GetCollectionsResponse:
        """Return every collection matching the request's filters.

        An empty ``tenant`` or ``database`` on the request means "any".
        ``id`` and ``name`` only filter when the field is present on the
        request. The response carries no status.
        """
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_name = request.name if request.HasField("name") else None

        # Gather the collections of every tenant/database the request selects.
        candidates: Dict[str, Collection] = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            if request.tenant != "" and tenant != request.tenant:
                continue
            for database, collections in databases.items():
                if request.database != "" and database != request.database:
                    continue
                candidates.update(collections)

        found_collections = []
        for collection in candidates.values():
            if target_id and collection["id"] != target_id:
                continue
            if target_name and collection["name"] != target_name:
                continue
            found_collections.append(collection)
        return GetCollectionsResponse(
            collections=[
                to_proto_collection(collection) for collection in found_collections
            ]
        )

    @overrides(check_signature=False)
    def UpdateCollection(
        self, request: UpdateCollectionRequest, context: grpc.ServicerContext
    ) -> UpdateCollectionResponse:
        """Update the name, dimension and/or metadata of a collection.

        The collection is searched for by id across all tenants and
        databases. Returns status 404 when it is not found, 200 otherwise.
        Metadata on the request replaces the stored metadata entirely (minus
        keys whose value is None); otherwise a true ``reset_metadata`` sets
        the metadata to an empty dict.
        """
        id_to_update = UUID(request.id)
        # Find the collection with this id
        collections: Dict[str, Collection] = {}
        for databases in self._tenants_to_databases_to_collections.values():
            for maybe_collections in databases.values():
                if id_to_update.hex in maybe_collections:
                    collections = maybe_collections

        if id_to_update.hex not in collections:
            return UpdateCollectionResponse(
                status=proto.Status(
                    code=404, reason=f"Collection {id_to_update} not found"
                )
            )

        collection = collections[id_to_update.hex]
        if request.HasField("name"):
            collection["name"] = request.name
        if request.HasField("dimension"):
            collection["dimension"] = request.dimension
        if request.HasField("metadata"):
            # TODO: IN SysDB SQlite we have technical debt where we
            # replace the entire metadata dict with the new one. We should
            # fix that by merging it. For now we just do the same thing here
            update_metadata = from_proto_update_metadata(request.metadata)
            cleaned_metadata = None
            if update_metadata is not None:
                cleaned_metadata = {
                    key: value
                    for key, value in update_metadata.items()
                    if value is not None
                }
            collection["metadata"] = cleaned_metadata
        elif request.HasField("reset_metadata"):
            if request.reset_metadata:
                collection["metadata"] = {}
        return UpdateCollectionResponse(status=proto.Status(code=200))

    @overrides(check_signature=False)
    def ResetState(
        self, request: Empty, context: grpc.ServicerContext
    ) -> ResetStateResponse:
        """RPC wrapper around ``reset_state``; always returns status 200."""
        self.reset_state()
        return ResetStateResponse(status=proto.Status(code=200))

    def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
        """Merge ``source`` into ``target`` in place.

        Keys with a non-None value are set on ``target``; keys whose value is
        None are removed from it. Nothing happens when the converted source
        is None (UpdateCollection guards against the same case).
        """
        target_metadata = cast(Dict[str, Any], target)
        source_metadata = from_proto_update_metadata(source)
        if source_metadata is None:
            return
        for key, value in source_metadata.items():
            if value is None:
                # If a key has a None value, remove it from the metadata
                target_metadata.pop(key, None)
            else:
                target_metadata[key] = value