# himanshud2611's picture
# Upload folder using huggingface_hub
# 60e3a80 verified
from abc import abstractmethod
import copy
import json
from multiprocessing import cpu_count
from typing import (
    Any,
    ClassVar,
    Dict,
    List,
    Optional,
    Protocol,
    Union,
    TypeVar,
    cast,
)

from overrides import override
from typing_extensions import Self

from chromadb.serde import JSONSerializable
# TODO: move out of API
class StaticParameterError(Exception):
    """Raised when code tries to change a parameter that is marked static."""
class InvalidConfigurationError(ValueError):
    """Raised when a configuration as a whole is invalid.

    Subclasses ValueError, so callers that catch ValueError also catch this.
    """
# Every type a configuration parameter value may have: a primitive, or a nested
# configuration. "ConfigurationInternal" is a string forward reference because
# that class is defined further down in this module.
ParameterValue = Union[str, int, float, bool, "ConfigurationInternal"]
class ParameterValidator(Protocol):
    """Structural type for callables that check a single parameter value."""

    @abstractmethod
    def __call__(self, value: ParameterValue) -> bool:
        """Return True when ``value`` is acceptable and False otherwise."""
        raise NotImplementedError
class ConfigurationDefinition:
    """Describes one parameter that a configuration class accepts.

    Attributes:
        name: The parameter's name.
        validator: Callable returning whether a candidate value is valid.
        is_static: Whether the parameter is flagged as static.
        default_value: Value used when the parameter is not supplied.
    """

    name: str
    validator: ParameterValidator
    is_static: bool
    default_value: ParameterValue

    def __init__(
        self,
        name: str,
        validator: ParameterValidator,
        is_static: bool,
        default_value: ParameterValue,
    ):
        """Store the four fields of the definition unchanged."""
        self.name, self.validator = name, validator
        self.is_static, self.default_value = is_static, default_value
class ConfigurationParameter:
    """A named value held by a configuration.

    Attributes:
        name: The parameter's name.
        value: The parameter's current value.
    """

    name: str
    value: ParameterValue

    def __init__(self, name: str, value: ParameterValue):
        """Store ``name`` and ``value`` as given; no validation happens here."""
        self.name = name
        self.value = value

    def __repr__(self) -> str:
        """Return a debugging representation showing the name and the value."""
        return f"ConfigurationParameter({self.name}, {self.value})"

    def __eq__(self, __value: object) -> bool:
        """Two parameters are equal when both their names and values are equal."""
        if isinstance(__value, ConfigurationParameter):
            names_match = self.name == __value.name
            return names_match and self.value == __value.value
        # Let Python try the reflected comparison for foreign types.
        return NotImplemented
