Spaces:
Build error
Build error
import array | |
from typing import Dict, Any, Callable | |
import pytest | |
from chromadb.config import System, Settings | |
from chromadb.logservice.logservice import LogService | |
from chromadb.test.conftest import skip_if_not_cluster | |
from chromadb.test.test_api import records # type: ignore | |
from chromadb.api.models.Collection import Collection | |
import time | |
# Minimal batch: embeddings and ids only (ids are URL-shaped strings).
batch_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["https://example.com/1", "https://example.com/2"],
}

# Batch with per-record metadata; the second record carries a single key.
metadata_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "ids": ["id1", "id2"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2},
    ],
}

# Batch with both documents and metadata for every record.
contains_records = {
    "embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
    "documents": ["this is doc1 and it's great!", "doc2 is also great!"],
    "ids": ["id1", "id2"],
    "metadatas": [
        {"int_value": 1, "string_value": "one", "float_value": 1.001},
        {"int_value": 2, "float_value": 2.002, "string_value": "two"},
    ],
}

# Sleep to allow memberlist to initialize after reset()
MEMBERLIST_DELAY_SLEEP_TIME = 5
def verify_records(
    logservice: "LogService",
    collection: "Collection",
    test_records_map: Dict[str, Dict[str, Any]],
    test_func: Callable,  # type: ignore
    operation: int,
) -> None:
    """Apply each batch via ``test_func`` and check the entries pulled from the log.

    For every batch in ``test_records_map`` (in iteration order) the batch is
    passed to ``test_func`` as keyword arguments, then up to 100 log entries
    are pulled for ``collection.id`` starting at the running offset. The
    pulled entries must match the batch one-to-one and in order: same id, the
    given ``operation`` code, the embedding packed as ``array`` typecode
    ``"f"`` bytes, and exactly the batch's metadata (plus one
    ``"chroma:document"`` entry when the batch has documents).

    Args:
        logservice: Object exposing ``pull_logs(collection_id, offset, limit)``.
        collection: Object whose ``id`` identifies the log to pull from.
        test_records_map: Named batches; each needs ``"ids"`` and
            ``"embeddings"`` and may have ``"metadatas"`` and ``"documents"``.
        test_func: Callable invoked as ``test_func(**batch)`` to push a batch.
        operation: Operation code every pulled entry is expected to carry.

    Raises:
        AssertionError: If any pulled entry differs from what was pushed, or
            a metadata value is not an ``int``, ``float`` or ``str``.
    """
    # Offsets are 1-based here: the first pull starts at 1 and each batch
    # advances the offset by the number of entries it produced.
    start_offset = 1
    for batch in test_records_map.values():
        test_func(**batch)
        pushed_records = logservice.pull_logs(collection.id, start_offset, 100)
        assert len(pushed_records) == len(batch["ids"])

        for i, entry in enumerate(pushed_records):
            record = entry.record
            assert record.id == batch["ids"][i]
            assert record.operation == operation

            # The log stores the vector as raw bytes; pack the expected
            # embedding the same way before comparing.
            expected_vector = array.array("f", batch["embeddings"][i]).tobytes()
            assert record.vector.vector == expected_vector

            logged_metadata = record.metadata.metadata
            metadata_count = 0
            if "metadatas" in batch:
                expected_metadata = batch["metadatas"][i]
                metadata_count += len(expected_metadata)
                for key, value in expected_metadata.items():
                    if isinstance(value, int):
                        assert logged_metadata[key].int_value == value
                    elif isinstance(value, float):
                        assert logged_metadata[key].float_value == value
                    elif isinstance(value, str):
                        assert logged_metadata[key].string_value == value
                    else:
                        # Raised explicitly (not `assert False`) so the check
                        # survives `python -O` and names the offending key.
                        raise AssertionError(
                            f"unsupported metadata value type for key {key!r}: "
                            f"{type(value).__name__}"
                        )
            if "documents" in batch:
                # Documents are carried as one extra metadata entry.
                metadata_count += 1
                assert (
                    logged_metadata["chroma:document"].string_value
                    == batch["documents"][i]
                )
            # No metadata keys beyond the expected ones may be present.
            assert len(logged_metadata) == metadata_count

        start_offset += len(pushed_records)
