|
from abc import abstractmethod
import copy
import json
from multiprocessing import cpu_count
from typing import (
    Any,
    ClassVar,
    Dict,
    List,
    Optional,
    Protocol,
    Union,
    TypeVar,
    cast,
)

from overrides import override
from typing_extensions import Self

from chromadb.serde import JSONSerializable
|
|
|
|
|
|
|
|
|
class StaticParameterError(Exception):
    """Raised when an attempt is made to change a static parameter."""
|
|
|
|
|
class InvalidConfigurationError(ValueError):
    """Raised when a configuration is found to be invalid."""
|
|
|
|
|
# A parameter value is either a primitive or a nested configuration object.
# The configuration class is referenced by name because it is defined later.
ParameterValue = Union[
    str,
    int,
    float,
    bool,
    "ConfigurationInternal",
]


class ParameterValidator(Protocol):
    """Callable protocol for objects that validate a single parameter value."""

    @abstractmethod
    def __call__(self, value: ParameterValue) -> bool:
        """Report whether ``value`` is acceptable.

        The protocol body itself performs no check and always raises
        NotImplementedError; implementations supply the actual predicate.
        """
        raise NotImplementedError
|
|
|
|
|
class ConfigurationDefinition:
    """Describes one parameter a configuration accepts.

    Attributes are plain data: the parameter's name, the callable used to
    validate candidate values, whether the parameter is static, and the value
    used when the caller supplies none.
    """

    name: str
    validator: ParameterValidator
    is_static: bool
    default_value: ParameterValue

    def __init__(
        self,
        name: str,
        validator: ParameterValidator,
        is_static: bool,
        default_value: ParameterValue,
    ):
        """Record the four fields of the definition unchanged.

        Args:
            name: Name of the parameter being defined.
            validator: Callable returning True for acceptable values.
            is_static: Flag marking the parameter as static.
            default_value: Value to fall back on when none is provided.
        """
        self.name, self.validator = name, validator
        self.is_static, self.default_value = is_static, default_value
|
|
|
|
|
class ConfigurationParameter:
    """A named value belonging to a configuration."""

    name: str
    value: ParameterValue

    def __init__(self, name: str, value: ParameterValue):
        """Store the parameter's name and its current value."""
        self.name = name
        self.value = value

    def __repr__(self) -> str:
        """Return ``ConfigurationParameter(<name>, <value>)``."""
        return f"ConfigurationParameter({self.name}, {self.value})"

    def __eq__(self, __value: object) -> bool:
        """Compare by name and value; defer to the other operand for foreign types."""
        if isinstance(__value, ConfigurationParameter):
            names_match = self.name == __value.name
            return names_match and self.value == __value.value
        return NotImplemented
|
|
|
|
|
# Type variable constrained to ConfigurationInternal and its subclasses.
T = TypeVar("T", bound="ConfigurationInternal")