# Type variable restricted to ConfigurationInternal and its subclasses.
# NOTE(review): not referenced in the visible part of this file - confirm it is
# still needed before removing it.
T = TypeVar("T", bound="ConfigurationInternal")
class ConfigurationInternal(JSONSerializable["ConfigurationInternal"]):
    """Represents an abstract configuration, used internally by Chroma.

    Subclasses declare the parameters they accept in ``definitions`` and
    implement ``configuration_validator`` for checks spanning several
    parameters. Instances can be converted to and from JSON; the JSON form
    carries a ``_type`` entry holding the class name.
    """

    # The internal data structure used to store the parameters
    # All expected parameters must be present with defaults or None values at initialization
    parameter_map: Dict[str, ConfigurationParameter]
    # Per-class table of accepted parameters, keyed by parameter name.
    definitions: ClassVar[Dict[str, ConfigurationDefinition]]

    def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None):
        """Initializes a new configuration, respecting defaults and validators.

        Args:
            parameters: Explicit parameter values. Any parameter declared in
                ``definitions`` but absent here receives its default value.

        Raises:
            ValueError: If a parameter name is unknown, a nested configuration
                dict has an unknown ``_type``, a value's type does not match
                the type of the default, or a validator rejects the value.
            InvalidConfigurationError: If ``configuration_validator`` rejects
                the assembled configuration.
        """
        self.parameter_map = {}
        if parameters is not None:
            for parameter in parameters:
                if parameter.name not in self.definitions:
                    raise ValueError(f"Invalid parameter name: {parameter.name}")

                definition = self.definitions[parameter.name]
                # Handle the case where we have a recursive configuration definition:
                # a nested configuration arrives as a dict when loaded from JSON and
                # is rebuilt through the class named by its "_type" entry.
                if isinstance(parameter.value, dict):
                    child_type = globals().get(parameter.value.get("_type", None))
                    if child_type is None:
                        raise ValueError(
                            f"Invalid configuration type: {parameter.value}"
                        )
                    parameter.value = child_type.from_json(parameter.value)
                if not isinstance(parameter.value, type(definition.default_value)):
                    raise ValueError(f"Invalid parameter value: {parameter.value}")

                parameter_validator = definition.validator
                if not parameter_validator(parameter.value):
                    raise ValueError(f"Invalid parameter value: {parameter.value}")
                self.parameter_map[parameter.name] = parameter

        # Apply the defaults for any missing parameters. The default is copied
        # so that a mutable default (e.g. a nested configuration) is never
        # shared between instances; for immutable primitives the copy is a no-op.
        for name, definition in self.definitions.items():
            if name not in self.parameter_map:
                self.parameter_map[name] = ConfigurationParameter(
                    name=name, value=copy.deepcopy(definition.default_value)
                )

        self.configuration_validator()

    def __repr__(self) -> str:
        """Returns a debugging representation listing the stored parameters."""
        return f"Configuration({self.parameter_map.values()})"

    def __eq__(self, __value: object) -> bool:
        """Configurations are equal when their parameter maps are equal."""
        if not isinstance(__value, ConfigurationInternal):
            return NotImplemented
        return self.parameter_map == __value.parameter_map

    @abstractmethod
    def configuration_validator(self) -> None:
        """Perform custom validation when parameters are dependent on each other.

        Raises an InvalidConfigurationError if the configuration is invalid.
        """
        pass

    def get_parameters(self) -> List[ConfigurationParameter]:
        """Returns the parameters of the configuration as a new list."""
        return list(self.parameter_map.values())

    def get_parameter(self, name: str) -> ConfigurationParameter:
        """Returns the parameter with the given name.

        Raises:
            ValueError: If no parameter with that name exists.
        """
        if name not in self.parameter_map:
            raise ValueError(
                f"Invalid parameter name: {name} for configuration {self.__class__.__name__}"
            )
        return self.parameter_map[name]

    def set_parameter(self, name: str, value: Union[str, int, float, bool]) -> None:
        """Sets the parameter with the given name to the given value.

        Raises:
            ValueError: If the name is not defined for this configuration or
                the parameter's validator rejects the value.
            StaticParameterError: If the parameter is defined as static.
        """
        if name not in self.definitions:
            raise ValueError(f"Invalid parameter name: {name}")
        definition = self.definitions[name]
        parameter = self.parameter_map[name]
        if definition.is_static:
            raise StaticParameterError(f"Cannot set static parameter: {name}")
        if not definition.validator(value):
            raise ValueError(f"Invalid value for parameter {name}: {value}")
        parameter.value = value

    @override
    def to_json_str(self) -> str:
        """Returns the JSON string representation of the configuration."""
        return json.dumps(self.to_json())

    @classmethod
    @override
    def from_json_str(cls, json_str: str) -> Self:
        """Returns a configuration from the given JSON string.

        Raises:
            ValueError: If the string is not valid JSON, or the decoded
                content is rejected by ``from_json``.
        """
        try:
            config_json = json.loads(json_str)
        except json.JSONDecodeError as err:
            # Keep the decoder's error as the cause so the position of the
            # malformed input is not lost.
            raise ValueError(
                f"Unable to decode configuration from JSON string: {json_str}"
            ) from err
        return cls.from_json(config_json)

    @override
    def to_json(self) -> Dict[str, Any]:
        """Returns the JSON compatible dictionary representation of the configuration.

        Nested configurations are converted recursively; the class name is
        stored under ``_type``.
        """
        json_dict: Dict[str, Any] = {
            name: parameter.value.to_json()
            if isinstance(parameter.value, ConfigurationInternal)
            else parameter.value
            for name, parameter in self.parameter_map.items()
        }
        # What kind of configuration is this?
        json_dict["_type"] = self.__class__.__name__
        return json_dict

    @classmethod
    @override
    def from_json(cls, json_map: Dict[str, Any]) -> Self:
        """Returns a configuration from the given JSON compatible dictionary.

        Raises:
            ValueError: If the ``_type`` entry is missing or does not name
                this class, or if any parameter is rejected on construction.
        """
        # Read "_type" once with .get so a missing entry yields the ValueError
        # below rather than a KeyError while formatting the message.
        json_type = json_map.get("_type", None)
        if cls.__name__ != json_type:
            raise ValueError(
                f"Trying to instantiate configuration of type {cls.__name__} from JSON with type {json_type}"
            )
        # Type value is only for storage, so it is not turned into a parameter.
        parameters = [
            ConfigurationParameter(name=name, value=value)
            for name, value in json_map.items()
            if name != "_type"
        ]
        return cls(parameters=parameters)
