"""A parameter dictionary class which supports the nest structure.""" |
from __future__ import absolute_import |
from __future__ import division |
from __future__ import print_function |
import collections |
import copy |
import re |
import six |
import tensorflow as tf |
import yaml |
# Matches one `name=value` entry of a comma-separated parameter string,
# together with the separator (or end of string) that follows it.
#   * `name` starts with a letter and may contain word characters and dots;
#     the dots are what callers use to express nesting.
#   * `val` is one of: a single-quoted string, a double-quoted string, a bare
#     value containing neither a comma nor `[`, or a bracketed list that
#     contains no `]` (so nested lists are not matched).
# The pattern is compiled with re.VERBOSE, so whitespace and `#` comments
# inside the pattern are ignored by the regex engine.
_PARAM_RE = re.compile(r"""
  (?P<name>[a-zA-Z][\w\.]*)    # variable name: "var" or "x"
  \s*=\s*
  ((?P<val>\'(.*?)\'           # single quote
  |
  \"(.*?)\"                    # double quote
  |
  [^,\[]*                      # single value
  |
  \[[^\]]*\]))                 # list of values
  ($|,\s*)""", re.VERBOSE)
# Matches a restriction operand that is a constant rather than a key path:
# text starting with a digit, a minus sign followed by a digit, or `None`.
_CONST_VALUE_RE = re.compile(r'(\d.*|-\d.*|None)')


