"""Tools to save/restore model from checkpoints."""
import argparse
import shutil
import sys
import os
import re
import json
import time
import torch
CHECKPOINT_PATTERN = re.compile('^model_checkpoint-(\d+)$')
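# A minimal sketch of what the pattern matches -- the filenames written by
# save_checkpoint below, where the captured group is the zero-padded step:
#   CHECKPOINT_PATTERN.match('model_checkpoint-00012000').group(1)  # '00012000'
#   CHECKPOINT_PATTERN.match('model_checkpoint')  # None (the unsuffixed link)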


class ArgsDict(dict):
    """Dict whose entries are also accessible as attributes."""

    def __init__(self, **kwargs):
        super(ArgsDict, self).__init__()
        for key, value in kwargs.items():
            self[key] = value
        self.__dict__ = self
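
# Usage sketch: entries can be read both as keys and as attributes.
#   args = ArgsDict(lr=1e-3, batch_size=32)
#   args.lr == args['lr']  # True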


def load_checkpoint(item_dict, model_dir, map_location=None, step=None):
    """item_dict: {"model": model, "opt1": opt1, ...}"""
    path = os.path.join(model_dir, 'model_checkpoint')
    if step is not None:
        path += '-{:08d}'.format(step)
    if os.path.exists(path):
        print("Loading model from %s" % path)
        checkpoint = torch.load(path, map_location=map_location)
        # Parameters present in the current model but missing from the
        # checkpoint (e.g. newly added modules) keep their initialized values.
        old_state_dict = item_dict["model"].state_dict()
        for key in old_state_dict.keys():
            if key not in checkpoint['model']:
                checkpoint['model'][key] = old_state_dict[key]
        for item_name in item_dict:
            item_dict[item_name].load_state_dict(checkpoint[item_name])
        return checkpoint.get('step', 0)
    return 0
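
# A hedged usage sketch (the model, optimizer, and directory are hypothetical):
#   model = MyModel()
#   optimizer = torch.optim.Adam(model.parameters())
#   last_step = load_checkpoint(
#       {"model": model, "optimizer": optimizer}, "runs/exp1")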


def load_and_map_checkpoint(model, model_dir, remap):
    path = os.path.join(model_dir, 'model_checkpoint')
    print("Loading parameters %s from %s" % (remap.keys(), model_dir))
    checkpoint = torch.load(path)
    new_state_dict = model.state_dict()
    for name, value in remap.items():
        # TODO: smarter mapping.
        new_state_dict[name] = checkpoint['model'][value]
    model.load_state_dict(new_state_dict)
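
# A hedged sketch (parameter names are hypothetical): initialize this model's
# encoder from a pretrained model's encoder weights.
#   load_and_map_checkpoint(model, "runs/pretrained",
#                           remap={"encoder.weight": "encoder.weight"})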


def save_checkpoint(items, step, model_dir, ignore=(),
                    keep_every_n=10000000):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    path_without_step = os.path.join(model_dir, 'model_checkpoint')
    step_padded = format(step, '08d')
    state_dict = items["model"].state_dict()
    if ignore:
        # Iterate over a copy of the keys: removing entries while iterating
        # over the dict itself raises a RuntimeError.
        for key in list(state_dict.keys()):
            for item in ignore:
                if key.startswith(item):
                    state_dict.pop(key)
    path_with_step = '{}-{}'.format(path_without_step, step_padded)
    saved_dic = {}
    for key in items:
        # Use the filtered state dict for the model so that ignored
        # parameters are actually dropped from the checkpoint.
        saved_dic[key] = state_dict if key == "model" else items[key].state_dict()
    torch.save({**saved_dic, "step": step}, path_with_step)
    # Point the unsuffixed 'model_checkpoint' name at the latest checkpoint;
    # fall back to a copy where symlinks are unavailable.
    try:
        os.unlink(path_without_step)
    except FileNotFoundError:
        pass
    try:
        os.symlink(os.path.basename(path_with_step), path_without_step)
    except OSError:
        shutil.copy2(path_with_step, path_without_step)
    # Cull old checkpoints, keeping at most one per keep_every_n steps.
    if keep_every_n is not None:
        all_checkpoints = []
        for name in os.listdir(model_dir):
            m = CHECKPOINT_PATTERN.match(name)
            if m is None or name == os.path.basename(path_with_step):
                continue
            checkpoint_step = int(m.group(1))
            all_checkpoints.append((checkpoint_step, name))
        all_checkpoints.sort()
        last_step = float('-inf')
        for checkpoint_step, name in all_checkpoints:
            if checkpoint_step - last_step >= keep_every_n:
                last_step = checkpoint_step
                continue
            os.unlink(os.path.join(model_dir, name))
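
# Retention sketch with keep_every_n=10000: existing checkpoints at steps
# 1000, 11000, 12000, and 22000 keep 1000 and 11000 (they are >= 10000 steps
# apart), cull 12000, and keep 22000; the checkpoint just written is always
# exempt from culling.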


class Saver(object):
    """Class to manage save and restore for the model and optimizer."""

    def __init__(self, items, keep_every_n=None):
        assert isinstance(items, dict)
        assert "model" in items
        self._items = items
        self._keep_every_n = keep_every_n

    def restore(self, model_dir, map_location=None,
                step=None, item_keys=("model", "optimizer")):
        """Restores the model and optimizer from the given directory.

        `item_keys` specifies which items should be restored.

        Returns:
            Last training step for the model restored.
        """
        items2restore = {k: self._items[k] for k in item_keys}
        last_step = load_checkpoint(
            items2restore, model_dir, map_location, step)
        return last_step

    def save(self, model_dir, step):
        """Saves the model and optimizer to the given directory.

        Args:
            model_dir: Model directory to save to.
            step: Current training step.
        """
        save_checkpoint(self._items, step, model_dir,
                        keep_every_n=self._keep_every_n)

    def restore_part(self, other_model_dir, remap):
        """Restores part of the model from another directory.

        Useful to initialize part of the model with another pretrained model.

        Args:
            other_model_dir: Model directory to load from.
            remap: dict mapping current parameter names to the other
                model's parameter names.
        """
        load_and_map_checkpoint(self._items["model"], other_model_dir, remap)
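

# A hedged end-to-end sketch (MyModel and "runs/exp1" are hypothetical):
#   model = MyModel()
#   optimizer = torch.optim.Adam(model.parameters())
#   saver = Saver({"model": model, "optimizer": optimizer},
#                 keep_every_n=10000)
#   last_step = saver.restore("runs/exp1")  # 0 when no checkpoint exists
#   for step in range(last_step + 1, 100001):
#       ...  # one training step
#       if step % 1000 == 0:
#           saver.save("runs/exp1", step)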