File size: 12,519 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93528c6
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Base configurations to standardize experiments."""

import copy
import dataclasses
import functools
import inspect
import types
import typing
from typing import Any, List, Mapping, Optional, Type, Union

from absl import logging
import tensorflow as tf, tf_keras
import yaml

from official.modeling.hyperparams import params_dict


# Config classes that already have a builder attached by `bind`.
_BOUND = set()


def bind(config_cls):
  """Returns a decorator that attaches a builder to `config_cls`.

  The decorated object (a class or a plain function) is stored on
  `config_cls` as the `_BUILDER` class attribute and is returned unchanged,
  so the decorator does not alter the decorated definition.

  Args:
    config_cls: the config class that receives the builder.

  Returns:
    A decorator taking the builder (class or function).

  Raises:
    ValueError: immediately if `config_cls` is not a class; from the returned
      decorator if `config_cls` was already bound, or if the builder is
      neither a class nor a function.
  """
  if not inspect.isclass(config_cls):
    raise ValueError('The bind decorator is supposed to apply on the class '
                     f'attribute. Received {config_cls}, not a class.')

  def decorator(builder):
    if config_cls in _BOUND:
      raise ValueError('Inside a program, we should not bind the config with a'
                       ' class twice.')

    if inspect.isfunction(builder):
      # A function stored as a class attribute is looked up as a bound
      # method on instances, so drop the implicit first argument and forward
      # the rest to the original function.
      def _forward_to_builder(self, *args, **kwargs):  # pylint: disable=unused-argument
        return builder(*args, **kwargs)

      to_store = _forward_to_builder
    elif inspect.isclass(builder):
      to_store = builder
    else:
      raise ValueError(f'The `BUILDER` type is not supported: {builder}')

    config_cls._BUILDER = to_store  # pylint: disable=protected-access
    _BOUND.add(config_cls)
    return builder

  return decorator


def _is_optional(field):
  """Returns True if `field` is a union type annotation that admits `None`.

  Recognizes `typing.Optional[X]` / `typing.Union[X, None]` and, on Python
  3.10+, the PEP 604 spelling `X | None`, whose origin is `types.UnionType`
  rather than `typing.Union`.

  Args:
    field: a type annotation.

  Returns:
    True if `field` is a union that includes `NoneType`, False otherwise.
  """
  # `types.UnionType` only exists on Python 3.10+; fall back to `Union` so the
  # tuple stays valid on older interpreters.
  union_origins = (Union, getattr(types, 'UnionType', Union))
  return (typing.get_origin(field) in union_origins and
          type(None) in typing.get_args(field))


@dataclasses.dataclass
class Config(params_dict.ParamsDict):
  """The base configuration class that supports YAML/JSON based overrides.

  Because of YAML/JSON serialization limitations, some semantics of dataclass
  are not supported:
  * It recursively enforces an allowlist of basic types and container types,
    so it avoids surprises with copy and reuse caused by unanticipated types.
  * Warning: it converts dicts to `Config` even within sequences,
    e.g. for config = Config({'key': [{'a': 42}]}),
         type(config.key[0]) is Config rather than dict.
    If you define/annotate some field as Dict, the field will convert to a
    `Config` instance and lose the dictionary type.
  * Only a single level of sequence is supported: a list/tuple may hold
    immutable values, dicts, or ParamsDicts, but not nested sequences.
  """
  # The class or method to bind with the params class (see `bind`).
  _BUILDER = None
  # It's safe to add bytes and other immutable types here.
  IMMUTABLE_TYPES = (str, int, float, bool, type(None))
  # It's safe to add set, frozenset and other collections here.
  SEQUENCE_TYPES = (list, tuple)

  default_params: dataclasses.InitVar[Optional[Mapping[str, Any]]] = None
  restrictions: dataclasses.InitVar[Optional[List[str]]] = None

  def __post_init__(self, default_params, restrictions):
    # Forward the init-only dataclass arguments to ParamsDict's constructor.
    super().__init__(
        default_params=default_params,
        restrictions=restrictions)

  @property
  def BUILDER(self):
    """The builder bound to this config class via `bind`, or None."""
    return self._BUILDER

  @classmethod
  def _get_annotations(cls):
    """Returns valid annotations.

    Note: this is similar to dataclasses.__annotations__ except it also includes
      annotations from its parent classes.
    """
    all_annotations = typing.get_type_hints(cls)
    # Removes Config class annotation from the value, e.g., default_params,
    # restrictions, etc.
    for k in Config.__annotations__:
      del all_annotations[k]
    return all_annotations

  @classmethod
  def _isvalidsequence(cls, v):
    """Check if the input values are valid sequences.

    Args:
      v: Input sequence.

    Returns:
      True if the sequence is valid. Valid sequence includes the sequence
      type in cls.SEQUENCE_TYPES and element type is in cls.IMMUTABLE_TYPES or
      is dict or ParamsDict. An empty sequence is valid.
    """
    if not isinstance(v, cls.SEQUENCE_TYPES):
      return False
    return (all(isinstance(e, cls.IMMUTABLE_TYPES) for e in v) or
            all(isinstance(e, dict) for e in v) or
            all(isinstance(e, params_dict.ParamsDict) for e in v))

  @classmethod
  def _import_config(cls, v, subconfig_type):
    """Returns v with dicts converted to Configs, recursively.

    Args:
      v: value to import: an immutable value, a single-level sequence, a
        dict, or a ParamsDict.
      subconfig_type: ParamsDict subclass used to wrap dict values.

    Returns:
      Immutable values unchanged; sequences rebuilt with the same sequence
      type and imported elements; ParamsDict values deep-copied; dicts
      wrapped in `subconfig_type`.

    Raises:
      TypeError: if `subconfig_type` is not a ParamsDict subclass, if `v` is
        an invalid sequence, or if `v` has an unsupported type.
    """
    if not issubclass(subconfig_type, params_dict.ParamsDict):
      raise TypeError(
          'Subconfig_type should be subclass of ParamsDict, found {!r}'.format(
              subconfig_type))
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      # Only support one layer of sequence.
      if not cls._isvalidsequence(v):
        raise TypeError(
            'Invalid sequence: only supports single level {!r} of {!r} or '
            'dict or ParamsDict found: {!r}'.format(cls.SEQUENCE_TYPES,
                                                    cls.IMMUTABLE_TYPES, v))
      import_fn = functools.partial(
          cls._import_config, subconfig_type=subconfig_type)
      return type(v)(map(import_fn, v))
    elif isinstance(v, params_dict.ParamsDict):
      # Deepcopy here is a temporary solution for preserving type in nested
      # Config object.
      return copy.deepcopy(v)
    elif isinstance(v, dict):
      return subconfig_type(v)
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _export_config(cls, v):
    """Returns v with Configs converted to dicts, recursively.

    Raises:
      TypeError: if `v` is a raw dict or has an unsupported type.
    """
    if isinstance(v, cls.IMMUTABLE_TYPES):
      return v
    elif isinstance(v, cls.SEQUENCE_TYPES):
      return type(v)(map(cls._export_config, v))
    elif isinstance(v, params_dict.ParamsDict):
      return v.as_dict()
    elif isinstance(v, dict):
      raise TypeError('dict value not supported in converting.')
    else:
      raise TypeError('Unknown type: {!r}'.format(type(v)))

  @classmethod
  def _get_subconfig_type(
      cls, k, subconfig_type=None
  ) -> Type[params_dict.ParamsDict]:
    """Get element type by the field name.

    Args:
      k: the key/name of the field.
      subconfig_type: default subconfig_type. If None, it is set to
        Config.

    Returns:
      The default (`Config` unless overridden). If a type annotation is found
      for `k` (after stripping any `Optional[...]` wrappers),
      1) returns the annotation itself if it is a subclass of Config;
      2) returns the element type if the annotation of `k` is List[SubType]
         or Tuple[SubType] and SubType is a ParamsDict subclass.
      Element types that are not classes (e.g. `List[Optional[X]]`) and bare
      `List`/`Tuple` annotations yield the default.
    """
    if not subconfig_type:
      subconfig_type = Config

    annotations = cls._get_annotations()
    if k not in annotations:
      return subconfig_type

    type_annotation = annotations[k]
    # Repeatedly strip `Optional[...]` until a concrete annotation is reached.
    while True:
      # Directly Config subtype.
      if (isinstance(type_annotation, type) and
          issubclass(type_annotation, Config)):
        return type_annotation

      # Check if the field is a sequence of subtypes.
      field_type = typing.get_origin(type_annotation)
      if (isinstance(field_type, type) and
          issubclass(field_type, cls.SEQUENCE_TYPES)):
        element_types = typing.get_args(type_annotation)
        # A bare `List`/`Tuple` has no type arguments, and an element type such
        # as `Optional[X]` or `Dict[str, int]` is not a class; `issubclass`
        # would raise TypeError on the latter, so both fall back to default.
        if (element_types and isinstance(element_types[0], type) and
            issubclass(element_types[0], params_dict.ParamsDict)):
          return element_types[0]
        return subconfig_type

      if _is_optional(type_annotation):
        # Strip the `Optional` annotation and process the subtype.
        type_annotation = typing.get_args(type_annotation)[0]
        continue

      return subconfig_type

  def _set(self, k, v):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods. A dict value either creates a new
    config object (when the key is missing or falsy) or is merged into the
    existing one. A sequence of non-immutable values is merged element-wise
    into an existing non-empty sequence of the same length. Everything else
    is stored via `_import_config`.

    Args:
      k: key to set.
      v: value.

    Raises:
      TypeError: if `v` (or one of its elements) has an unsupported type.
    """
    subconfig_type = self._get_subconfig_type(k)

    def is_null(k):
      # Missing keys and falsy values (e.g. None or an empty sequence) are
      # both treated as "nothing to merge into".
      return not self.__dict__.get(k)

    if isinstance(v, dict):
      if is_null(k):
        # If the key does not exist or the value is None, a new Config-family
        # object should be created for the key.
        self.__dict__[k] = subconfig_type(v)
      else:
        self.__dict__[k].override(v)
    elif (not is_null(k) and isinstance(v, self.SEQUENCE_TYPES) and
          all(not isinstance(e, self.IMMUTABLE_TYPES) for e in v)):
      current = self.__dict__[k]
      if len(current) == len(v):
        # Same length: merge each new value into the existing element.
        for existing, new_value in zip(current, v):
          existing.override(new_value)
      else:
        # The guard above ensures `v` holds no immutable values, so a
        # non-empty `v` here is a length mismatch against the existing values.
        if v:
          logging.warning(
              "The list/tuple don't match the value dictionaries provided. "
              'Thus, the list/tuple is determined by the type annotation and '
              'values provided. This is error-prone.')
        self.__dict__[k] = self._import_config(v, subconfig_type)
    else:
      self.__dict__[k] = self._import_config(v, subconfig_type)

  def __setattr__(self, k, v):
    """Sets attribute `k`, rejecting builder names and locked configs.

    Raises:
      AttributeError: if `k` is `BUILDER` or `_BUILDER`.
      ValueError: if the config is locked and `k` is not a reserved attribute.
    """
    if k == 'BUILDER' or k == '_BUILDER':
      raise AttributeError('`BUILDER` is a property and `_BUILDER` is the '
                           'reserved class attribute. We should only assign '
                           '`_BUILDER` at the class level.')

    if k not in self.RESERVED_ATTR:
      # `getattr` with a default: dataclass field assignment happens before
      # `__post_init__` runs ParamsDict's constructor, so `_locked` may not
      # exist yet.
      if getattr(self, '_locked', False):
        raise ValueError('The Config has been locked. ' 'No change is allowed.')
    self._set(k, v)

  def _override(self, override_dict, is_strict=True):
    """Overrides same method in ParamsDict.

    Also called by ParamsDict methods.

    Args:
      override_dict: dictionary of values to write into this config.
      is_strict: If True, does not allow adding new keys.

    Raises:
      KeyError: overriding reserved keys or keys not exist (is_strict=True).
    """
    for k, v in sorted(override_dict.items()):
      if k in self.RESERVED_ATTR:
        raise KeyError('The key {!r} is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key {!r} does not exist in {!r}. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(
                             k, type(self)))
        else:
          self._set(k, v)
      else:
        # Merge nested values into an existing (truthy) sub-config rather
        # than replacing it.
        if isinstance(v, dict) and self.__dict__[k]:
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, params_dict.ParamsDict) and self.__dict__[k]:
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self._set(k, v)

  def as_dict(self):
    """Returns a dict representation of params_dict.ParamsDict.

    For the nested params_dict.ParamsDict, a nested dict will be returned.
    Reserved attributes are excluded.
    """
    return {
        k: self._export_config(v)
        for k, v in self.__dict__.items()
        if k not in self.RESERVED_ATTR
    }

  def replace(self, **kwargs):
    """Returns an unlocked, overridden copy; the current config is unchanged.

    Args:
      **kwargs: existing keys to override (strict; new keys raise KeyError).
    """
    # pylint: disable=protected-access
    params = copy.deepcopy(self)
    params._locked = False
    params._override(kwargs, is_strict=True)
    # pylint: enable=protected-access
    return params

  @classmethod
  def from_yaml(cls, file_path: str):
    """Builds a config by overriding the class defaults with a YAML file.

    Note: This only works if the Config has all default values, since the
    instance is created with `cls()` before the file contents are applied.

    Args:
      file_path: path to the YAML file, opened with `tf.io.gfile.GFile`.

    Returns:
      An instance of `cls` overridden with the file's contents.
    """
    with tf.io.gfile.GFile(file_path, 'r') as f:
      # NOTE(review): `FullLoader` resolves more tags than `SafeLoader`; do
      # not use this on untrusted files.
      loaded = yaml.load(f, Loader=yaml.FullLoader)
    config = cls()
    config.override(loaded)
    return config

  @classmethod
  def from_json(cls, file_path: str):
    """Wrapper for `from_yaml`; relies on the YAML parser accepting JSON."""
    return cls.from_yaml(file_path)

  @classmethod
  def from_args(cls, *args, **kwargs):
    """Builds a config from the given list of arguments."""
    # Note we intend to keep `__annotations__` instead of `_get_annotations`.
    # Assuming a parent class of (a, b) with the sub-class of (c, d), the
    # sub-class will take (c, d) for args, rather than starting from (a, b).
    attributes = list(cls.__annotations__.keys())
    default_params = {a: p for a, p in zip(attributes, args)}
    default_params.update(kwargs)
    return cls(default_params=default_params)