File size: 4,820 Bytes
158b61b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import torch
from collections import deque
from onmt.utils.logging import logger
from copy import deepcopy
def build_model_saver(model_opt, opt, model, fields, optim):
    """Create the directory for checkpoints and return a `ModelSaver`.

    Args:
        model_opt: model options, stored alongside each checkpoint.
        opt: run options; `opt.save_model` is the checkpoint path prefix
            and `opt.keep_checkpoint` the retention setting.
        model: the model whose weights will be saved.
        fields: vocabulary fields, stored alongside each checkpoint.
        optim: optimizer whose state will be saved.

    Returns:
        ModelSaver: a saver configured from the arguments above.
    """
    # Make sure the directory holding the checkpoint prefix exists
    # before anything tries to write into it.
    checkpoint_dir = os.path.dirname(os.path.abspath(opt.save_model))
    os.makedirs(checkpoint_dir, exist_ok=True)

    return ModelSaver(
        opt.save_model,
        model,
        model_opt,
        fields,
        optim,
        opt.keep_checkpoint,
    )
def load_checkpoint(ckpt_path):
    """Return the checkpoint stored at `ckpt_path`, or `None` if it is falsy.

    Args:
        ckpt_path (str): path of the checkpoint file; `None` or an empty
            string means "nothing to load".

    Returns:
        The object deserialized by `torch.load`, or `None` when no path
        was given.
    """
    if not ckpt_path:
        return None

    logger.info('Loading checkpoint from %s' % ckpt_path)
    # The map_location callable hands each storage back unchanged rather
    # than relocating it to the device recorded in the checkpoint.
    return torch.load(ckpt_path,
                      map_location=lambda storage, loc: storage)
class ModelSaverBase(object):
    """Base class for model saving operations.

    Inherited classes must implement private methods:

    * `_save`
    * `_rm_checkpoint`
    """

    def __init__(self, base_path, model, model_opt, fields, optim,
                 keep_checkpoint=-1):
        """Store what is needed to write checkpoints.

        Args:
            base_path (str): checkpoint name/path prefix.
            model: the model to save.
            model_opt: model options kept with the checkpoint.
            fields: vocabulary fields kept with the checkpoint.
            optim: optimizer whose state is kept with the checkpoint.
            keep_checkpoint (int): `0` disables saving entirely; a
                positive value keeps only that many most recent
                checkpoints; any other value keeps every checkpoint.
        """
        self.base_path = base_path
        self.model = model
        self.model_opt = model_opt
        self.fields = fields
        self.optim = optim
        self.last_saved_step = None
        self.keep_checkpoint = keep_checkpoint
        if keep_checkpoint > 0:
            # Names of the checkpoints currently kept, oldest first.
            self.checkpoint_queue = deque([], maxlen=keep_checkpoint)

    def save(self, step, moving_average=None):
        """Main entry point for model saver.

        It wraps the `_save` method with checks and applies the
        `keep_checkpoint` related logic.

        Args:
            step (int): step number; saving the same step twice in a row
                is a no-op.
            moving_average (iterable, optional): averaged parameters,
                aligned with `self.model.parameters()`. When truthy, the
                averages are written to the checkpoint in place of the
                live weights, which are put back afterwards.
        """
        if self.keep_checkpoint == 0 or step == self.last_saved_step:
            return

        save_model = self.model
        # (param, original data) pairs for every parameter swapped so far.
        swapped = []
        try:
            if moving_average:
                for avg, param in zip(moving_average,
                                      save_model.parameters()):
                    swapped.append((param, param.data))
                    param.data = avg.data
            _, chkpt_name = self._save(step, save_model)
            self.last_saved_step = step
        finally:
            # Always put the live weights back, even when `_save` (or the
            # swap itself) fails, so the model is never left holding the
            # averaged parameters.
            for param, original_data in swapped:
                param.data = original_data

        if self.keep_checkpoint > 0:
            if len(self.checkpoint_queue) == self.checkpoint_queue.maxlen:
                # Queue is full: drop the oldest checkpoint first.
                todel = self.checkpoint_queue.popleft()
                self._rm_checkpoint(todel)
            self.checkpoint_queue.append(chkpt_name)

    def _save(self, step, model):
        """Save a resumable checkpoint.

        Args:
            step (int): step number
            model (nn.Module): torch model to save

        Returns:
            (object, str):

            * checkpoint: the saved object
            * checkpoint_name: name (or path) of the saved checkpoint
        """
        raise NotImplementedError()

    def _rm_checkpoint(self, name):
        """Remove a checkpoint.

        Args:
            name (str): name that identifies the checkpoint
                (it may be a filepath)
        """
        raise NotImplementedError()
class ModelSaver(ModelSaverBase):
    """Simple model saver to filesystem"""

    def _save(self, step, model):
        """Write a checkpoint for `step` to disk.

        Args:
            step (int): step number, used in the checkpoint file name.
            model (nn.Module): torch model to save; it must expose a
                `generator` submodule.

        Returns:
            (dict, str): the checkpoint dictionary that was saved and the
            path it was written to.
        """
        # The generator weights are stored under their own key, so they
        # are filtered out of the main state dict.
        body_state = {
            name: tensor
            for name, tensor in model.state_dict().items()
            if 'generator' not in name
        }
        generator_state = model.generator.state_dict()

        # NOTE: We need to trim the vocab to remove any unk tokens that
        # were not originally here.
        vocab = deepcopy(self.fields)
        for side in ("src", "tgt"):
            if not hasattr(vocab[side], "fields"):
                continue
            side_vocab = vocab[side].fields[0][1].vocab
            unk_token = side_vocab.itos[0]
            # Collect first, pop afterwards: the mapping must not change
            # size while it is being iterated.
            stale_tokens = [
                token
                for token, index in side_vocab.stoi.items()
                if index == 0 and token != unk_token
            ]
            for token in stale_tokens:
                side_vocab.stoi.pop(token, None)

        checkpoint = {
            'model': body_state,
            'generator': generator_state,
            'vocab': vocab,
            'opt': self.model_opt,
            'optim': self.optim.state_dict(),
        }

        logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
        checkpoint_path = '%s_step_%d.pt' % (self.base_path, step)
        torch.save(checkpoint, checkpoint_path)
        return checkpoint, checkpoint_path

    def _rm_checkpoint(self, name):
        """Delete the checkpoint file `name` if it exists on disk."""
        if not os.path.exists(name):
            return
        os.remove(name)
|