class ConfigurationInternal(JSONSerializable["ConfigurationInternal"]):
    """Represents an abstract configuration, used internally by Chroma.

    Concrete subclasses declare the parameters they accept in the class-level
    ``definitions`` mapping and implement ``configuration_validator`` for
    checks that involve more than one parameter.
    """

    # Parameters of this instance keyed by name. Once __init__ has finished,
    # every name in ``definitions`` has an entry: either the caller-supplied
    # parameter or one built from the definition's default value.
    parameter_map: Dict[str, ConfigurationParameter]
    # Accepted parameters keyed by name; supplied by each concrete subclass.
    definitions: ClassVar[Dict[str, ConfigurationDefinition]]

    def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None):
        """Initializes a new configuration, respecting defaults and validators.

        Args:
            parameters: Explicit parameters. The supplied objects are stored
                as-is (not copied). A parameter whose value is a dict is
                treated as the JSON form of a nested configuration and its
                value is replaced, in place, by the deserialized object.
                Parameters not supplied take their definition's default.

        Raises:
            ValueError: If a parameter name has no definition, a nested dict
                does not name a known configuration type, a value is not an
                instance of the default value's type, or the parameter's
                validator rejects the value.
            Exception: Whatever ``configuration_validator`` raises (documented
                there as InvalidConfigurationError).
        """
        self.parameter_map = {}
        if parameters is not None:
            for parameter in parameters:
                if parameter.name not in self.definitions:
                    raise ValueError(f"Invalid parameter name: {parameter.name}")

                definition = self.definitions[parameter.name]

                # A dict here is the serialized form of a nested configuration.
                if isinstance(parameter.value, dict):
                    parameter.value = self._nested_from_json(parameter.value)
                if not isinstance(parameter.value, type(definition.default_value)):
                    raise ValueError(f"Invalid parameter value: {parameter.value}")

                if not definition.validator(parameter.value):
                    raise ValueError(f"Invalid parameter value: {parameter.value}")
                self.parameter_map[parameter.name] = parameter

        for name, definition in self.definitions.items():
            if name not in self.parameter_map:
                # Copy the default so that a mutable default (a nested
                # configuration) is never shared between instances; otherwise
                # set_parameter on one instance's nested default would change
                # the default seen by every later instance.
                self.parameter_map[name] = ConfigurationParameter(
                    name=name, value=copy.deepcopy(definition.default_value)
                )

        self.configuration_validator()

    @staticmethod
    def _nested_from_json(json_map: Dict[str, Any]) -> "ConfigurationInternal":
        """Deserialize a nested configuration from its JSON dictionary.

        The ``_type`` entry is looked up among this module's globals and must
        name a ConfigurationInternal subclass. Anything else (missing or
        non-string ``_type``, or a global that is not a configuration class)
        is rejected with ValueError instead of failing later with an
        unrelated AttributeError or TypeError.
        """
        type_name = json_map.get("_type")
        candidate = globals().get(type_name) if isinstance(type_name, str) else None
        if not (
            isinstance(candidate, type) and issubclass(candidate, ConfigurationInternal)
        ):
            raise ValueError(f"Invalid configuration type: {json_map}")
        return cast("ConfigurationInternal", candidate.from_json(json_map))

    def __repr__(self) -> str:
        return f"Configuration({self.parameter_map.values()})"

    def __eq__(self, __value: object) -> bool:
        """Two configurations are equal when their parameter maps are equal."""
        if not isinstance(__value, ConfigurationInternal):
            return NotImplemented
        return self.parameter_map == __value.parameter_map

    @abstractmethod
    def configuration_validator(self) -> None:
        """Perform custom validation when parameters are dependent on each other.

        Raises an InvalidConfigurationError if the configuration is invalid.
        """
        pass

    def get_parameters(self) -> List[ConfigurationParameter]:
        """Returns the parameters of the configuration as a new list."""
        return list(self.parameter_map.values())

    def get_parameter(self, name: str) -> ConfigurationParameter:
        """Returns the parameter with the given name.

        Raises:
            ValueError: If no parameter with that name exists.
        """
        try:
            return self.parameter_map[name]
        except KeyError:
            raise ValueError(
                f"Invalid parameter name: {name} for configuration {self.__class__.__name__}"
            ) from None

    def set_parameter(self, name: str, value: Union[str, int, float, bool]) -> None:
        """Sets the parameter with the given name to the given value.

        Only the parameter's own validator is consulted;
        ``configuration_validator`` is not re-run here.

        Raises:
            ValueError: If the name has no definition or the validator
                rejects the value.
            StaticParameterError: If the parameter is defined as static.
        """
        if name not in self.definitions:
            raise ValueError(f"Invalid parameter name: {name}")
        definition = self.definitions[name]
        parameter = self.parameter_map[name]
        if definition.is_static:
            raise StaticParameterError(f"Cannot set static parameter: {name}")
        if not definition.validator(value):
            raise ValueError(f"Invalid value for parameter {name}: {value}")
        parameter.value = value

    @override
    def to_json_str(self) -> str:
        """Returns the JSON representation of the configuration."""
        return json.dumps(self.to_json())

    @classmethod
    @override
    def from_json_str(cls, json_str: str) -> Self:
        """Returns a configuration from the given JSON string.

        Raises:
            ValueError: If the string is not valid JSON, if it does not
                decode to a JSON object, or for any reason ``from_json``
                rejects the decoded mapping.
        """
        try:
            config_json = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(
                f"Unable to decode configuration from JSON string: {json_str}"
            ) from e
        # from_json needs a mapping; valid JSON such as "[1]" or "3" is not one.
        if not isinstance(config_json, dict):
            raise ValueError(
                f"Unable to decode configuration from JSON string: {json_str}"
            )
        return cls.from_json(config_json)

    @override
    def to_json(self) -> Dict[str, Any]:
        """Returns the JSON compatible dictionary representation of the configuration.

        Nested configurations are serialized recursively. The class name is
        stored under ``_type`` so that ``from_json`` can check it.
        """
        json_dict: Dict[str, Any] = {
            name: parameter.value.to_json()
            if isinstance(parameter.value, ConfigurationInternal)
            else parameter.value
            for name, parameter in self.parameter_map.items()
        }

        json_dict["_type"] = self.__class__.__name__
        return json_dict

    @classmethod
    @override
    def from_json(cls, json_map: Dict[str, Any]) -> Self:
        """Returns a configuration from the given JSON compatible dictionary.

        Raises:
            ValueError: If the ``_type`` entry is missing or does not equal
                this class's name, or if constructing the configuration from
                the remaining entries fails validation.
        """
        # Use .get so that a missing "_type" yields the ValueError below
        # rather than a KeyError while formatting the message.
        type_name = json_map.get("_type", None)
        if cls.__name__ != type_name:
            raise ValueError(
                f"Trying to instantiate configuration of type {cls.__name__} from JSON with type {type_name}"
            )
        # "_type" is bookkeeping only; every other entry is a parameter.
        parameters = [
            ConfigurationParameter(name=name, value=value)
            for name, value in json_map.items()
            if name != "_type"
        ]
        return cls(parameters=parameters)
