"""Average the weights of several PyTorch model checkpoints into a single checkpoint."""
import argparse

import torch


def average_checkpoints(checkpoint_paths):
    """Element-wise average of the 'state_dict' tensors across the given checkpoints."""
    # Use the last checkpoint as the template so that any non-weight entries
    # (epoch counter, optimizer state, hyperparameters, ...) are carried over.
    averaged_ckpt = torch.load(checkpoint_paths[-1], map_location=torch.device('cpu'))
    param_sum_dict = {}
    for key, value in averaged_ckpt['state_dict'].items():
        param_sum_dict[key] = value.clone()
    num_checkpoints = len(checkpoint_paths)
    # Accumulate the parameters of the remaining checkpoints.
    for ckpt_path in checkpoint_paths[:-1]:
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        for key, value in checkpoint['state_dict'].items():
            param_sum_dict[key] += value
    # Divide the accumulated sums by the number of checkpoints to get the mean.
    for key in param_sum_dict.keys():
        param_sum_dict[key] = param_sum_dict[key] / num_checkpoints
    averaged_ckpt['state_dict'] = param_sum_dict
    return averaged_ckpt


def parse_arguments():
    parser = argparse.ArgumentParser(
        description="Averages the weights of multiple transformer model checkpoints.")
    parser.add_argument('--checkpoint_paths', nargs='+', required=True,
                        help='List of paths to the checkpoints to be averaged. '
                             'Example: --checkpoint_paths path1 path2 path3')
    parser.add_argument('--output_path', type=str, required=True,
                        help='Path where the averaged checkpoint will be written.')
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()
    averaged_checkpoint = average_checkpoints(args.checkpoint_paths)
    torch.save(averaged_checkpoint, args.output_path)
    print(f"Averaged checkpoint saved to {args.output_path}")