import os
import shutil
import tempfile
import pytest
from typing import Generator, List, Callable, Dict, Union

from chromadb.db.impl.grpc.client import GrpcSysDB
from chromadb.db.impl.grpc.server import GrpcMockSysDB
from chromadb.errors import NotFoundError
from chromadb.test.conftest import find_free_port
from chromadb.types import Collection, Segment, SegmentScope
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import (
    DEFAULT_DATABASE,
    DEFAULT_TENANT,
    System,
    Settings,
)
from chromadb.db.system import SysDB
from chromadb.db.base import UniqueConstraintError
from pytest import FixtureRequest
import uuid
from chromadb.api.configuration import CollectionConfigurationInternal

TENANT = "default"
NAMESPACE = "default"

# These are the sample collections that are used in the tests below. Tests can override
# the fields as needed.
# NOTE: several tests assign to the "database"/"tenant" fields of these shared
# objects in place; test_create_get_delete_collections resets both fields to the
# defaults, so the expected values below depend on the order the tests run in.
sample_collections: List[Collection] = [
    Collection(
        id=uuid.UUID(int=1),
        name="test_collection_1",
        configuration=CollectionConfigurationInternal(),
        metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3},
        dimension=128,
        database=DEFAULT_DATABASE,
        tenant=DEFAULT_TENANT,
        version=0,
    ),
    Collection(
        id=uuid.UUID(int=2),
        name="test_collection_2",
        configuration=CollectionConfigurationInternal(),
        metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3},
        dimension=None,
        database=DEFAULT_DATABASE,
        tenant=DEFAULT_TENANT,
        version=0,
    ),
    Collection(
        id=uuid.UUID(int=3),
        name="test_collection_3",
        configuration=CollectionConfigurationInternal(),
        metadata={"test_str": "str3", "test_int": 3, "test_float": 3.3},
        dimension=None,
        database=DEFAULT_DATABASE,
        tenant=DEFAULT_TENANT,
        version=0,
    ),
]


def sqlite() -> Generator[SysDB, None, None]:
    """Fixture generator for a SqliteDB configured with only allow_reset=True.

    Starts the DB, yields it, and stops it once the generator is resumed.
    """
    db = SqliteDB(
        System(
            Settings(
                allow_reset=True,
            )
        )
    )
    db.start()
    yield db
    db.stop()


def sqlite_persistent() -> Generator[SysDB, None, None]:
    """Fixture generator for a persistent SqliteDB backed by a temp directory.

    The directory is created with tempfile.mkdtemp() and removed again after
    the DB has been stopped.
    """
    save_path = tempfile.mkdtemp()
    db = SqliteDB(
        System(
            Settings(
                allow_reset=True,
                is_persistent=True,
                persist_directory=save_path,
            )
        )
    )
    db.start()
    yield db
    db.stop()
    if os.path.exists(save_path):
        shutil.rmtree(save_path)


def grpc_with_mock_server() -> Generator[SysDB, None, None]:
    """Fixture generator for a grpc sysdb client backed by a mock grpc sysdb
    server; both are instantiated in the same System on a free port."""
    port = find_free_port()
    system = System(
        Settings(
            allow_reset=True,
            chroma_server_grpc_port=port,
        )
    )
    system.instance(GrpcMockSysDB)
    client = system.instance(GrpcSysDB)
    system.start()
    client.reset_and_wait_for_ready()
    yield client
    system.stop()


def grpc_with_real_server() -> Generator[SysDB, None, None]:
    """Fixture generator for a grpc sysdb client that connects to an already
    running server on port 50051 (no server is created here)."""
    system = System(
        Settings(
            allow_reset=True,
            chroma_server_grpc_port=50051,
        )
    )
    client = system.instance(GrpcSysDB)
    system.start()
    client.reset_and_wait_for_ready()
    yield client
    # Stop the client-side system, mirroring grpc_with_mock_server.
    system.stop()


