Spaces:
Build error
Build error
from abc import abstractmethod | |
import json | |
from overrides import override | |
from typing import ( | |
Any, | |
ClassVar, | |
Dict, | |
List, | |
Optional, | |
Protocol, | |
Union, | |
TypeVar, | |
cast, | |
) | |
from typing_extensions import Self | |
from multiprocessing import cpu_count | |
from chromadb.serde import JSONSerializable | |
# TODO: move out of API | |
class StaticParameterError(Exception): | |
"""Represents an error that occurs when a static parameter is set.""" | |
pass | |
class InvalidConfigurationError(ValueError): | |
"""Represents an error that occurs when a configuration is invalid.""" | |
pass | |
ParameterValue = Union[str, int, float, bool, "ConfigurationInternal"] | |
class ParameterValidator(Protocol): | |
"""Represents an abstract parameter validator.""" | |
def __call__(self, value: ParameterValue) -> bool: | |
"""Returns whether the given value is valid.""" | |
raise NotImplementedError() | |
class ConfigurationDefinition: | |
"""Represents the definition of a configuration.""" | |
name: str | |
validator: ParameterValidator | |
is_static: bool | |
default_value: ParameterValue | |
def __init__( | |
self, | |
name: str, | |
validator: ParameterValidator, | |
is_static: bool, | |
default_value: ParameterValue, | |
): | |
self.name = name | |
self.validator = validator | |
self.is_static = is_static | |
self.default_value = default_value | |
class ConfigurationParameter: | |
"""Represents a parameter of a configuration.""" | |
name: str | |
value: ParameterValue | |
def __init__(self, name: str, value: ParameterValue): | |
self.name = name | |
self.value = value | |
def __repr__(self) -> str: | |
return f"ConfigurationParameter({self.name}, {self.value})" | |
def __eq__(self, __value: object) -> bool: | |
if not isinstance(__value, ConfigurationParameter): | |
return NotImplemented | |
return self.name == __value.name and self.value == __value.value | |
T = TypeVar("T", bound="ConfigurationInternal") | |
class ConfigurationInternal(JSONSerializable["ConfigurationInternal"]): | |
"""Represents an abstract configuration, used internally by Chroma.""" | |
# The internal data structure used to store the parameters | |
# All expected parameters must be present with defaults or None values at initialization | |
parameter_map: Dict[str, ConfigurationParameter] | |
definitions: ClassVar[Dict[str, ConfigurationDefinition]] | |
def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None): | |
"""Initializes a new instance of the Configuration class. Respecting defaults and | |
validators.""" | |
self.parameter_map = {} | |
if parameters is not None: | |
for parameter in parameters: | |
if parameter.name not in self.definitions: | |
raise ValueError(f"Invalid parameter name: {parameter.name}") | |
definition = self.definitions[parameter.name] | |
# Handle the case where we have a recursive configuration definition | |
if isinstance(parameter.value, dict): | |
child_type = globals().get(parameter.value.get("_type", None)) | |
if child_type is None: | |
raise ValueError( | |
f"Invalid configuration type: {parameter.value}" | |
) | |
parameter.value = child_type.from_json(parameter.value) | |
if not isinstance(parameter.value, type(definition.default_value)): | |
raise ValueError(f"Invalid parameter value: {parameter.value}") | |
parameter_validator = definition.validator | |
if not parameter_validator(parameter.value): | |
raise ValueError(f"Invalid parameter value: {parameter.value}") | |
self.parameter_map[parameter.name] = parameter | |
# Apply the defaults for any missing parameters | |
for name, definition in self.definitions.items(): | |
if name not in self.parameter_map: | |
self.parameter_map[name] = ConfigurationParameter( | |
name=name, value=definition.default_value | |
) | |
self.configuration_validator() | |
def __repr__(self) -> str: | |
return f"Configuration({self.parameter_map.values()})" | |
def __eq__(self, __value: object) -> bool: | |
if not isinstance(__value, ConfigurationInternal): | |
return NotImplemented | |
return self.parameter_map == __value.parameter_map | |
def configuration_validator(self) -> None: | |
"""Perform custom validation when parameters are dependent on each other. | |
Raises an InvalidConfigurationError if the configuration is invalid. | |
""" | |
pass | |
def get_parameters(self) -> List[ConfigurationParameter]: | |
"""Returns the parameters of the configuration.""" | |
return list(self.parameter_map.values()) | |
def get_parameter(self, name: str) -> ConfigurationParameter: | |
"""Returns the parameter with the given name, or except if it doesn't exist.""" | |
if name not in self.parameter_map: | |
raise ValueError( | |
f"Invalid parameter name: {name} for configuration {self.__class__.__name__}" | |
) | |
param_value = cast(ConfigurationParameter, self.parameter_map.get(name)) | |
return param_value | |
def set_parameter(self, name: str, value: Union[str, int, float, bool]) -> None: | |
"""Sets the parameter with the given name to the given value.""" | |
if name not in self.definitions: | |
raise ValueError(f"Invalid parameter name: {name}") | |
definition = self.definitions[name] | |
parameter = self.parameter_map[name] | |
if definition.is_static: | |
raise StaticParameterError(f"Cannot set static parameter: {name}") | |
if not definition.validator(value): | |
raise ValueError(f"Invalid value for parameter {name}: {value}") | |
parameter.value = value | |
def to_json_str(self) -> str: | |
"""Returns the JSON representation of the configuration.""" | |
return json.dumps(self.to_json()) | |
def from_json_str(cls, json_str: str) -> Self: | |
"""Returns a configuration from the given JSON string.""" | |
try: | |
config_json = json.loads(json_str) | |
except json.JSONDecodeError: | |
raise ValueError( | |
f"Unable to decode configuration from JSON string: {json_str}" | |
) | |
return cls.from_json(config_json) | |
def to_json(self) -> Dict[str, Any]: | |
"""Returns the JSON compatible dictionary representation of the configuration.""" | |
json_dict = { | |
name: parameter.value.to_json() | |
if isinstance(parameter.value, ConfigurationInternal) | |
else parameter.value | |
for name, parameter in self.parameter_map.items() | |
} | |
# What kind of configuration is this? | |
json_dict["_type"] = self.__class__.__name__ | |
return json_dict | |
def from_json(cls, json_map: Dict[str, Any]) -> Self: | |
"""Returns a configuration from the given JSON string.""" | |
if cls.__name__ != json_map.get("_type", None): | |
raise ValueError( | |
f"Trying to instantiate configuration of type {cls.__name__} from JSON with type {json_map['_type']}" | |
) | |
parameters = [] | |
for name, value in json_map.items(): | |
# Type value is only for storage | |
if name == "_type": | |
continue | |
parameters.append(ConfigurationParameter(name=name, value=value)) | |
return cls(parameters=parameters) | |
class HNSWConfigurationInternal(ConfigurationInternal): | |
"""Internal representation of the HNSW configuration. | |
Used for validation, defaults, serialization and deserialization.""" | |
definitions = { | |
"space": ConfigurationDefinition( | |
name="space", | |
validator=lambda value: isinstance(value, str) | |
and value in ["l2", "ip", "cosine"], | |
is_static=True, | |
default_value="l2", | |
), | |
"ef_construction": ConfigurationDefinition( | |
name="ef_construction", | |
validator=lambda value: isinstance(value, int) and value >= 1, | |
is_static=True, | |
default_value=100, | |
), | |
"ef_search": ConfigurationDefinition( | |
name="ef_search", | |
validator=lambda value: isinstance(value, int) and value >= 1, | |
is_static=False, | |
default_value=10, | |
), | |
"num_threads": ConfigurationDefinition( | |
name="num_threads", | |
validator=lambda value: isinstance(value, int) and value >= 1, | |
is_static=False, | |
default_value=cpu_count(), # By default use all cores available | |
), | |
"M": ConfigurationDefinition( | |
name="M", | |
validator=lambda value: isinstance(value, int) and value >= 1, | |
is_static=True, | |
default_value=16, | |
), | |
"resize_factor": ConfigurationDefinition( | |
name="resize_factor", | |
validator=lambda value: isinstance(value, float) and value >= 1, | |
is_static=True, | |
default_value=1.2, | |
), | |
"batch_size": ConfigurationDefinition( | |
name="batch_size", | |
validator=lambda value: isinstance(value, int) and value >= 1, | |
is_static=True, | |
default_value=100, | |
), | |
"sync_threshold": ConfigurationDefinition( | |
name="sync_threshold", | |
validator=lambda value: isinstance(value, int) and value >= 1, | |
is_static=True, | |
default_value=1000, | |
), | |
} | |
def configuration_validator(self) -> None: | |
batch_size = self.parameter_map.get("batch_size") | |
sync_threshold = self.parameter_map.get("sync_threshold") | |
if ( | |
batch_size | |
and sync_threshold | |
and cast(int, batch_size.value) > cast(int, sync_threshold.value) | |
): | |
raise InvalidConfigurationError( | |
"batch_size must be less than or equal to sync_threshold" | |
) | |
def from_legacy_params(cls, params: Dict[str, Any]) -> Self: | |
"""Returns an HNSWConfiguration from a metadata dict containing legacy HNSW parameters. Used for migration.""" | |
# We maintain this map to avoid a circular import with HnswParams, and | |
# because then names won't change since we intend to deprecate HNSWParams | |
# in favor of this type of configuration. | |
old_to_new = { | |
"hnsw:space": "space", | |
"hnsw:construction_ef": "ef_construction", | |
"hnsw:search_ef": "ef_search", | |
"hnsw:M": "M", | |
"hnsw:num_threads": "num_threads", | |
"hnsw:resize_factor": "resize_factor", | |
"hnsw:batch_size": "batch_size", | |
"hnsw:sync_threshold": "sync_threshold", | |
} | |
parameters = [] | |
for name, value in params.items(): | |
if name not in old_to_new: | |
raise ValueError(f"Invalid legacy HNSW parameter name: {name}") | |
parameters.append( | |
ConfigurationParameter(name=old_to_new[name], value=value) | |
) | |
return cls(parameters) | |
# This is the user-facing interface for HNSW index configuration parameters. | |
# Internally, we pass around HNSWConfigurationInternal objects, which perform | |
# validation, serialization and deserialization. Users don't need to know | |
# about that and instead get a clean constructor with default arguments. | |
class HNSWConfigurationInterface(HNSWConfigurationInternal): | |
"""HNSW index configuration parameters. | |
See https://docs.trychroma.com/guides#changing-the-distance-function for more information. | |
""" | |
def __init__( | |
self, | |
space: str = "l2", | |
ef_construction: int = 100, | |
ef_search: int = 10, | |
num_threads: int = cpu_count(), | |
M: int = 16, | |
resize_factor: float = 1.2, | |
batch_size: int = 100, | |
sync_threshold: int = 1000, | |
): | |
parameters = [ | |
ConfigurationParameter(name="space", value=space), | |
ConfigurationParameter(name="ef_construction", value=ef_construction), | |
ConfigurationParameter(name="ef_search", value=ef_search), | |
ConfigurationParameter(name="num_threads", value=num_threads), | |
ConfigurationParameter(name="M", value=M), | |
ConfigurationParameter(name="resize_factor", value=resize_factor), | |
ConfigurationParameter(name="batch_size", value=batch_size), | |
ConfigurationParameter(name="sync_threshold", value=sync_threshold), | |
] | |
super().__init__(parameters=parameters) | |
# Alias for user convenience - the user doesn't need to know this is an 'Interface' | |
HNSWConfiguration = HNSWConfigurationInterface | |
class CollectionConfigurationInternal(ConfigurationInternal): | |
"""Internal representation of the collection configuration. | |
Used for validation, defaults, and serialization / deserialization.""" | |
definitions = { | |
"hnsw_configuration": ConfigurationDefinition( | |
name="hnsw_configuration", | |
validator=lambda value: isinstance(value, HNSWConfigurationInternal), | |
is_static=True, | |
default_value=HNSWConfigurationInternal(), | |
), | |
} | |
def configuration_validator(self) -> None: | |
pass | |
# This is the user-facing interface for HNSW index configuration parameters. | |
# Internally, we pass around HNSWConfigurationInternal objects, which perform | |
# validation, serialization and deserialization. Users don't need to know | |
# about that and instead get a clean constructor with default arguments. | |
class CollectionConfigurationInterface(CollectionConfigurationInternal): | |
"""Configuration parameters for creating a collection.""" | |
def __init__(self, hnsw_configuration: Optional[HNSWConfigurationInternal]): | |
"""Initializes a new instance of the CollectionConfiguration class. | |
Args: | |
hnsw_configuration: The HNSW configuration to use for the collection. | |
""" | |
if hnsw_configuration is None: | |
hnsw_configuration = HNSWConfigurationInternal() | |
parameters = [ | |
ConfigurationParameter(name="hnsw_configuration", value=hnsw_configuration) | |
] | |
super().__init__(parameters=parameters) | |
# Alias for user convenience - the user doesn't need to know this is an 'Interface'. | |
CollectionConfiguration = CollectionConfigurationInterface | |
class EmbeddingsQueueConfigurationInternal(ConfigurationInternal): | |
definitions = { | |
"automatically_purge": ConfigurationDefinition( | |
name="automatically_purge", | |
validator=lambda value: isinstance(value, bool), | |
is_static=False, | |
default_value=True, | |
), | |
} | |
def configuration_validator(self) -> None: | |
pass | |