|
|
|
|
|
|
|
|
|
import os |
|
|
|
import yaml |
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    # Command-line value parsing.
    "parse_with_yaml",
    "parse_unknown_args",
    # In-place merging of nested config dictionaries.
    "partial_update_config",
    # Reading and writing YAML config files.
    "resolve_and_load_config",
    "load_config",
    "dump_config",
]
|
|
|
|
|
def parse_with_yaml(config_str: str) -> object:
    """Parse a value string with YAML, falling back to the raw text.

    Args:
        config_str: Raw value text, e.g. ``"0.1"``, ``"[1,2]"`` or ``"{a:1}"``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the text (a scalar, list,
        dict or ``None``), or ``config_str`` itself when the text cannot be
        parsed.
    """
    # Inline dicts are commonly typed without a space after the colon
    # ("{a:1}"); YAML needs "a: 1" to see a key/value pair, so add the space.
    if "{" in config_str and "}" in config_str and ":" in config_str:
        out_str = config_str.replace(":", ": ")
    else:
        out_str = config_str

    try:
        return yaml.safe_load(out_str)
    except (ValueError, yaml.YAMLError):
        # PyYAML signals malformed input with yaml.YAMLError, which does not
        # derive from ValueError; catch both so the fallback actually runs.
        return config_str
|
|
|
|
|
def parse_unknown_args(unknown: list) -> dict:
    """Parse leftover ``--key value`` tokens into a (possibly nested) dict.

    Tokens are consumed two at a time as ``(key, value)``. A pair whose key
    does not start with ``--`` is skipped. A dotted key addresses nested
    dictionaries, so ``--a.b.c 1`` stores the value at
    ``result["a"]["b"]["c"]``; an intermediate entry that is missing or is
    not a dict is replaced by a new dict. Values are converted with
    ``parse_with_yaml``.

    Args:
        unknown: Flat list of command-line tokens.

    Returns:
        Dictionary of parsed options. When a key repeats, the last value wins.
    """
    parsed_dict = {}
    # The range stops before a lone trailing token: it has no value partner,
    # so it is ignored rather than indexing past the end of the list.
    for index in range(0, len(unknown) - 1, 2):
        key, val = unknown[index], unknown[index + 1]
        if not key.startswith("--"):
            continue

        # "a.b.c" -> parents ["a", "b"], leaf "c"; a plain "a" has no parents,
        # so both notations share the same code path.
        *parents, leaf = key[2:].split(".")
        target = parsed_dict
        for name in parents:
            if not isinstance(target.get(name), dict):
                target[name] = {}
            target = target[name]
        target[leaf] = parse_with_yaml(val)
    return parsed_dict
|
|
|
|
|
def partial_update_config(config: dict, partial_config: dict) -> dict:
    """Merge ``partial_config`` into ``config`` in place and return ``config``.

    When both sides hold a dict under the same key, the two dicts are merged
    recursively. In every other case the value from ``partial_config``
    replaces (or adds) the entry in ``config``.

    Args:
        config: Dictionary that is updated in place.
        partial_config: Overrides to apply on top of ``config``.

    Returns:
        The same ``config`` object that was passed in.
    """
    for name, override in partial_config.items():
        both_are_dicts = (
            name in config
            and isinstance(override, dict)
            and isinstance(config[name], dict)
        )
        if both_are_dicts:
            # Descend so sibling keys of the existing sub-dict are preserved.
            partial_update_config(config[name], override)
            continue
        config[name] = override
    return config
|
|
|
|
|
def resolve_and_load_config(path: str, config_name="config.yaml") -> dict:
    """Load a config from a file path or from a directory that contains one.

    Args:
        path: Path to a config file, or to a directory holding
            ``config_name``. ``~`` is expanded and the path is resolved with
            ``os.path.realpath`` first.
        config_name: File name looked up inside ``path`` when ``path`` is a
            directory.

    Returns:
        The value returned by ``load_config`` for the resolved file.

    Raises:
        FileNotFoundError: If the resolved location is not an existing file.
    """
    path = os.path.realpath(os.path.expanduser(path))
    config_path = os.path.join(path, config_name) if os.path.isdir(path) else path
    if not os.path.isfile(config_path):
        # FileNotFoundError is a subclass of Exception, so handlers written
        # against the previous generic Exception keep working.
        raise FileNotFoundError(f"Cannot find a valid config at {path}")
    return load_config(config_path)
|
|
|
|
|
class SafeLoaderWithTuple(yaml.SafeLoader):
    """``yaml.SafeLoader`` variant that can also build Python tuples."""

    def construct_python_tuple(self, node):
        """Build a tuple from the items of the sequence ``node``."""
        items = self.construct_sequence(node)
        return tuple(items)


# Route nodes tagged as python/tuple to the tuple constructor defined above.
SafeLoaderWithTuple.add_constructor(
    "tag:yaml.org,2002:python/tuple",
    SafeLoaderWithTuple.construct_python_tuple,
)
|
|
|
|
|
def load_config(filename: str) -> dict:
    """Load a yaml file.

    Args:
        filename: Path to the YAML file. ``~`` is expanded and the path is
            resolved with ``os.path.realpath`` before opening.

    Returns:
        The object produced by ``yaml.load`` using ``SafeLoaderWithTuple``.
    """
    filename = os.path.realpath(os.path.expanduser(filename))
    # The context manager closes the file even if parsing raises.
    with open(filename) as stream:
        return yaml.load(stream, Loader=SafeLoaderWithTuple)
|
|
|
|
|
def dump_config(config: dict, filename: str) -> None:
    """Write ``config`` to ``filename`` as YAML.

    Keys are written in the dictionary's own order (``sort_keys=False``).

    Args:
        config: Configuration to serialize.
        filename: Destination path. ``~`` is expanded and the path is
            resolved with ``os.path.realpath``; the file is opened with mode
            ``"w"``.
    """
    filename = os.path.realpath(os.path.expanduser(filename))
    # Closing the file on exit guarantees the YAML text is flushed to disk
    # instead of relying on garbage collection of the handle.
    with open(filename, "w") as stream:
        yaml.dump(config, stream, sort_keys=False)
|
|