|
import sys |
|
|
|
import torch |
|
import yaml |
|
|
|
|
|
def load_yaml_config(path):
    """Load and return the contents of the YAML file at *path*.

    Args:
        path: Path (str or os.PathLike) of the YAML file to read.

    Returns:
        The parsed document, as produced by ``yaml.full_load``.
    """
    # Binary mode: PyYAML then detects the encoding (UTF-8, or UTF-16 via
    # BOM) itself instead of relying on the platform's locale encoding,
    # which is not UTF-8 everywhere (e.g. cp1252 on many Windows setups).
    # NOTE(review): full_load is kept for compatibility, but it is not meant
    # for untrusted input -- switch to yaml.safe_load if configs can come
    # from outside sources.
    with open(path, 'rb') as f:
        return yaml.full_load(f)
|
|
|
|
|
def save_config_to_yaml(config, path):
    """Serialize *config* with ``yaml.dump`` and write it to *path*.

    Any existing file at *path* is overwritten.

    Args:
        config: Object to serialize with ``yaml.dump``.
        path: Destination file path (str or os.PathLike); must end in
            ``.yaml``.

    Raises:
        ValueError: If *path* does not end with ``.yaml``. The check runs
            before the file is opened, so nothing is created on failure.
    """
    import os

    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O`` and the extension guard would silently disappear.
    if not os.fspath(path).endswith('.yaml'):
        raise ValueError(
            'config path must end with .yaml, got {!r}'.format(path))

    # ``with`` closes the file; dumping straight into the stream avoids
    # building the whole document as an intermediate string first.
    with open(path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f)
|
|
|
|
|
def write_args(args, path):
    """Append a run summary (torch/cuDNN versions, command line, args) to *path*.

    Args:
        args: Object whose public attributes (names from ``dir(args)`` not
            starting with ``_``) are recorded. Presumably an
            ``argparse.Namespace`` -- TODO confirm against callers.
        path: File to append to; created if it does not exist.
    """
    # dir() also lists class-level attributes and methods, so an object that
    # is not a plain namespace may contribute those entries as well.
    args_dict = {name: getattr(args, name)
                 for name in dir(args) if not name.startswith('_')}

    # Explicit UTF-8 so non-ASCII argument values cannot raise
    # UnicodeEncodeError under a narrow locale encoding. The ``with`` block
    # closes the file; no manual close() is needed.
    with open(path, 'a', encoding='utf-8') as args_file:
        args_file.write('==> torch version: {}\n'.format(torch.__version__))
        args_file.write(
            '==> cudnn version: {}\n'.format(torch.backends.cudnn.version()))
        args_file.write('==> Cmd:\n')
        args_file.write(str(sys.argv))
        args_file.write('\n==> args:\n')
        # Keys are unique attribute names, so sorting orders by name only.
        for k, v in sorted(args_dict.items()):
            args_file.write(' %s: %s\n' % (str(k), str(v)))
|
|