himanshud2611's picture
Upload folder using huggingface_hub
60e3a80 verified
import os
from typing import Any, Dict, Optional
from unittest.mock import patch, MagicMock

import pytest

import chromadb
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import System, Settings
@pytest.mark.parametrize("migrations_hash_algorithm", [None, "md5", "sha256"])
@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_valid_hash_algorithm(
    api_mock: MagicMock, migrations_hash_algorithm: Optional[str]
) -> None:
    """
    Ensure that when no hash algorithm or a valid one is provided, the client is set up
    with that value.

    Args:
        api_mock: Mock standing in for ``chromadb.api.fastapi.FastAPI`` (injected by
            ``@patch``).
        migrations_hash_algorithm: Algorithm to configure, or ``None`` to leave the
            setting unspecified, in which case the test expects ``"md5"``.
    """
    settings_kwargs: Dict[str, Any] = {
        "chroma_api_impl": "chromadb.api.fastapi.FastAPI",
        "is_persistent": True,
        "persist_directory": "./foo",
    }
    # Only pass the algorithm when one was requested, so that the `None` case leaves
    # the setting at whatever default `Settings` itself provides.
    if migrations_hash_algorithm:
        settings_kwargs["migrations_hash_algorithm"] = migrations_hash_algorithm
    settings = chromadb.config.Settings(**settings_kwargs)

    client = chromadb.Client(settings)
    try:
        # Check that the mocked API implementation was instantiated
        assert api_mock.called

        # `call_args` unpacks into (positional args, keyword args). We assume here that
        # the object carrying the settings is passed as the first positional argument.
        args, _kwargs = api_mock.call_args
        api_init_arg = args[0] if args else None
        assert (
            api_init_arg is not None
        ), "API implementation was called without positional arguments"

        # Check that the requested algorithm (or the expected default) was set
        expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"
        assert (
            getattr(api_init_arg.settings, "migrations_hash_algorithm", None)
            == expected_migrations_hash_algorithm
        )
    finally:
        # Always clear the cache, even when an assertion above fails, so that state
        # from this test does not carry over into later tests.
        client.clear_system_cache()
@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_invalid_hash_algorithm(mock: MagicMock) -> None:
    """
    Check that an unsupported hash algorithm is rejected.

    An exception must be raised while building the settings or creating the client,
    and the patched FastAPI implementation must never be invoked.
    """
    with pytest.raises(Exception):
        # Both steps stay inside the context: either one may be the one that raises.
        invalid_settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            migrations_hash_algorithm="invalid_hash_alg",
            persist_directory="./foo",
        )
        chromadb.Client(invalid_settings)

    api_was_instantiated = mock.called
    assert not api_was_instantiated
@pytest.mark.parametrize("migrations_hash_algorithm", ["md5", "sha256"])
@patch("chromadb.db.migrations.verify_migration_sequence")
@patch("chromadb.db.migrations.hashlib")
@patch.dict(os.environ, {}, clear=True)
def test_hashlib_alg(
    hashlib_mock: MagicMock,
    verify_migration_sequence_mock: MagicMock,
    migrations_hash_algorithm: str,
) -> None:
    """
    Test that only the appropriate hashlib functions are called.

    Args:
        hashlib_mock: Mock replacing ``chromadb.db.migrations.hashlib``.
        verify_migration_sequence_mock: Mock replacing
            ``chromadb.db.migrations.verify_migration_sequence``.
        migrations_hash_algorithm: Algorithm configured in the settings.
    """
    db = SqliteDB(
        System(
            Settings(
                migrations="apply",
                allow_reset=True,
                migrations_hash_algorithm=migrations_hash_algorithm,
            )
        )
    )

    # replace the real migration application call with a mock we can check
    db.apply_migration = MagicMock()  # type: ignore [method-assign]
    # NOTE(review): why `db.config` is replaced is not evident from this test; confirm
    # it is still needed.
    db.config = MagicMock()

    # we don't want `verify_migration_sequence` to actually run since a) we're not
    # testing that functionality and b) db may be cached between tests, and we're
    # changing the algorithm, so it may fail.
    # Instead, return a fake unapplied migration (expect `apply_migration` to be
    # called after)
    verify_migration_sequence_mock.return_value = ["unapplied_migration"]

    db.start()

    assert db.apply_migration.called

    # we only support the algorithms below; fail loudly (even under `python -O`) if
    # the parametrization is ever extended without updating this check
    supported_algorithms = ("md5", "sha256")
    if migrations_hash_algorithm not in supported_algorithms:
        pytest.fail(
            f"Unsupported migrations_hash_algorithm: {migrations_hash_algorithm!r}"
        )

    # check that the right algorithm was used, and only that one
    for algorithm in supported_algorithms:
        was_called = getattr(hashlib_mock, algorithm).called
        if algorithm == migrations_hash_algorithm:
            assert was_called, f"hashlib.{algorithm} should have been called"
        else:
            assert not was_called, f"hashlib.{algorithm} should not have been called"