"""Tools to save/restore model from checkpoints."""
import argparse
import shutil
import sys
import os
import re
import json
import time
import torch
CHECKPOINT_PATTERN = re.compile('^model_checkpoint-(\d+)$')


class ArgsDict(dict):
    """Dict whose entries are also accessible as attributes."""

    def __init__(self, **kwargs):
        super(ArgsDict, self).__init__()
        for key, value in kwargs.items():
            self[key] = value
        self.__dict__ = self
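
# Usage sketch (values are illustrative): keys set via the constructor are
# readable both as mapping entries and as attributes, because __dict__ is
# aliased to the dict itself.
#
#     args = ArgsDict(batch_size=32, lr=1e-3)
#     assert args.lr == args['lr'] == 1e-3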


def load_checkpoint(item_dict, model_dir, map_location=None, step=None):
    """item_dict: {"model": model, "opt1": opt1, ...}"""
    path = os.path.join(model_dir, 'model_checkpoint')
    if step is not None:
        path += '-{:08d}'.format(step)
    if os.path.exists(path):
        print("Loading model from %s" % path)
        checkpoint = torch.load(path, map_location=map_location)
        # Backfill entries that exist in the current model but not in the
        # checkpoint (e.g. parameters ignored at save time), so that
        # load_state_dict() below does not raise on missing keys.
        old_state_dict = item_dict["model"].state_dict()
        for key in old_state_dict.keys():
            if key not in checkpoint['model']:
                checkpoint['model'][key] = old_state_dict[key]
        for item_name in item_dict:
            item_dict[item_name].load_state_dict(checkpoint[item_name])
        return checkpoint.get('step', 0)
    return 0
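
# Usage sketch (model, optimizer, and the directory are illustrative):
#
#     last_step = load_checkpoint({"model": model, "optimizer": optimizer},
#                                 "runs/exp1")
#     print("resuming from step", last_step)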


def load_and_map_checkpoint(model, model_dir, remap):
    path = os.path.join(model_dir, 'model_checkpoint')
    print("Loading parameters %s from %s" % (remap.keys(), model_dir))
    checkpoint = torch.load(path)
    new_state_dict = model.state_dict()
    for name, value in remap.items():
        # TODO: smarter mapping.
        new_state_dict[name] = checkpoint['model'][value]
    model.load_state_dict(new_state_dict)
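
# remap maps a parameter name in the current model to the corresponding name
# in the other checkpoint. A hypothetical example:
#
#     load_and_map_checkpoint(model, "runs/pretrained",
#                             remap={"encoder.weight": "embedding.weight"})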


def save_checkpoint(items, step, model_dir, ignore=(),
                    keep_every_n=10000000):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    path_without_step = os.path.join(model_dir, 'model_checkpoint')
    step_padded = format(step, '08d')
    state_dict = items["model"].state_dict()
    if ignore:
        # Iterate over a copy of the keys: popping from a dict while
        # iterating over it raises a RuntimeError in Python 3. Break after
        # the first matching prefix so a key is never popped twice.
        for key in list(state_dict.keys()):
            for item in ignore:
                if key.startswith(item):
                    state_dict.pop(key)
                    break
    path_with_step = '{}-{}'.format(path_without_step, step_padded)
    saved_dic = {}
    for key in items:
        saved_dic[key] = items[key].state_dict()
    # Save the filtered state dict for the model, so that ignored parameters
    # are actually left out of the checkpoint.
    saved_dic["model"] = state_dict
    torch.save({**saved_dic, "step": step}, path_with_step)
    try:
        os.unlink(path_without_step)
    except FileNotFoundError:
        pass
    # Point the unnumbered 'model_checkpoint' name at the latest checkpoint;
    # fall back to a copy on filesystems where symlinks are unavailable.
    try:
        os.symlink(os.path.basename(path_with_step), path_without_step)
    except OSError:
        shutil.copy2(path_with_step, path_without_step)
    # Cull old checkpoints, keeping at most one every keep_every_n steps.
    # The checkpoint just written is always kept.
    if keep_every_n is not None:
        all_checkpoints = []
        for name in os.listdir(model_dir):
            m = CHECKPOINT_PATTERN.match(name)
            if m is None or name == os.path.basename(path_with_step):
                continue
            checkpoint_step = int(m.group(1))
            all_checkpoints.append((checkpoint_step, name))
        all_checkpoints.sort()
        last_step = float('-inf')
        for checkpoint_step, name in all_checkpoints:
            if checkpoint_step - last_step >= keep_every_n:
                last_step = checkpoint_step
                continue
            os.unlink(os.path.join(model_dir, name))
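
# Worked culling example: with keep_every_n=1000 and old checkpoints at steps
# 100, 500, 1100, 1600, and 2400, the loop keeps 100, 1100, and 2400 (each at
# least 1000 steps after the previously kept one) and deletes 500 and 1600.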


class Saver(object):
    """Class to manage save and restore for the model and optimizer."""

    def __init__(self, items, keep_every_n=None):
        assert isinstance(items, dict)
        assert "model" in items
        self._items = items
        self._keep_every_n = keep_every_n

    def restore(self, model_dir, map_location=None,
                step=None, item_keys=("model", "optimizer")):
        """Restores the items named in item_keys from the given directory.

        Returns:
            Last training step for the model restored.
        """
        items2restore = {k: self._items[k] for k in item_keys}
        last_step = load_checkpoint(
            items2restore, model_dir, map_location, step)
        return last_step

    def save(self, model_dir, step):
        """Saves model and optimizer to given directory.

        Args:
            model_dir: Model directory to save.
            step: Current training step.
        """
        save_checkpoint(self._items, step, model_dir,
                        keep_every_n=self._keep_every_n)

    def restore_part(self, other_model_dir, remap):
        """Restores part of the model from other directory.

        Useful to initialize part of the model with another pretrained model.

        Args:
            other_model_dir: Model directory to load from.
            remap: dict, remapping current parameters to the other model's.
        """
        load_and_map_checkpoint(self._items["model"], other_model_dir, remap)
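

if __name__ == '__main__':
    # Round-trip smoke-test sketch: the tiny model, optimizer, and temporary
    # directory below are illustrative, not part of the module's API.
    import tempfile

    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    saver = Saver({"model": model, "optimizer": optimizer}, keep_every_n=5)

    with tempfile.TemporaryDirectory() as model_dir:
        saver.save(model_dir, step=10)
        restored_step = saver.restore(model_dir)
        assert restored_step == 10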