import sys | |
import torch | |
import yaml | |
def load_yaml_config(path):
    """Read the YAML file at *path* and return the parsed document.

    Args:
        path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.full_load`` for the file's contents.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which differs across systems (e.g. cp1252 on Windows) and would
    # mis-decode non-ASCII config values there.
    with open(path, encoding="utf-8") as f:
        # NOTE(review): FullLoader resolves some Python-specific tags
        # (e.g. !!python/tuple); prefer yaml.safe_load if *path* can ever come
        # from an untrusted source.
        return yaml.full_load(f)
def save_config_to_yaml(config, path):
    """Serialize *config* with ``yaml.dump`` and write it to *path*.

    Any existing file at *path* is overwritten.

    Args:
        config: Object to serialize (presumably a config dict — confirm with callers).
        path: Destination file name; must end with ``".yaml"``.

    Raises:
        ValueError: If *path* does not end with ``".yaml"``.
    """
    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O``, which would silently skip this validation.
    if not path.endswith(".yaml"):
        raise ValueError(f"config path must end with '.yaml', got {path!r}")
    # Same encoding as load_yaml_config so files round-trip on any platform;
    # the ``with`` block closes the file, so no explicit close() is needed.
    with open(path, "w", encoding="utf-8") as f:
        yaml.dump(config, f)
def write_args(args, path):
    """Append a run-environment report and the public attributes of *args* to *path*.

    The appended text contains the torch version, the value of
    ``torch.backends.cudnn.version()``, the raw ``sys.argv`` list, and then one
    ``" name: value"`` line per attribute of *args* whose name does not start
    with an underscore, sorted by name.

    Args:
        args: Object whose attributes are dumped (presumably an
            ``argparse.Namespace`` — verify with callers). Attributes are found
            via ``dir()``, so class-level attributes are included too.
        path: File to append to; opened in ``"a"`` mode, so it is created if
            missing and never truncated.
    """
    public_attrs = {
        attr: getattr(args, attr)
        for attr in dir(args)
        if not attr.startswith("_")
    }
    with open(path, "a") as out:
        out.write(f"==> torch version: {torch.__version__}\n")
        out.write(f"==> cudnn version: {torch.backends.cudnn.version()}\n")
        out.write("==> Cmd:\n")
        out.write(str(sys.argv))
        out.write("\n==> args:\n")
        # ``!s`` applies str() first, matching the original "%s" % str(...) output.
        out.writelines(
            f" {key!s}: {value!s}\n" for key, value in sorted(public_attrs.items())
        )