def db_fixtures() -> List[Callable[[], Generator[SysDB, None, None]]]:
    """Return the fixture generators the `sysdb` fixture is parametrized with.

    When CHROMA_CLUSTER_TEST_ONLY is present in the environment only the real
    grpc server is used; otherwise the sqlite variants and the mock grpc
    server are used.
    """
    if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
        return [grpc_with_real_server]
    else:
        return [sqlite, sqlite_persistent, grpc_with_mock_server]


@pytest.fixture(scope="module", params=db_fixtures())
def sysdb(request: FixtureRequest) -> Generator[SysDB, None, None]:
    """Module-scoped SysDB fixture, parametrized over db_fixtures()."""
    # Delegate with `yield from` rather than `yield next(...)`: the latter
    # abandons the inner generator after its first yield, so the cleanup code
    # written after that yield (stop(), rmtree, ...) would never execute.
    yield from request.param()


# region Collection tests
def test_create_get_delete_collections(sysdb: SysDB) -> None:
    """Create the sample collections, look them up by name and id, then delete one."""
    sysdb.reset_state()

    for collection in sample_collections:
        sysdb.create_collection(
            id=collection.id,
            name=collection.name,
            configuration=collection.get_configuration(),
            metadata=collection["metadata"],
            dimension=collection["dimension"],
        )
        # Reset fields that other tests may have overwritten on the shared samples
        collection["database"] = DEFAULT_DATABASE
        collection["tenant"] = DEFAULT_TENANT

    results = sysdb.get_collections()
    results = sorted(results, key=lambda c: c.name)

    assert results == sample_collections

    # Duplicate create fails
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            name=sample_collections[0].name,
            id=sample_collections[0].id,
            configuration=sample_collections[0].get_configuration(),
        )

    # Find by name
    for collection in sample_collections:
        result = sysdb.get_collections(name=collection["name"])
        assert result == [collection]

    # Find by id
    for collection in sample_collections:
        result = sysdb.get_collections(id=collection["id"])
        assert result == [collection]

    # Delete
    c1 = sample_collections[0]
    sysdb.delete_collection(c1.id)

    results = sysdb.get_collections()
    assert c1 not in results
    assert len(results) == len(sample_collections) - 1
    assert sorted(results, key=lambda c: c.name) == sample_collections[1:]

    by_id_result = sysdb.get_collections(id=c1["id"])
    assert by_id_result == []

    # Duplicate delete throws an exception
    with pytest.raises(NotFoundError):
        sysdb.delete_collection(c1.id)


def test_update_collections(sysdb: SysDB) -> None:
    """Update name, dimension and metadata of a collection and read each change back."""
    # Work on a private copy of the first sample so the shared one is not modified
    coll = Collection(
        name=sample_collections[0].name,
        id=sample_collections[0].id,
        configuration=sample_collections[0].get_configuration(),
        metadata=sample_collections[0]["metadata"],
        dimension=sample_collections[0]["dimension"],
        database=DEFAULT_DATABASE,
        tenant=DEFAULT_TENANT,
        version=0,
    )

    sysdb.reset_state()

    sysdb.create_collection(
        id=coll.id,
        name=coll.name,
        configuration=coll.get_configuration(),
        metadata=coll["metadata"],
        dimension=coll["dimension"],
    )

    # Update name
    coll["name"] = "new_name"
    sysdb.update_collection(coll.id, name=coll.name)
    result = sysdb.get_collections(name=coll.name)
    assert result == [coll]

    # Update dimension
    coll["dimension"] = 128
    sysdb.update_collection(coll.id, dimension=coll.dimension)
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]

    # Reset the metadata
    coll["metadata"] = {"test_str2": "str2"}
    sysdb.update_collection(coll.id, metadata=coll["metadata"])
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]

    # Delete all metadata keys
    coll["metadata"] = None
    sysdb.update_collection(coll.id, metadata=None)
    result = sysdb.get_collections(id=coll["id"])
    assert result == [coll]


