# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
# Changes to script
Change the script to import the NeMo model class whose checkpoints you would like to average,
then update the model constructor to use this class. The relevant location is marked by the line:

<<< Change model class here ! >>>

By default, this script imports and creates the `EncDecCTCModelBPE` class, but it can be
changed to any NeMo model.
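
For example, to average checkpoints of an RNN-T model instead, a minimal sketch
(assuming `EncDecRNNTBPEModel` is available in your installed NeMo version) would be:

    from nemo.collections.asr.models import EncDecRNNTBPEModel
    ...
    model = EncDecRNNTBPEModel(cfg=cfg.model, trainer=trainer)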

# Run the script

## Saving a .nemo model file (loaded with ModelPT.restore_from(...))
HYDRA_FULL_ERROR=1 python average_model_checkpoints.py \
    --config-path="<path to config directory>" \
    --config-name="<config name>" \
    name=<name of the averaged checkpoint> \
    +checkpoint_dir=<OPTIONAL: directory containing the checkpoints> \
    +checkpoint_paths=\"[/path/to/ptl_1.ckpt,/path/to/ptl_2.ckpt,/path/to/ptl_3.ckpt,...]\"


## Saving an averaged pytorch checkpoint (loaded with torch.load(...))
HYDRA_FULL_ERROR=1 python average_model_checkpoints.py \
    --config-path="<path to config directory>" \
    --config-name="<config name>" \
    name=<name of the averaged checkpoint> \
    +checkpoint_dir=<OPTIONAL: directory containing the checkpoints> \
    +checkpoint_paths=\"[/path/to/ptl_1.ckpt,/path/to/ptl_2.ckpt,/path/to/ptl_3.ckpt,...]\" \
    +save_ckpt_only=true
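
## Example invocation (hypothetical paths and names, for illustration only)
HYDRA_FULL_ERROR=1 python average_model_checkpoints.py \
    --config-path="../conf/citrinet" \
    --config-name="config.yaml" \
    name="citrinet_averaged" \
    +checkpoint_paths=\"[ckpts/epoch_8.ckpt,ckpts/epoch_9.ckpt,ckpts/epoch_10.ckpt]\" \
    +save_ckpt_only=true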

"""

import os

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, ListConfig, open_dict

# Change this import to the model you would like to average
from nemo.collections.asr.models import EncDecCTCModelBPE
from nemo.core.config import hydra_runner
from nemo.utils import logging


def process_config(cfg: DictConfig):
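    """Validate `cfg` and extract `(name_prefix, checkpoint_paths, save_ckpt_only)`.

    `checkpoint_paths` may be given either as a YAML/Hydra list or as a single
    bracketed string such as "[a.ckpt,b.ckpt]" (e.g. when quoted on the command
    line); the string form is split into individual paths.
    """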
    if 'name' not in cfg or cfg.name is None:
        raise ValueError("`cfg.name` must be provided to save a model checkpoint")

    if 'checkpoint_paths' not in cfg or cfg.checkpoint_paths is None:
        raise ValueError(
            "`cfg.checkpoint_paths` must be provided as a list of one or more str paths to "
            "pytorch lightning checkpoints"
        )

    save_ckpt_only = False

    with open_dict(cfg):
        name_prefix = cfg.name
        checkpoint_paths = cfg.pop('checkpoint_paths')

        if 'checkpoint_dir' in cfg:
            checkpoint_dir = cfg.pop('checkpoint_dir')
        else:
            checkpoint_dir = None

        if 'save_ckpt_only' in cfg:
            save_ckpt_only = cfg.pop('save_ckpt_only')

    if not isinstance(checkpoint_paths, (list, tuple, ListConfig)):
        # Hydra may pass the paths as one bracketed string, e.g. "[a.ckpt,b.ckpt]";
        # strip the brackets and split on commas to recover the individual paths.
        checkpoint_paths = str(checkpoint_paths).replace("[", "").replace("]", "")
        checkpoint_paths = [ckpt_path.strip() for ckpt_path in checkpoint_paths.split(",")]

    if checkpoint_dir is not None:
        checkpoint_paths = [os.path.join(checkpoint_dir, path) for path in checkpoint_paths]

    return name_prefix, checkpoint_paths, save_ckpt_only


@hydra_runner(config_path=None, config_name=None)
def main(cfg):
    name_prefix, checkpoint_paths, save_ckpt_only = process_config(cfg)

    if not save_ckpt_only:
        trainer = pl.Trainer(**cfg.trainer)

        # <<< Change model class here ! >>>
        # Model architecture which will contain the averaged checkpoints
        # Change the model constructor to the one you would like (if needed)
        model = EncDecCTCModelBPE(cfg=cfg.model, trainer=trainer)

    """ < Checkpoint Averaging Logic > """
    # load state dicts
    n = len(checkpoint_paths)
    avg_state = None
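    # The averaged weights are the arithmetic mean over the n checkpoints:
    #     avg_state[k] = (1 / n) * sum_i state_i[k]
    # The loop below accumulates the sum; the division by n happens afterwards.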

    logging.info(f"Averaging {n} checkpoints ...")

    for ix, path in enumerate(checkpoint_paths):
        checkpoint = torch.load(path, map_location='cpu')

        if 'state_dict' in checkpoint:
            checkpoint = checkpoint['state_dict']

        if ix == 0:
            # Initial state
            avg_state = checkpoint

            logging.info(f"Initialized average state dict with checkpoint : {path}")
        else:
            # Accumulated state
            for k in avg_state:
                avg_state[k] = avg_state[k] + checkpoint[k]

            logging.info(f"Updated average state dict with state from checkpoint : {path}")

    for k in avg_state:
        # Integer buffers (e.g. BatchNorm.num_batches_tracked) stay accumulated,
        # not averaged; only non-integer tensors are divided by n.
        if not str(avg_state[k].dtype).startswith("torch.int"):
            avg_state[k] = avg_state[k] / n

    # Save model
    if save_ckpt_only:
        ckpt_name = name_prefix + '-averaged.ckpt'
        torch.save(avg_state, ckpt_name)

        logging.info(f"Averaged pytorch checkpoint saved as : {ckpt_name}")
    else:
        # Set model state
        logging.info("Loading averaged state dict in provided model")
        model.load_state_dict(avg_state, strict=True)

        ckpt_name = name_prefix + '-averaged.nemo'
        model.save_to(ckpt_name)

        logging.info(f"Averaged model saved as : {ckpt_name}")


if __name__ == '__main__':
    main()