# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example: python scripts/checkpoint_averaging/distributed_checkpoint_averaging.py \
             --name_prefix=<checkpoint name> \
             --checkpoint_dir=<folder with mp_rank_X subfolders containing checkpoints>
             --steps <optinally a list of checkpoint steps to average, if not provided, it will average all the checkpoints>

will generate a new directory in each of the distributed checkpoint subfolders named <checkpoint name>-averaged
"""

import argparse
import logging
import os
import shutil

import zarr

logging.basicConfig(level=logging.INFO)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--name_prefix', help='Name of the final checkpoint. Will append -averaged automatically.',
    )
    parser.add_argument(
        '--checkpoint_dir', help='Folder containing all the distributed checkpoints.',
    )
    # list of checkpoint steps to average
    parser.add_argument(
        '--steps',
        nargs='+',
        type=int,
        help='List of checkpoint steps to average. If not specified, will average all.',
    )

    args = parser.parse_args()

    if args.steps is not None:
        logging.info(f"Will average only steps {args.steps}")

    # collect the checkpoint directories to average (the '*0-last' checkpoint is skipped)

    checkpoint_paths = []
    for ckpt_dir in os.listdir(args.checkpoint_dir):
        logging.info("Processing %s", ckpt_dir)
        if ckpt_dir.endswith('0-last'):
            continue
        if args.steps is None:
            checkpoint_paths.append(ckpt_dir)
        else:
            for step in args.steps:
                key = f"-step={step}-"
                if key in ckpt_dir:
                    checkpoint_paths.append(ckpt_dir)

    n = len(checkpoint_paths)
    # running sums of the weights to average, keyed by zarr array name
    avg_weights = {}

    logging.info(f"Averaging {n} checkpoints ... {'at steps:' + str(args.steps) if args.steps is not None else ''}")

    # items that need to be copied unchanged to the new checkpoint folder
    copy_items = []
    for ix, path in enumerate(checkpoint_paths):
        full_path = os.path.join(args.checkpoint_dir, path)

        for item in os.listdir(full_path):

            # plain files are not averaged; copy them from the first checkpoint
            if not os.path.isdir(os.path.join(full_path, item)):
                if ix == 0:
                    copy_items.append(os.path.join(full_path, item))
                continue

            # transformer engine extra states are not numeric weights; copy them from the first checkpoint
            if item.endswith('._extra_state'):
                if ix == 0:
                    copy_items.append(os.path.join(full_path, item))
                continue

            # optimizer states; no point in averaging them, copy them from the first checkpoint
            if item.startswith('optimizer.'):
                if ix == 0:
                    copy_items.append(os.path.join(full_path, item))
                continue

            if item not in avg_weights:
                logging.info(f"Initialized average weights dict with: {item}")
                avg_weights[item] = zarr.open(os.path.join(full_path, item), mode='r')
            else:
                logging.info(f"Updated average weights dict with weight: {item}")
                array_z = zarr.open(os.path.join(full_path, item), mode='r')
                sum_array = avg_weights[item][:] + array_z[:]
                avg_weights[item] = zarr.array(sum_array, chunks=array_z.chunks, dtype=array_z.dtype)

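    # Each entry in avg_weights now holds the element-wise sum over the n checkpoints;
    # divide by n to complete the average. Note that the running sums are held fully in
    # memory, so peak usage is roughly the size of one copy of the model weights.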
    for k in avg_weights:
        logging.info(f"Average weights dict key : {k}, dtype : {avg_weights[k].dtype}, shape : {avg_weights[k].shape}")
        if str(avg_weights[k].dtype).startswith("int"):
            raise ValueError("Int type not supported")
        else:
            array_z = avg_weights[k][:]
            array_z = array_z / n
            avg_weights[k] = zarr.array(array_z, chunks=avg_weights[k].chunks, dtype=avg_weights[k].dtype)

    # Save model
    if args.steps is None:
        ckpt_name = os.path.join(args.checkpoint_dir, args.name_prefix + '-averaged')
    else:
        steps_combined = '_'.join([str(x) for x in args.steps])
        ckpt_name = os.path.join(args.checkpoint_dir, args.name_prefix + '-' + steps_combined + '-averaged')

    # save avg_weights
    for k in avg_weights:
        logging.info(f"Saving {k} to {ckpt_name}")
        zarr.save(os.path.join(ckpt_name, k), avg_weights[k])
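
    # A quick way to sanity-check one of the saved arrays afterwards (illustrative only;
    # the array name below is hypothetical and depends on the model that was averaged):
    #
    #   import zarr
    #   arr = zarr.open(os.path.join(ckpt_name, 'model.decoder.final_layernorm.weight'), mode='r')
    #   print(arr.shape, arr.dtype)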

    # copy the untouched items (plain files and non-averaged state directories) so the
    # averaged directory remains a complete distributed checkpoint
    for item in copy_items:
        is_file = os.path.isfile(item)
        logging.info(f"Copying {'file' if is_file else 'directory'} {item} to {ckpt_name}")
        if is_file:
            # copy single file
            shutil.copy(item, ckpt_name)
        else:
            # copy directory
            shutil.copytree(item, os.path.join(ckpt_name, os.path.basename(item)), dirs_exist_ok=True)

    logging.info(f"Averaged distributed checkpoint saved as : {ckpt_name}")


if __name__ == '__main__':
    main()