class HNSWConfigurationInternal(ConfigurationInternal):
    """Internal representation of the HNSW configuration.

    Declares each HNSW parameter together with its validator, its static flag
    and its default. Validation, defaults, serialization and deserialization
    come from ConfigurationInternal.
    """

    # Parameters flagged is_static=True cannot be changed through set_parameter.
    definitions = {
        # Must be one of the three listed strings.
        "space": ConfigurationDefinition(
            name="space",
            validator=lambda value: isinstance(value, str)
            and value in ("l2", "ip", "cosine"),
            is_static=True,
            default_value="l2",
        ),
        "ef_construction": ConfigurationDefinition(
            name="ef_construction",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=True,
            default_value=100,
        ),
        "ef_search": ConfigurationDefinition(
            name="ef_search",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=False,
            default_value=10,
        ),
        "num_threads": ConfigurationDefinition(
            name="num_threads",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=False,
            # By default use all cores available (evaluated once, at import).
            default_value=cpu_count(),
        ),
        "M": ConfigurationDefinition(
            name="M",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=True,
            default_value=16,
        ),
        # The only float-typed parameter; an int is rejected by its validator.
        "resize_factor": ConfigurationDefinition(
            name="resize_factor",
            validator=lambda value: isinstance(value, float) and value >= 1,
            is_static=True,
            default_value=1.2,
        ),
        "batch_size": ConfigurationDefinition(
            name="batch_size",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=True,
            default_value=100,
        ),
        "sync_threshold": ConfigurationDefinition(
            name="sync_threshold",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=True,
            default_value=1000,
        ),
    }

    @override
    def configuration_validator(self) -> None:
        """Check that batch_size does not exceed sync_threshold.

        Raises:
            InvalidConfigurationError: If batch_size is greater than
                sync_threshold.
        """
        batch = self.parameter_map.get("batch_size")
        threshold = self.parameter_map.get("sync_threshold")
        # Nothing to compare unless both parameters are present.
        if not batch or not threshold:
            return
        if cast(int, batch.value) > cast(int, threshold.value):
            raise InvalidConfigurationError(
                "batch_size must be less than or equal to sync_threshold"
            )

    @classmethod
    def from_legacy_params(cls, params: Dict[str, Any]) -> Self:
        """Build a configuration from a dict of legacy ``hnsw:``-prefixed keys.

        Used for migration. Every key of ``params`` must be a known legacy
        name.

        Raises:
            ValueError: If ``params`` contains a key that is not a legacy HNSW
                parameter name, or a value is rejected on construction.
        """
        # The mapping lives here to avoid a circular import with HnswParams;
        # the legacy names are not expected to change because HNSWParams is
        # intended to be deprecated in favor of this type of configuration.
        legacy_to_current = {
            "hnsw:space": "space",
            "hnsw:construction_ef": "ef_construction",
            "hnsw:search_ef": "ef_search",
            "hnsw:M": "M",
            "hnsw:num_threads": "num_threads",
            "hnsw:resize_factor": "resize_factor",
            "hnsw:batch_size": "batch_size",
            "hnsw:sync_threshold": "sync_threshold",
        }
        # Fail on the first unknown key, in the dict's iteration order.
        for legacy_name in params:
            if legacy_name not in legacy_to_current:
                raise ValueError(
                    f"Invalid legacy HNSW parameter name: {legacy_name}"
                )
        converted = [
            ConfigurationParameter(name=legacy_to_current[legacy_name], value=raw)
            for legacy_name, raw in params.items()
        ]
        return cls(converted)
