|
|
|
|
|
|
|
import copy |
|
import dataclasses |
|
from pathlib import Path |
|
from typing import Any, Dict, Optional, Type, TypeVar |
|
|
|
import yaml |
|
|
|
from . import util |
|
|
|
# Alias for values accepted as filesystem paths (e.g. the `path` argument of
# `ConfigBase.load()`); defined in `util` and re-exported through `__all__`.
PathLike = util.PathLike

# Type variable bound to ConfigBase so that `load()` and `canonical()` are
# annotated as returning the same (sub)class they are invoked on.
T = TypeVar("T", bound="ConfigBase")

__all__ = ["ConfigBase", "PathLike"]
|
|
|
def _is_missing(obj: Any) -> bool:
    """
    Tell whether `obj` is the `dataclasses.MISSING` sentinel.

    The check is type-based: any instance of MISSING's type counts as missing.
    """
    sentinel_type = type(dataclasses.MISSING)
    return isinstance(obj, sentinel_type)
|
|
|
class ConfigBase:
    """
    Base class of config classes.

    Subclasses are expected to be dataclasses, because `dataclasses.fields()` is
    called on instances. A subclass may override `_canonical_rules` and
    `_validation_rules`, and override `validate()` if the logic is too complex
    to express as per-field rules.
    """

    # Per-field canonicalization rules used by `canonical()`.
    # Maps a field name to a callable that receives the field's current value
    # and returns its canonical form. Fields without a rule are kept as-is,
    # except nested `ConfigBase` values (canonicalized recursively) and `Path`
    # values (converted to `str`).
    _canonical_rules = {}

    # Per-field validation rules used by `validate()`.
    # Maps a field name to a callable that receives the canonical value (never
    # `None` or MISSING). It may return a bool, or a `(ok, message)` pair whose
    # message is reported when `ok` is falsy. Any exception raised by the rule
    # is reported as a "bad value" error.
    _validation_rules = {}

    def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
        """
        Initialize a config object and set its fields.

        Names of keyword arguments can be either snake_case or camelCase; they
        are matched against field names through `util.case_insensitive`.

        A field that is not given takes its default value, or the result of its
        `default_factory`. If it has neither, it is set to `dataclasses.MISSING`,
        which `validate()` reports as "not set".

        For fields whose annotated type mentions `Path`, non-None values are
        converted to `pathlib.Path` with `~` expanded. Relative results are
        joined onto `_base_path`, which defaults to `Path()` and so leaves them
        relative. `basepath` is accepted as an alias of `_base_path`.

        Raises:
            ValueError: if a keyword argument does not match any field.
        """
        if 'basepath' in kwargs:
            _base_path = kwargs.pop('basepath')
        base_path = Path() if _base_path is None else Path(_base_path)
        kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}

        for field in dataclasses.fields(self):
            key = util.case_insensitive(field.name)
            if key in kwargs:
                value = kwargs.pop(key)
            elif not _is_missing(field.default_factory):
                # `field.default` is MISSING for fields declared with a
                # default_factory, so the factory must be called explicitly.
                value = field.default_factory()
            else:
                value = field.default

            if value is not None and not _is_missing(value):
                # Relative paths (e.g. loaded from a config file) are resolved
                # against the base path rather than the process working directory.
                if 'Path' in str(field.type):
                    value = Path(value).expanduser()
                    if not value.is_absolute():
                        value = base_path / value
            setattr(self, field.name, value)

        if kwargs:
            cls = type(self).__name__
            fields = ', '.join(kwargs.keys())
            raise ValueError(f'{cls}: Unrecognized fields {fields}')

    @classmethod
    def load(cls: Type[T], path: PathLike) -> T:
        """
        Load a config from a YAML (or JSON) file.

        Keys in the file can be either camelCase or snake_case. Relative values
        of path-typed fields are resolved against the file's directory.

        Raises:
            ValueError: if the top-level content of the file is not a mapping.
        """
        # Close the file deterministically. YAML text is UTF-8 by specification,
        # so do not depend on the platform's default encoding.
        with open(path, encoding='utf-8') as f:
            data = yaml.safe_load(f)
        if not isinstance(data, dict):
            raise ValueError(f'Content of config file {path} is not a dict/object')
        return cls(**data, _base_path=Path(path).parent)

    def json(self) -> Dict[str, Any]:
        """
        Validate the config and convert it to a JSON object (a dict).

        The result is built from `canonical()`. Keys of config fields are
        camelCase, and fields whose value is `None` are omitted.

        Raises:
            ValueError: propagated from `validate()`.
        """
        self.validate()
        return dataclasses.asdict(
            self.canonical(),
            dict_factory=lambda items: {util.camel_case(k): v for k, v in items if v is not None},
        )

    def canonical(self: T) -> T:
        """
        Return a deep copy in which fields supporting multiple formats are converted
        to their canonical format. The original object is left untouched.

        For each field, the field's rule in `_canonical_rules` is applied if
        present; a rule may, for example, turn relative paths into absolute ones.
        Otherwise, nested `ConfigBase` values are canonicalized recursively and
        `Path` values are converted to `str`. Other values are kept as-is.
        """
        ret = copy.deepcopy(self)
        for field in dataclasses.fields(ret):
            key, value = field.name, getattr(ret, field.name)
            rule = ret._canonical_rules.get(key)
            if rule is not None:
                setattr(ret, key, rule(value))
            elif isinstance(value, ConfigBase):
                setattr(ret, key, value.canonical())
            elif isinstance(value, Path):
                setattr(ret, key, str(value))
        return ret

    def validate(self) -> None:
        """
        Validate the config object and raise an exception if it is ill-formed.

        Validation runs on `canonical()`. For each field, in declaration order:
          * it must not be `dataclasses.MISSING`;
          * it may be `None` only if its annotated type is optional
            (`Optional[...]`, a `Union[...]` or `... | ...` containing `None`,
            or `Any`);
          * the field's rule in `_validation_rules`, if any, must pass;
          * nested `ConfigBase` values are validated recursively.

        Raises:
            ValueError: describing the first problem found.
        """
        class_name = type(self).__name__
        config = self.canonical()

        for field in dataclasses.fields(config):
            key, value = field.name, getattr(config, field.name)

            # Required fields that were never given.
            if _is_missing(value):
                raise ValueError(f'{class_name}: {key} is not set')

            # `None` is only accepted for optional fields. The check works on the
            # annotation's text, so it also handles postponed (string) annotations.
            type_name = str(field.type).replace('typing.', '')
            # Members of a top-level PEP 604 union such as `int | None`. Recent
            # Python versions may also print `Optional[X]` in this form.
            union_members = [part.strip() for part in type_name.split('|')]
            optional = any([
                type_name.startswith('Optional['),
                type_name.startswith('Union[') and 'None' in type_name,
                type_name == 'Any',
                'None' in union_members or 'NoneType' in union_members,
            ])
            if value is None:
                if optional:
                    continue
                raise ValueError(f'{class_name}: {key} cannot be None')

            # Field-specific rules.
            rule = config._validation_rules.get(key)
            if rule is not None:
                try:
                    result = rule(value)
                except Exception as e:
                    # Any failure inside a rule means the value is unusable;
                    # keep the original error as the cause for debugging.
                    raise ValueError(f'{class_name}: {key} has bad value {repr(value)}') from e

                if isinstance(result, bool):
                    if not result:
                        raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
                elif not result[0]:
                    raise ValueError(f'{class_name}: {key} {result[1]}')

            # Nested configs validate themselves.
            if isinstance(value, ConfigBase):
                value.validate()
|
|