|
import glob |
|
import os |
|
import pickle |
|
|
|
import torch |
|
|
|
|
|
def _remove_files(files): |
|
for f in files: |
|
return os.remove(f) |
|
|
|
|
|
def assert_dir_exits(path):
    """Create directory ``path`` (including parents) if nothing exists there.

    If ``path`` already exists, whether as a directory or as a file, nothing
    is done.

    Args:
        path: Directory path to create when missing.
    """
    if not os.path.exists(path):
        # exist_ok=True closes the window between the check above and the
        # creation: another process creating the directory in between no
        # longer makes this call raise FileExistsError.
        os.makedirs(path, exist_ok=True)
|
|
|
|
|
def save_model(model, epoch, out_path):
    """Save ``model``'s state dict and discard older ``.pth`` checkpoints.

    The checkpoint is written to ``out_path + str(epoch) + '.pth'``, so
    ``out_path`` acts as a plain string prefix (no separator is inserted).

    Args:
        model: Object exposing ``state_dict()``; its result is passed to
            ``torch.save``.
        epoch: Value used (via ``str``) to build the checkpoint file name.
        out_path: Path prefix for checkpoints; also created as a directory
            if nothing exists at that path.

    Returns:
        The path of the checkpoint file that was written.
    """
    assert_dir_exits(out_path)
    model_file = out_path + str(epoch) + '.pth'

    # Write the new checkpoint BEFORE deleting the old ones: if saving fails,
    # the previous checkpoints are still on disk.
    torch.save(model.state_dict(), model_file)

    # Remove every other checkpoint matching the prefix, keeping the one just
    # written. Paths are normalised so spelling differences between the glob
    # result and model_file cannot cause the fresh file to be deleted.
    current = os.path.normcase(os.path.abspath(model_file))
    stale_files = [
        f for f in glob.glob(out_path + '*.pth')
        if os.path.normcase(os.path.abspath(f)) != current
    ]
    _remove_files(stale_files)

    print('model saved for epoch: {}'.format(epoch))
    return model_file
|
|
|
|
|
def save_objects(obj, epoch, out_path):
    """Pickle ``obj`` and discard older ``.dat`` files under the same prefix.

    The data is written to ``out_path + str(epoch) + '.dat'``, so
    ``out_path`` acts as a plain string prefix (no separator is inserted).

    Args:
        obj: Picklable object to store.
        epoch: Value used (via ``str``) to build the file name.
        out_path: Path prefix for data files; also created as a directory
            if nothing exists at that path.
    """
    assert_dir_exits(out_path)
    data_file = out_path + str(epoch) + '.dat'

    # Write the new file BEFORE deleting the old ones: if pickling fails,
    # the previously saved objects are still on disk.
    with open(data_file, 'wb') as output:
        pickle.dump(obj, output)

    # Remove every other .dat file matching the prefix, keeping the one just
    # written. Paths are normalised so spelling differences between the glob
    # result and data_file cannot cause the fresh file to be deleted.
    current = os.path.normcase(os.path.abspath(data_file))
    stale_files = [
        f for f in glob.glob(out_path + '*.dat')
        if os.path.normcase(os.path.abspath(f)) != current
    ]
    _remove_files(stale_files)

    print('objects saved for epoch: {}'.format(epoch))
|
|
|
|
|
def restore_model(model, out_path):
    """Load the newest ``out_path + '*.pth'`` checkpoint into ``model``.

    Args:
        model: Object exposing ``load_state_dict``; updated in place when a
            checkpoint is found.
        out_path: Path prefix that checkpoint files start with.

    Returns:
        The same ``model`` object, restored if a checkpoint was found and
        untouched otherwise.
    """
    chk_files = glob.glob(out_path + '*.pth')

    if chk_files:
        # glob gives no ordering guarantee, so taking element 0 would restore
        # an arbitrary checkpoint when several exist. Pick the most recently
        # modified one instead.
        chk_file = max(chk_files, key=os.path.getmtime)
        print('found modeL {}, restoring'.format(chk_file))
        model.load_state_dict(torch.load(chk_file))
    else:
        print('Model not found, using untrained model')
    return model
|
|
|
|
|
def restore_objects(out_path, default):
    """Unpickle the newest ``out_path + '*.dat'`` file, or return ``default``.

    Args:
        out_path: Path prefix that data files start with.
        default: Value returned when no matching file exists.

    Returns:
        The unpickled object from the most recently modified matching file,
        or ``default`` if there is none.
    """
    data_files = glob.glob(out_path + '*.dat')
    if not data_files:
        return default

    # glob gives no ordering guarantee, so taking element 0 would restore an
    # arbitrary file when several exist. Pick the most recently modified one.
    data_file = max(data_files, key=os.path.getmtime)
    print('found data {}, restoring'.format(data_file))

    # NOTE(security): pickle.load can execute arbitrary code. Only point
    # out_path at files this program wrote itself, never at untrusted data.
    with open(data_file, 'rb') as input_:
        obj = pickle.load(input_)

    return obj
|
|