File size: 3,978 Bytes
a006afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import unittest
import os
from unittest.mock import patch, Mock
import pytest
import chromadb
import chromadb.config
from chromadb.db.system import SysDB
from chromadb.ingest import Consumer, Producer


class GetDBTest(unittest.TestCase):
    """Check that ``System`` resolves storage interfaces to ``SqliteDB``.

    ``SqliteDB`` is replaced by an autospec mock in every test, so no database
    files are touched; each test only asserts that the mock was constructed.
    ``os.environ`` is emptied for every test so host environment variables
    cannot influence the Settings under test.
    """

    @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True)
    @patch.dict(os.environ, {}, clear=True)
    def test_default_db(self, mock: Mock) -> None:
        """Without an explicit sysdb impl, requesting ``SysDB`` builds SqliteDB."""
        system = chromadb.config.System(
            chromadb.config.Settings(persist_directory="./foo")
        )
        system.instance(SysDB)
        assert mock.called

    @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True)
    @patch.dict(os.environ, {}, clear=True)
    def test_sqlite_sysdb(self, mock: Mock) -> None:
        """An explicitly configured SqliteDB sysdb impl is constructed."""
        system = chromadb.config.System(
            chromadb.config.Settings(
                chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
                persist_directory="./foo",
            )
        )
        system.instance(SysDB)
        assert mock.called

    @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True)
    @patch.dict(os.environ, {}, clear=True)
    def test_sqlite_queue(self, mock: Mock) -> None:
        """SqliteDB configured as producer and consumer is constructed."""
        system = chromadb.config.System(
            chromadb.config.Settings(
                chroma_sysdb_impl="chromadb.db.impl.sqlite.SqliteDB",
                chroma_producer_impl="chromadb.db.impl.sqlite.SqliteDB",
                chroma_consumer_impl="chromadb.db.impl.sqlite.SqliteDB",
                persist_directory="./foo",
            )
        )
        system.instance(Producer)
        system.instance(Consumer)
        assert mock.called


class GetAPITest(unittest.TestCase):
    """Check which API / DB implementations ``chromadb.Client`` wires up.

    Every test runs with ``os.environ`` emptied (``patch.dict(..., clear=True)``)
    so host environment variables cannot influence the test.
    """

    @patch("chromadb.api.segment.SegmentAPI", autospec=True)
    @patch.dict(os.environ, {}, clear=True)
    def test_local(self, mock_api: Mock) -> None:
        """A persistent Client with default settings constructs SegmentAPI."""
        chromadb.Client(chromadb.config.Settings(persist_directory="./foo"))
        assert mock_api.called

    @patch("chromadb.db.impl.sqlite.SqliteDB", autospec=True)
    @patch.dict(os.environ, {}, clear=True)
    def test_local_db(self, mock_db: Mock) -> None:
        """A persistent Client with default settings constructs SqliteDB."""
        chromadb.Client(chromadb.config.Settings(persist_directory="./foo"))
        assert mock_db.called

    @patch("chromadb.api.fastapi.FastAPI", autospec=True)
    @patch.dict(os.environ, {}, clear=True)
    def test_fastapi(self, mock: Mock) -> None:
        """Selecting the FastAPI impl constructs the FastAPI client."""
        chromadb.Client(
            chromadb.config.Settings(
                chroma_api_impl="chromadb.api.fastapi.FastAPI",
                persist_directory="./foo",
                chroma_server_host="foo",
                chroma_server_http_port="80",
            )
        )
        assert mock.called

    @patch("chromadb.api.fastapi.FastAPI", autospec=True)
    @patch.dict(os.environ, {}, clear=True)
    def test_settings_pass_to_fastapi(self, mock: Mock) -> None:
        """The Settings given to ``chromadb.Client`` reach the FastAPI impl."""
        settings = chromadb.config.Settings(
            chroma_api_impl="chromadb.api.fastapi.FastAPI",
            chroma_server_host="foo",
            chroma_server_http_port="80",
            chroma_server_headers={"foo": "bar"},
        )
        chromadb.Client(settings)

        assert mock.called

        # `call_args` unpacks into (positional args tuple, keyword args dict).
        args, _ = mock.call_args
        # Fail with a readable message rather than an AttributeError on None
        # if the implementation was built with keyword arguments only.
        assert args, "FastAPI was constructed without positional arguments"
        # The first positional argument (presumably the owning System) must
        # expose exactly the Settings object handed to chromadb.Client.
        assert args[0].settings == settings


def test_legacy_values() -> None:
    """Configuring the legacy ``LocalAPI`` implementation must raise ValueError.

    Both the Settings construction and the Client construction happen inside
    the ``pytest.raises`` block, so the error may come from either step.
    """
    with pytest.raises(ValueError):
        legacy_settings = chromadb.config.Settings(
            chroma_server_http_port="80",
            chroma_server_host="foo",
            persist_directory="./foo",
            chroma_api_impl="chromadb.api.local.LocalAPI",
        )
        chromadb.Client(legacy_settings)