# Spaces:
# Sleeping
# Sleeping
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Copied from fvcore, with fvcore's PathManager replaced by the stdlib `io` module.
# Stand-in for fvcore's PathManager: only `PathManager.open` is used here.
import io as PathManager
import logging
import os
import sys
from typing import Any

import yaml
from yacs.config import CfgNode as _CfgNode
# Reserved top-level key in a YAML config file; its value names a parent
# config file whose settings the current file inherits and overrides.
BASE_KEY: str = "_BASE_"
class CfgNode(_CfgNode):
    """
    Our own extended version of :class:`yacs.config.CfgNode`.
    It contains the following extra features:

    1. The :meth:`merge_from_file` method supports the "_BASE_" key,
       which allows the new CfgNode to inherit all the attributes from the
       base configuration file.
    2. Keys that start with "COMPUTED_" are treated as insertion-only
       "computed" attributes. They can be inserted regardless of whether
       the CfgNode is frozen or not.
    3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
       expressions in config. See examples in
       https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
       Note that this may lead to arbitrary code execution: you must not
       load a config file from untrusted sources before manually inspecting
       the content of the file.
    """

    @staticmethod
    def load_yaml_with_base(filename, allow_unsafe=False):
        """
        Just like `yaml.safe_load(open(filename))`, but inherit attributes
        from the file named by its `_BASE_` key (recursively).

        Args:
            filename (str): the file name of the current config. Relative
                `_BASE_` paths are resolved against this file's directory.
            allow_unsafe (bool): whether to fall back to `yaml.unsafe_load`
                when `yaml.safe_load` rejects a tag in the file.

        Returns:
            (dict): the loaded yaml, merged on top of its base config (if
                any). An empty file yields an empty dict.

        Raises:
            yaml.constructor.ConstructorError: if the file contains tags
                that only `yaml.unsafe_load` accepts and `allow_unsafe`
                is False.
        """
        # Read the raw bytes once so that both loaders parse identical
        # content, and so PyYAML detects the encoding (UTF-8 / UTF-16)
        # itself instead of relying on the platform's default text encoding.
        with PathManager.open(filename, "rb") as f:
            content = f.read()
        try:
            cfg = yaml.safe_load(content)
        except yaml.constructor.ConstructorError:
            if not allow_unsafe:
                raise
            logger = logging.getLogger(__name__)
            logger.warning(
                "Loading config {} with yaml.unsafe_load. Your machine may "
                "be at risk if the file contains malicious content.".format(
                    filename
                )
            )
            # SECURITY: unsafe_load can execute code embedded in the file;
            # this path is only reached when the caller explicitly opted in.
            cfg = yaml.unsafe_load(content)
        if cfg is None:
            # An empty YAML document loads as None; treat it as "no settings"
            # so the `BASE_KEY in cfg` check and base merging below still work.
            cfg = {}

        def merge_a_into_b(a, b):
            # merge dict a into dict b. values in a will overwrite b.
            # Nested dicts are merged recursively rather than replaced.
            for k, v in a.items():
                if isinstance(v, dict) and k in b:
                    assert isinstance(
                        b[k], dict
                    ), "Cannot inherit key '{}' from base!".format(k)
                    merge_a_into_b(v, b[k])
                else:
                    b[k] = v

        if BASE_KEY in cfg:
            base_cfg_file = cfg[BASE_KEY]
            if base_cfg_file.startswith("~"):
                base_cfg_file = os.path.expanduser(base_cfg_file)
            if not base_cfg_file.startswith(("/", "https://", "http://")):
                # the path to base cfg is relative to the config file itself.
                base_cfg_file = os.path.join(
                    os.path.dirname(filename), base_cfg_file
                )
            base_cfg = CfgNode.load_yaml_with_base(
                base_cfg_file, allow_unsafe=allow_unsafe
            )
            # The reserved key must not leak into the merged result.
            del cfg[BASE_KEY]
            merge_a_into_b(cfg, base_cfg)
            return base_cfg
        return cfg

    def merge_from_file(self, cfg_filename, allow_unsafe=False):
        """
        Merge configs from a given yaml file, honoring its `_BASE_` chain.

        Args:
            cfg_filename: the file name of the yaml config.
            allow_unsafe: whether to allow loading the config file with
                `yaml.unsafe_load`.
        """
        loaded_cfg = CfgNode.load_yaml_with_base(
            cfg_filename, allow_unsafe=allow_unsafe
        )
        # Wrap the plain dict in this node's own type before merging.
        loaded_cfg = type(self)(loaded_cfg)
        self.merge_from_other_cfg(loaded_cfg)

    # Forward the following calls to base, but with a check on the BASE_KEY.
    def merge_from_other_cfg(self, cfg_other):
        """
        Args:
            cfg_other (CfgNode): configs to merge from. Must not contain
                the reserved `_BASE_` key, which is only meaningful in files.
        """
        assert (
            BASE_KEY not in cfg_other
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_other_cfg(cfg_other)

    def merge_from_list(self, cfg_list):
        """
        Args:
            cfg_list (list): list of configs to merge from, as alternating
                key, value entries. No key may be the reserved `_BASE_` key.
        """
        # Keys sit at even positions of the flat [key, value, ...] list.
        keys = set(cfg_list[0::2])
        assert (
            BASE_KEY not in keys
        ), "The reserved key '{}' can only be used in files!".format(BASE_KEY)
        return super().merge_from_list(cfg_list)

    def __setattr__(self, name, val):
        """
        Set an attribute; "COMPUTED_*" names are insert-only.

        Raises:
            KeyError: if a "COMPUTED_*" attribute already exists with a
                different value. Re-setting the same value is a no-op.
        """
        if name.startswith("COMPUTED_"):
            if name in self:
                old_val = self[name]
                if old_val == val:
                    return
                raise KeyError(
                    "Computed attributed '{}' already exists "
                    "with a different value! old={}, new={}.".format(
                        name, old_val, val
                    )
                )
            # Store via item assignment instead of super().__setattr__;
            # per the class docstring this lets computed keys be inserted
            # even when the node is frozen.
            self[name] = val
        else:
            super().__setattr__(name, val)
if __name__ == "__main__":
    # Demo: load a config (with its _BASE_ chain resolved) and print it.
    # An optional first CLI argument overrides the default config path.
    config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/updown_long.yml"
    cfg = CfgNode.load_yaml_with_base(config_path)
    print(cfg)