File size: 3,680 Bytes
60e3a80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 |
from overrides import overrides
import pytest
from chromadb.api.configuration import (
ConfigurationInternal,
ConfigurationDefinition,
InvalidConfigurationError,
StaticParameterError,
ConfigurationParameter,
HNSWConfiguration,
)
class TestConfiguration(ConfigurationInternal):
    """Small configuration used as a fixture by the tests in this module.

    It declares two parameters, each with a type-checking validator and a
    default value:

    * ``static_str_value`` -- a string declared with ``is_static=True``.
    * ``int_value`` -- an integer declared with ``is_static=False``.
    """

    # Despite the ``Test`` prefix this class is a fixture, not a test case.
    # Opting out of collection keeps pytest from trying to collect it (and
    # warning that it cannot, because the class has a constructor).
    __test__ = False

    definitions = {
        "static_str_value": ConfigurationDefinition(
            name="static_str_value",
            validator=lambda value: isinstance(value, str),
            is_static=True,
            default_value="default",
        ),
        "int_value": ConfigurationDefinition(
            name="int_value",
            validator=lambda value: isinstance(value, int),
            is_static=False,
            default_value=0,
        ),
    }

    @overrides
    def configuration_validator(self) -> None:
        """Accept every configuration; no cross-parameter rules apply here."""
        pass
def test_default_values() -> None:
    """A configuration built without parameters exposes every default value.

    For each parameter declared on ``TestConfiguration`` the test checks that
    the parameter is present and that its value equals the ``default_value``
    of the matching definition.
    """
    default_test_configuration = TestConfiguration()

    # Loop over both parameter names so each one gets its own presence check
    # (a hand-copied pair of asserts is easy to leave pointing at one name).
    for parameter_name in ("static_str_value", "int_value"):
        parameter = default_test_configuration.get_parameter(parameter_name)
        assert parameter is not None
        assert (
            parameter.value
            == TestConfiguration.definitions[parameter_name].default_value
        )
def test_set_values() -> None:
    """Static parameters reject updates; non-static parameters accept them."""
    config = TestConfiguration()

    # Overwriting the static string parameter must be refused.
    with pytest.raises(StaticParameterError):
        config.set_parameter("static_str_value", "new_value")

    # The non-static integer parameter can be changed and read back.
    config.set_parameter("int_value", 1)
    stored_parameter = config.get_parameter("int_value")
    assert stored_parameter.value == 1
def test_get_invalid_parameter() -> None:
    """Looking up a parameter name that is not defined raises ``ValueError``."""
    config = TestConfiguration()
    unknown_name = "invalid_name"

    with pytest.raises(ValueError):
        config.get_parameter(unknown_name)
def test_validation() -> None:
    """Constructor arguments are validated per parameter.

    Covers three cases: well-formed parameters are stored as given, a value
    of the wrong type is rejected, and an undeclared parameter name is
    rejected. Both rejections are expected to surface as ``ValueError``.
    """
    # Case 1: valid names and values are accepted and readable afterwards.
    str_parameter = ConfigurationParameter(name="static_str_value", value="valid_value")
    int_parameter = ConfigurationParameter(name="int_value", value=1)
    accepted = TestConfiguration(parameters=[str_parameter, int_parameter])

    stored_str = accepted.get_parameter("static_str_value").value
    stored_int = accepted.get_parameter("int_value").value
    assert stored_str == "valid_value"
    assert stored_int == 1

    # Case 2: a float is not a valid value for the string parameter.
    wrong_type = ConfigurationParameter(name="static_str_value", value=1.0)
    with pytest.raises(ValueError):
        TestConfiguration(parameters=[wrong_type])

    # Case 3: a name that has no definition is refused.
    wrong_name = ConfigurationParameter(name="invalid_name", value="some_value")
    with pytest.raises(ValueError):
        TestConfiguration(parameters=[wrong_name])
def test_configuration_validation() -> None:
    """A subclass's ``configuration_validator`` is enforced on construction.

    ``FooConfiguration`` only accepts ``foo == "bar"``. Building it with any
    other value must raise an error carrying the validator's message, while
    building it with ``"bar"`` must succeed.
    """

    class FooConfiguration(ConfigurationInternal):
        """Configuration whose single parameter ``foo`` must equal ``"bar"``."""

        definitions = {
            "foo": ConfigurationDefinition(
                name="foo",
                validator=lambda value: isinstance(value, str),
                is_static=False,
                default_value="default",
            ),
        }

        @overrides
        def configuration_validator(self) -> None:
            # Compare the parameter's *value*: the rest of this module reads
            # values through ``get_parameter(...).value``, so comparing the
            # stored parameter object itself to a string would never match
            # and the validator would reject every configuration.
            # NOTE(review): assumes parameters are already stored when the
            # validator runs during construction -- confirm against
            # ConfigurationInternal.__init__.
            if self.get_parameter("foo").value != "bar":
                raise InvalidConfigurationError("foo must be 'bar'")

    # Rejected: any value other than "bar".
    with pytest.raises(ValueError, match="foo must be 'bar'"):
        FooConfiguration(parameters=[ConfigurationParameter(name="foo", value="baz")])

    # Accepted: the one valid value, proving the validator is not vacuous.
    accepted = FooConfiguration(
        parameters=[ConfigurationParameter(name="foo", value="bar")]
    )
    assert accepted.get_parameter("foo").value == "bar"
def test_hnsw_validation() -> None:
    """``HNSWConfiguration(batch_size=500, sync_threshold=100)`` is rejected.

    The error is expected to be a ``ValueError`` whose message contains
    "must be less than or equal".
    """
    expected_message = "must be less than or equal"
    rejected_settings = {"batch_size": 500, "sync_threshold": 100}

    with pytest.raises(ValueError, match=expected_message):
        HNSWConfiguration(**rejected_settings)
|