|
|
|
|
|
def _is_supported_space(value: object) -> bool:
    """True for the string names of the accepted distance spaces."""
    return isinstance(value, str) and value in ["l2", "ip", "cosine"]


def _is_positive_int(value: object) -> bool:
    """True for int values that are at least 1."""
    return isinstance(value, int) and value >= 1


def _is_float_at_least_one(value: object) -> bool:
    """True for float values that are at least 1."""
    return isinstance(value, float) and value >= 1


class HNSWConfigurationInternal(ConfigurationInternal):
    """Internal form of the HNSW configuration.

    Supplies the parameter definitions (validators, defaults, static flags)
    consumed by the base class for validation and (de)serialization.
    """

    definitions = {
        "space": ConfigurationDefinition(
            name="space",
            validator=_is_supported_space,
            is_static=True,
            default_value="l2",
        ),
        "ef_construction": ConfigurationDefinition(
            name="ef_construction",
            validator=_is_positive_int,
            is_static=True,
            default_value=100,
        ),
        "ef_search": ConfigurationDefinition(
            name="ef_search",
            validator=_is_positive_int,
            is_static=False,
            default_value=10,
        ),
        "num_threads": ConfigurationDefinition(
            name="num_threads",
            validator=_is_positive_int,
            is_static=False,
            # Evaluated once, when the class body is executed.
            default_value=cpu_count(),
        ),
        "M": ConfigurationDefinition(
            name="M",
            validator=_is_positive_int,
            is_static=True,
            default_value=16,
        ),
        "resize_factor": ConfigurationDefinition(
            name="resize_factor",
            validator=_is_float_at_least_one,
            is_static=True,
            default_value=1.2,
        ),
        "batch_size": ConfigurationDefinition(
            name="batch_size",
            validator=_is_positive_int,
            is_static=True,
            default_value=100,
        ),
        "sync_threshold": ConfigurationDefinition(
            name="sync_threshold",
            validator=_is_positive_int,
            is_static=True,
            default_value=1000,
        ),
    }

    @override
    def configuration_validator(self) -> None:
        """Reject configurations whose batch_size exceeds sync_threshold.

        Raises:
            InvalidConfigurationError: If both parameters are present and
                batch_size is greater than sync_threshold.
        """
        batch_param = self.parameter_map.get("batch_size")
        sync_param = self.parameter_map.get("sync_threshold")

        # Nothing to compare unless both parameters are available.
        if not batch_param or not sync_param:
            return

        if cast(int, batch_param.value) > cast(int, sync_param.value):
            raise InvalidConfigurationError(
                "batch_size must be less than or equal to sync_threshold"
            )

    @classmethod
    def from_legacy_params(cls, params: Dict[str, Any]) -> Self:
        """Build a configuration from a dict of legacy ``hnsw:``-prefixed keys.

        Each key is translated to its current parameter name and passed, with
        its value, to the constructor.

        Raises:
            ValueError: If ``params`` contains a key that is not one of the
                known legacy names (constructor validation may also raise).
        """
        old_to_new = {
            "hnsw:space": "space",
            "hnsw:construction_ef": "ef_construction",
            "hnsw:search_ef": "ef_search",
            "hnsw:M": "M",
            "hnsw:num_threads": "num_threads",
            "hnsw:resize_factor": "resize_factor",
            "hnsw:batch_size": "batch_size",
            "hnsw:sync_threshold": "sync_threshold",
        }

        translated = []
        for name, value in params.items():
            current_name = old_to_new.get(name)
            if current_name is None:
                raise ValueError(f"Invalid legacy HNSW parameter name: {name}")
            translated.append(ConfigurationParameter(name=current_name, value=value))
        return cls(translated)