def test_get_or_create_collection(sysdb: SysDB) -> None:
    """Exercise create_collection with get_or_create set to True and False."""
    sysdb.reset_state()

    # get_or_create = True returns existing collection
    collection = sample_collections[0]

    sysdb.create_collection(
        id=collection.id,
        name=collection.name,
        configuration=collection.get_configuration(),
        metadata=collection["metadata"],
        dimension=collection["dimension"],
    )

    result, created = sysdb.create_collection(
        name=collection.name,
        id=uuid.uuid4(),
        configuration=CollectionConfigurationInternal(),
        get_or_create=True,
        metadata=collection["metadata"],
    )
    assert result == collection

    # Only one collection with the same name exists
    get_result = sysdb.get_collections(name=collection["name"])
    assert get_result == [collection]

    # get_or_create = True creates new collection
    result, created = sysdb.create_collection(
        name=sample_collections[1].name,
        id=sample_collections[1].id,
        configuration=sample_collections[1].get_configuration(),
        get_or_create=True,
        metadata=sample_collections[1]["metadata"],
    )
    assert result == sample_collections[1]

    # get_or_create = False creates new collection
    result, created = sysdb.create_collection(
        name=sample_collections[2].name,
        id=sample_collections[2].id,
        configuration=sample_collections[2].get_configuration(),
        get_or_create=False,
        metadata=sample_collections[2]["metadata"],
    )
    assert result == sample_collections[2]

    # get_or_create = False fails if collection already exists
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            name=sample_collections[2].name,
            id=sample_collections[2].id,
            configuration=sample_collections[2].get_configuration(),
            get_or_create=False,
            metadata=collection["metadata"],
        )

    # get_or_create = True does not overwrite metadata
    overlayed_metadata: Dict[str, Union[str, int, float]] = {
        "test_new_str": "new_str",
        "test_int": 1,
    }
    result, created = sysdb.create_collection(
        name=sample_collections[2].name,
        id=sample_collections[2].id,
        configuration=sample_collections[2].get_configuration(),
        get_or_create=True,
        metadata=overlayed_metadata,
    )

    assert result["metadata"] != overlayed_metadata
    assert result["metadata"] == sample_collections[2]["metadata"]

    # get_or_create = True with None metadata does not overwrite metadata
    result, created = sysdb.create_collection(
        name=sample_collections[2].name,
        id=sample_collections[2].id,
        configuration=sample_collections[2].get_configuration(),
        get_or_create=True,
        metadata=None,
    )
    assert result["metadata"] == sample_collections[2]["metadata"]


def test_create_get_delete_database_and_collection(sysdb: SysDB) -> None:
    """Create, fetch and delete collections in a new database and in the default one."""
    sysdb.reset_state()

    # Create a new database
    sysdb.create_database(id=uuid.uuid4(), name="new_database")

    # Create a new collection in the new database
    sysdb.create_collection(
        id=sample_collections[0].id,
        name=sample_collections[0].name,
        configuration=sample_collections[0].get_configuration(),
        metadata=sample_collections[0]["metadata"],
        dimension=sample_collections[0]["dimension"],
        database="new_database",
    )

    # Create a new collection with the same id but different name in the new database
    # and expect an error
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            id=sample_collections[0].id,
            name="new_name",
            configuration=sample_collections[0].get_configuration(),
            metadata=sample_collections[0]["metadata"],
            dimension=sample_collections[0]["dimension"],
            database="new_database",
            get_or_create=False,
        )

    # Create a new collection in the default database
    sysdb.create_collection(
        id=sample_collections[1].id,
        name=sample_collections[1].name,
        configuration=sample_collections[1].get_configuration(),
        metadata=sample_collections[1]["metadata"],
        dimension=sample_collections[1]["dimension"],
    )

    # Check that the new database and collections exist
    result = sysdb.get_collections(
        name=sample_collections[0]["name"], database="new_database"
    )
    assert len(result) == 1
    # Mutates the shared sample so that it matches what was stored
    sample_collections[0]["database"] = "new_database"
    assert result[0] == sample_collections[0]

    # Check that the collection in the default database exists
    result = sysdb.get_collections(name=sample_collections[1]["name"])
    assert len(result) == 1
    assert result[0] == sample_collections[1]

    # Get for a database that doesn't exist with a name that exists in the new database and expect no results
    assert (
        len(
            sysdb.get_collections(
                name=sample_collections[0]["name"], database="fake_db"
            )
        )
        == 0
    )

    # Delete the collection in the new database
    sysdb.delete_collection(id=sample_collections[0].id, database="new_database")

    # Check that the collection in the new database was deleted
    result = sysdb.get_collections(database="new_database")
    assert len(result) == 0

    # Check that the collection in the default database still exists
    result = sysdb.get_collections(name=sample_collections[1].name)
    assert len(result) == 1
    assert result[0] == sample_collections[1]

    # Delete the deleted collection in the default database and expect an error
    with pytest.raises(NotFoundError):
        sysdb.delete_collection(id=sample_collections[0].id)

    # Delete the collection that only exists in the default database from the
    # new database and expect an error
    with pytest.raises(NotFoundError):
        sysdb.delete_collection(id=sample_collections[1].id, database="new_database")


