File size: 4,245 Bytes
60e3a80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
from typing import Optional
from unittest.mock import patch, MagicMock

import pytest

import chromadb
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.config import System, Settings


@pytest.mark.parametrize("migrations_hash_algorithm", [None, "md5", "sha256"])
@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_valid_hash_algorithm(
    api_mock: MagicMock, migrations_hash_algorithm: Optional[str]
) -> None:
    """
    Ensure that when no hash algorithm or a valid one is provided, the client is set up
    with that value.

    A ``None`` parameter means the setting is omitted from ``Settings`` entirely; in
    that case the value seen by the (mocked) API implementation is expected to be
    "md5".
    """
    if migrations_hash_algorithm:
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            is_persistent=True,
            persist_directory="./foo",
            migrations_hash_algorithm=migrations_hash_algorithm,
        )
    else:
        # Leave `migrations_hash_algorithm` out so the Settings default applies.
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            is_persistent=True,
            persist_directory="./foo",
        )

    client = chromadb.Client(settings)

    try:
        # Check that the mocked API implementation was instantiated
        assert api_mock.called

        # `call_args` unpacks into (positional args, keyword args). We assume here
        # that the object carrying the settings is passed as the first positional
        # argument and exposes them through its `.settings` attribute.
        args, _ = api_mock.call_args
        assert args, "expected the API implementation to receive a positional argument"
        first_arg = args[0]

        # Check that the requested algorithm (or the default, when omitted) was set
        expected_migrations_hash_algorithm = migrations_hash_algorithm or "md5"
        assert (
            getattr(first_arg.settings, "migrations_hash_algorithm", None)
            == expected_migrations_hash_algorithm
        )
    finally:
        # Run the cleanup even when an assertion above fails; otherwise a failing
        # case would skip `clear_system_cache()` entirely.
        client.clear_system_cache()


@patch("chromadb.api.fastapi.FastAPI")
@patch.dict(os.environ, {}, clear=True)
def test_settings_invalid_hash_algorithm(mock: MagicMock) -> None:
    """
    An unsupported hash algorithm name must raise, and the patched API
    implementation must never be invoked as a result.
    """
    with pytest.raises(Exception):
        # The Settings object is built first (as the call argument), so a failure
        # while constructing it happens before `chromadb.Client` is entered.
        chromadb.Client(
            chromadb.config.Settings(
                chroma_api_impl="chromadb.api.fastapi.FastAPI",
                migrations_hash_algorithm="invalid_hash_alg",
                persist_directory="./foo",
            )
        )

    # The mocked FastAPI implementation must not have been called at all.
    mock.assert_not_called()


@pytest.mark.parametrize("migrations_hash_algorithm", ["md5", "sha256"])
@patch("chromadb.db.migrations.verify_migration_sequence")
@patch("chromadb.db.migrations.hashlib")
@patch.dict(os.environ, {}, clear=True)
def test_hashlib_alg(
    hashlib_mock: MagicMock,
    verify_migration_sequence_mock: MagicMock,
    migrations_hash_algorithm: str,
) -> None:
    """
    Test that only the appropriate hashlib functions are called.

    `hashlib` and `verify_migration_sequence` are both patched inside
    `chromadb.db.migrations`, so after `db.start()` the test can inspect which
    hash constructor on the mocked `hashlib` was invoked.
    """
    db = SqliteDB(
        System(
            Settings(
                migrations="apply",
                allow_reset=True,
                migrations_hash_algorithm=migrations_hash_algorithm,
            )
        )
    )

    # replace the real migration application call with a mock we can check
    db.apply_migration = MagicMock()  # type: ignore [method-assign]
    db.config = MagicMock()

    # we don't want `verify_migration_sequence` to actually run since a) we're not testing that functionality and
    # b) db may be cached between tests, and we're changing the algorithm, so it may fail.
    # Instead, return a fake unapplied migration (expect `apply_migration` to be called after)
    verify_migration_sequence_mock.return_value = ["unapplied_migration"]

    db.start()

    assert db.apply_migration.called

    # check that the right algorithm was used, and that the other one was not
    if migrations_hash_algorithm == "md5":
        assert hashlib_mock.md5.called, "md5 was configured but never used"
        assert not hashlib_mock.sha256.called, "sha256 used although md5 was configured"
    elif migrations_hash_algorithm == "sha256":
        assert not hashlib_mock.md5.called, "md5 used although sha256 was configured"
        assert hashlib_mock.sha256.called, "sha256 was configured but never used"
    else:
        # we only support the algorithms above; fail loudly (and independently of
        # `python -O`, which strips plain asserts) if the parametrization grows
        pytest.fail(
            f"unsupported migrations_hash_algorithm: {migrations_hash_algorithm!r}"
        )