File size: 6,124 Bytes
2d8da09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Example: python scripts/checkpoint_averaging/distributed_checkpoint_averaging.py \
--name_prefix=<checkpoint name> \
--checkpoint_dir=<folder with mp_rank_X subfolders containing checkpoints> \
--steps <optionally a list of checkpoint steps to average; if not provided, all checkpoints are averaged>
will generate a new directory in each of the distributed checkpoint subfolders named <checkpoint name>-averaged
"""
import argparse
import logging
import os
import shutil
import zarr
# Configure the root logger so the INFO-level progress messages emitted by this script are shown.
logging.basicConfig(level=logging.INFO)
def main():
    """Average the model weights of several distributed (zarr) checkpoints.

    Command-line arguments:
        --name_prefix: name of the resulting checkpoint; ``-averaged`` (and,
            when ``--steps`` is given, the joined step numbers) is appended.
        --checkpoint_dir: folder whose sub-directories are the checkpoints.
        --steps: optional list of steps; only checkpoints whose directory name
            contains ``-step=<step>-`` are averaged. Without it, every
            checkpoint directory is averaged.

    Every sub-directory of a checkpoint is treated as one zarr array and is
    summed across checkpoints, then divided by the number of checkpoints.
    Plain files, ``*._extra_state`` directories and ``optimizer.*`` directories
    are not averaged; they are copied verbatim from the first checkpoint
    (in sorted name order).

    Raises:
        ValueError: if no checkpoint directory matches, or if an averaged
            array has an integer dtype.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--name_prefix',
        required=True,
        help='Name of the final checkpoint. Will append -averaged automatically.',
    )
    parser.add_argument(
        '--checkpoint_dir',
        required=True,
        help='Folder containing all the distributed checkpoints.',
    )
    # list of checkpoint steps to average
    parser.add_argument(
        '--steps',
        nargs='+',
        type=int,
        help='List of checkpoint steps to average. If not specified, will average all.',
    )
    args = parser.parse_args()

    # The output location is resolved up front so that a directory left over
    # from a previous run can be excluded from the inputs below.
    if args.steps is None:
        output_name = args.name_prefix + '-averaged'
    else:
        logging.info(f"Will average only steps {args.steps}")
        steps_combined = '_'.join(str(x) for x in args.steps)
        output_name = args.name_prefix + '-' + steps_combined + '-averaged'
    ckpt_name = os.path.join(args.checkpoint_dir, output_name)

    # Collect the checkpoint directories to average. Sorting makes the run
    # deterministic: os.listdir() order is arbitrary, and the first entry
    # decides which checkpoint provides the files that are copied verbatim.
    checkpoint_paths = []
    for ckpt_dir in sorted(os.listdir(args.checkpoint_dir)):
        logging.info("Processing %s", ckpt_dir)
        # NOTE(review): presumably the "-last" checkpoint duplicates a step
        # checkpoint and would be counted twice -- confirm with the trainer.
        if ckpt_dir.endswith('0-last'):
            continue
        # Never feed a previous result of this very run configuration back in.
        if ckpt_dir == output_name:
            continue
        # Stray files next to the checkpoints cannot be listed as directories.
        if not os.path.isdir(os.path.join(args.checkpoint_dir, ckpt_dir)):
            continue
        # any() adds a checkpoint at most once, even if a step is repeated on
        # the command line (a repeat would otherwise double its weight).
        if args.steps is None or any(f"-step={step}-" in ckpt_dir for step in args.steps):
            checkpoint_paths.append(ckpt_dir)

    n = len(checkpoint_paths)
    if n == 0:
        raise ValueError(f"No checkpoints to average were found in {args.checkpoint_dir}")

    # initialize dict, will be used to store the weights that need to be averaged
    avg_weights = {}

    logging.info(f"Averaging {n} checkpoints ... {'at steps:' + str(args.steps) if args.steps is not None else ''}")

    # items that need to be copied (not averaged) to the new checkpoint folder
    copy_items = []
    for ix, path in enumerate(checkpoint_paths):
        full_path = os.path.join(args.checkpoint_dir, path)

        for item in os.listdir(full_path):
            item_path = os.path.join(full_path, item)

            # Plain files, transformer engine states ("._extra_state") and
            # optimizer states are never averaged; the copy of the first
            # checkpoint is carried over unchanged.
            skip_averaging = (
                not os.path.isdir(item_path) or item.endswith('._extra_state') or item.startswith('optimizer.')
            )
            if skip_averaging:
                if ix == 0:
                    copy_items.append(item_path)
                continue

            if item not in avg_weights:
                logging.info(f"Initialized average weights dict with: {item}")
                avg_weights[item] = zarr.open(item_path, mode='r')
            else:
                logging.info(f"Updated average weights dict with weight: {item}")
                array_z = zarr.open(item_path, mode='r')
                # running sum, stored with the chunking and dtype of the new array
                sum_array = avg_weights[item][:] + array_z[:]
                avg_weights[item] = zarr.array(sum_array, chunks=array_z.chunks, dtype=array_z.dtype)

    # Turn the running sums into means.
    for k in avg_weights:
        logging.info(f"Average weights dict key : {k}, dtype : {avg_weights[k].dtype}, shape : {avg_weights[k].shape}")
        if str(avg_weights[k].dtype).startswith("int"):
            raise ValueError("Int type not supported")
        array_z = avg_weights[k][:]
        array_z = array_z / n
        avg_weights[k] = zarr.array(array_z, chunks=avg_weights[k].chunks, dtype=avg_weights[k].dtype)

    # Save model
    # The folder must exist before single files are copied into it; otherwise
    # shutil.copy() would create a *file* named like the checkpoint folder.
    os.makedirs(ckpt_name, exist_ok=True)

    # save avg_weights
    for k in avg_weights:
        logging.info(f"Saving {k} to {ckpt_name}")
        zarr.save(os.path.join(ckpt_name, k), avg_weights[k])

    # copy other files
    for item in copy_items:
        is_file = os.path.isfile(item)
        logging.info(f"Copying {'file' if is_file else 'directory'} {item} to {ckpt_name}")
        if is_file:
            # copy single file
            shutil.copy(item, ckpt_name)
        else:
            # copy directory
            shutil.copytree(item, os.path.join(ckpt_name, os.path.basename(item)), dirs_exist_ok=True)

    logging.info(f"Averaged distributed checkpoint saved as : {ckpt_name}")
# Run the averaging only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()
|