def test_create_update_with_database(sysdb: SysDB) -> None:
    """Rename collections that live in different databases and read the names back."""
    sysdb.reset_state()

    # Create a new database
    sysdb.create_database(id=uuid.uuid4(), name="new_database")

    # Create a new collection in the new database
    sysdb.create_collection(
        id=sample_collections[0].id,
        name=sample_collections[0].name,
        configuration=sample_collections[0].get_configuration(),
        metadata=sample_collections[0]["metadata"],
        dimension=sample_collections[0]["dimension"],
        database="new_database",
    )

    # Create a new collection in the default database
    sysdb.create_collection(
        id=sample_collections[1].id,
        name=sample_collections[1].name,
        configuration=sample_collections[1].get_configuration(),
        metadata=sample_collections[1]["metadata"],
        dimension=sample_collections[1]["dimension"],
    )

    # Update the collection in the default database
    sysdb.update_collection(
        id=sample_collections[1].id,
        name="new_name_1",
    )

    # Check that the collection in the default database was updated
    result = sysdb.get_collections(id=sample_collections[1]["id"])
    assert len(result) == 1
    assert result[0]["name"] == "new_name_1"

    # Update the collection in the new database
    sysdb.update_collection(
        id=sample_collections[0].id,
        name="new_name_0",
    )

    # Check that the collection in the new database was updated
    result = sysdb.get_collections(
        id=sample_collections[0]["id"], database="new_database"
    )
    assert len(result) == 1
    assert result[0]["name"] == "new_name_0"

    # Try to create the collection in the default database in the new database and expect an error
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            id=sample_collections[1].id,
            name=sample_collections[1].name,
            configuration=sample_collections[1].get_configuration(),
            metadata=sample_collections[1]["metadata"],
            dimension=sample_collections[1]["dimension"],
            database="new_database",
        )


def test_get_multiple_with_database(sysdb: SysDB) -> None:
    """List all collections of a non-default database; the default one stays empty."""
    sysdb.reset_state()

    # Create a new database
    sysdb.create_database(id=uuid.uuid4(), name="new_database")

    # Create sample collections in the new database
    for collection in sample_collections:
        sysdb.create_collection(
            id=collection.id,
            name=collection.name,
            configuration=collection.get_configuration(),
            metadata=collection["metadata"],
            dimension=collection["dimension"],
            database="new_database",
        )
        # Mutates the shared sample so that it matches what was stored
        collection["database"] = "new_database"

    # Get all collections in the new database
    result = sysdb.get_collections(database="new_database")
    assert len(result) == len(sample_collections)
    assert sorted(result, key=lambda c: c.name) == sample_collections

    # Get all collections in the default database
    result = sysdb.get_collections()
    assert len(result) == 0