def test_add(client):  # type: ignore
    """Check that ``collection.add`` pushes the expected records to the log.

    Pushes the three module-level fixture batches through ``collection.add``
    on a freshly created collection and verifies each resulting log entry
    with ``verify_records``.
    """
    system = System(Settings(allow_reset=True))
    logservice = system.instance(LogService)
    system.start()

    client.reset()
    # Sleep to allow memberlist to initialize after reset()
    time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)

    test_records_map = {
        "batch_records": batch_records,
        "metadata_records": metadata_records,
        "contains_records": contains_records,
    }
    collection = client.create_collection("testadd")
    # 0 is the operation code the log entries are expected to carry for add
    # (presumably the log's "add" operation value -- confirm against the proto).
    verify_records(logservice, collection, test_records_map, collection.add, 0)
def test_update(client):  # type: ignore
    """Check that ``collection.update`` pushes the expected record to the log.

    Sends a single update (new embedding and metadata for the first id of the
    shared ``records`` fixture) to a freshly created collection and verifies
    the resulting log entry with ``verify_records``.
    """
    system = System(Settings(allow_reset=True))
    logservice = system.instance(LogService)
    system.start()

    client.reset()
    # Sleep to allow memberlist to initialize after reset()
    time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)

    test_records_map = {
        "updated_records": {
            "ids": [records["ids"][0]],
            "embeddings": [[0.1, 0.2, 0.3]],
            "metadatas": [{"foo": "bar"}],
        },
    }
    collection = client.create_collection("testupdate")
    # 1 is the operation code the log entries are expected to carry for update
    # (presumably the log's "update" operation value -- confirm against the proto).
    verify_records(logservice, collection, test_records_map, collection.update, 1)
def test_delete(client):  # type: ignore
    """Check that adding two records to a fresh collection logs two entries.

    Despite its name, this test issues no delete: it only seeds
    ``contains_records`` and confirms that two entries are readable from log
    offset 1.
    """
    system = System(Settings(allow_reset=True))
    logservice = system.instance(LogService)
    system.start()

    client.reset()
    # Sleep to allow memberlist to initialize after reset()
    time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)

    collection = client.create_collection("testdelete")

    # push 2 records
    collection.add(**contains_records)
    pushed_records = logservice.pull_logs(collection.id, 1, 100)
    assert len(pushed_records) == 2
# TODO: These tests should be enabled when the distributed system has metadata segments | |
def test_delete_filter(client):  # type: ignore
    """Check which delete calls produce log entries on a seeded collection.

    Two records are added first so that log offsets 1-2 are occupied; every
    later pull starts at offset 3 and therefore sees only entries produced by
    the delete calls.
    """
    system = System(Settings(allow_reset=True))
    logservice = system.instance(LogService)
    system.start()

    client.reset()
    # Sleep to allow memberlist to initialize after reset()
    time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)

    collection = client.create_collection("testdelete_filter")

    # Seed 2 records. Without them nothing occupies offsets 1-2, and the
    # delete entries checked below could never appear at offset 3.
    collection.add(**contains_records)
    pushed_records = logservice.pull_logs(collection.id, 1, 100)
    assert len(pushed_records) == 2

    # delete by where
    # NOTE(review): the assertion below expects these filter-based deletes to
    # log nothing; presumably that reflects the missing metadata segments
    # mentioned in the TODO above this function -- confirm before relying on it.
    collection.delete(where_document={"$contains": "doc1"})
    collection.delete(where_document={"$contains": "bad"})
    collection.delete(where_document={"$contains": "great"})
    pushed_records = logservice.pull_logs(collection.id, 3, 100)
    assert len(pushed_records) == 0

    # delete by ids
    collection.delete(ids=["id1", "id2"])
    pushed_records = logservice.pull_logs(collection.id, 3, 100)
    assert len(pushed_records) == 2
    for record in pushed_records:
        # 3 is the operation code expected on delete entries.
        assert record.record.operation == 3
        assert record.record.id in ["id1", "id2"]
def test_upsert(client):  # type: ignore
    """Check that ``collection.upsert`` pushes the expected records to the log.

    Pushes the three module-level fixture batches through ``collection.upsert``
    on a freshly created collection and verifies each resulting log entry
    with ``verify_records``.
    """
    system = System(Settings(allow_reset=True))
    logservice = system.instance(LogService)
    system.start()

    client.reset()
    # Sleep to allow memberlist to initialize after reset()
    time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)

    test_records_map = {
        "batch_records": batch_records,
        "metadata_records": metadata_records,
        "contains_records": contains_records,
    }
    collection = client.create_collection("testupsert")
    # 2 is the operation code the log entries are expected to carry for upsert
    # (presumably the log's "upsert" operation value -- confirm against the proto).
    verify_records(logservice, collection, test_records_map, collection.upsert, 2)