"""Tools to save/restore model from checkpoints."""

import os
import re
import shutil

import torch

CHECKPOINT_PATTERN = re.compile(r'^model_checkpoint-(\d+)$')


class ArgsDict(dict):
    """Dict whose entries can also be read and written as attributes."""

    def __init__(self, **kwargs):
        super(ArgsDict, self).__init__()
        for key, value in kwargs.items():
            self[key] = value
        self.__dict__ = self
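
# Illustrative ArgsDict use (hypothetical values): entries can be read
# either as dict keys or as attributes.
#
#   args = ArgsDict(lr=1e-3, steps=1000)
#   assert args.lr == args['lr']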


def load_checkpoint(item_dict, model_dir, map_location=None, step=None):
    """Loads a checkpoint into every item in `item_dict`.

    Args:
        item_dict: {"model": model, "opt1": opt1, ...}; every value must
            implement `state_dict()` and `load_state_dict()`.
        model_dir: Directory containing the checkpoint.
        map_location: Passed through to `torch.load`.
        step: If given, load the checkpoint written at that step instead
            of the latest one.

    Returns:
        The training step stored in the checkpoint, or 0 if no checkpoint
        was found.
    """
    path = os.path.join(model_dir, 'model_checkpoint')
    if step is not None:
        path += '-{:08d}'.format(step)
    if os.path.exists(path):
        print("Loading model from %s" % path)
        checkpoint = torch.load(path, map_location=map_location)

        # Backfill parameters that exist in the current model but not in
        # the checkpoint (e.g. layers added after it was written), so that
        # load_state_dict below does not fail on missing keys.
        old_state_dict = item_dict["model"].state_dict()
        for key in old_state_dict.keys():
            if key not in checkpoint['model']:
                checkpoint['model'][key] = old_state_dict[key]

        for item_name in item_dict:
            item_dict[item_name].load_state_dict(checkpoint[item_name])
        return checkpoint.get('step', 0)
    return 0


def load_and_map_checkpoint(model, model_dir, remap):
    """Loads selected parameters from another model's checkpoint.

    Args:
        model: Model whose parameters are (partially) overwritten.
        model_dir: Directory containing the other model's checkpoint.
        remap: dict mapping this model's parameter names to the names used
            in the checkpoint.
    """
    path = os.path.join(model_dir, 'model_checkpoint')
    print("Loading parameters %s from %s" % (remap.keys(), model_dir))
    checkpoint = torch.load(path)
    new_state_dict = model.state_dict()
    for name, value in remap.items():
        # TODO: smarter mapping.
        new_state_dict[name] = checkpoint['model'][value]
    model.load_state_dict(new_state_dict)
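
# Sketch of a load_and_map_checkpoint call (the directory and parameter
# names below are hypothetical):
#
#   load_and_map_checkpoint(
#       model, 'runs/pretrained',
#       remap={'encoder.weight': 'encoder.weight',
#              'encoder.bias': 'encoder.bias'})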


def save_checkpoint(items, step, model_dir, ignore=None,
                    keep_every_n=10000000):
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    path_without_step = os.path.join(model_dir, 'model_checkpoint')
    step_padded = format(step, '08d')
    state_dict = items["model"].state_dict()
    if ignore:
        # Drop parameters matching any ignored prefix. Copy the keys
        # first, since popping while iterating over the live view would
        # raise a RuntimeError.
        for key in list(state_dict.keys()):
            if any(key.startswith(item) for item in ignore):
                state_dict.pop(key)
    path_with_step = '{}-{}'.format(path_without_step, step_padded)

    saved_dic = {key: item.state_dict() for key, item in items.items()}
    # Reuse the (possibly filtered) model state dict computed above rather
    # than re-fetching it, so ignored parameters are actually left out.
    saved_dic["model"] = state_dict
    torch.save({**saved_dic, "step": step}, path_with_step)

    # Point the stable 'model_checkpoint' name at the newest checkpoint;
    # fall back to a full copy on filesystems without symlink support.
    try:
        os.unlink(path_without_step)
    except FileNotFoundError:
        pass
    try:
        os.symlink(os.path.basename(path_with_step), path_without_step)
    except OSError:
        shutil.copy2(path_with_step, path_without_step)

    # Cull old checkpoints.
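    # Worked example (hypothetical steps): with keep_every_n=1000 and old
    # checkpoints at steps 100, 500, 1100, and 1600, this pass keeps 100
    # and 1100 and deletes 500 and 1600, since surviving checkpoints must
    # be at least keep_every_n steps apart.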
    if keep_every_n is not None:
        all_checkpoints = []
        for name in os.listdir(model_dir):
            m = CHECKPOINT_PATTERN.match(name)
            if m is None or name == os.path.basename(path_with_step):
                continue
            checkpoint_step = int(m.group(1))
            all_checkpoints.append((checkpoint_step, name))
        all_checkpoints.sort()

        last_step = float('-inf')
        for checkpoint_step, name in all_checkpoints:
            if checkpoint_step - last_step >= keep_every_n:
                last_step = checkpoint_step
                continue
            os.unlink(os.path.join(model_dir, name))


class Saver(object):
    """Class to manage save and restore for the model and optimizer."""

    def __init__(self, items, keep_every_n=None):
        assert isinstance(items, dict)
        assert "model" in items
        self._items = items
        self._keep_every_n = keep_every_n

    def restore(self, model_dir, map_location=None,
                step=None, item_keys=("model", "optimizer")):
        """Restores the specified items from the given directory.

        Args:
            model_dir: Directory to load the checkpoint from.
            map_location: Passed through to `torch.load`.
            step: If given, restore that step's checkpoint instead of the
                latest one.
            item_keys: Names of the managed items to restore.

        Returns:
           Last training step for the model restored.
        """
        items2restore = {k: self._items[k] for k in item_keys}
        last_step = load_checkpoint(
            items2restore, model_dir, map_location, step)
        return last_step

    def save(self, model_dir, step):
        """Saves model and optimizer to given directory.

        Args:
           model_dir: Directory to write the checkpoint to.
           step: Current training step.
        """
        save_checkpoint(self._items, step, model_dir,
                        keep_every_n=self._keep_every_n)

    def restore_part(self, other_model_dir, remap):
        """Restores part of the model from another directory.

        Useful to initialize part of the model with parameters from a
        pretrained model.

        Args:
            other_model_dir: Model directory to load from.
            remap: dict mapping this model's parameter names to the other
                model's.
        """
        load_and_map_checkpoint(self._items["model"], other_model_dir, remap)
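

if __name__ == '__main__':
    # Minimal end-to-end sketch (the model, optimizer, and directory here
    # are hypothetical): save one checkpoint, then restore from it.
    import tempfile

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    saver = Saver({"model": model, "optimizer": optimizer})

    model_dir = tempfile.mkdtemp()
    saver.save(model_dir, step=100)
    last_step = saver.restore(model_dir)
    print("Restored checkpoint from step", last_step)  # -> 100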