|
import os
from typing import Optional
from unittest.mock import patch, MagicMock

import pytest

import chromadb
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import System, Settings
|
|
|
|
|
@pytest.mark.parametrize("migrations_hash_algorithm", [None, "md5", "sha256"])
@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_valid_hash_algorithm(
    api_mock: MagicMock, migrations_hash_algorithm: Optional[str]
) -> None:
    """
    Ensure that when no hash algorithm or a valid one is provided, the client is set up
    with that value.

    Args:
        api_mock: Mock injected by ``@patch`` in place of
            ``chromadb.api.fastapi.FastAPI``.
        migrations_hash_algorithm: Algorithm name to configure, or ``None`` to leave
            the setting unspecified; the test then expects ``"md5"``.
    """
    if migrations_hash_algorithm:
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            is_persistent=True,
            persist_directory="./foo",
            migrations_hash_algorithm=migrations_hash_algorithm,
        )
    else:
        # Omit the keyword entirely so the Settings default is exercised.
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            is_persistent=True,
            persist_directory="./foo",
        )

    client = chromadb.Client(settings)

    # Run the cleanup even when an assertion fails, so a failing case does not leave
    # cached state behind for the remaining parametrized cases.
    try:
        assert api_mock.called

        # `call_args` unpacks to (positional args, keyword args); only the first
        # positional argument is inspected here.
        args, _ = api_mock.call_args
        passed_system = args[0] if args else None

        expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"
        assert passed_system, "API implementation received no positional argument"
        # The settings are read from the `.settings` attribute of that argument.
        assert (
            getattr(passed_system.settings, "migrations_hash_algorithm", None)
            == expected_migrations_hash_algorithm
        )
    finally:
        client.clear_system_cache()
|
|
|
|
|
@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_invalid_hash_algorithm(mock: MagicMock) -> None:
    """
    An unsupported hash algorithm must raise, and the patched API implementation
    must never be invoked.

    Args:
        mock: Mock injected by ``@patch`` in place of
            ``chromadb.api.fastapi.FastAPI``.
    """
    # Settings construction and client creation both sit inside the `raises` block:
    # the test only requires that one of them rejects the bad algorithm.
    with pytest.raises(Exception):
        invalid_settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            migrations_hash_algorithm="invalid_hash_alg",
            persist_directory="./foo",
        )
        chromadb.Client(invalid_settings)

    # Checked after the block exits; inside it this line would be unreachable.
    mock.assert_not_called()
|
|
|
|
|
@pytest.mark.parametrize("migrations_hash_algorithm", ["md5", "sha256"])
@patch("chromadb.db.migrations.verify_migration_sequence")
@patch("chromadb.db.migrations.hashlib")
@patch.dict(os.environ, {}, clear=True)
def test_hashlib_alg(
    hashlib_mock: MagicMock,
    verify_migration_sequence_mock: MagicMock,
    migrations_hash_algorithm: str,
) -> None:
    """
    Test that only the appropriate hashlib functions are called.

    Args:
        hashlib_mock: Mock replacing ``chromadb.db.migrations.hashlib``. It is the
            innermost ``@patch``, so it is the first injected argument.
        verify_migration_sequence_mock: Mock replacing
            ``chromadb.db.migrations.verify_migration_sequence``.
        migrations_hash_algorithm: Algorithm name passed to ``Settings``.
    """
    db = SqliteDB(
        System(
            Settings(
                migrations="apply",
                allow_reset=True,
                migrations_hash_algorithm=migrations_hash_algorithm,
            )
        )
    )

    # Replace the migration-application call with a mock so the call can be observed.
    db.apply_migration = MagicMock()
    # NOTE(review): config is mocked too; the reason is not visible here - confirm.
    db.config = MagicMock()

    # Return one placeholder entry instead of running the real verification.
    # Presumably this is what makes start() apply a migration - confirm.
    verify_migration_sequence_mock.return_value = ["unapplied_migration"]

    db.start()

    assert db.apply_migration.called

    expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"

    # Exactly one of the two hashlib constructors may have been used.
    if expected_migrations_hash_algorithm == "md5":
        assert hashlib_mock.md5.called
        assert not hashlib_mock.sha256.called
    elif expected_migrations_hash_algorithm == "sha256":
        assert not hashlib_mock.md5.called
        assert hashlib_mock.sha256.called
    else:
        # Raised explicitly: `assert False` is removed under `python -O` and would
        # let an unhandled parametrization pass silently.
        raise AssertionError(
            "unexpected migrations_hash_algorithm: "
            f"{expected_migrations_hash_algorithm!r}"
        )
|
|