"""Integration tests checking that collection writes show up in the log service.

Every test is wrapped in ``skip_if_not_cluster()``. Each one resets the client,
performs writes through a ``Collection``, and then pulls the resulting entries
back from the ``LogService`` to compare them with what was written.
"""

import array
import time
from typing import Any, Callable, Dict

import pytest

from chromadb.config import System, Settings
from chromadb.logservice.logservice import LogService
from chromadb.test.conftest import skip_if_not_cluster
from chromadb.test.test_api import records  # type: ignore
from chromadb.api.models.Collection import Collection

# Records with only ids and embeddings.
batch_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["https://example.com/1", "https://example.com/2"],
}

# Records carrying int, string and float metadata values.
metadata_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["id1", "id2"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2},
    ],
}

# Records carrying both documents and metadata.
contains_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "documents": ["this is doc1 and it's great!", "doc2 is also great!"],
    "ids": ["id1", "id2"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2, "float_value": 2.002, "string_value": "two"},
    ],
}

# Sleep to allow memberlist to initialize after reset()
MEMBERLIST_DELAY_SLEEP_TIME = 5

# Operation codes expected on pulled log records. The values are exactly the
# ones these tests assert for add/update/upsert/delete; presumably they mirror
# the log service's operation enum -- TODO confirm against its definition.
OPERATION_ADD = 0
OPERATION_UPDATE = 1
OPERATION_UPSERT = 2
OPERATION_DELETE = 3


def _start_logservice(client) -> LogService:  # type: ignore
    """Run the setup shared by every test and return the log service.

    Builds a ``System`` with ``allow_reset=True``, obtains its ``LogService``
    instance, starts the system, resets ``client`` and then sleeps for
    ``MEMBERLIST_DELAY_SLEEP_TIME`` seconds so the memberlist can initialize.
    """
    system = System(Settings(allow_reset=True))
    logservice = system.instance(LogService)
    system.start()
    client.reset()
    time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)
    return logservice


def verify_records(
    logservice: LogService,
    collection: Collection,
    test_records_map: Dict[str, Dict[str, Any]],
    test_func: Callable,  # type: ignore
    operation: int,
) -> None:
    """Apply each record batch and assert the log service recorded it.

    For every batch in ``test_records_map`` (in dict order) the batch is passed
    to ``test_func`` as keyword arguments, then the log is pulled starting at
    the offset just past the previously verified entries. Each pulled entry
    must match the batch's id, ``operation``, float32-encoded embedding and
    metadata; a batch's ``documents`` are expected under the
    ``"chroma:document"`` metadata key.

    Args:
        logservice: Log service to pull entries from.
        collection: Collection whose id is used when pulling the log.
        test_records_map: Named batches of keyword arguments for ``test_func``.
        test_func: Collection method under test (e.g. ``collection.add``).
        operation: Operation code every pulled entry must carry.

    Raises:
        AssertionError: If any pulled entry differs from what was written, or a
            metadata value is not an int, float or str.
    """
    # Offsets start at 1 here and advance by the number of entries pulled.
    start_offset = 1
    # Named test_records (not batch_records) to avoid shadowing the
    # module-level fixture of that name.
    for test_records in test_records_map.values():
        test_func(**test_records)
        # NOTE(review): the trailing 100 is presumably a batch-size cap -- confirm.
        pushed_records = logservice.pull_logs(collection.id, start_offset, 100)
        assert len(pushed_records) == len(test_records["ids"])
        for i, record in enumerate(pushed_records):
            assert record.record.id == test_records["ids"][i]
            assert record.record.operation == operation
            # The log is expected to hold the embedding as packed float32 bytes.
            embedding = array.array("f", test_records["embeddings"][i]).tobytes()
            assert record.record.vector.vector == embedding
            metadata_count = 0
            if "metadatas" in test_records:
                metadata_count += len(test_records["metadatas"][i])
                for key, value in test_records["metadatas"][i].items():
                    logged_value = record.record.metadata.metadata[key]
                    if isinstance(value, int):
                        assert logged_value.int_value == value
                    elif isinstance(value, float):
                        assert logged_value.float_value == value
                    elif isinstance(value, str):
                        assert logged_value.string_value == value
                    else:
                        # Raise explicitly: a bare `assert False` is removed
                        # under -O and carries no diagnostic.
                        raise AssertionError(
                            f"unsupported metadata value type for key {key!r}: "
                            f"{type(value).__name__}"
                        )
            if "documents" in test_records:
                # The document counts as one extra metadata entry.
                metadata_count += 1
                assert (
                    record.record.metadata.metadata["chroma:document"].string_value
                    == test_records["documents"][i]
                )
            assert len(record.record.metadata.metadata) == metadata_count
        start_offset += len(pushed_records)


@skip_if_not_cluster()
def test_add(client):  # type: ignore
    """Entries written with ``collection.add`` are logged with the add code."""
    logservice = _start_logservice(client)

    test_records_map = {
        "batch_records": batch_records,
        "metadata_records": metadata_records,
        "contains_records": contains_records,
    }

    collection = client.create_collection("testadd")
    verify_records(
        logservice, collection, test_records_map, collection.add, OPERATION_ADD
    )


@skip_if_not_cluster()
def test_update(client):  # type: ignore
    """Entries written with ``collection.update`` are logged with the update code."""
    logservice = _start_logservice(client)

    test_records_map = {
        "updated_records": {
            "ids": [records["ids"][0]],
            "embeddings": [[0.1, 0.2, 0.3]],
            "metadatas": [{"foo": "bar"}],
        },
    }

    collection = client.create_collection("testupdate")
    verify_records(
        logservice, collection, test_records_map, collection.update, OPERATION_UPDATE
    )


@skip_if_not_cluster()
def test_delete(client):  # type: ignore
    """Adding two records yields two log entries starting at offset 1."""
    logservice = _start_logservice(client)

    collection = client.create_collection("testdelete")

    # push 2 records
    collection.add(**contains_records)
    pushed_records = logservice.pull_logs(collection.id, 1, 100)
    assert len(pushed_records) == 2


# TODO: These tests should be enabled when the distributed system has metadata segments
@pytest.mark.xfail
@skip_if_not_cluster()
def test_delete_filter(client):  # type: ignore
    """Deleting by document filter logs nothing; deleting by id logs deletes.

    NOTE(review): no records are added to this collection before deleting, yet
    the assertions read the log from offset 3 as if two adds came first. Confirm
    whether ``collection.add(**contains_records)`` is missing before enabling.
    """
    logservice = _start_logservice(client)

    collection = client.create_collection("testdelete_filter")

    # delete by where
    collection.delete(where_document={"$contains": "doc1"})
    collection.delete(where_document={"$contains": "bad"})
    collection.delete(where_document={"$contains": "great"})
    pushed_records = logservice.pull_logs(collection.id, 3, 100)
    assert len(pushed_records) == 0

    # delete by ids
    collection.delete(ids=["id1", "id2"])
    pushed_records = logservice.pull_logs(collection.id, 3, 100)
    assert len(pushed_records) == 2
    for record in pushed_records:
        assert record.record.operation == OPERATION_DELETE
        assert record.record.id in ["id1", "id2"]


@skip_if_not_cluster()
def test_upsert(client):  # type: ignore
    """Entries written with ``collection.upsert`` are logged with the upsert code."""
    logservice = _start_logservice(client)

    test_records_map = {
        "batch_records": batch_records,
        "metadata_records": metadata_records,
        "contains_records": contains_records,
    }

    collection = client.create_collection("testupsert")
    verify_records(
        logservice, collection, test_records_map, collection.upsert, OPERATION_UPSERT
    )