import glob |
import os |
import pickle |
import torch |
def _remove_files(files): |
for f in files: |
return os.remove(f) |
def assert_dir_exits(path):
    """Make sure *path* exists, creating it and any missing parents.

    (The function name is a misspelling of "exists"; it is kept unchanged so
    existing callers keep working.)

    An empty *path* means "the current directory" and needs no creation. If
    *path* already exists, nothing happens.

    Args:
        path: directory path to create if it is missing.
    """
    if not path:
        # os.makedirs('') raises FileNotFoundError; '' already refers to cwd.
        return
    try:
        os.makedirs(path)
    except FileExistsError:
        # EAFP instead of exists()-then-makedirs(): avoids the race where
        # another process creates the directory between the check and the call.
        pass
def save_model(model, epoch, out_path):
    """Save ``model.state_dict()`` as the only ``.pth`` checkpoint under *out_path*.

    The checkpoint is written to ``out_path + str(epoch) + '.pth'``. Because
    *out_path* is used as a plain string prefix, it should normally end with a
    path separator. All other ``*.pth`` files matching that prefix are deleted,
    but only after the new checkpoint has been written successfully.

    Args:
        model: object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        epoch: value embedded in the file name via ``str(epoch)``.
        out_path: directory/prefix for checkpoint files; created if missing.

    Returns:
        The path of the checkpoint that was written.
    """
    assert_dir_exits(out_path)
    model_file = out_path + str(epoch) + '.pth'
    # Write under a temporary name and move it into place. A failed save then
    # never leaves a truncated *.pth file, and the previous checkpoint is
    # still intact because nothing has been deleted yet.
    tmp_file = model_file + '.tmp'
    try:
        torch.save(model.state_dict(), tmp_file)
        os.replace(tmp_file, model_file)
    except BaseException:
        try:
            os.remove(tmp_file)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
    # Now that the new checkpoint is on disk, prune older ones. Normalized
    # comparison guarantees the file just written is never removed, and
    # glob.escape keeps '[', '*', '?' in out_path from acting as wildcards.
    keep = os.path.normpath(model_file)
    pattern = glob.escape(out_path) + '*.pth'
    _remove_files([f for f in glob.glob(pattern) if os.path.normpath(f) != keep])
    print('model saved for epoch: {}'.format(epoch))
    return model_file
def save_objects(obj, epoch, out_path):
    """Pickle *obj* as the only ``.dat`` file under *out_path*.

    The data is written to ``out_path + str(epoch) + '.dat'``. Because
    *out_path* is used as a plain string prefix, it should normally end with a
    path separator. All other ``*.dat`` files matching that prefix are deleted,
    but only after the new file has been written successfully.

    Args:
        obj: any picklable object.
        epoch: value embedded in the file name via ``str(epoch)``.
        out_path: directory/prefix for data files; created if missing.

    Returns:
        The path of the ``.dat`` file that was written. (Previously nothing
        was returned; callers that ignore the result are unaffected.)
    """
    assert_dir_exits(out_path)
    dat_file = out_path + str(epoch) + '.dat'
    # Write under a temporary name and move it into place. A failed dump then
    # never leaves a truncated *.dat file, and earlier data is still intact.
    tmp_file = dat_file + '.tmp'
    try:
        with open(tmp_file, 'wb') as output:
            pickle.dump(obj, output)
        os.replace(tmp_file, dat_file)
    except BaseException:
        try:
            os.remove(tmp_file)
        except OSError:
            pass  # best-effort cleanup; the original error is re-raised below
        raise
    # Prune older .dat files only after the new one is safely in place.
    keep = os.path.normpath(dat_file)
    pattern = glob.escape(out_path) + '*.dat'
    _remove_files([f for f in glob.glob(pattern) if os.path.normpath(f) != keep])
    print('objects saved for epoch: {}'.format(epoch))
    return dat_file
def restore_model(model, out_path, map_location=None):
    """Load the newest ``.pth`` checkpoint under *out_path* into *model*.

    Looks for files matching ``out_path + '*.pth'``. If there are several,
    the most recently modified one is used. If there are none, *model* is
    returned untouched.

    Args:
        model: object exposing ``load_state_dict()`` (e.g. a ``torch.nn.Module``).
        out_path: directory/prefix the checkpoints were written under.
        map_location: forwarded to ``torch.load``. Pass e.g. ``'cpu'`` to
            restore a GPU-saved checkpoint on a CPU-only machine. The default
            ``None`` matches the previous behavior.

    Returns:
        The same *model* object, updated in place if a checkpoint was found.
    """
    # glob.escape so '[', '*', '?' in the directory name are matched literally.
    chk_files = glob.glob(glob.escape(out_path) + '*.pth')
    if not chk_files:
        print('Model not found, using untrained model')
        return model
    # glob() order is unspecified; choose deterministically (newest on disk).
    chk_file = max(chk_files, key=os.path.getmtime)
    print('found model {}, restoring'.format(chk_file))
    model.load_state_dict(torch.load(chk_file, map_location=map_location))
    return model
def restore_objects(out_path, default):
    """Unpickle the newest ``.dat`` file under *out_path*, or return *default*.

    Looks for files matching ``out_path + '*.dat'``. If there are several,
    the most recently modified one is used.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only use this on files this program wrote itself (e.g. via
    ``save_objects``), never on data from an untrusted source.

    Args:
        out_path: directory/prefix the ``.dat`` files were written under.
        default: value returned when no ``.dat`` file is found.

    Returns:
        The unpickled object, or *default* when nothing was found.
    """
    # glob.escape so '[', '*', '?' in the directory name are matched literally.
    data_files = glob.glob(glob.escape(out_path) + '*.dat')
    if not data_files:
        return default
    # glob() order is unspecified; choose deterministically (newest on disk).
    data_file = max(data_files, key=os.path.getmtime)
    print('found data {}, restoring'.format(data_file))
    with open(data_file, 'rb') as input_:
        return pickle.load(input_)