|
|
|
|
|
|
|
|
|
|
|
|
|
class HNSWConfigurationInterface(HNSWConfigurationInternal):
    """HNSW index configuration parameters.

    Exposes each parameter as a keyword argument with a default.
    See https://docs.trychroma.com/guides#changing-the-distance-function for more information.
    """

    def __init__(
        self,
        space: str = "l2",
        ef_construction: int = 100,
        ef_search: int = 10,
        num_threads: int = cpu_count(),
        M: int = 16,
        resize_factor: float = 1.2,
        batch_size: int = 100,
        sync_threshold: int = 1000,
    ):
        """Wrap the keyword arguments as parameters and hand them to the base class.

        Validation of the individual values is performed by the base-class
        constructor.
        """
        # Insertion order fixes the order in which parameters are passed on.
        supplied = {
            "space": space,
            "ef_construction": ef_construction,
            "ef_search": ef_search,
            "num_threads": num_threads,
            "M": M,
            "resize_factor": resize_factor,
            "batch_size": batch_size,
            "sync_threshold": sync_threshold,
        }
        super().__init__(
            parameters=[
                ConfigurationParameter(name=key, value=val)
                for key, val in supplied.items()
            ]
        )


# Short alias for HNSWConfigurationInterface.
HNSWConfiguration = HNSWConfigurationInterface
|
|
|
|
|
class CollectionConfigurationInternal(ConfigurationInternal):
    """Internal form of the collection configuration.

    Declares a single static parameter, ``hnsw_configuration``, whose value
    must be an HNSWConfigurationInternal.
    """

    definitions = dict(
        hnsw_configuration=ConfigurationDefinition(
            name="hnsw_configuration",
            validator=lambda value: isinstance(value, HNSWConfigurationInternal),
            is_static=True,
            default_value=HNSWConfigurationInternal(),
        ),
    )

    @override
    def configuration_validator(self) -> None:
        """No cross-parameter checks are performed for this configuration."""
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
class CollectionConfigurationInterface(CollectionConfigurationInternal):
    """Configuration parameters for creating a collection."""

    def __init__(self, hnsw_configuration: Optional[HNSWConfigurationInternal] = None):
        """Initializes a new instance of the CollectionConfiguration class.

        Args:
            hnsw_configuration: The HNSW configuration to use for the
                collection. When omitted or None, a default-constructed
                HNSWConfigurationInternal is used. (The argument previously
                had no default even though None was explicitly handled, so it
                can now simply be left out.)
        """
        if hnsw_configuration is None:
            hnsw_configuration = HNSWConfigurationInternal()
        parameters = [
            ConfigurationParameter(name="hnsw_configuration", value=hnsw_configuration)
        ]
        super().__init__(parameters=parameters)


# Short alias for CollectionConfigurationInterface.
CollectionConfiguration = CollectionConfigurationInterface
|
|
|
|
|
class EmbeddingsQueueConfigurationInternal(ConfigurationInternal):
    """Internal form of the embeddings queue configuration.

    Declares a single non-static boolean parameter, ``automatically_purge``,
    which defaults to True.
    """

    definitions = dict(
        automatically_purge=ConfigurationDefinition(
            name="automatically_purge",
            validator=lambda value: isinstance(value, bool),
            is_static=False,
            default_value=True,
        ),
    )

    @override
    def configuration_validator(self) -> None:
        """No cross-parameter checks are performed for this configuration."""
        return None
|
|