# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A parameter dictionary class which supports the nest structure."""

import collections
import copy
import operator
import re

import six
import tensorflow as tf, tf_keras
import yaml

# Regex pattern that matches one key-value pair at a time in a
# comma-separated key-value pair string. It splits each k-v pair on the = sign,
# and matches values that are within single quotes, double quotes, single
# values (e.g. floats, ints, etc.), and lists within brackets.
#
# Named groups (read by `nested_csv_str_to_json_str`):
#   name:            dotted variable name, e.g. "var" or "a.b.c".
#   bracketed_index: optional array index including brackets, e.g. "[0]".
#   val:             the raw value text.
_PARAM_RE = re.compile(
    r"""
  (?P<name>[a-zA-Z][\w\.]*)(?P<bracketed_index>\[?[0-9]*\]?)
                               # variable name: "var" or "x" followed by
                               # optional index: "[0]" or "[23]"
  \s*=\s*
  ((?P<val>\'(.*?)\'           # single quote
  |
  \"(.*?)\"                    # double quote
  |
  [^,\[]*                      # single value
  |
  \[[^\]]*\]))                 # list of values
  ($|,\s*)""", re.VERBOSE)

# Operands of a restriction that start with a digit, a minus sign followed by
# a digit, or the literal `None` are treated as constants rather than keys.
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')

# Maps each supported restriction operator to the predicate that detects a
# *violation* of it. Two-character operators are listed before the
# one-character operators they contain ('<=' before '<', '>=' before '>') so
# that substring detection picks the right operator.
_RESTRICTION_VIOLATIONS = (
    ('==', operator.ne),
    ('!=', operator.eq),
    ('<=', operator.gt),
    ('>=', operator.lt),
    ('<', operator.ge),
    ('>', operator.le),
)

# YAML loader with an extra implicit resolver intended to parse floats in
# decimal and exponential format. The regular expression has four
# alternatives:
#   1- Decimal number with an optional exponential term.
#   2- Integer number with an exponential term (e.g. `1e-5`).
#   3- Decimal number starting with the dot, with an optional signed
#      exponential term.
#   4- Colon-separated number followed by a decimal part.
# `_LOADER` is an alias of `yaml.FullLoader`, so the resolver is registered on
# that class itself.
# NOTE(review): the pattern is a raw string, so each `\\.` matches a literal
# backslash followed by any character rather than a decimal point. As written
# only alternative 2 can match ordinary numbers -- confirm whether `\.` was
# intended before changing it, since that alters how configs are parsed.
_LOADER = yaml.FullLoader
_LOADER.add_implicit_resolver(
    'tag:yaml.org,2002:float',
    re.compile(r'''
    ^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |
    [-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |
    \\.[0-9_]+(?:[eE][-+][0-9]+)?
    |
    [-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*)$''', re.X),
    list('-+0123456789.'))


