"""Average the ``state_dict`` tensors of several checkpoints into one."""

import argparse

import torch


def _accumulator_dtype(tensor):
    """Return the dtype in which *tensor* is summed across checkpoints."""
    if tensor.dtype in (torch.float16, torch.bfloat16):
        # Summing many half-precision tensors can overflow to inf long before
        # the final division brings the value back into range.
        return torch.float32
    if tensor.is_floating_point() or tensor.is_complex():
        return tensor.dtype
    # Integer and bool tensors are summed as int64.
    return torch.int64


def average_checkpoints(checkpoint_paths):
    """Average the ``state_dict`` entries of several checkpoints.

    The last path is loaded first and serves as the reference: every entry of
    the returned checkpoint other than ``'state_dict'`` is taken from it
    unchanged.

    Each averaged tensor keeps the dtype it has in the reference checkpoint.
    float16/bfloat16 tensors are summed in float32; integer and bool tensors
    are summed in int64 and floor-divided by the number of checkpoints.

    Args:
        checkpoint_paths: Non-empty sequence of checkpoint paths. Each
            checkpoint must be a mapping with a ``'state_dict'`` entry that
            maps names to tensors.

    Returns:
        The reference checkpoint with ``'state_dict'`` replaced by a plain
        dict of averaged tensors.

    Raises:
        ValueError: If ``checkpoint_paths`` is empty.
        KeyError: If a checkpoint's ``state_dict`` keys differ from the
            reference checkpoint's keys.
    """
    if not checkpoint_paths:
        raise ValueError("checkpoint_paths must contain at least one path")

    # NOTE(security): torch.load unpickles its input; only pass checkpoint
    # files that come from a trusted source.
    averaged_ckpt = torch.load(checkpoint_paths[-1], map_location=torch.device('cpu'))
    reference_state = averaged_ckpt['state_dict']

    # copy=True guarantees a fresh tensor even when no dtype change is needed,
    # so the in-place additions below never touch the loaded tensors.
    param_sums = {
        key: value.to(_accumulator_dtype(value), copy=True)
        for key, value in reference_state.items()
    }

    for ckpt_path in checkpoint_paths[:-1]:
        checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
        state = checkpoint['state_dict']
        if state.keys() != param_sums.keys():
            # A key present in only some checkpoints would otherwise be
            # divided by the full count and yield a silently wrong average.
            mismatched = sorted(str(key) for key in state.keys() ^ param_sums.keys())
            raise KeyError(
                f"state_dict keys of {ckpt_path} differ from those of "
                f"{checkpoint_paths[-1]}: {mismatched}"
            )
        for key, value in state.items():
            param_sums[key] += value

    num_checkpoints = len(checkpoint_paths)
    averaged_state = {}
    for key, total in param_sums.items():
        if total.is_floating_point() or total.is_complex():
            mean = total / num_checkpoints
        else:
            # True division would turn integer tensors into float tensors.
            mean = torch.div(total, num_checkpoints, rounding_mode='floor')
        averaged_state[key] = mean.to(reference_state[key].dtype)

    averaged_ckpt['state_dict'] = averaged_state
    return averaged_ckpt


def parse_arguments():
    """Parse the command-line arguments of the averaging script.

    Returns:
        An ``argparse.Namespace`` with ``checkpoint_paths`` (list of str) and
        ``output_path`` (str).
    """
    parser = argparse.ArgumentParser(
        description="Averages the weights of multiple transformer model checkpoints."
    )
    parser.add_argument(
        '--checkpoint_paths',
        nargs='+',
        required=True,
        help='List of paths to the checkpoints to be averaged. '
             'Example: --checkpoint_paths path1 path2 path3',
    )
    parser.add_argument(
        '--output_path',
        type=str,
        required=True,
        help='Path where the averaged checkpoint is saved.',
    )
    return parser.parse_args()


def main():
    """Average the checkpoints named on the command line and save the result."""
    args = parse_arguments()
    averaged_ckpt = average_checkpoints(args.checkpoint_paths)
    torch.save(averaged_ckpt, args.output_path)
    print(f"Averaged checkpoint saved to {args.output_path}")


if __name__ == "__main__":
    main()