# cold-email-generator/env/Lib/site-packages/chromadb/test/configurations/test_configurations.py
from overrides import overrides | |
import pytest | |
from chromadb.api.configuration import ( | |
ConfigurationInternal, | |
ConfigurationDefinition, | |
InvalidConfigurationError, | |
StaticParameterError, | |
ConfigurationParameter, | |
HNSWConfiguration, | |
) | |
class TestConfiguration(ConfigurationInternal):
    """Small configuration used as a fixture by the tests in this module.

    It declares two parameters:

    * ``static_str_value`` -- a static string whose default is ``"default"``.
    * ``int_value`` -- a non-static integer whose default is ``0``.
    """

    # This class is a helper, not a test case. Its ``Test`` prefix would
    # otherwise make pytest try to collect it and emit a collection warning,
    # because it cannot collect a class that has a constructor.
    __test__ = False

    definitions = {
        "static_str_value": ConfigurationDefinition(
            name="static_str_value",
            validator=lambda value: isinstance(value, str),
            is_static=True,
            default_value="default",
        ),
        "int_value": ConfigurationDefinition(
            name="int_value",
            validator=lambda value: isinstance(value, int),
            is_static=False,
            default_value=0,
        ),
    }

    def configuration_validator(self) -> None:
        """Impose no configuration-level constraints."""
        pass
def test_default_values() -> None:
    """Check that a configuration built without arguments exposes every default.

    Each parameter declared in ``TestConfiguration.definitions`` must be
    retrievable and must carry the default value from its definition.
    """
    default_test_configuration = TestConfiguration()

    # Loop over the definitions so that every parameter is looked up under its
    # own name; hand-written per-parameter assertions are easy to copy-paste
    # with the wrong name and silently leave a parameter unchecked.
    for name, definition in TestConfiguration.definitions.items():
        parameter = default_test_configuration.get_parameter(name)
        assert parameter is not None
        assert parameter.value == definition.default_value
def test_set_values() -> None:
    """Static parameters reject updates; non-static parameters accept them."""
    config = TestConfiguration()

    # "static_str_value" is declared with is_static=True, so writing must fail.
    with pytest.raises(StaticParameterError):
        config.set_parameter("static_str_value", "new_value")

    # "int_value" is declared with is_static=False and can be overwritten.
    new_int = 1
    config.set_parameter("int_value", new_int)
    updated = config.get_parameter("int_value")
    assert updated.value == new_int
def test_get_invalid_parameter() -> None:
    """Looking up a name that was never defined raises ``ValueError``."""
    config = TestConfiguration()
    unknown_name = "invalid_name"
    with pytest.raises(ValueError):
        config.get_parameter(unknown_name)
def test_validation() -> None:
    """Check constructor-time validation of explicitly supplied parameters.

    Well-typed values for known names are stored as given; a value rejected
    by its validator, or a name with no definition, raises ``ValueError``.
    """
    # Accepted: both parameters are known and pass their validators.
    accepted = TestConfiguration(
        parameters=[
            ConfigurationParameter(name="static_str_value", value="valid_value"),
            ConfigurationParameter(name="int_value", value=1),
        ]
    )
    static_param = accepted.get_parameter("static_str_value")
    assert static_param.value == "valid_value"
    int_param = accepted.get_parameter("int_value")
    assert int_param.value == 1

    # Rejected: a float where the validator requires a str.
    wrong_type = ConfigurationParameter(name="static_str_value", value=1.0)
    with pytest.raises(ValueError):
        TestConfiguration(parameters=[wrong_type])

    # Rejected: a name that has no entry in `definitions`.
    unknown_name = ConfigurationParameter(name="invalid_name", value="some_value")
    with pytest.raises(ValueError):
        TestConfiguration(parameters=[unknown_name])
def test_configuration_validation() -> None:
    """Check that ``configuration_validator`` gates construction.

    A configuration whose ``foo`` parameter is not ``"bar"`` must fail with an
    error matching ``foo must be 'bar'``; one whose ``foo`` is ``"bar"`` must
    be constructed successfully.
    """

    class FooConfiguration(ConfigurationInternal):
        """Configuration with a single string parameter that must be "bar"."""

        definitions = {
            "foo": ConfigurationDefinition(
                name="foo",
                validator=lambda value: isinstance(value, str),
                is_static=False,
                default_value="default",
            ),
        }

        def configuration_validator(self) -> None:
            # Compare the parameter's `.value`, as every other access in this
            # module does. NOTE(review): `parameter_map` presumably stores
            # parameter objects rather than raw values, in which case
            # comparing a map entry directly to "bar" is always unequal and
            # would reject valid configurations too -- confirm against
            # ConfigurationInternal.
            if self.get_parameter("foo").value != "bar":
                raise InvalidConfigurationError("foo must be 'bar'")

    with pytest.raises(ValueError, match="foo must be 'bar'"):
        FooConfiguration(parameters=[ConfigurationParameter(name="foo", value="baz")])

    # The positive case: without it, a validator that rejects everything would
    # still pass this test.
    accepted = FooConfiguration(
        parameters=[ConfigurationParameter(name="foo", value="bar")]
    )
    assert accepted.get_parameter("foo").value == "bar"
def test_hnsw_validation() -> None:
    """``HNSWConfiguration`` rejects batch_size=500 with sync_threshold=100.

    Construction is expected to raise ``ValueError`` with a message matching
    ``must be less than or equal``.
    """
    conflicting_settings = {"batch_size": 500, "sync_threshold": 100}
    with pytest.raises(ValueError, match="must be less than or equal"):
        HNSWConfiguration(**conflicting_settings)