NMT-LaVi / utils /save.py
hieungo1410's picture
'add'
8cb4f3b
import torch
import os, re, io
import json
import dill as pickle
from shutil import copy2 as copy
# Extension shared by the model checkpoint files and the pickled vocab files.
MODEL_EXTENSION = ".pkl"
MODEL_FILE_FORMAT = "{:s}_{:d}{:s}" # model_prefix, epoch and extension
# Default name of the JSON file that stores the list of best-model scores.
BEST_MODEL_FILE = ".model_score.txt"
# Default name of the text file save_model_name() writes the model name into.
MODEL_SERVE_FILE = ".serve.txt"
# name_prefix, language extension and MODEL_EXTENSION
# (the language extension presumably carries its own leading dot - confirm against the loader)
VOCAB_FILE_FORMAT = "{:s}{:s}{:s}"
def save_model_name(name, path, serve_config_path=MODEL_SERVE_FILE):
    """Write ``name`` into the serve-config file located under ``path``.

    The file is opened in text mode with UTF-8 encoding and is overwritten
    if it already exists; its whole content becomes ``name``.
    """
    target = os.path.join(path, serve_config_path)
    with io.open(target, "w", encoding="utf-8") as handle:
        handle.write(name)
def save_vocab_to_path(path, language_tuple, fields, name_prefix="vocab", check_saved_vocab=True):
    """Pickle the source and target vocabularies into the directory ``path``.

    Args:
        path: directory that receives the two vocab files.
        language_tuple: ``(src_ext, trg_ext)`` strings embedded in the file names.
        fields: ``(src_field, trg_field)``; each must expose a ``.vocab`` attribute.
        name_prefix: leading part of both vocab file names.
        check_saved_vocab: when true and BOTH vocab files already exist,
            nothing is written.
    """
    src_field, trg_field = fields
    src_ext, trg_ext = language_tuple

    # pair every field with the file its vocab goes to (source first, then target)
    targets = []
    for field, ext in ((src_field, src_ext), (trg_field, trg_ext)):
        file_name = VOCAB_FILE_FORMAT.format(name_prefix, ext, MODEL_EXTENSION)
        targets.append((field, os.path.join(path, file_name)))

    if check_saved_vocab and all(os.path.isfile(vocab_path) for _, vocab_path in targets):
        # both dumps are already on disk, keep them
        return

    for field, vocab_path in targets:
        with io.open(vocab_path, "wb") as vocab_file:
            pickle.dump(field.vocab, vocab_file)
def load_vocab_from_path(path, language_tuple, fields, name_prefix="vocab"):
    """Restore pickled vocabularies from ``path`` onto the given fields.

    Returns False (leaving the fields untouched) when either vocab file is
    missing; returns True once both ``.vocab`` attributes have been assigned.
    """
    src_field, trg_field = fields
    src_ext, trg_ext = language_tuple

    def _vocab_path(ext):
        # same naming scheme as the one used when the vocab is dumped
        return os.path.join(path, VOCAB_FILE_FORMAT.format(name_prefix, ext, MODEL_EXTENSION))

    src_path = _vocab_path(src_ext)
    trg_path = _vocab_path(trg_ext)

    both_present = os.path.isfile(src_path) and os.path.isfile(trg_path)
    if not both_present:
        # at least one vocab was never dumped
        return False

    # NOTE: unpickling can execute arbitrary code - only load vocab files from a trusted source
    with io.open(src_path, "rb") as src_handle, io.open(trg_path, "rb") as trg_handle:
        src_field.vocab = pickle.load(src_handle)
        trg_field.vocab = pickle.load(trg_handle)
    return True
def save_model_to_path(model, path, name_prefix="model", checkpoint_idx=0, save_vocab=True):
    """Store ``model.state_dict()`` as a checkpoint file inside ``path``.

    The file name is built from ``name_prefix``, ``checkpoint_idx`` and the
    model extension. Unless ``save_vocab`` is false, the vocabularies of
    ``model.fields`` are dumped next to it through ``save_vocab_to_path``
    (which skips the dump when both vocab files already exist).
    """
    checkpoint_name = MODEL_FILE_FORMAT.format(name_prefix, checkpoint_idx, MODEL_EXTENSION)
    checkpoint_path = os.path.join(path, checkpoint_name)
    torch.save(model.state_dict(), checkpoint_path)
    if not save_vocab:
        return
    save_vocab_to_path(path, model.loader._language_tuple, model.fields)
def load_model_from_path(model, path, name_prefix="model", checkpoint_idx=0, map_location=None):
    """Load the checkpoint ``{name_prefix}_{checkpoint_idx}`` from ``path`` into ``model``.

    Args:
        model: object exposing ``load_state_dict``.
        path: directory containing the checkpoint file.
        name_prefix: checkpoint file name prefix.
        checkpoint_idx: integer index embedded in the file name.
        map_location: forwarded to ``torch.load``; pass e.g. ``"cpu"`` to load
            a checkpoint saved on a GPU onto a host without one. The default
            ``None`` keeps torch's own behavior.
    """
    # do not load vocab here, as the vocab structure will be decided in model.loader.build_vocab
    save_path = os.path.join(path, MODEL_FILE_FORMAT.format(name_prefix, checkpoint_idx, MODEL_EXTENSION))
    model.load_state_dict(torch.load(save_path, map_location=map_location))
def load_model(model, model_path, map_location=None):
    """Load the state dict stored at ``model_path`` into ``model``.

    Args:
        model: object exposing ``load_state_dict``.
        model_path: full path of the checkpoint file.
        map_location: forwarded to ``torch.load``; pass e.g. ``"cpu"`` to load
            a checkpoint saved on a GPU onto a host without one. The default
            ``None`` keeps torch's own behavior.
    """
    model.load_state_dict(torch.load(model_path, map_location=map_location))