class ParamsDict(object):
  """A hyperparameter container class.

  Parameters are stored as instance attributes. Nested dicts are converted to
  nested `ParamsDict` objects, other values are deep-copied on assignment.
  """

  # Attribute names used for internal bookkeeping; they are never exposed as
  # parameters and cannot be overridden or deleted.
  RESERVED_ATTR = ['_locked', '_restrictions']

  def __init__(self, default_params=None, restrictions=None):
    """Instantiate a ParamsDict.

    Instantiate a ParamsDict given a set of default parameters and a list of
    restrictions. The restrictions are only stored here; they are checked when
    `validate()` is called.

    Args:
      default_params: a Python dict or another ParamsDict object including the
        default parameters to initialize.
      restrictions: a list of strings, which define a list of restrictions to
        ensure the consistency of different parameters internally. Each
        restriction string is defined as a binary relation with a set of
        operators, including {'==', '!=', '<', '<=', '>', '>='}.
    """
    self._locked = False
    self._restrictions = []
    if restrictions:
      self._restrictions = restrictions
    if default_params is None:
      default_params = {}
    # Non-strict so that the defaults are allowed to define new keys.
    self.override(default_params, is_strict=False)

  def _set(self, k, v):
    """Stores `v` under `k`, wrapping dicts and deep-copying everything else."""
    if isinstance(v, dict):
      self.__dict__[k] = ParamsDict(v)
    else:
      self.__dict__[k] = copy.deepcopy(v)

  def __setattr__(self, k, v):
    """Sets the value of the existing key.

    Note that this does not allow directly defining a new key. Use the
    `override` method with `is_strict=False` instead.

    Args:
      k: the key string.
      v: the value to be used to set the key `k`.

    Raises:
      KeyError: if k is not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    # Reserved attributes bypass both the existence and the lock checks so
    # that `__init__` and `lock` can set them.
    if k not in ParamsDict.RESERVED_ATTR:
      if k not in self.__dict__.keys():
        raise KeyError('The key `{}` does not exist. '
                       'To extend the existing keys, use '
                       '`override` with `is_strict` = False.'.format(k))
      if self._locked:
        raise ValueError('The ParamsDict has been locked. '
                         'No change is allowed.')
    self._set(k, v)

  def __getattr__(self, k):
    """Gets the value of the existing key.

    Args:
      k: the key string.

    Returns:
      the value of the key.

    Raises:
      AttributeError: if k is not defined in the ParamsDict.
    """
    if k not in self.__dict__.keys():
      raise AttributeError('The key `{}` does not exist. '.format(k))
    return self.__dict__[k]

  def __contains__(self, key):
    """Implements the membership test operator."""
    return key in self.__dict__

  def get(self, key, value=None):
    """Accesses through built-in dictionary get method."""
    return self.__dict__.get(key, value)

  def __delattr__(self, k):
    """Deletes the key and removes its values.

    Args:
      k: the key string.

    Raises:
      AttributeError: if k is reserved or not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if k in ParamsDict.RESERVED_ATTR:
      raise AttributeError(
          'The key `{}` is reserved. No change is allowes. '.format(k))
    if k not in self.__dict__.keys():
      raise AttributeError('The key `{}` does not exist. '.format(k))
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    del self.__dict__[k]

  def override(self, override_params, is_strict=True):
    """Override the ParamsDict with a set of given params.

    Args:
      override_params: a dict or a ParamsDict specifying the parameters to be
        overridden.
      is_strict: a boolean specifying whether override is strict or not. If
        True, keys in `override_params` must be present in the ParamsDict. If
        False, keys in `override_params` can be different from what is
        currently defined in the ParamsDict. In this case, the ParamsDict will
        be extended to include the new keys.

    Raises:
      ValueError: if the ParamsDict instance has been locked.
      KeyError: if a key is reserved, or if `is_strict` is True and a key is
        not already defined.
    """
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    if isinstance(override_params, ParamsDict):
      override_params = override_params.as_dict()
    self._override(override_params, is_strict)  # pylint: disable=protected-access

  def _override(self, override_dict, is_strict=True):
    """The implementation of `override`.

    Args:
      override_dict: a dict of the parameters to be overridden.
      is_strict: see `override`.

    Raises:
      KeyError: if a key is reserved, or if `is_strict` is True and a key is
        not already defined.
    """
    for k, v in six.iteritems(override_dict):
      if k in ParamsDict.RESERVED_ATTR:
        raise KeyError('The key `{}` is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__.keys():
        if is_strict:
          raise KeyError('The key `{}` does not exist. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(k))
        else:
          self._set(k, v)
      else:
        # Nested values are merged key by key into the existing entry (which
        # is expected to be a ParamsDict); anything else replaces it.
        if isinstance(v, dict):
          self.__dict__[k]._override(v, is_strict)  # pylint: disable=protected-access
        elif isinstance(v, ParamsDict):
          self.__dict__[k]._override(v.as_dict(), is_strict)  # pylint: disable=protected-access
        else:
          self.__dict__[k] = copy.deepcopy(v)

  def lock(self):
    """Makes the ParamsDict immutable."""
    self._locked = True

  def as_dict(self):
    """Returns a dict representation of ParamsDict.

    For the nested ParamsDict, a nested dict will be returned. Reserved
    attributes are omitted and leaf values are deep-copied.
    """
    params_dict = {}
    for k, v in six.iteritems(self.__dict__):
      if k not in ParamsDict.RESERVED_ATTR:
        if isinstance(v, ParamsDict):
          params_dict[k] = v.as_dict()
        else:
          params_dict[k] = copy.deepcopy(v)
    return params_dict

  def validate(self):
    """Validate the parameters consistency based on the restrictions.

    This method validates the internal consistency using the pre-defined list
    of restrictions. A restriction is defined as a string which specifies a
    binary operation. The supported binary operations are
    {'==', '!=', '<', '<=', '>', '>='}. Note that the meaning of these
    operators are consistent with the underlying Python implementation. Users
    should make sure the defined restrictions on their types make sense.

    Each operand is either a dotted key path or a constant (a number, parsed
    as a float, or `None`).

    For example, for a ParamsDict like the following
    ```
    a:
      a1: 1
      a2: 2
    b:
      bb:
        bb1: 10
        bb2: 20
      ccc:
        a1: 1
        a3: 3
    ```
    one can define two restrictions like this
    ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

    What it enforces are:
     - a.a1 = 1 == b.ccc.a1 = 1
     - a.a2 = 2 <= b.bb.bb2 = 20

    Raises:
      KeyError: if any of the following happens
        (1) any of parameters in any of restrictions is not defined in
            ParamsDict,
        (2) any inconsistency violating the restriction is found.
      ValueError: if the restriction defined in the string is not supported.
    """

    def _get_value(operand, params_dict):
      """Resolves one stripped operand to a constant or a looked-up value."""
      if _CONST_VALUE_RE.match(operand) is not None:
        return None if operand == 'None' else float(operand)
      value = params_dict
      for token in operand.split('.'):
        value = value[token]
      return value

    def _get_operands(tokens, params_dict):
      """Resolves the two sides of a restriction to their values."""
      if len(tokens) != 2:
        raise ValueError('Only support binary relation in restriction.')
      left, right = (t.strip() for t in tokens)
      return _get_value(left, params_dict), _get_value(right, params_dict)

    params_dict = self.as_dict()
    for restriction in self._restrictions:
      for symbol, is_violated in _RESTRICTION_VIOLATIONS:
        if symbol in restriction:
          break
      else:
        raise ValueError('Unsupported relation in restriction.')

      tokens = restriction.split(symbol)
      left_v, right_v = _get_operands(tokens, params_dict)
      if is_violated(left_v, right_v):
        raise KeyError(
            'Found inconsistency between key `{}` and key `{}`.'.format(
                tokens[0], tokens[1]))


def read_yaml_to_params_dict(file_path: str):
  """Reads a YAML file to a ParamsDict.

  Args:
    file_path: path of the YAML file, opened through `tf.io.gfile.GFile`.

  Returns:
    a ParamsDict built from the loaded YAML content.
  """
  # NOTE(review): `_LOADER` is yaml.FullLoader, not SafeLoader; only load
  # files from trusted sources.
  with tf.io.gfile.GFile(file_path, 'r') as f:
    params_dict = yaml.load(f, Loader=_LOADER)
    return ParamsDict(params_dict)


def save_params_dict_to_yaml(params, file_path):
  """Saves the input ParamsDict to a YAML file.

  Lists are written in flow style (`[a, b]`), everything else in block style.

  Args:
    params: the ParamsDict to save.
    file_path: destination path, opened through `tf.io.gfile.GFile`.
  """
  with tf.io.gfile.GFile(file_path, 'w') as f:

    def _my_list_rep(dumper, data):
      # u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
      return dumper.represent_sequence(
          u'tag:yaml.org,2002:seq', data, flow_style=True)

    # This registers the representer through the module-level yaml API, so it
    # is not scoped to this single dump call.
    yaml.add_representer(list, _my_list_rep)
    yaml.dump(params.as_dict(), f, default_flow_style=False)


def nested_csv_str_to_json_str(csv_str):
  """Converts a nested (using '.') comma-separated k=v string to a JSON string.

  Converts a comma-separated string of key/value pairs that supports
  nesting of keys to a JSON string. Nesting is implemented using
  '.' between levels for a given key.

  Spacing between commas and = is supported (e.g. there is no difference
  between "a=1,b=2", "a = 1, b = 2", or "a=1, b=2") but there should be no
  spaces before keys or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not
  supported).

  Note that this will only support values supported by CSV, meaning
  values such as nested lists (e.g. "a=[[1,2,3],[4,5,6]]") are not
  supported. Strings are supported as well, e.g. "a='hello'".

  Indexed keys such as "a[0]=1,a[1]=2" build a numeric list; every index up to
  the largest one must be provided.

  An example conversion would be:

  "a=1, b=2, c.a=2, c.b=3, d.a.a=5"

  to

  "{ a: 1, b : 2, c: {a : 2, b : 3}, d: {a: {a : 5}}}"

  Args:
    csv_str: the comma separated string.

  Returns:
    the converted JSON string, or '' when `csv_str` is empty.

  Raises:
    ValueError: If csv_str is not in a comma separated string or
      if the string is formatted incorrectly.
  """
  if not csv_str:
    return ''

  array_param_map = {}  # array name -> list of values, None where unset.
  formatted_entries = []
  nested_map = collections.defaultdict(list)  # top-level key -> "sub.key=v".
  pos = 0
  while pos < len(csv_str):
    m = _PARAM_RE.match(csv_str, pos)
    if not m:
      raise ValueError('Malformed hyperparameter value while parsing '
                       'CSV string: %s' % csv_str[pos:])
    pos = m.end()
    # Parse the values.
    m_dict = m.groupdict()
    name = m_dict['name']
    v = m_dict['val']
    bracketed_index = m_dict['bracketed_index']

    # If we reach the name of the array.
    if bracketed_index and '.' not in name:
      # Extract the array's index by removing '[' and ']'
      index = int(bracketed_index[1:-1])
      if '.' in v:
        numeric_val = float(v)
      else:
        numeric_val = int(v)
      # Add the value to the array, growing it with None placeholders so that
      # elements may arrive in any order.
      values = array_param_map.setdefault(name, [])
      if index >= len(values):
        values.extend([None] * (index + 1 - len(values)))
      values[index] = numeric_val
      continue

    # If a GCS path (e.g. gs://...) is provided, wrap this in quotes
    # as yaml.load would otherwise throw an exception
    # NOTE(review): `[gs://]` is a character class, so this quotes every
    # unquoted value whose first character is 'g', 's', ':' or '/', not only
    # `gs://` paths. Left as is because tightening it changes which values
    # get quoted.
    if re.match(r'(?=[^\"\'])(?=[gs://])', v):
      v = '\'{}\''.format(v)

    name_nested = name.split('.')
    if len(name_nested) > 1:
      # Defer nested keys: strip the first level and group by it so the
      # remainder can be converted recursively below.
      grouping = name_nested[0]
      if bracketed_index:
        value = '.'.join(name_nested[1:]) + bracketed_index + '=' + v
      else:
        value = '.'.join(name_nested[1:]) + '=' + v
      nested_map[grouping].append(value)
    else:
      formatted_entries.append('%s : %s' % (name, v))

  for grouping, value in nested_map.items():
    value = ','.join(value)
    value = nested_csv_str_to_json_str(value)
    formatted_entries.append('%s : %s' % (grouping, value))

  # Add array parameters and check that the array is fully initialized.
  for name in array_param_map:
    if any(v is None for v in array_param_map[name]):
      raise ValueError('Did not pass all values of array: %s' % name)
    formatted_entries.append('%s : %s' % (name, array_param_map[name]))

  return '{' + ', '.join(formatted_entries) + '}'


def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
  """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.

  The logic of the function is outlined below:
  1. Test that the input is a dict. If not, proceed to 2.
  2. Tests that the input is a string. If not, raise unknown ValueError
  2.1. Test if the string is in a CSV format. If so, parse.
  If not, proceed to 2.2.
  2.2. Try loading the string as a YAML/JSON. If successful, parse to
  dict and use it to override. If not, proceed to 2.3.
  2.3. Try using the string as a file path and load the YAML file.

  Args:
    params: a ParamsDict object to be overridden.
    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or path to
      a YAML file specifying the parameters to be overridden.
    is_strict: a boolean specifying whether override is strict or not.

  Returns:
    params: the overridden ParamsDict object. It is returned unchanged when
      `dict_or_string_or_yaml_file` is empty.

  Raises:
    ValueError: if failed to override the parameters.
  """
  if not dict_or_string_or_yaml_file:
    return params
  if isinstance(dict_or_string_or_yaml_file, dict):
    params.override(dict_or_string_or_yaml_file, is_strict)
  elif isinstance(dict_or_string_or_yaml_file, six.string_types):
    try:
      dict_or_string_or_yaml_file = (
          nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
    except ValueError:
      # Not a CSV string; fall through and treat it as YAML/JSON text or as a
      # file path.
      pass
    # NOTE(review): `_LOADER` is yaml.FullLoader, not SafeLoader; only pass
    # strings and files from trusted sources.
    params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=_LOADER)
    if isinstance(params_dict, dict):
      params.override(params_dict, is_strict)
    else:
      # The text did not parse to a mapping, so interpret it as a file path.
      with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
        params.override(yaml.load(f, Loader=_LOADER), is_strict)
  else:
    raise ValueError('Unknown input type to parse.')
  return params