# himanshud2611's picture
# Upload folder using huggingface_hub
# 60e3a80 verified
import os
import shutil
import tempfile
import pytest
from typing import (
Generator,
List,
Callable,
Iterator,
Dict,
Optional,
Union,
Sequence,
)
from chromadb.api.types import validate_metadata
from chromadb.config import System, Settings
from chromadb.db.base import ParameterValue, get_sql
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
OperationRecord,
MetadataEmbeddingRecord,
Operation,
RequestVersionContext,
ScalarEncoding,
Segment,
SegmentScope,
SeqId,
)
from pypika import Table
from chromadb.ingest import Producer
from chromadb.segment import MetadataReader
import uuid
import time
from chromadb.segment.impl.metadata.sqlite import SqliteMetadataSegment
from pytest import FixtureRequest
from itertools import count
def sqlite() -> Generator[System, None, None]:
    """Yield a started, non-persistent System backed by sqlite.

    The System is stopped only when the consumer advances this generator
    past its ``yield``.
    """
    system = System(Settings(is_persistent=False, allow_reset=True))
    system.start()
    yield system
    system.stop()
def sqlite_persistent() -> Generator[System, None, None]:
    """Yield a started, persistent System stored in a fresh temp directory.

    When the consumer advances this generator past its ``yield`` the System
    is stopped and the temp directory is removed if it still exists.
    """
    persist_dir = tempfile.mkdtemp()
    system = System(
        Settings(
            allow_reset=True,
            is_persistent=True,
            persist_directory=persist_dir,
        )
    )
    system.start()
    yield system
    system.stop()
    if not os.path.exists(persist_dir):
        return
    shutil.rmtree(persist_dir)
def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the System generator factories used as ``system`` fixture params."""
    factories = [sqlite, sqlite_persistent]
    return factories
@pytest.fixture(scope="module", params=system_fixtures())
def system(request: FixtureRequest) -> Generator[System, None, None]:
    """Module-scoped System fixture, parametrized over ``system_fixtures()``.

    Delegates to the selected generator with ``yield from`` so that the code
    after its ``yield`` (stopping the System, removing any temporary persist
    directory) runs at fixture teardown. Yielding ``next(generator)`` instead
    abandons the generator at its ``yield`` and that cleanup never executes.
    """
    yield from request.param()
@pytest.fixture(scope="function")
def sample_embeddings() -> Iterator[OperationRecord]:
    """Provide an endless stream of ADD records ``embedding_0``, ``embedding_1``, ...

    Record 0 has ``metadata=None``. Every later record carries string, int,
    float and bool keys plus a ``chroma:document`` built from its index;
    multiples of three also get ``div_by_three`` and even indices have
    ``bool_key`` set to ``False``.
    """

    def build_metadata(
        i: int,
    ) -> Optional[Dict[str, Union[str, int, float, bool]]]:
        # The very first record is intentionally left without any metadata.
        if i == 0:
            return None
        meta: Dict[str, Union[str, int, float, bool]] = {
            "str_key": f"value_{i}",
            "int_key": i,
            "float_key": i + i * 0.1,
            "bool_key": True,
        }
        if i % 3 == 0:
            meta["div_by_three"] = "true"
        if i % 2 == 0:
            meta["bool_key"] = False
        meta["chroma:document"] = _build_document(i)
        return meta

    def create_record(i: int) -> OperationRecord:
        return OperationRecord(
            id=f"embedding_{i}",
            embedding=[i + i * 0.1, i + 1 + i * 0.1],  # type: ignore[typeddict-item]
            encoding=ScalarEncoding.FLOAT32,
            metadata=build_metadata(i),
            operation=Operation.ADD,
        )

    return (create_record(i) for i in count())
# Lookup from each decimal digit character to its English word.
_digit_map = dict(
    zip(
        "0123456789",
        (
            "zero",
            "one",
            "two",
            "three",
            "four",
            "five",
            "six",
            "seven",
            "eight",
            "nine",
        ),
    )
)


def _build_document(i: int) -> str:
    """Spell out the decimal digits of ``i`` as space-separated English words.

    For example ``42`` becomes ``"four two"``.
    """
    words = [_digit_map[ch] for ch in str(i)]
    return " ".join(words)
# Metadata-scoped segment definition shared by most tests in this module.
# It gets a fresh random id on import and is bound to the fixed collection
# UUID with integer value 0.
segment_definition = Segment(
    id=uuid.uuid4(),
    type="test_type",
    scope=SegmentScope.METADATA,
    collection=uuid.UUID(int=0),
    metadata=None,
)
# Second metadata-scoped segment definition, identical except for its own
# random id and a different collection (UUID with integer value 1).
segment_definition2 = Segment(
    id=uuid.uuid4(),
    type="test_type",
    scope=SegmentScope.METADATA,
    collection=uuid.UUID(int=1),
    metadata=None,
)
def sync(
    segment: MetadataReader,
    seq_id: SeqId,
    timeout: float = 5.0,
    poll_interval: float = 0.25,
) -> None:
    """Block until ``segment.max_seqid()`` has reached ``seq_id``.

    Args:
        segment: Reader whose ``max_seqid()`` is polled.
        seq_id: Sequence id to wait for.
        timeout: Maximum number of seconds to wait (default 5).
        poll_interval: Seconds to sleep between polls (default 0.25).

    Raises:
        TimeoutError: If ``seq_id`` has not been reached within ``timeout``.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while True:
        # Check before looking at the deadline so the segment always gets one
        # final poll after the last sleep.
        if segment.max_seqid() >= seq_id:
            return
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Timed out waiting for seq_id {seq_id}")
        time.sleep(poll_interval)