def check_model_in_path(path, name_prefix="model", return_all_checkpoint=False):
    """Find the checkpoint indices of ``name_prefix`` models stored in ``path``.

    Only regular files whose whole name is ``{name_prefix}_{digits}{MODEL_EXTENSION}``
    are counted.

    Returns:
        With ``return_all_checkpoint``: the sorted list of indices found
        (empty when there are none or ``path`` is not a directory).
        Otherwise: the highest index, or 0 when nothing was found.
    """
    if not os.path.isdir(path):
        # keep the return type consistent with the mode the caller asked for
        return [] if return_all_checkpoint else 0

    # escape both parts so characters such as "." are matched literally
    model_re = re.compile(r"{:s}_(\d+){:s}".format(re.escape(name_prefix), re.escape(MODEL_EXTENSION)))

    indices = []
    for file_name in os.listdir(path):
        # fullmatch: names with trailing junk (e.g. "model_3.pkl.bak") are not checkpoints
        match = model_re.fullmatch(file_name)
        if match is not None and os.path.isfile(os.path.join(path, file_name)):
            indices.append(int(match.group(1)))
    indices.sort()

    if return_all_checkpoint:
        return indices
    return indices[-1] if indices else 0
def save_and_clear_model(model, path, name_prefix="model", checkpoint_idx=0, maximum_saved_model=5):
    """Save ``model`` while keeping at most ``maximum_saved_model`` checkpoints.

    The oldest checkpoints (lowest indices) are deleted until at most
    ``maximum_saved_model - 1`` remain, then the model is saved regardless of
    its checkpoint index, e.g. if checkpoint_idx=0 and models 3 4 5 6 7 are in
    path with the default limit, 3 is removed and 0 is saved instead.
    """
    # `or []`: the lookup may answer with 0 instead of a list when `path` is not a directory
    indices = check_model_in_path(path, name_prefix=name_prefix, return_all_checkpoint=True) or []
    if maximum_saved_model <= len(indices):
        # number of existing checkpoints allowed to survive next to the new one
        keep = max(maximum_saved_model - 1, 0)
        # slice by explicit end position: a negative-index slice would turn
        # into `[:0]` (remove nothing) when keep == 0
        for i in indices[:len(indices) - keep]:
            os.remove(os.path.join(path, MODEL_FILE_FORMAT.format(name_prefix, i, MODEL_EXTENSION)))
    # perform save as normal
    save_model_to_path(model, path, name_prefix=name_prefix, checkpoint_idx=checkpoint_idx)
def load_model_score(path, score_file=BEST_MODEL_FILE):
    """Read the model score list from its JSON dump under ``path``.

    The list is organized from best to worst. When the score file does not
    exist an empty list is returned.
    """
    score_file_path = os.path.join(path, score_file)
    if os.path.isfile(score_file_path):
        with io.open(score_file_path, "r") as score_handle:
            return json.load(score_handle)
    return []
def write_model_score(path, score_obj, score_file=BEST_MODEL_FILE):
    """Dump ``score_obj`` as JSON into the score file under ``path``, overwriting it."""
    target = os.path.join(path, score_file)
    with io.open(target, "w") as score_handle:
        json.dump(score_obj, score_handle)
def save_model_best_to_path(model, path, score_obj, model_metric, best_model_prefix="best_model", maximum_saved_model=5, score_file=BEST_MODEL_FILE, save_after_update=True):
    """Insert ``model`` into the ranked set of best checkpoints if it qualifies.

    ``score_obj`` lists the scores of the saved best models from best to worst
    (higher is better); the model ranked ``i`` is stored under ``path`` with
    ``best_model_prefix`` and checkpoint index ``i``.

    The model qualifies when its rank falls inside the first
    ``maximum_saved_model`` slots: it either beats a stored score or there is
    still a free slot at the end. Models ranked below it are moved down one
    slot, and the one pushed past the limit is dropped from the score list.

    Args:
        model: model to save.
        path: directory holding the best-model checkpoints and the score file.
        score_obj: list of scores, best first; updated in place.
        model_metric: score of ``model``.
        best_model_prefix: checkpoint name prefix of the best models.
        maximum_saved_model: number of best models to keep.
        score_file: name of the score file inside ``path``.
        save_after_update: when true, the updated scores are written to disk.

    Returns:
        ``score_obj`` (the same list), updated when the model was saved.
    """
    # rank = first stored score that is strictly beaten (ties keep the older model ahead);
    # when none is beaten the model can only go to the end of the list
    insert_loc = next((idx for idx, score in enumerate(score_obj) if model_metric > score), len(score_obj))
    if insert_loc >= maximum_saved_model:
        # not good enough for any of the available slots
        return score_obj

    def _slot_path(idx):
        return os.path.join(path, MODEL_FILE_FORMAT.format(best_model_prefix, idx, MODEL_EXTENSION))

    # every model from insert_loc down moves one slot lower; a model already in the
    # last allowed slot is simply overwritten. Go from the bottom up so that a file
    # is never overwritten before it has been moved itself.
    first_dropped = min(len(score_obj), maximum_saved_model - 1)
    for i in reversed(range(insert_loc, first_dropped)):
        copy(_slot_path(i), _slot_path(i + 1))
    # save the model to the selected loc
    save_model_to_path(model, path, name_prefix=best_model_prefix, checkpoint_idx=insert_loc)
    # update the score obj in place so the caller's list matches the returned one
    score_obj.insert(insert_loc, model_metric)
    del score_obj[maximum_saved_model:]
    # also update in disk, if enabled
    if save_after_update:
        write_model_score(path, score_obj, score_file=score_file)
    # after routine had been done, return the obj
    return score_obj