# This is the user-facing interface for HNSW index configuration parameters.
# Internally, we pass around HNSWConfigurationInternal objects, which perform
# validation, serialization and deserialization. Users don't need to know
# about that and instead get a clean constructor with default arguments.
class HNSWConfigurationInterface(HNSWConfigurationInternal):
    """HNSW index configuration parameters.

    See https://docs.trychroma.com/guides#changing-the-distance-function for more information.
    """

    def __init__(
        self,
        space: str = "l2",
        ef_construction: int = 100,
        ef_search: int = 10,
        num_threads: int = cpu_count(),
        M: int = 16,
        resize_factor: float = 1.2,
        batch_size: int = 100,
        sync_threshold: int = 1000,
    ):
        """Create an HNSW configuration from keyword-style arguments.

        Each argument is wrapped in a ConfigurationParameter of the same name
        and handed to the parent constructor, which performs validation.
        """
        # Insertion order of the dict fixes the order of the parameter list.
        supplied = {
            "space": space,
            "ef_construction": ef_construction,
            "ef_search": ef_search,
            "num_threads": num_threads,
            "M": M,
            "resize_factor": resize_factor,
            "batch_size": batch_size,
            "sync_threshold": sync_threshold,
        }
        super().__init__(
            parameters=[
                ConfigurationParameter(name=key, value=given)
                for key, given in supplied.items()
            ]
        )
# Alias for user convenience - the user doesn't need to know this is an 'Interface'.
# Both names refer to the same class object.
HNSWConfiguration = HNSWConfigurationInterface
class CollectionConfigurationInternal(ConfigurationInternal):
    """Internal representation of the collection configuration.

    Holds a single static parameter, ``hnsw_configuration``, whose value is a
    nested HNSWConfigurationInternal. Validation, defaults and
    (de)serialization come from ConfigurationInternal.
    """

    definitions = {
        "hnsw_configuration": ConfigurationDefinition(
            name="hnsw_configuration",
            default_value=HNSWConfigurationInternal(),
            is_static=True,
            validator=lambda value: isinstance(value, HNSWConfigurationInternal),
        ),
    }

    @override
    def configuration_validator(self) -> None:
        """No cross-parameter checks are performed for this configuration."""
        return None
# This is the user-facing interface for collection configuration parameters.
# Internally, we pass around CollectionConfigurationInternal objects, which perform
# validation, serialization and deserialization. Users don't need to know
# about that and instead get a clean constructor taking the nested HNSW configuration.
class CollectionConfigurationInterface(CollectionConfigurationInternal):
    """Configuration parameters for creating a collection."""

    def __init__(
        self, hnsw_configuration: Optional[HNSWConfigurationInternal] = None
    ):
        """Initializes a new instance of the CollectionConfiguration class.

        Args:
            hnsw_configuration: The HNSW configuration to use for the collection.
                When omitted or None, a default HNSWConfigurationInternal is
                created.

        Raises:
            ValueError: If the given value is not accepted by the parent
                constructor's validation.
        """
        # None is an accepted input, so it is also the default; a fresh
        # instance is built per call rather than shared.
        if hnsw_configuration is None:
            hnsw_configuration = HNSWConfigurationInternal()
        parameters = [
            ConfigurationParameter(name="hnsw_configuration", value=hnsw_configuration)
        ]
        super().__init__(parameters=parameters)
# Alias for user convenience - the user doesn't need to know this is an 'Interface'.
# Both names refer to the same class object.
CollectionConfiguration = CollectionConfigurationInterface
class EmbeddingsQueueConfigurationInternal(ConfigurationInternal):
    """Internal representation of the embeddings queue configuration.

    Declares one non-static boolean parameter, ``automatically_purge``, which
    defaults to True.
    """

    definitions = {
        "automatically_purge": ConfigurationDefinition(
            name="automatically_purge",
            default_value=True,
            is_static=False,
            validator=lambda value: isinstance(value, bool),
        ),
    }

    @override
    def configuration_validator(self) -> None:
        """No cross-parameter checks are performed for this configuration."""
        return None