class ParamsDict(object):
  """A hyperparameter container class.

  Parameters are stored as instance attributes. A value that is a Python dict
  is stored as a nested ParamsDict, so nested parameters can be read with
  chained attribute access (e.g. `params.a.b`). The names listed in
  `RESERVED_ATTR` hold internal state and cannot be used as parameter keys.
  """

  RESERVED_ATTR = ['_locked', '_restrictions']

  def __init__(self, default_params=None, restrictions=None):
    """Instantiate a ParamsDict.

    Instantiate a ParamsDict given a set of default parameters and a list of
    restrictions. Upon initialization, it validates itself by checking all the
    defined restrictions, and raises an error if it finds an inconsistency.

    Args:
      default_params: a Python dict or another ParamsDict object including the
        default parameters to initialize.
      restrictions: a list of strings, which define a list of restrictions to
        ensure the consistency of different parameters internally. Each
        restriction string is defined as a binary relation with a set of
        operators, including {'==', '!=', '<', '<=', '>', '>='}.

    Raises:
      KeyError: if a restriction refers to an undefined key or is violated.
      ValueError: if a restriction is not a supported binary relation.
    """
    self._locked = False
    self._restrictions = []
    if restrictions:
      self._restrictions = restrictions
    if default_params is None:
      default_params = {}
    self.override(default_params, is_strict=False)
    self.validate()

  def _set(self, k, v):
    """Stores `v` under key `k`, bypassing the existence and lock checks.

    A dict is wrapped into a nested ParamsDict; any other value is deep-copied
    so that later mutation of the caller's object does not leak in.
    """
    if isinstance(v, dict):
      self.__dict__[k] = ParamsDict(v)
    else:
      self.__dict__[k] = copy.deepcopy(v)

  def __setattr__(self, k, v):
    """Sets the value of the existing key.

    Note that this does not allow directly defining a new key. Use the
    `override` method with `is_strict=False` instead.

    Args:
      k: the key string.
      v: the value to be used to set the key `k`.

    Raises:
      KeyError: if k is not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    # Reserved attributes are internal state: they may be created and updated
    # freely, even on a locked instance.
    if k not in ParamsDict.RESERVED_ATTR:
      if k not in self.__dict__:
        raise KeyError('The key `{}` does not exist. '
                       'To extend the existing keys, use '
                       '`override` with `is_strict` = False.'.format(k))
      if self._locked:
        raise ValueError('The ParamsDict has been locked. '
                         'No change is allowed.')
    self._set(k, v)

  def __getattr__(self, k):
    """Gets the value of the existing key.

    Only invoked when normal attribute lookup fails.

    Args:
      k: the key string.

    Returns:
      the value of the key.

    Raises:
      AttributeError: if k is not defined in the ParamsDict.
    """
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    return self.__dict__[k]

  def __contains__(self, key):
    """Implements the membership test operator."""
    return key in self.__dict__

  def get(self, key, value=None):
    """Accesses through built-in dictionary get method."""
    return self.__dict__.get(key, value)

  def __delattr__(self, k):
    """Deletes the key and removes its values.

    Args:
      k: the key string.

    Raises:
      AttributeError: if k is reserved or not defined in the ParamsDict.
      ValueError: if the ParamsDict instance has been locked.
    """
    if k in ParamsDict.RESERVED_ATTR:
      raise AttributeError('The key `{}` is reserved. No change is allowes. '
                           .format(k))
    if k not in self.__dict__:
      raise AttributeError('The key `{}` does not exist. '.format(k))
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    del self.__dict__[k]

  def override(self, override_params, is_strict=True):
    """Override the ParamsDict with a set of given params.

    Args:
      override_params: a dict or a ParamsDict specifying the parameters to
        be overridden.
      is_strict: a boolean specifying whether override is strict or not. If
        True, keys in `override_params` must be present in the ParamsDict.
        If False, keys in `override_params` can be different from what is
        currently defined in the ParamsDict. In this case, the ParamsDict will
        be extended to include the new keys.

    Raises:
      KeyError: if a key is reserved, or is unknown while `is_strict` is True.
      ValueError: if the ParamsDict instance has been locked.
    """
    if self._locked:
      raise ValueError('The ParamsDict has been locked. No change is allowed.')
    if isinstance(override_params, ParamsDict):
      override_params = override_params.as_dict()
    self._override(override_params, is_strict)

  def _override(self, override_dict, is_strict=True):
    """The implementation of `override`."""
    for k, v in override_dict.items():
      if k in ParamsDict.RESERVED_ATTR:
        raise KeyError('The key `{}` is internally reserved. '
                       'Can not be overridden.'.format(k))
      if k not in self.__dict__:
        if is_strict:
          raise KeyError('The key `{}` does not exist. '
                         'To extend the existing keys, use '
                         '`override` with `is_strict` = False.'.format(k))
        self._set(k, v)
      elif isinstance(v, dict):
        # Merge into the existing nested ParamsDict instead of replacing it.
        self.__dict__[k]._override(v, is_strict)
      elif isinstance(v, ParamsDict):
        self.__dict__[k]._override(v.as_dict(), is_strict)
      else:
        self.__dict__[k] = copy.deepcopy(v)

  def lock(self):
    """Makes the ParamsDict immutable."""
    self._locked = True

  def as_dict(self):
    """Returns a dict representation of ParamsDict.

    For the nested ParamsDict, a nested dict will be returned. Reserved
    attributes are left out and leaf values are deep-copied.
    """
    params_dict = {}
    for k, v in self.__dict__.items():
      if k in ParamsDict.RESERVED_ATTR:
        continue
      if isinstance(v, ParamsDict):
        params_dict[k] = v.as_dict()
      else:
        params_dict[k] = copy.deepcopy(v)
    return params_dict

  def validate(self):
    """Validate the parameters consistency based on the restrictions.

    This method validates the internal consistency using the pre-defined list
    of restrictions. A restriction is defined as a string which specifies a
    binary operation. The supported binary operations are {'==', '!=', '<',
    '<=', '>', '>='}. Note that the meaning of these operators is consistent
    with the underlying Python implementation. Users should make sure the
    restrictions they define make sense for the types involved.

    Either side of a restriction may be a dotted key path or a constant (a
    number, which is compared as a float, or `None`).

    For example, for a ParamsDict like the following
    ```
    a:
      a1: 1
      a2: 2
    b:
      bb:
        bb1: 10
        bb2: 20
      ccc:
        a1: 1
        a3: 3
    ```
    one can define two restrictions like this
    ['a.a1 == b.ccc.a1', 'a.a2 <= b.bb.bb2']

    What it enforces are:
     - a.a1 = 1 == b.ccc.a1 = 1
     - a.a2 = 2 <= b.bb.bb2 = 20

    Raises:
      KeyError: if any of the following happens
        (1) any of parameters in any of restrictions is not defined in
            ParamsDict,
        (2) any inconsistency violating the restriction is found.
      ValueError: if the restriction defined in the string is not supported.
    """

    def _get_kv(dotted_string, params_dict):
      """Get keys and values indicated by dotted_string."""
      if _CONST_VALUE_RE.match(dotted_string) is not None:
        const_str = dotted_string
        if const_str == 'None':
          constant = None
        else:
          constant = float(const_str)
        return None, constant
      tokenized_params = dotted_string.split('.')
      v = params_dict
      for t in tokenized_params:
        v = v[t]
      return tokenized_params[-1], v

    def _get_kvs(tokens, params_dict):
      """Resolves both operands of a restriction split on its operator."""
      if len(tokens) != 2:
        raise ValueError('Only support binary relation in restriction.')
      stripped_tokens = [t.strip() for t in tokens]
      left_k, left_v = _get_kv(stripped_tokens[0], params_dict)
      right_k, right_v = _get_kv(stripped_tokens[1], params_dict)
      return left_k, left_v, right_k, right_v

    # Pairs of (operator, predicate that is True when the relation is
    # violated). Two-character operators must be tried before '<' and '>':
    # otherwise 'a <= b' is split on '<' and its right operand becomes '= b'.
    violation_checks = (
        ('==', lambda left, right: left != right),
        ('!=', lambda left, right: left == right),
        ('<=', lambda left, right: left > right),
        ('>=', lambda left, right: left < right),
        ('<', lambda left, right: left >= right),
        ('>', lambda left, right: left <= right),
    )

    params_dict = self.as_dict()
    for restriction in self._restrictions:
      for symbol, is_violated in violation_checks:
        if symbol not in restriction:
          continue
        tokens = restriction.split(symbol)
        _, left_v, _, right_v = _get_kvs(tokens, params_dict)
        if is_violated(left_v, right_v):
          raise KeyError('Found inconsistncy between key `{}` and key `{}`.'
                         .format(tokens[0], tokens[1]))
        break
      else:
        # No supported operator appears in the restriction string.
        raise ValueError('Unsupported relation in restriction.')
def read_yaml_to_params_dict(file_path):
  """Reads a YAML file to a ParamsDict.

  Args:
    file_path: path of the YAML file; it is opened with `tf.io.gfile.GFile`.

  Returns:
    A ParamsDict constructed from the parsed content of the file.
  """
  # Calling `yaml.load` without a Loader is deprecated since PyYAML 5.1 and
  # raises TypeError in PyYAML 6. Pass the loader explicitly: `FullLoader`
  # where it exists, otherwise `Loader`, which was the default on the older
  # releases that lack `FullLoader`.
  # NOTE(review): neither loader is intended for untrusted input. If config
  # files can come from untrusted sources, consider `yaml.safe_load`.
  loader = getattr(yaml, 'FullLoader', yaml.Loader)
  with tf.io.gfile.GFile(file_path, 'r') as f:
    params_dict = yaml.load(f, Loader=loader)
  return ParamsDict(params_dict)
def save_params_dict_to_yaml(params, file_path):
  """Dumps a ParamsDict to `file_path` in YAML format.

  Lists are written in flow style (`[1, 2, 3]`) while the rest of the document
  is written in block style.

  Note that `yaml.add_representer` is called without a `Dumper` argument, so
  the list representer registered here is not scoped to this call.

  Args:
    params: the ParamsDict to save; its `as_dict()` output is what is dumped.
    file_path: destination path, opened for writing with `tf.io.gfile.GFile`.
  """

  def _flow_style_list(dumper, data):
    # Represent a Python list as an inline YAML sequence.
    return dumper.represent_sequence(
        u'tag:yaml.org,2002:seq', data, flow_style=True)

  with tf.io.gfile.GFile(file_path, 'w') as out_file:
    yaml.add_representer(list, _flow_style_list)
    plain_params = params.as_dict()
    yaml.dump(plain_params, out_file, default_flow_style=False)
def nested_csv_str_to_json_str(csv_str):
  """Converts a nested (using '.') comma-separated k=v string to a JSON string.

  Each `key=value` entry becomes a `key : value` member of a brace-delimited
  mapping. A key containing '.' is nested: all entries sharing the same first
  key component are grouped, and the group is converted recursively.

  Spaces around '=' and after ',' are accepted (e.g. "a=1,b=2", "a = 1, b = 2"
  and "a=1, b=2" are equivalent), but there should be no spaces before keys
  or after values (e.g. " a=1,b=2" and "a=1,b=2 " are not supported).

  Values are limited to what the entry pattern accepts: quoted strings such
  as "a='hello'", bare values, and flat lists. Nested lists such as
  "a=[[1,2,3],[4,5,6]]" are not supported.

  Example:
    "a=1, b=2, c.a=2, c.b=3, d.a.a=5"
  is converted to
    "{a : 1, b : 2, c : {a : 2, b : 3}, d : {a : {a : 5}}}"

  Args:
    csv_str: the comma separated string.

  Returns:
    the converted JSON string, or '' if `csv_str` is empty.

  Raises:
    ValueError: If csv_str is not in a comma separated string or
      if the string is formatted incorrectly.
  """
  if not csv_str:
    return ''

  members = []
  # First key component -> list of "rest.of.key=value" strings.
  groups = {}
  cursor = 0
  total_len = len(csv_str)
  while cursor < total_len:
    match = _PARAM_RE.match(csv_str, cursor)
    if not match:
      raise ValueError('Malformed hyperparameter value while parsing '
                       'CSV string: %s' % csv_str[cursor:])
    cursor = match.end()
    key = match.group('name')
    val = match.group('val')

    # Wrap selected unquoted values in single quotes.
    # NOTE(review): `[gs://]` is a character class, so this fires for any
    # value whose first character is `g`, `s`, `:` or `/`, not only for the
    # literal prefix `gs://`. Presumably meant for GCS paths; confirm before
    # tightening, since it changes which values end up quoted.
    if re.match(r'(?=[^\"\'])(?=[gs://])', val):
      val = '\'{}\''.format(val)

    head, dot, remainder = key.partition('.')
    if dot:
      # Nested key: defer it until all siblings of `head` are collected.
      groups.setdefault(head, []).append(remainder + '=' + val)
    else:
      members.append('%s : %s' % (key, val))

  for head, sub_entries in groups.items():
    sub_json = nested_csv_str_to_json_str(','.join(sub_entries))
    members.append('%s : %s' % (head, sub_json))
  return '{' + ', '.join(members) + '}'
def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
  """Override a given ParamsDict using a dict, JSON/YAML/CSV string or YAML file.

  The logic of the function is outlined below:
  1. If the input is empty (None, '', {}), return `params` unchanged.
  2. If the input is a dict, use it to override directly.
  3. Otherwise the input must be a string; any other type raises ValueError.
    3.1. Try to convert the string from the nested CSV format (`a.b=1,c=2`).
         If it is not in that format, keep the string as it is.
    3.2. Parse the resulting string as YAML/JSON. If that yields a dict, use
         it to override.
    3.3. Otherwise treat the string as a file path, load that file as YAML
         and use the result to override.

  Args:
    params: a ParamsDict object to be overridden.
    dict_or_string_or_yaml_file: a Python dict, JSON/YAML/CSV string or
      path to a YAML file specifying the parameters to be overridden.
    is_strict: a boolean specifying whether override is strict or not.

  Returns:
    params: the overridden ParamsDict object.

  Raises:
    ValueError: if the input is neither a dict nor a string. Errors raised by
      `params.override` itself propagate unchanged.
  """
  if not dict_or_string_or_yaml_file:
    return params

  if isinstance(dict_or_string_or_yaml_file, dict):
    params.override(dict_or_string_or_yaml_file, is_strict)
    return params

  if not isinstance(dict_or_string_or_yaml_file, six.string_types):
    raise ValueError('Unknown input type to parse.')

  # Calling `yaml.load` without a Loader is deprecated since PyYAML 5.1 and
  # raises TypeError in PyYAML 6. Pass the loader explicitly: `FullLoader`
  # where it exists, otherwise `Loader`, which was the default on the older
  # releases that lack `FullLoader`.
  # NOTE(review): neither loader is intended for untrusted input. If the
  # override string or file can come from untrusted sources, consider
  # `yaml.safe_load`.
  loader = getattr(yaml, 'FullLoader', yaml.Loader)

  try:
    dict_or_string_or_yaml_file = (
        nested_csv_str_to_json_str(dict_or_string_or_yaml_file))
  except ValueError:
    # Deliberate: the string is not in CSV form, so hand it to the YAML parser
    # unchanged. It may be YAML/JSON text or a file path.
    pass

  params_dict = yaml.load(dict_or_string_or_yaml_file, Loader=loader)
  if isinstance(params_dict, dict):
    params.override(params_dict, is_strict)
  else:
    # The string did not parse to a mapping, so treat it as a path.
    with tf.io.gfile.GFile(dict_or_string_or_yaml_file) as f:
      params.override(yaml.load(f, Loader=loader), is_strict)
  return params