SuperFeatures / how /utils /io_helpers.py
YannisK's picture
temp state
32408ed
"""Helper functions related to io"""
import os.path
import sys
import shutil
import urllib.request
from pathlib import Path
import yaml
import torch
def progress(iterable, *, size=None, print_freq=1, handle=sys.stdout):
"""Generator wrapping an iterable to print progress"""
for i, element in enumerate(iterable):
yield element
if i == 0 or (i+1) % print_freq == 0 or (i+1) == size:
if size:
handle.write(f'\r>>>> {i+1}/{size} done...')
else:
handle.write(f'\r>>>> {i+1} done...')
handle.write("\n")
# Params
def load_params(path):
"""Return loaded parameters from a yaml file"""
with open(path, "r") as handle:
content = yaml.safe_load(handle)
return load_nested_templates(content, os.path.dirname(path))
def save_params(path, params):
"""Save given parameters to a yaml file"""
with open(path, "w") as handle:
yaml.safe_dump(params, handle, default_flow_style=False)
def load_nested_templates(params, root_path):
"""Find keys '__template__' in nested dictionary and replace corresponding value with loaded
yaml file"""
if not isinstance(params, dict):
return params
if "__template__" in params:
template_path = os.path.expanduser(params.pop("__template__"))
path = os.path.join(root_path, template_path)
root_path = os.path.dirname(path)
# Treat template as defaults
params = dict_deep_overlay(load_params(path), params)
for key, value in params.items():
params[key] = load_nested_templates(value, root_path)
return params
def dict_deep_overlay(defaults, params):
"""If defaults and params are both dictionaries, perform deep overlay (use params value for
keys defined in params), otherwise use defaults value"""
if isinstance(defaults, dict) and isinstance(params, dict):
for key in params:
defaults[key] = dict_deep_overlay(defaults.get(key, None), params[key])
return defaults
return params
def dict_deep_set(dct, key, value):
"""Set key to value for a nested dictionary where the key is a sequence (e.g. list)"""
if len(key) == 1:
dct[key[0]] = value
return
if not isinstance(dct[key[0]], dict) or key[0] not in dct:
dct[key[0]] = {}
dict_deep_set(dct[key[0]], key[1:], value)
# Download
def download_files(names, root_path, base_url, logfunc=None):
"""Download file names from given url to given directory path. If logfunc given, use it to log
status."""
root_path = Path(root_path)
for name in names:
path = root_path / name
if path.exists():
continue
if logfunc:
logfunc(f"Downloading file '{name}'")
path.parent.mkdir(parents=True, exist_ok=True)
urllib.request.urlretrieve(base_url + name, path)
# Checkpoints
def save_checkpoint(state, is_best, keep_epoch, directory):
"""Save state dictionary to the directory providing whether the corresponding epoch is the best
and whether to keep it anyway"""
filename = os.path.join(directory, 'model_epoch%d.pth' % state['epoch'])
filename_best = os.path.join(directory, 'model_best.pth')
if is_best and keep_epoch:
torch.save(state, filename)
shutil.copyfile(filename, filename_best)
elif is_best or keep_epoch:
torch.save(state, filename_best if is_best else filename)