import array
from typing import Dict, Any, Callable
import pytest
from chromadb.config import System, Settings
from chromadb.logservice.logservice import LogService
from chromadb.test.conftest import skip_if_not_cluster
from chromadb.test.test_api import records # type: ignore
from chromadb.api.models.Collection import Collection
import time
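# Integration tests for the distributed LogService: each test writes records
# through a Collection (add/update/upsert/delete) and pulls them back with
# pull_logs() to verify ids, embedding bytes, metadata, and the recorded
# operation code. The three record dicts below cover embeddings only,
# embeddings with metadata, and embeddings with metadata and documents.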
batch_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["https://example.com/1", "https://example.com/2"],
}
metadata_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"ids": ["id1", "id2"],
"metadatas": [
{"int_value": 1, "string_value": "one", "float_value": 1.001},
{"int_value": 2},
],
}
contains_records = {
"embeddings": [[1.1, 2.3, 3.2], [1.2, 2.24, 3.2]],
"documents": ["this is doc1 and it's great!", "doc2 is also great!"],
"ids": ["id1", "id2"],
"metadatas": [
{"int_value": 1, "string_value": "one", "float_value": 1.001},
{"int_value": 2, "float_value": 2.002, "string_value": "two"},
],
}
# Sleep to allow memberlist to initialize after reset()
MEMBERLIST_DELAY_SLEEP_TIME = 5
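# Operation codes asserted against pulled log records below:
# 0 = ADD, 1 = UPDATE, 2 = UPSERT, 3 = DELETE.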
def verify_records(
logservice: LogService,
collection: Collection,
test_records_map: Dict[str, Dict[str, Any]],
test_func: Callable, # type: ignore
operation: int,
) -> None:
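    """Run test_func on each record batch, then pull the pushed records back
    from the log service and verify ids, operation code, embedding bytes, and
    metadata (including the document stored under "chroma:document")."""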
start_offset = 1
for batch_records in test_records_map.values():
test_func(**batch_records)
pushed_records = logservice.pull_logs(collection.id, start_offset, 100)
assert len(pushed_records) == len(batch_records["ids"])
for i, record in enumerate(pushed_records):
assert record.record.id == batch_records["ids"][i]
assert record.record.operation == operation
embedding = array.array("f", batch_records["embeddings"][i]).tobytes()
assert record.record.vector.vector == embedding
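            # Expected metadata size: one entry per user-supplied metadata key,
            # plus one for the document, which is stored under the reserved
            # "chroma:document" key.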
metadata_count = 0
if "metadatas" in batch_records:
metadata_count += len(batch_records["metadatas"][i])
for key, value in batch_records["metadatas"][i].items():
if isinstance(value, int):
assert record.record.metadata.metadata[key].int_value == value
elif isinstance(value, float):
assert record.record.metadata.metadata[key].float_value == value
elif isinstance(value, str):
assert (
record.record.metadata.metadata[key].string_value == value
)
else:
assert False
if "documents" in batch_records:
metadata_count += 1
assert (
record.record.metadata.metadata["chroma:document"].string_value
== batch_records["documents"][i]
)
assert len(record.record.metadata.metadata) == metadata_count
start_offset += len(pushed_records)
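# Each test builds its own System with allow_reset=True to obtain a LogService
# handle, resets the client, and sleeps MEMBERLIST_DELAY_SLEEP_TIME seconds so
# the memberlist can re-initialize before a collection is created.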
@skip_if_not_cluster()
def test_add(client): # type: ignore
system = System(Settings(allow_reset=True))
logservice = system.instance(LogService)
system.start()
client.reset()
time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)
test_records_map = {
"batch_records": batch_records,
"metadata_records": metadata_records,
"contains_records": contains_records,
}
collection = client.create_collection("testadd")
verify_records(logservice, collection, test_records_map, collection.add, 0)
@skip_if_not_cluster()
def test_update(client): # type: ignore
system = System(Settings(allow_reset=True))
logservice = system.instance(LogService)
system.start()
client.reset()
time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)
test_records_map = {
"updated_records": {
"ids": [records["ids"][0]],
"embeddings": [[0.1, 0.2, 0.3]],
"metadatas": [{"foo": "bar"}],
},
}
collection = client.create_collection("testupdate")
verify_records(logservice, collection, test_records_map, collection.update, 1)
@skip_if_not_cluster()
def test_delete(client): # type: ignore
system = System(Settings(allow_reset=True))
logservice = system.instance(LogService)
system.start()
client.reset()
time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)
collection = client.create_collection("testdelete")
# push 2 records
collection.add(**contains_records)
pushed_records = logservice.pull_logs(collection.id, 1, 100)
assert len(pushed_records) == 2
# TODO: These tests should be enabled when the distributed system has metadata segments
@pytest.mark.xfail
@skip_if_not_cluster()
def test_delete_filter(client): # type: ignore
system = System(Settings(allow_reset=True))
logservice = system.instance(LogService)
system.start()
client.reset()
time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)
collection = client.create_collection("testdelete_filter")
# delete by where_document filter
collection.delete(where_document={"$contains": "doc1"})
collection.delete(where_document={"$contains": "bad"})
collection.delete(where_document={"$contains": "great"})
pushed_records = logservice.pull_logs(collection.id, 3, 100)
assert len(pushed_records) == 0
# delete by ids
collection.delete(ids=["id1", "id2"])
pushed_records = logservice.pull_logs(collection.id, 3, 100)
assert len(pushed_records) == 2
for record in pushed_records:
assert record.record.operation == 3
assert record.record.id in ["id1", "id2"]
@skip_if_not_cluster()
def test_upsert(client): # type: ignore
system = System(Settings(allow_reset=True))
logservice = system.instance(LogService)
system.start()
client.reset()
time.sleep(MEMBERLIST_DELAY_SLEEP_TIME)
test_records_map = {
"batch_records": batch_records,
"metadata_records": metadata_records,
"contains_records": contains_records,
}
collection = client.create_collection("testupsert")
verify_records(logservice, collection, test_records_map, collection.upsert, 2)