def test_create_database_with_tenants(sysdb: SysDB) -> None:
    """Create same-named databases under two tenants and check collection isolation."""
    sysdb.reset_state()

    # Create a new tenant
    sysdb.create_tenant(name="tenant1")

    # Create tenant that already exists and expect an error
    with pytest.raises(UniqueConstraintError):
        sysdb.create_tenant(name="tenant1")

    with pytest.raises(UniqueConstraintError):
        sysdb.create_tenant(name=DEFAULT_TENANT)

    # Create a new database within this tenant and also in the default tenant
    sysdb.create_database(id=uuid.uuid4(), name="new_database", tenant="tenant1")
    sysdb.create_database(id=uuid.uuid4(), name="new_database")

    # Create a new collection in the new tenant
    sysdb.create_collection(
        id=sample_collections[0].id,
        name=sample_collections[0].name,
        configuration=sample_collections[0].get_configuration(),
        metadata=sample_collections[0]["metadata"],
        dimension=sample_collections[0]["dimension"],
        database="new_database",
        tenant="tenant1",
    )
    # Mutates the shared sample so that it matches what was stored
    sample_collections[0]["tenant"] = "tenant1"
    sample_collections[0]["database"] = "new_database"

    # Create a new collection in the default tenant
    sysdb.create_collection(
        id=sample_collections[1].id,
        name=sample_collections[1].name,
        configuration=sample_collections[1].get_configuration(),
        metadata=sample_collections[1]["metadata"],
        dimension=sample_collections[1]["dimension"],
        database="new_database",
    )
    sample_collections[1]["database"] = "new_database"

    # Check that both tenants have the correct collections
    result = sysdb.get_collections(database="new_database", tenant="tenant1")
    assert len(result) == 1
    assert result[0] == sample_collections[0]

    result = sysdb.get_collections(database="new_database")
    assert len(result) == 1
    assert result[0] == sample_collections[1]

    # Creating a collection id that already exists in a tenant that does not have it
    # should error
    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            id=sample_collections[0].id,
            name=sample_collections[0].name,
            configuration=sample_collections[0].get_configuration(),
            metadata=sample_collections[0]["metadata"],
            dimension=sample_collections[0]["dimension"],
            database="new_database",
        )

    with pytest.raises(UniqueConstraintError):
        sysdb.create_collection(
            id=sample_collections[1].id,
            name=sample_collections[1].name,
            configuration=sample_collections[1].get_configuration(),
            metadata=sample_collections[1]["metadata"],
            dimension=sample_collections[1]["dimension"],
            database="new_database",
            tenant="tenant1",
        )

    # A new tenant DOES NOT have a default database. This does not error, instead 0
    # results are returned
    result = sysdb.get_collections(database=DEFAULT_DATABASE, tenant="tenant1")
    assert len(result) == 0


def test_get_database_with_tenants(sysdb: SysDB) -> None:
    """Fetch tenants and databases, including lookups expected to raise NotFoundError."""
    sysdb.reset_state()

    # Create a new tenant
    sysdb.create_tenant(name="tenant1")

    # Get the tenant and check that it exists
    result = sysdb.get_tenant(name="tenant1")
    assert result["name"] == "tenant1"

    # Get a tenant that does not exist and expect an error
    with pytest.raises(NotFoundError):
        sysdb.get_tenant(name="tenant2")

    # Create a new database within this tenant
    sysdb.create_database(id=uuid.uuid4(), name="new_database", tenant="tenant1")

    # Get the database and check that it exists
    result = sysdb.get_database(name="new_database", tenant="tenant1")
    assert result["name"] == "new_database"
    assert result["tenant"] == "tenant1"

    # Get a database that does not exist in a tenant that does exist and expect an error
    with pytest.raises(NotFoundError):
        sysdb.get_database(name="new_database1", tenant="tenant1")

    # Get a database that does not exist in a tenant that does not exist and expect an
    # error
    with pytest.raises(NotFoundError):
        sysdb.get_database(name="new_database1", tenant="tenant2")


# endregion

# region Segment tests
sample_segments = [
    Segment(
        id=uuid.UUID("00000000-d7d7-413b-92e1-731098a6e492"),
        type="test_type_a",
        scope=SegmentScope.VECTOR,
        collection=sample_collections[0].id,
        metadata={"test_str": "str1", "test_int": 1, "test_float": 1.3},
    ),
    Segment(
        id=uuid.UUID("11111111-d7d7-413b-92e1-731098a6e492"),
        type="test_type_b",
        scope=SegmentScope.VECTOR,
        collection=sample_collections[1].id,
        metadata={"test_str": "str2", "test_int": 2, "test_float": 2.3},
    ),
]


