import os
from collections import deque
from copy import deepcopy

import torch

from onmt.utils.logging import logger

def build_model_saver(model_opt, opt, model, fields, optim):
    # _check_save_model_path
    save_model_path = os.path.abspath(opt.save_model)
    os.makedirs(os.path.dirname(save_model_path), exist_ok=True)

    model_saver = ModelSaver(opt.save_model,
                             model,
                             model_opt,
                             fields,
                             optim,
                             opt.keep_checkpoint)
    return model_saver
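
# A minimal usage sketch (hypothetical `opt` values, not from this file):
# the saver only needs a path prefix and a retention count from `opt`.
#
#     saver = build_model_saver(model_opt, opt, model, fields, optim)
#     saver.save(1000)  # writes <opt.save_model>_step_1000.pt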

def load_checkpoint(ckpt_path):
    """Load checkpoint from `ckpt_path` if given, else return `None`."""
    checkpoint = None
    if ckpt_path:
        logger.info('Loading checkpoint from %s' % ckpt_path)
        # map_location keeps the tensors on CPU regardless of the device
        # they were saved from
        checkpoint = torch.load(ckpt_path,
                                map_location=lambda storage, loc: storage)
    return checkpoint
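
# A sketch of the resume path (the checkpoint name is hypothetical):
# training scripts load the checkpoint first, then pull the saved options
# and vocab back out of the returned dict (see `ModelSaver._save` below
# for the keys it contains).
#
#     checkpoint = load_checkpoint('demo_step_1000.pt')
#     if checkpoint is not None:
#         model_opt = checkpoint['opt']
#         vocab = checkpoint['vocab']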

class ModelSaverBase(object):
    """Base class for model saving operations

    Inherited classes must implement private methods:

    * `_save`
    * `_rm_checkpoint`
    """

    def __init__(self, base_path, model, model_opt, fields, optim,
                 keep_checkpoint=-1):
        self.base_path = base_path
        self.model = model
        self.model_opt = model_opt
        self.fields = fields
        self.optim = optim
        self.last_saved_step = None
        self.keep_checkpoint = keep_checkpoint
        if keep_checkpoint > 0:
            self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
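
    # Illustrative note on the queue (values hypothetical): with
    # keep_checkpoint=2 the deque tracks at most two checkpoint paths.
    # `save` pops and deletes the oldest file explicitly before
    # appending, instead of letting `maxlen` drop the path silently,
    # because the stale .pt file on disk must be removed as well.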

    def save(self, step, moving_average=None):
        """Main entry point for model saver

        It wraps the `_save` method with checks and applies the
        `keep_checkpoint`-related logic
        """

        if self.keep_checkpoint == 0 or step == self.last_saved_step:
            return

        save_model = self.model
        if moving_average:
            # Temporarily swap the moving-average weights in so they are
            # the ones that get serialized
            model_params_data = []
            for avg, param in zip(moving_average, save_model.parameters()):
                model_params_data.append(param.data)
                param.data = avg.data

        chkpt, chkpt_name = self._save(step, save_model)
        self.last_saved_step = step

        if moving_average:
            # Restore the original (non-averaged) weights
            for param_data, param in zip(model_params_data,
                                         save_model.parameters()):
                param.data = param_data

        if self.keep_checkpoint > 0:
            if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                todel = self.checkpoint_queue.popleft()
                self._rm_checkpoint(todel)
            self.checkpoint_queue.append(chkpt_name)
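
    # Usage sketch for the `moving_average` path (names hypothetical):
    # a list of averaged parameter tensors, maintained elsewhere by the
    # trainer, is swapped in for serialization, and the live weights are
    # restored afterwards.
    #
    #     ema = [p.data.clone() for p in model.parameters()]
    #     saver.save(step, moving_average=ema)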

    def _save(self, step, model):
        """Save a resumable checkpoint.

        Args:
            step (int): step number
            model (nn.Module): torch model to save

        Returns:
            (object, str):

            * checkpoint: the saved object
            * checkpoint_name: name (or path) of the saved checkpoint
        """

        raise NotImplementedError()

    def _rm_checkpoint(self, name):
        """Remove a checkpoint

        Args:
            name(str): name that identifies the checkpoint
                (it may be a filepath)
        """

        raise NotImplementedError()

class ModelSaver(ModelSaverBase):
    """Simple model saver to filesystem"""

    def _save(self, step, model):
        # The generator's parameters are stored under their own key, so
        # strip them out of the main model state dict
        model_state_dict = model.state_dict()
        model_state_dict = {k: v for k, v in model_state_dict.items()
                            if 'generator' not in k}
        generator_state_dict = model.generator.state_dict()

        # NOTE: We need to trim the vocab to remove any unk tokens that
        # were not originally here. `stoi` behaves like a defaultdict,
        # so every out-of-vocabulary word ever looked up was silently
        # recorded with index 0; drop those entries so only the real
        # unk token maps to 0 in the saved vocab.
        vocab = deepcopy(self.fields)
        for side in ["src", "tgt"]:
            keys_to_pop = []
            if hasattr(vocab[side], "fields"):
                unk_token = vocab[side].fields[0][1].vocab.itos[0]
                for key, value in vocab[side].fields[0][1].vocab.stoi.items():
                    if value == 0 and key != unk_token:
                        keys_to_pop.append(key)
                for key in keys_to_pop:
                    vocab[side].fields[0][1].vocab.stoi.pop(key, None)

        checkpoint = {
            'model': model_state_dict,
            'generator': generator_state_dict,
            'vocab': vocab,
            'opt': self.model_opt,
            'optim': self.optim.state_dict(),
        }

        checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
        logger.info("Saving checkpoint %s" % checkpoint_path)
        torch.save(checkpoint, checkpoint_path)
        return checkpoint, checkpoint_path

    def _rm_checkpoint(self, name):
        if os.path.exists(name):
            os.remove(name)
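
# End-to-end sketch (all names hypothetical): `_save` stores the
# generator's parameters separately from the rest of the model, so
# restoring by hand means loading the two state dicts individually.
#
#     ckpt = load_checkpoint('demo_step_500.pt')
#     model.load_state_dict(ckpt['model'], strict=False)
#     model.generator.load_state_dict(ckpt['generator'])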