def test_insert_and_count(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Count records produced before the segment starts and after it runs."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    # First batch is produced before the segment object exists.
    last_seq_id = produce_fns(producer, collection_id, sample_embeddings, 3)[1][-1]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()
    sync(segment, last_seq_id)

    version_context = RequestVersionContext(collection_version=0, log_position=0)
    assert segment.count(request_version_context=version_context) == 3

    # Second batch is submitted one record at a time while the segment runs.
    for _ in range(3):
        last_seq_id = producer.submit_embedding(
            collection_id, next(sample_embeddings)
        )
    sync(segment, last_seq_id)

    version_context = RequestVersionContext(collection_version=0, log_position=0)
    assert segment.count(request_version_context=version_context) == 6
def assert_equiv_records(
    expected: Sequence[OperationRecord], actual: Sequence[MetadataEmbeddingRecord]
) -> None:
    """Assert both sequences hold the same ids and metadata, ignoring order.

    Records are paired up after sorting each sequence by ``"id"``.
    """

    def record_id(record):  # type: ignore[no-untyped-def]
        return record["id"]

    assert len(expected) == len(actual)
    pairs = zip(sorted(expected, key=record_id), sorted(actual, key=record_id))
    for want, got in pairs:
        assert want["id"] == got["id"]
        assert want["metadata"] == got["metadata"]
def test_get(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check ``get_metadata`` by id, with paging, and with ``where`` filters.

    Ten sample records (``embedding_0`` .. ``embedding_9``) are produced before
    the segment is started; ``embedding_0`` carries no metadata.
    """
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()
    sync(segment, seq_ids[-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)

    # get with bool key
    for flag, expected_count in ((True, 5), (False, 4)):
        found = segment.get_metadata(
            where={"bool_key": flag}, request_version_context=ctx
        )
        assert len(found) == expected_count, flag

    # Get all records
    everything = segment.get_metadata(request_version_context=ctx)
    assert_equiv_records(embeddings, everything)

    # get by ID
    first_five = embeddings[0:5]
    by_id = segment.get_metadata(
        ids=[e["id"] for e in first_five],
        request_version_context=ctx,
    )
    assert_equiv_records(first_five, by_id)

    # Get with limit and offset
    # Cannot rely on order (yet), but can rely on retrieving exactly the
    # whole set eventually
    paged: List[MetadataEmbeddingRecord] = list(
        segment.get_metadata(limit=3, request_version_context=ctx)
    )
    assert len(paged) == 3
    for page_offset, running_total in ((3, 6), (6, 9), (9, 10)):
        paged.extend(
            segment.get_metadata(
                limit=3, offset=page_offset, request_version_context=ctx
            )
        )
        assert len(paged) == running_total
    assert_equiv_records(embeddings, paged)

    # Each entry is a ``where`` clause and the number of records it must match.
    where_cases = [
        # Get with simple where
        ({"div_by_three": "true"}, 3),
        # Get with gt/gte/lt/lte on int keys
        ({"int_key": {"$gt": 5}}, 4),
        ({"int_key": {"$gte": 5}}, 5),
        ({"int_key": {"$lt": 5}}, 4),
        ({"int_key": {"$lte": 5}}, 5),
        # Get with gt/lt on float keys with float values
        ({"float_key": {"$gt": 5.01}}, 5),
        ({"float_key": {"$lt": 4.99}}, 4),
        # Get with gt/lt on float keys with int values
        ({"float_key": {"$gt": 5}}, 5),
        ({"float_key": {"$lt": 5}}, 4),
        # Get with gt/lt on int keys with float values
        ({"int_key": {"$gt": 5.01}}, 4),
        ({"int_key": {"$lt": 4.99}}, 4),
        # Get with $ne
        # Returns metadata that has an int_key but not equal to 5, or without an int_key
        ({"int_key": {"$ne": 5}}, 9),
        # get with multiple heterogeneous conditions
        ({"div_by_three": "true", "int_key": {"$gt": 5}}, 2),
        # get with OR conditions
        ({"$or": [{"int_key": 1}, {"int_key": 2}]}, 2),
        # get with AND conditions
        ({"$and": [{"int_key": 3}, {"float_key": {"$gt": 5}}]}, 0),
        ({"$and": [{"int_key": 3}, {"float_key": {"$lt": 5}}]}, 1),
    ]
    for where, expected_count in where_cases:
        found = segment.get_metadata(
            where=where,  # type: ignore[arg-type]
            request_version_context=ctx,
        )
        assert len(found) == expected_count, where
def test_fulltext(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check document search through ``where`` and ``where_document`` filters.

    One hundred sample records are produced after the segment has started.
    Record 0 has no document, so every ``$not_contains`` expectation adds one
    for it on top of the documents counted locally.
    """
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    seq_ids = produce_fns(producer, collection_id, sample_embeddings, 100)[1]
    sync(segment, seq_ids[-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)

    # Documents of records 1..99, keyed by record index.
    documents = {i: _build_document(i) for i in range(1, 100)}
    # The first record does not have a document, which should be included in
    # every $not_contains result.
    undocumented = 1

    via_where = segment.get_metadata(
        where={"chroma:document": "four two"},
        request_version_context=ctx,
    )
    via_id = segment.get_metadata(ids=["embedding_42"], request_version_context=ctx)
    assert via_where == via_id

    # Test single result
    hits = segment.get_metadata(
        where_document={"$contains": "four two"},
        request_version_context=ctx,
    )
    assert len(hits) == 1

    # Test not_contains
    # Returns records without documents or with documents not containing the searched text.
    hits = segment.get_metadata(
        where_document={"$not_contains": "four two"},
        request_version_context=ctx,
    )
    without_four_two = sum(1 for doc in documents.values() if "four two" not in doc)
    assert len(hits) == without_four_two + undocumented

    # Test many results
    hits = segment.get_metadata(
        where_document={"$contains": "zero"},
        request_version_context=ctx,
    )
    assert len(hits) == 9

    # Test not_contains
    hits = segment.get_metadata(
        where_document={"$not_contains": "zero"},
        request_version_context=ctx,
    )
    without_zero = sum(1 for doc in documents.values() if "zero" not in doc)
    assert len(hits) == without_zero + undocumented

    # test $and
    hits = segment.get_metadata(
        where_document={"$and": [{"$contains": "four"}, {"$contains": "two"}]},
        request_version_context=ctx,
    )
    assert len(hits) == 2
    assert {r["id"] for r in hits} == {"embedding_42", "embedding_24"}

    hits = segment.get_metadata(
        where_document={"$and": [{"$not_contains": "four"}, {"$not_contains": "two"}]},
        request_version_context=ctx,
    )
    neither_four_nor_two = sum(
        1 for doc in documents.values() if "four" not in doc and "two" not in doc
    )
    assert len(hits) == neither_four_nor_two + undocumented

    # test $or
    hits = segment.get_metadata(
        where_document={"$or": [{"$contains": "zero"}, {"$contains": "one"}]},
        request_version_context=ctx,
    )
    expected_ids = {
        f"embedding_{i}"
        for i, doc in documents.items()
        if "one" in doc or "zero" in doc
    }
    assert {r["id"] for r in hits} == expected_ids

    hits = segment.get_metadata(
        where_document={"$or": [{"$not_contains": "zero"}, {"$not_contains": "one"}]},
        request_version_context=ctx,
    )
    missing_zero_or_one = sum(
        1 for doc in documents.values() if "zero" not in doc or "one" not in doc
    )
    assert len(hits) == missing_zero_or_one + undocumented

    # test combo with where clause (negative case)
    hits = segment.get_metadata(
        where={"int_key": {"$eq": 42}},  # type:ignore[dict-item]
        where_document={"$contains": "zero"},
        request_version_context=ctx,
    )
    assert len(hits) == 0

    # test combo with where clause (positive case)
    hits = segment.get_metadata(
        where={"int_key": {"$eq": 42}},  # type:ignore[dict-item]
        where_document={"$contains": "four"},
        request_version_context=ctx,
    )
    assert len(hits) == 1

    # test partial words
    hits = segment.get_metadata(
        where_document={"$contains": "zer"},
        request_version_context=ctx,
    )
    assert len(hits) == 9
def test_delete(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Delete a record by id, delete it again, then re-add it.

    After each step the segment's count and the lookup of ``embedding_0`` are
    checked, including the lookup after the re-add.
    """
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)
    max_id = seq_ids[-1]
    sync(segment, max_id)

    # One context is enough: every query in this test uses the same values.
    version_context = RequestVersionContext(collection_version=0, log_position=0)

    assert segment.count(request_version_context=version_context) == 10
    results = segment.get_metadata(
        ids=["embedding_0"], request_version_context=version_context
    )
    assert_equiv_records(embeddings[:1], results)

    # Delete by ID
    delete_embedding = OperationRecord(
        id="embedding_0",
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
    )
    max_id = produce_fns(
        producer, collection_id, (delete_embedding for _ in range(1)), 1
    )[1][-1]
    sync(segment, max_id)

    assert segment.count(request_version_context=version_context) == 9
    assert (
        segment.get_metadata(
            ids=["embedding_0"], request_version_context=version_context
        )
        == []
    )

    # Delete is idempotent
    max_id = produce_fns(
        producer, collection_id, (delete_embedding for _ in range(1)), 1
    )[1][-1]
    sync(segment, max_id)

    assert segment.count(request_version_context=version_context) == 9
    assert (
        segment.get_metadata(
            ids=["embedding_0"], request_version_context=version_context
        )
        == []
    )

    # re-add
    max_id = producer.submit_embedding(collection_id, embeddings[0])
    sync(segment, max_id)
    assert segment.count(request_version_context=version_context) == 10
    results = segment.get_metadata(
        ids=["embedding_0"], request_version_context=version_context
    )
    # The re-added record must be retrievable again, not merely counted.
    assert_equiv_records(embeddings[:1], results)
def test_update(system: System, sample_embeddings: Iterator[OperationRecord]) -> None:
    """Run the shared update scenario with UPDATE, then update an unknown id.

    Updating an id that was never added must not create a record: the lookup
    comes back empty and the count stays at the three records added by the
    shared scenario.
    """
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    _test_update(sample_embeddings, producer, segment, collection_id, Operation.UPDATE)

    # Update nonexisting ID
    unknown_update = OperationRecord(
        id="no_such_id",
        metadata={"foo": "bar"},
        embedding=None,
        encoding=None,
        operation=Operation.UPDATE,
    )
    sync(segment, producer.submit_embedding(collection_id, unknown_update))

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    found = segment.get_metadata(ids=["no_such_id"], request_version_context=ctx)
    assert len(found) == 0
    assert segment.count(request_version_context=ctx) == 3
def test_upsert(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Run the shared update scenario with UPSERT, then upsert an unknown id.

    Unlike a plain update, upserting an id that was never added is expected
    to make its metadata retrievable.
    """
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    _test_update(sample_embeddings, producer, segment, collection_id, Operation.UPSERT)

    # upsert previously nonexisting ID
    new_record = OperationRecord(
        id="no_such_id",
        metadata={"foo": "bar"},
        embedding=None,
        encoding=None,
        operation=Operation.UPSERT,
    )
    produced = produce_fns(
        producer=producer,
        collection_id=collection_id,
        embeddings=(new_record for _ in range(1)),
        n=1,
    )
    sync(segment, produced[1][-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    found = segment.get_metadata(ids=["no_such_id"], request_version_context=ctx)
    assert found[0]["metadata"] == {"foo": "bar"}
def _test_update(
    sample_embeddings: Iterator[OperationRecord],
    producer: Producer,
    segment: MetadataReader,
    collection_id: uuid.UUID,
    op: Operation,
) -> None:
    """Scenario shared by the update and upsert tests.

    Adds three sample records, then applies a series of ``op`` operations to
    ``embedding_0``: set a document, overwrite it, add another key, and remove
    the document by setting its key to ``None``. After each step the stored
    metadata and the document search results are checked.
    """
    embeddings = [next(sample_embeddings) for _ in range(3)]

    last_seq_id = 0
    for record in embeddings:
        last_seq_id = producer.submit_embedding(collection_id, record)
    sync(segment, last_seq_id)

    ctx = RequestVersionContext(collection_version=0, log_position=0)

    def fetch_target() -> Sequence[MetadataEmbeddingRecord]:
        # Look up the record every step of this scenario modifies.
        return segment.get_metadata(ids=["embedding_0"], request_version_context=ctx)

    def search_documents(text: str) -> Sequence[MetadataEmbeddingRecord]:
        # Document search for records containing ``text``.
        return segment.get_metadata(
            where_document={"$contains": text},
            request_version_context=ctx,
        )

    def apply_change(
        metadata: Dict[str, Union[str, int, float, bool, None]],
    ) -> None:
        # Submit an ``op`` record for embedding_0 and wait until it is applied.
        change = OperationRecord(
            id="embedding_0",
            metadata=metadata,
            embedding=None,
            encoding=None,
            operation=op,
        )
        sync(segment, producer.submit_embedding(collection_id, change))

    assert_equiv_records(embeddings[:1], fetch_target())

    # Update embedding with no metadata
    apply_change({"chroma:document": "foo bar"})
    assert fetch_target()[0]["metadata"] == {"chroma:document": "foo bar"}
    assert search_documents("foo")[0]["metadata"] == {"chroma:document": "foo bar"}

    # Update and overwrite key
    apply_change({"chroma:document": "biz buz"})
    assert fetch_target()[0]["metadata"] == {"chroma:document": "biz buz"}
    assert search_documents("biz")[0]["metadata"] == {"chroma:document": "biz buz"}
    assert len(search_documents("foo")) == 0

    # Update and add key
    apply_change({"baz": 42})
    assert fetch_target()[0]["metadata"] == {"chroma:document": "biz buz", "baz": 42}

    # Update and delete key
    apply_change({"chroma:document": None})
    assert fetch_target()[0]["metadata"] == {"baz": 42}
    assert len(search_documents("biz")) == 0
def test_limit(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check ``limit`` and ``offset`` handling with two collections populated.

    Records are produced into two collections, each read by its own segment;
    all count and paging assertions are made against the first segment.
    """
    producer = system.instance(Producer)
    system.reset_state()

    first_collection = segment_definition["collection"]
    first_seq_id = produce_fns(producer, first_collection, sample_embeddings, 3)[1][-1]

    second_collection = segment_definition2["collection"]
    second_seq_id = produce_fns(producer, second_collection, sample_embeddings, 3)[1][
        -1
    ]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()
    segment2 = SqliteMetadataSegment(system, segment_definition2)
    segment2.start()

    sync(segment, first_seq_id)
    sync(segment2, second_seq_id)

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    assert segment.count(request_version_context=ctx) == 3

    for _ in range(3):
        first_seq_id = producer.submit_embedding(
            first_collection, next(sample_embeddings)
        )
    sync(segment, first_seq_id)
    assert segment.count(request_version_context=ctx) == 6

    page = segment.get_metadata(limit=3, request_version_context=ctx)
    assert len(page) == 3

    # if limit is negative, throw error
    with pytest.raises(ValueError):
        segment.get_metadata(limit=-1, request_version_context=ctx)

    # if offset is more than number of results, return empty list
    page = segment.get_metadata(limit=3, offset=10, request_version_context=ctx)
    assert len(page) == 0
def test_delete_segment(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Deleting a segment leaves no embeddings or full-text rows for it.

    After ``segment.delete()`` the ``embeddings`` table and the
    ``embedding_fulltext_search`` table are queried directly through the
    SqliteDB instance for rows tied to the segment's id.
    """
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)
    sync(segment, seq_ids[-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    assert segment.count(request_version_context=ctx) == 10
    found = segment.get_metadata(ids=["embedding_0"], request_version_context=ctx)
    assert_equiv_records(embeddings[:1], found)

    # Remember the id before deleting so the tables can be inspected afterwards.
    segment_id = segment._id
    segment.delete()

    db = system.instance(SqliteDB)
    embeddings_t = Table("embeddings")
    segment_rows = (
        db.querybuilder()
        .from_(embeddings_t)
        .select(embeddings_t.id)
        .where(embeddings_t.segment_id == ParameterValue(db.uuid_to_db(segment_id)))
    )
    sql, params = get_sql(segment_rows)
    with db.tx() as cur:
        rows = cur.execute(sql, params).fetchall()
        # assert that the segment is gone
        assert len(rows) == 0

    fulltext_t = Table("embedding_fulltext_search")
    segment_row_ids = (
        db.querybuilder()
        .from_(embeddings_t)
        .select(embeddings_t.id)
        .where(embeddings_t.segment_id == ParameterValue(db.uuid_to_db(segment_id)))
    )
    fulltext_rows = (
        db.querybuilder()
        .from_(fulltext_t)
        .select()
        .where(fulltext_t.rowid.isin(segment_row_ids))
    )
    sql, params = get_sql(fulltext_rows)
    with db.tx() as cur:
        rows = cur.execute(sql, params).fetchall()
        # assert that all FTS rows are gone
        assert len(rows) == 0
def test_delete_single_fts_record(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Deleting one record leaves no full-text row reachable through it.

    After a DELETE for ``embedding_0`` has been applied, the
    ``embedding_fulltext_search`` table is queried for rows whose rowid maps
    to that embedding in this segment.
    """
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)
    sync(segment, seq_ids[-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    assert segment.count(request_version_context=ctx) == 10
    found = segment.get_metadata(ids=["embedding_0"], request_version_context=ctx)
    assert_equiv_records(embeddings[:1], found)

    segment_id = segment._id
    db = system.instance(SqliteDB)

    # Delete by ID
    delete_embedding = OperationRecord(
        id="embedding_0",
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
    )
    produced = produce_fns(
        producer, collection_id, (delete_embedding for _ in range(1)), 1
    )
    embeddings_t = Table("embeddings")
    sync(segment, produced[1][-1])

    fulltext_t = Table("embedding_fulltext_search")
    deleted_row_ids = (
        db.querybuilder()
        .from_(embeddings_t)
        .select(embeddings_t.id)
        .where(embeddings_t.segment_id == ParameterValue(db.uuid_to_db(segment_id)))
        .where(embeddings_t.embedding_id == ParameterValue(delete_embedding["id"]))
    )
    fulltext_rows = (
        db.querybuilder()
        .from_(fulltext_t)
        .select()
        .where(fulltext_t.rowid.isin(deleted_row_ids))
    )
    sql, params = get_sql(fulltext_rows)
    with db.tx() as cur:
        rows = cur.execute(sql, params).fetchall()
        # assert that the ids that are deleted from the segment are also gone from the fts table
        assert len(rows) == 0
def test_include_metadata(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check the ``include_metadata`` flag of ``get_metadata``.

    ``embedding_0`` has no metadata, so on its own it cannot tell
    ``include_metadata=False`` apart from ``include_metadata=True``. The same
    checks are therefore repeated for ``embedding_1``, which does carry
    metadata.
    """
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)
    max_id = seq_ids[-1]
    sync(segment, max_id)

    request_version_context = RequestVersionContext(
        collection_version=0, log_position=0
    )
    assert segment.count(request_version_context=request_version_context) == 10
    results = segment.get_metadata(
        ids=["embedding_0"], request_version_context=request_version_context
    )
    assert_equiv_records(embeddings[:1], results)

    # Test include_metadata=False
    results = segment.get_metadata(
        ids=["embedding_0"],
        include_metadata=False,
        request_version_context=request_version_context,
    )
    assert len(results) == 1
    assert results[0]["metadata"] is None

    # Test include_metadata=True
    results = segment.get_metadata(
        ids=["embedding_0"],
        include_metadata=True,
        request_version_context=request_version_context,
    )
    assert len(results) == 1
    assert results[0]["metadata"] == embeddings[0]["metadata"]

    # Repeat with a record that actually has metadata, so the two settings
    # produce observably different results.
    assert embeddings[1]["metadata"] is not None
    results = segment.get_metadata(
        ids=["embedding_1"],
        include_metadata=False,
        request_version_context=request_version_context,
    )
    assert len(results) == 1
    assert results[0]["id"] == "embedding_1"
    assert results[0]["metadata"] is None

    results = segment.get_metadata(
        ids=["embedding_1"],
        include_metadata=True,
        request_version_context=request_version_context,
    )
    assert len(results) == 1
    assert results[0]["metadata"] == embeddings[1]["metadata"]
def test_metadata_validation_forbidden_key() -> None:
    """``validate_metadata`` raises ValueError naming the ``chroma:document`` key."""
    forbidden = {"chroma:document": "this is not the document you are looking for"}
    with pytest.raises(ValueError, match="chroma:document"):
        validate_metadata(forbidden)