def test_create_get_delete_segments(sysdb: SysDB) -> None:
    """Create the sample segments, query them by id/type/collection, then delete one."""
    sysdb.reset_state()
    for collection in sample_collections:
        sysdb.create_collection(
            id=collection.id,
            name=collection.name,
            configuration=collection.get_configuration(),
            metadata=collection["metadata"],
            dimension=collection["dimension"],
        )

    for segment in sample_segments:
        sysdb.create_segment(segment)

    results: List[Segment] = []
    for collection in sample_collections:
        results.extend(sysdb.get_segments(collection=collection.id))
    results = sorted(results, key=lambda c: c["id"])

    assert results == sample_segments

    # Duplicate create fails
    with pytest.raises(UniqueConstraintError):
        sysdb.create_segment(sample_segments[0])

    # Find by id
    for segment in sample_segments:
        result = sysdb.get_segments(id=segment["id"], collection=segment["collection"])
        assert result == [segment]

    # Find by type
    result = sysdb.get_segments(type="test_type_a", collection=sample_collections[0].id)
    assert result == sample_segments[:1]

    result = sysdb.get_segments(type="test_type_b", collection=sample_collections[1].id)
    assert sorted(result, key=lambda c: c["id"]) == sample_segments[1:]

    # Find by collection ID
    result = sysdb.get_segments(collection=sample_collections[0].id)
    assert result == sample_segments[:1]

    # Find by type and collection ID (positive case)
    result = sysdb.get_segments(type="test_type_a", collection=sample_collections[0].id)
    assert result == sample_segments[:1]

    # Find by type and collection ID (negative case)
    result = sysdb.get_segments(type="test_type_b", collection=sample_collections[0].id)
    assert result == []

    # Delete
    s1 = sample_segments[0]
    sysdb.delete_segment(s1["collection"], s1["id"])

    results = []
    for collection in sample_collections:
        results.extend(sysdb.get_segments(collection=collection.id))
    assert s1 not in results
    assert len(results) == len(sample_segments) - 1
    assert sorted(results, key=lambda c: c["id"]) == sample_segments[1:]

    # Duplicate delete throws an exception
    with pytest.raises(NotFoundError):
        sysdb.delete_segment(s1["collection"], s1["id"])


def test_update_segment(sysdb: SysDB) -> None:
    """Add, change and delete segment metadata keys and read each change back."""
    # `segment` is built from this dict; the assertions below rely on edits to
    # `metadata` being visible through the expected `segment` value.
    metadata: Dict[str, Union[str, int, float]] = {
        "test_str": "str1",
        "test_int": 1,
        "test_float": 1.3,
    }
    segment = Segment(
        id=uuid.uuid4(),
        type="test_type_a",
        scope=SegmentScope.VECTOR,
        collection=sample_collections[0].id,
        metadata=metadata,
    )

    sysdb.reset_state()
    for c in sample_collections:
        sysdb.create_collection(
            id=c.id,
            name=c.name,
            configuration=c.get_configuration(),
            metadata=c["metadata"],
            dimension=c["dimension"],
        )

    sysdb.create_segment(segment)

    # TODO: revisit update segment - push collection id

    result = sysdb.get_segments(id=segment["id"], collection=segment["collection"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Add a new metadata key
    metadata["test_str2"] = "str2"
    sysdb.update_segment(
        segment["collection"], segment["id"], metadata={"test_str2": "str2"}
    )
    result = sysdb.get_segments(id=segment["id"], collection=segment["collection"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Update a metadata key
    metadata["test_str"] = "str3"
    sysdb.update_segment(
        segment["collection"], segment["id"], metadata={"test_str": "str3"}
    )
    result = sysdb.get_segments(id=segment["id"], collection=segment["collection"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Delete a metadata key
    del metadata["test_str"]
    sysdb.update_segment(
        segment["collection"], segment["id"], metadata={"test_str": None}
    )
    result = sysdb.get_segments(id=segment["id"], collection=segment["collection"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]

    # Delete all metadata keys
    segment["metadata"] = None
    sysdb.update_segment(segment["collection"], segment["id"], metadata=None)
    result = sysdb.get_segments(id=segment["id"], collection=segment["collection"])
    result[0]["collection"] = segment["collection"]
    assert result == [segment]


# endregion