File size: 1,535 Bytes
f831146 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import glob
import os
import pickle
import torch
def _remove_files(files):
for f in files:
return os.remove(f)
def assert_dir_exits(path):
    """Ensure the directory *path* exists, creating it (and parents) if needed.

    Fix: the original exists-then-makedirs pair had a TOCTOU race — another
    process creating the directory between the two calls raised
    FileExistsError. `exist_ok=True` makes the operation atomic and
    idempotent.

    Args:
        path: directory path to create if absent.
    """
    os.makedirs(path, exist_ok=True)
def save_model(model, epoch, out_path):
    """Persist the model's state_dict as ``<out_path><epoch>.pth``.

    Any ``*.pth`` checkpoints already under *out_path* are deleted first,
    so at most one checkpoint file is kept.

    Args:
        model: torch module whose state_dict is saved.
        epoch: epoch number used to name the checkpoint file.
        out_path: output prefix (string-concatenated, so a trailing
            separator matters — assumed to end with '/'; confirm at callers).

    Returns:
        The path of the checkpoint file that was written.
    """
    assert_dir_exits(out_path)
    # Purge previous checkpoints before writing the new one.
    stale = glob.glob(out_path + '*.pth')
    _remove_files(stale)
    target = out_path + str(epoch) + '.pth'
    torch.save(model.state_dict(), target)
    print('model saved for epoch: {}'.format(epoch))
    return target
def save_objects(obj, epoch, out_path):
    """Pickle *obj* to ``<out_path><epoch>.dat``, replacing older .dat files.

    Existing ``*.dat`` files under *out_path* are removed first so only the
    latest epoch's objects remain on disk.

    Args:
        obj: picklable object to persist (callers pass a tuple, per the
            original note).
        epoch: epoch number used to name the data file.
        out_path: output prefix (string-concatenated; assumed to end
            with '/' — confirm at callers).
    """
    assert_dir_exits(out_path)
    _remove_files(glob.glob(out_path + '*.dat'))
    # object should be tuple
    target = out_path + str(epoch) + '.dat'
    with open(target, 'wb') as output:
        pickle.dump(obj, output)
    print('objects saved for epoch: {}'.format(epoch))
def restore_model(model, out_path):
    """Load the first ``*.pth`` checkpoint found under *out_path* into *model*.

    Fix: corrected the typo in the restore log message ('modeL' -> 'model').

    Args:
        model: torch module to load the state_dict into (mutated in place).
        out_path: checkpoint prefix searched via glob; assumed to end with
            a path separator — confirm at callers.

    Returns:
        The same *model* instance, restored if a checkpoint existed,
        otherwise untouched.
    """
    chk_file = glob.glob(out_path + '*.pth')
    if chk_file:
        # glob order is unspecified; save_model keeps at most one file,
        # so taking the first entry is safe.
        chk_file = str(chk_file[0])
        print('found model {}, restoring'.format(chk_file))
        model.load_state_dict(torch.load(chk_file))
    else:
        print('Model not found, using untrained model')
    return model
def restore_objects(out_path, default):
    """Unpickle and return the first ``*.dat`` file under *out_path*.

    Args:
        out_path: data-file prefix searched via glob; assumed to end with
            a path separator — confirm at callers.
        default: value returned when no ``*.dat`` file exists.

    Returns:
        The unpickled object, or *default* if nothing was found.
    """
    matches = glob.glob(out_path + '*.dat')
    if not matches:
        return default
    data_file = str(matches[0])
    print('found data {}, restoring'.format(data_file))
    with open(data_file, 'rb') as input_:
        return pickle.load(input_)
|