LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import dataclasses
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar
import yaml
from . import util
# Type variable bound to ConfigBase so that `ConfigBase.load` is typed as returning the calling subclass.
T = TypeVar('T', bound='ConfigBase')
# Path type alias taken from the sibling `util` module and re-exported via `__all__` below.
PathLike = util.PathLike
# Public API of this module.
__all__ = ['ConfigBase', 'PathLike']
def _is_missing(obj: Any) -> bool:
    """
    Return True if ``obj`` is the ``dataclasses.MISSING`` sentinel.

    The check is an ``isinstance`` test against the sentinel's own type,
    so any instance of that type is reported as missing.
    """
    sentinel_type = type(dataclasses.MISSING)
    return isinstance(obj, sentinel_type)
class ConfigBase:
    """
    Base class of config classes.

    Config entries are read with ``dataclasses.fields(self)``, so concrete
    subclasses must be dataclasses.
    Subclass may override `_canonical_rules` and `_validation_rules`,
    and `validate()` if the logic is complex.
    """

    # Rules to convert field value to canonical format.
    # The key is field name.
    # The value is callable `value -> canonical_value`.
    # It is not type-hinted so dataclass won't treat it as a field.
    _canonical_rules = {}  # type: ignore

    # Rules to validate field value.
    # The key is field name.
    # The value is callable `value -> valid` or `value -> (valid, error_message)`.
    # The rule will be called with canonical format and is only called when `value` is not None.
    # `error_message` is used when `valid` is False.
    # It will be prepended with class name and field name in exception message.
    _validation_rules = {}  # type: ignore

    def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
        """
        Initialize a config object and set all of its dataclass fields.

        Name of keyword arguments can either be snake_case or camelCase;
        keys are normalized with ``util.case_insensitive`` before being matched
        against field names.

        A field that is not given takes its ``default``, or the result of its
        ``default_factory``; if it has neither, it is set to ``dataclasses.MISSING``
        (``validate()`` reports such fields as not set).

        For fields whose annotation text contains ``Path``, non-None values are
        converted to ``Path``, ``~`` is expanded, and relative paths are joined
        onto ``_base_path`` (default ``Path()``, i.e. left relative).

        Raises:
            ValueError: if a keyword argument does not match any field.
        """
        # `basepath` is accepted as an alternative spelling of `_base_path`.
        if 'basepath' in kwargs:
            _base_path = kwargs.pop('basepath')
        kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
        _base_path = Path() if _base_path is None else Path(_base_path)

        for field in dataclasses.fields(self):
            key = util.case_insensitive(field.name)
            if key in kwargs:
                value = kwargs.pop(key)
            elif not _is_missing(field.default_factory):
                # `field.default` is MISSING for fields declared with `default_factory`;
                # call the factory so the field gets a real (and per-instance) default.
                value = field.default_factory()
            else:
                value = field.default

            if value is not None and not _is_missing(value):
                # Relative paths loaded from a config file are relative to the file, not to pwd.
                # NOTE(review): this is a substring test on the annotation text, so e.g.
                # `List[PathLike]` would also match and fail in `Path(value)` -- confirm no such fields exist.
                if 'Path' in str(field.type):
                    value = Path(value).expanduser()
                    if not value.is_absolute():
                        value = _base_path / value
            setattr(self, field.name, value)

        if kwargs:
            cls = type(self).__name__
            fields = ', '.join(kwargs.keys())
            raise ValueError(f'{cls}: Unrecognized fields {fields}')

    @classmethod
    def load(cls: Type[T], path: PathLike) -> T:
        """
        Load config from YAML (or JSON) file.

        Keys in YAML file can either be camelCase or snake_case.
        Relative paths in path-typed fields are resolved against the directory
        containing the file.

        Raises:
            ValueError: if the top-level content of the file is not a mapping,
                or (via ``__init__``) if it contains unrecognized keys.
        """
        # Context manager ensures the handle is closed even if parsing fails.
        with open(path) as config_file:
            data = yaml.safe_load(config_file)
        if not isinstance(data, dict):
            raise ValueError(f'Content of config file {path} is not a dict/object')
        return cls(**data, _base_path=Path(path).parent)

    def json(self) -> Dict[str, Any]:
        """
        Convert config to JSON object.

        The config is validated first, then its canonical copy is converted with
        ``dataclasses.asdict``. Keys of returned object (including nested config
        objects) will be camelCase, and entries whose value is None are omitted.

        Raises:
            ValueError: if ``validate()`` fails.
        """
        self.validate()
        return dataclasses.asdict(
            self.canonical(),
            dict_factory=lambda items: {util.camel_case(k): v for k, v in items if v is not None},
        )

    def canonical(self: T) -> T:
        """
        Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format.

        For each field: the matching `_canonical_rules` entry is applied if present
        (to any value, including None and ``dataclasses.MISSING``); otherwise nested
        `ConfigBase` values are canonicalized recursively and `Path` values are
        converted to `str`.
        Noticeably, relative path may be converted to absolute path.
        """
        ret = copy.deepcopy(self)
        for field in dataclasses.fields(ret):
            key, value = field.name, getattr(ret, field.name)
            rule = ret._canonical_rules.get(key)
            if rule is not None:
                setattr(ret, key, rule(value))
            elif isinstance(value, ConfigBase):
                # value will be copied twice, should not be a performance issue anyway
                setattr(ret, key, value.canonical())
            elif isinstance(value, Path):
                setattr(ret, key, str(value))
        return ret

    def validate(self) -> None:
        """
        Validate the config object and raise Exception if it's ill-formed.

        For every field of the canonical copy, checks that the value is set (not
        ``dataclasses.MISSING``), that it is not None unless the annotation is
        optional or ``Any``, that it passes the matching `_validation_rules` entry,
        and validates nested `ConfigBase` values recursively.

        Raises:
            ValueError: describing the first offending field.
        """
        class_name = type(self).__name__
        config = self.canonical()
        for field in dataclasses.fields(config):
            key, value = field.name, getattr(config, field.name)

            # check existence
            if _is_missing(value):
                raise ValueError(f'{class_name}: {key} is not set')

            # check type (TODO): only None-ness is checked, based on the annotation's text
            type_name = str(field.type).replace('typing.', '')
            optional = any([
                type_name.startswith('Optional['),
                type_name.startswith('Union[') and 'None' in type_name,
                # PEP 604 union spelling, e.g. `int | None`; splitting on '|' keeps
                # nested unions such as `List[int | None]` from matching.
                any(member.strip() == 'None' for member in type_name.split('|')),
                type_name == 'Any',
            ])
            if value is None:
                if optional:
                    continue
                raise ValueError(f'{class_name}: {key} cannot be None')

            # check value
            rule = config._validation_rules.get(key)
            if rule is not None:
                try:
                    result = rule(value)
                except Exception as error:
                    raise ValueError(f'{class_name}: {key} has bad value {repr(value)}') from error
                if isinstance(result, bool):
                    if not result:
                        raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
                elif not result[0]:
                    # rule returned `(valid, error_message)`
                    raise ValueError(f'{class_name}: {key} {result[1]}')

            # check nested config
            if isinstance(value, ConfigBase):
                value.validate()