Spaces:
Sleeping
Sleeping
File size: 5,412 Bytes
9bf9e42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copy from fvcore
import logging
import os
from typing import Any
import yaml
from yacs.config import CfgNode as _CfgNode
import io as PathManager
BASE_KEY = "_BASE_"
class CfgNode(_CfgNode):
    """
    Our own extended version of :class:`yacs.config.CfgNode`.
    It contains the following extra features:

    1. The :meth:`merge_from_file` method supports the "_BASE_" key,
       which allows the new CfgNode to inherit all the attributes from the
       base configuration file.
    2. Keys that start with "COMPUTED_" are treated as insertion-only
       "computed" attributes. They can be inserted regardless of whether
       the CfgNode is frozen or not.
    3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
       expressions in config. See examples in
       https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
       Note that this may lead to arbitrary code execution: you must not
       load a config file from untrusted sources before manually inspecting
       the content of the file.
    """

    @staticmethod
    def load_yaml_with_base(filename: str, allow_unsafe: bool = False) -> dict:
        """
        Just like `yaml.load(open(filename))`, but inherit attributes from its
        `_BASE_`.

        Args:
            filename (str): the file name of the current config. Will be used to
                find the base config file.
            allow_unsafe (bool): whether to allow loading the config file with
                `yaml.unsafe_load`.

        Returns:
            (dict): the loaded yaml. An empty file yields an empty dict.

        Raises:
            yaml.constructor.ConstructorError: the file needs unsafe tags and
                ``allow_unsafe`` is False.
            AssertionError: a nested section in this file tries to extend a
                base key whose value is not a mapping.
        """
        try:
            with PathManager.open(filename, "r") as f:
                cfg = yaml.safe_load(f)
        except yaml.constructor.ConstructorError:
            if not allow_unsafe:
                raise
            logger = logging.getLogger(__name__)
            logger.warning(
                "Loading config {} with yaml.unsafe_load. Your machine may "
                "be at risk if the file contains malicious content.".format(
                    filename
                )
            )
            # SECURITY: yaml.unsafe_load can execute arbitrary code. This path
            # is reached only when the caller opted in via allow_unsafe=True.
            # The first handle is already closed here; reopen with the same
            # opener to parse the file again from the start.
            with PathManager.open(filename, "r") as f:
                cfg = yaml.unsafe_load(f)

        if cfg is None:
            # An empty document parses to None; treat it as an empty config so
            # the membership test below does not fail with a TypeError.
            cfg = {}

        def merge_a_into_b(a: dict, b: dict) -> None:
            # merge dict a into dict b. values in a will overwrite b.
            for k, v in a.items():
                if isinstance(v, dict) and k in b:
                    # Raised explicitly (not via `assert`) so the check still
                    # runs under `python -O`; the exception type is unchanged.
                    if not isinstance(b[k], dict):
                        raise AssertionError(
                            "Cannot inherit key '{}' from base!".format(k)
                        )
                    merge_a_into_b(v, b[k])
                else:
                    b[k] = v

        if BASE_KEY not in cfg:
            return cfg

        base_cfg_file = cfg[BASE_KEY]
        if base_cfg_file.startswith("~"):
            base_cfg_file = os.path.expanduser(base_cfg_file)
        if not base_cfg_file.startswith(("/", "https://", "http://")):
            # the path to base cfg is relative to the config file itself.
            base_cfg_file = os.path.join(
                os.path.dirname(filename), base_cfg_file
            )
        base_cfg = CfgNode.load_yaml_with_base(
            base_cfg_file, allow_unsafe=allow_unsafe
        )
        # The reserved key is consumed here and must not leak into the result.
        del cfg[BASE_KEY]
        merge_a_into_b(cfg, base_cfg)
        return base_cfg

    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False) -> None:
        """
        Merge configs from a given yaml file.

        Args:
            cfg_filename: the file name of the yaml config.
            allow_unsafe: whether to allow loading the config file with
                `yaml.unsafe_load`.
        """
        loaded_cfg = CfgNode.load_yaml_with_base(
            cfg_filename, allow_unsafe=allow_unsafe
        )
        # Wrap the plain dict in the caller's own node class before merging.
        loaded_cfg = type(self)(loaded_cfg)
        self.merge_from_other_cfg(loaded_cfg)

    # Forward the following calls to base, but with a check on the BASE_KEY.
    def merge_from_other_cfg(self, cfg_other):
        """
        Args:
            cfg_other (CfgNode): configs to merge from.

        Raises:
            AssertionError: ``cfg_other`` contains the reserved base key.
        """
        # Explicit raise so the check is not stripped under `python -O`.
        if BASE_KEY in cfg_other:
            raise AssertionError(
                "The reserved key '{}' can only be used in files!".format(BASE_KEY)
            )
        return super().merge_from_other_cfg(cfg_other)

    def merge_from_list(self, cfg_list: list):
        """
        Args:
            cfg_list (list): list of configs to merge from.

        Raises:
            AssertionError: one of the keys is the reserved base key.
        """
        # Every even-indexed element of the list is treated as a key.
        keys = set(cfg_list[0::2])
        # Explicit raise so the check is not stripped under `python -O`.
        if BASE_KEY in keys:
            raise AssertionError(
                "The reserved key '{}' can only be used in files!".format(BASE_KEY)
            )
        return super().merge_from_list(cfg_list)

    def __setattr__(self, name: str, val: Any) -> None:
        """
        Set an attribute; names starting with "COMPUTED_" are insertion-only.

        Raises:
            KeyError: a "COMPUTED_" attribute already exists with a different
                value.
        """
        if not name.startswith("COMPUTED_"):
            super().__setattr__(name, val)
            return

        if name in self:
            old_val = self[name]
            if old_val == val:
                # Re-inserting the same value is a no-op.
                return
            raise KeyError(
                "Computed attributed '{}' already exists "
                "with a different value! old={}, new={}.".format(
                    name, old_val, val
                )
            )
        # Stored by item assignment rather than through the parent's
        # __setattr__.
        self[name] = val
if __name__ == '__main__':
    # Manual smoke test: load a sample config (resolving any base file it
    # names) and print the merged result.
    cfg = CfgNode.load_yaml_with_base('configs/updown_long.yml')
    print(cfg)