deanna-emery's picture
updates
93528c6
raw
history blame
12.5 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base configurations to standardize experiments."""
import copy
import dataclasses
import functools
import inspect
import types
import typing
from typing import Any, List, Mapping, Optional, Type, Union

from absl import logging
import tensorflow as tf, tf_keras
import yaml

from official.modeling.hyperparams import params_dict
# Config classes that already had a builder attached through `bind`.
_BOUND = set()


def bind(config_cls):
  """Creates a decorator that attaches its target to `config_cls._BUILDER`.

  Args:
    config_cls: the config class that should remember the decorated builder.

  Returns:
    A decorator. Applied to a class, the class itself is stored as
    `config_cls._BUILDER`. Applied to a plain function, a wrapper that drops
    the leading `self` argument is stored instead, so the function can be
    invoked through a config instance. The decorated object is returned
    unchanged in both cases.

  Raises:
    ValueError: if `config_cls` is not a class, if `config_cls` was already
      bound once in this program, or if the decorated object is neither a
      class nor a function.
  """
  if not inspect.isclass(config_cls):
    raise ValueError('The bind decorator is supposed to apply on the class '
                     f'attribute. Received {config_cls}, not a class.')

  def decorator(builder):
    if config_cls in _BOUND:
      raise ValueError('Inside a program, we should not bind the config with a'
                       ' class twice.')

    builder_is_class = inspect.isclass(builder)
    if not builder_is_class and not inspect.isfunction(builder):
      raise ValueError(f'The `BUILDER` type is not supported: {builder}')

    if builder_is_class:
      registered = builder
    else:
      # Stored as a class attribute, the function becomes a method when read
      # from an instance; swallow the implicit `self` before delegating.
      def _wrapper(self, *args, **kwargs):  # pylint: disable=unused-argument
        return builder(*args, **kwargs)

      registered = _wrapper

    config_cls._BUILDER = registered  # pylint: disable=protected-access
    _BOUND.add(config_cls)
    return builder

  return decorator
def _is_optional(field):
  """Returns True if `field` is a union annotation that admits `None`.

  Both spellings of a union are recognised: `typing.Optional[X]` /
  `typing.Union[X, None]` and, on interpreters that provide
  `types.UnionType`, the PEP 604 form `X | None`.

  Args:
    field: a type annotation object.

  Returns:
    True when `field` is a union with `NoneType` among its members, False for
    every other annotation (including unions that do not contain `None`).
  """
  origin = typing.get_origin(field)
  # `X | Y` reports `types.UnionType` as its origin instead of `typing.Union`;
  # the attribute is missing on interpreters older than Python 3.10.
  pep604_union = getattr(types, 'UnionType', None)
  is_union = origin is Union or (
      pep604_union is not None and origin is pep604_union)
  return is_union and type(None) in typing.get_args(field)
@dataclasses.dataclass
class Config(params_dict.ParamsDict):
  """The base configuration class that supports YAML/JSON based overrides.

  Because of YAML/JSON serialization limitations, some semantics of dataclass
  are not supported:
  * It recursively enforces an allowlist of basic types and container types,
    so it avoids surprises with copy and reuse caused by unanticipated types.
  * Warning: it converts Dict to `Config` even within sequences,
    e.g. for config = Config({'key': [([{'a': 42}],)]),
         type(config.key[0][0][0]) is Config rather than dict.
    If you define/annotate some field as Dict, the field will convert to a
    `Config` instance and lose the dictionary type.
  """

  # The class or method to bind with the params class.
  _BUILDER = None
  # It's safe to add bytes and other immutable types here.
  IMMUTABLE_TYPES = (str, int, float, bool, type(None))
  # It's safe to add set, frozenset and other collections here.
  SEQUENCE_TYPES = (list, tuple)

  # Init-only pseudo-fields: dataclasses passes them to `__post_init__`.
  default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None

  def __post_init__(self, default_params, restrictions):
    """Forwards the init-only arguments to the `ParamsDict` initializer."""
    super().__init__(
        default_params=default_params,
        restrictions=restrictions)

  @property
  def BUILDER(self):
    """The class-level `_BUILDER` bound to this config (None if unset)."""
    return self._BUILDER

  @classmethod
  def _get_annotations(cls):
    """Returns valid annotations.

    Note: this is similar to dataclasses.__annotations__ except it also includes
    annotations from its parent classes.
    """
    all_annotations = typing.get_type_hints(cls)
    # Removes Config class annotation from the value, e.g., default_params,
    # restrictions, etc.
    for k in Config.__annotations__:
      del all_annotations[k]
    return all_annotations

  @classmethod
  def _isvalidsequence(cls, v):
    """Check if the input values are valid sequences.

    Args:
      v: Input sequence.

    Returns:
      True if the sequence is valid. Valid sequence includes the sequence
      type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
      is dict or ParamsDict.
    """
    if not isinstance(v, cls.SEQUENCE_TYPES):
      return False
    # Elements must be homogeneous in kind: all immutable, all dict, or all
    # ParamsDict.
    return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
            all(isinstance(e, dict) for e in v) or
            all(isinstance(e, params_dict.ParamsDict) for e in v))

  @classmethod
  def _import_config(cls, v, subconfig_type):
    """Returns v with dicts converted to Configs, recursively.

    Args:
      v: the value to import.
      subconfig_type: the `ParamsDict` subclass used to wrap plain dicts.

    Returns:
      `v` itself for immutable values, a same-typed sequence of imported
      elements for sequences, a deep copy for `ParamsDict` instances and a
      `subconfig_type` instance for dicts.

    Raises:
      TypeError: if `subconfig_type` is not a `ParamsDict` subclass, if a
        sequence fails `_isvalidsequence`, or if `v` has an unsupported type.
    """
    if not issubclass(subconfig_type, params_dict.ParamsDict):
      raise TypeError(
          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
              subconfig_type))
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      # Only support one layer of sequence.
      if not cls._isvalidsequence(v):
        raise TypeError(
            'Invalid sequence: only supports single level {!r} of {!r} or '
            'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
                                                    cls.IMMUTABLE_TYPES, v))
      import_fn = functools.partial(
          cls._import_config, subconfig_type=subconfig_type)
      return type(v)(map(import_fn, v))
    elif isinstance(v, params_dict.ParamsDict):
      # Deepcopy here is a temporary solution for preserving type in nested
      # Config object.
      return copy.deepcopy(v)
    elif isinstance(v, dict):
      return subconfig_type(v)
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _export_config(cls, v):
    """Returns v with Configs converted to dicts, recursively.

    Raises:
      TypeError: if `v` is a plain dict or any other unsupported type.
    """
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      return type(v)(map(cls._export_config, v))
    elif isinstance(v, params_dict.ParamsDict):
      return v.as_dict()
    elif isinstance(v, dict):
      raise TypeError('dict value not supported in converting.')
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _get_subconfig_type(
      cls, k, subconfig_type=None
  ) -> Type[params_dict.ParamsDict]:
    """Get element type by the field name.

    Args:
      k: the key/name of the field.
      subconfig_type: default subconfig_type. If None, it is set to
        Config.

    Returns:
      Config as default. If a type annotation is found for `k`,
      1) returns the type of the annotation if it is subtype of Config;
      2) returns the element type if the annotation of `k` is List[SubType]
         or Tuple[SubType] and SubType is a subtype of ParamsDict.
      An optional wrapper around the annotation is stripped before these
      checks. Annotations that match neither case yield the default.
    """
    if not subconfig_type:
      subconfig_type = Config

    annotations = cls._get_annotations()
    if k not in annotations:
      return subconfig_type

    type_annotation = annotations[k]
    # Loop for stripping the Optional annotation.
    while True:
      origin = typing.get_origin(type_annotation)
      if origin is None:
        # Not a parameterized annotation: either a Config subtype or nothing
        # this method can use. Checking the origin first keeps parameterized
        # built-in aliases such as `list[int]` away from `issubclass`, which
        # rejects them even though some interpreters report them as `type`
        # instances.
        if (isinstance(type_annotation, type) and
            issubclass(type_annotation, Config)):
          return type_annotation
        return subconfig_type

      # Check if the field is a sequence of subtypes.
      if isinstance(origin, type) and issubclass(origin, cls.SEQUENCE_TYPES):
        # Unparameterized `List`/`Tuple` has no arguments, and the first
        # argument may be a non-class annotation (e.g. `Optional[int]`), so
        # guard before calling `issubclass`.
        element_args = typing.get_args(type_annotation)
        element_type = element_args[0] if element_args else None
        if (isinstance(element_type, type) and
            issubclass(element_type, params_dict.ParamsDict)):
          return element_type
        return subconfig_type

      if _is_optional(type_annotation):
        # Strip the `Optional` annotation and process the subtype. Skip
        # `NoneType` explicitly so `Union[None, X]` resolves to `X` as well.
        members = [
            arg for arg in typing.get_args(type_annotation)
            if arg is not type(None)
        ]
        if not members:
          return subconfig_type
        type_annotation = members[0]
        continue

      return subconfig_type

  def _set(self, k, v):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    A dict either creates a new sub-config (when the key is missing or holds a
    falsy value) or overrides the stored one. A sequence of non-immutable
    elements is merged element by element into a stored sequence of
    `ParamsDict` objects of the same length; anything else replaces the stored
    value after going through `_import_config`.

    Args:
      k: key to set.
      v: value.

    Raises:
      TypeError: if `_import_config` rejects `v` or one of its elements.
    """
    subconfig_type = self._get_subconfig_type(k)

    def is_null(k):
      # A missing key and a falsy stored value both mean "nothing to merge
      # into".
      return k not in self.__dict__ or not self.__dict__[k]

    if isinstance(v, dict):
      if is_null(k):
        # If the key not exist or the value is None, a new Config-family object
        # should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        self.__dict__[k].override(v)
      return

    if (not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and
        not any(isinstance(e, self.IMMUTABLE_TYPES) for e in v)):
      current = self.__dict__[k]
      # In-place merging calls `.override` on every stored element, so it is
      # only attempted when the stored value is a same-length sequence of
      # ParamsDict objects; otherwise the value is replaced below.
      can_merge = (
          isinstance(current, self.SEQUENCE_TYPES) and
          len(current) == len(v) and
          all(isinstance(e, params_dict.ParamsDict) for e in current))
      if can_merge:
        for existing, update in zip(current, v):
          existing.override(update)
        return
      if v:
        logging.warning(
            "The list/tuple don't match the value dictionaries provided. Thus, "
            'the list/tuple is determined by the type annotation and '
            'values provided. This is error-prone.')

    self.__dict__[k] = self._import_config(v, subconfig_type)

  def __setattr__(self, k, v):
    """Routes attribute assignment through `_set`.

    Raises:
      AttributeError: when assigning `BUILDER` or `_BUILDER` on an instance.
      ValueError: when the config is locked and `k` is not a reserved
        attribute.
    """
    if k in ('BUILDER', '_BUILDER'):
      raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
                           'reserved class attribute. We should only assign '
                           '`_BUILDER` at the class level.')
    # Reserved attributes stay writable even on a locked config.
    if k not in self.RESERVED_ATTR and getattr(self, '_locked', False):
      raise ValueError('The Config has been locked. ' 'No change is allowed.')
    self._set(k, v)

  def _override(self, override_dict, is_strict=True):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      override_dict: dictionary to write to .
      is_strict: If True, not allows to add new keys.

    Raises:
      KeyError: overriding reserved keys or keys not exist (is_strict=True).
    """
    for k, v in sorted(override_dict.items()):
      if k in self.RESERVED_ATTR:
        raise KeyError('The key {!r} is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key {!r} does not exist in {!r}. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(
                             k, type(self)))
        else:
          self._set(k, v)
      else:
        # Recurse into a truthy stored value so `is_strict` also applies to
        # nested keys; otherwise let `_set` decide how to store `v`.
        if isinstance(v, dict) and self.__dict__[k]:
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self._set(k, v)

  def as_dict(self):
    """Returns a dict representation of params_dict.ParamsDict.

    For the nested params_dict.ParamsDict, a nested dict will be returned.
    Reserved attributes are left out.
    """
    return {
        k: self._export_config(v)
        for k, v in self.__dict__.items()
        if k not in self.RESERVED_ATTR
    }

  def replace(self, **kwargs):
    """Overrides/returns a unlocked copy with the current config unchanged."""
    # pylint: disable=protected-access
    params = copy.deepcopy(self)
    params._locked = False
    params._override(kwargs, is_strict=True)
    # pylint: enable=protected-access
    return params

  @classmethod
  def from_yaml(cls, file_path: str):
    """Builds a default-constructed config and overrides it from a YAML file.

    Note: This only works if the Config has all default values.

    Args:
      file_path: path of the YAML file, opened with `tf.io.gfile.GFile`.

    Returns:
      A new `cls` instance with the loaded values applied through `override`.
      An empty YAML document leaves the defaults untouched.
    """
    # NOTE(security): `yaml.FullLoader` accepts more than plain YAML. Only load
    # trusted files; consider `yaml.safe_load` if configs never need the extra
    # tags.
    with tf.io.gfile.GFile(file_path, 'r') as f:
      loaded = yaml.load(f, Loader=yaml.FullLoader)
    config = cls()
    # An empty or comment-only document parses to None; `_override` iterates
    # `.items()`, so there is nothing to apply in that case.
    if loaded is not None:
      config.override(loaded)
    return config

  @classmethod
  def from_json(cls, file_path: str):
    """Wrapper for `from_yaml`."""
    return cls.from_yaml(file_path)

  @classmethod
  def from_args(cls, *args, **kwargs):
    """Builds a config from the given list of arguments.

    Positional arguments are paired, in order, with the names in
    `cls.__annotations__`; keyword arguments are applied afterwards and win on
    conflicts.
    """
    # Note we intend to keep `__annotations__` instead of `_get_annotations`.
    # Assuming a parent class of (a, b) with the sub-class of (c, d), the
    # sub-class will take (c, d) for args, rather than starting from (a, b).
    attributes = list(cls.__annotations__.keys())
    default_params = dict(zip(attributes, args))
    default_params.update(kwargs)
    return cls(default_params=default_params)