import argparse
from collections.abc import Mapping
import concurrent.futures
import os
import sys

import torch


def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron saver')

    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')

    group.add_argument('--target-tensor-parallel-size', type=int,
                       help='Target tensor model parallel size, defaults to the tensor parallel size '
                            'in the input checkpoint if provided by the loader, otherwise to 1')
    group.add_argument('--target-pipeline-parallel-size', type=int,
                       help='Target pipeline model parallel size, defaults to the pipeline parallel size '
                            'in the input checkpoint if provided by the loader, otherwise to 1')


def save_checkpoint(queue, args):
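    # Look for the Megatron package one directory above this script;
    # --megatron-path, if given, takes precedence on sys.path.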
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.arguments import (parse_args, validate_args)
        from megatron.checkpointing import save_checkpoint
        from megatron.global_vars import set_global_variables, get_args
        from megatron.model import ModelType
        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        exit(1)

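    # Helpers for reading messages from the loader process and verifying that
    # each one arrives in the expected order and carries no unexpected entries.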
    def queue_get(name=None):
        val = queue.get()
        if val == "exit":
            print("Loader exited, exiting saver")
            exit(1)
        if name is not None and args.checking and val["name"] != name:
            val_name = val["name"]
            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
            exit(1)
        if name is not None:
            print(f"received {name}")
        return val

    def check_message(msg):
        if not args.checking:
            return
        msg_name = msg.pop("name")
        if len(msg.keys()) > 0:
            print(f"Unexpected values in {msg_name}:")
            for key in msg.keys():
                print(f"   {key}")
            print("Exiting. If you want to ignore this, use the argument --no-checking.")
            exit(1)

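    # The first message from the loader carries the model metadata (layer
    # counts, sizes, dtype, tokenizer type, etc.), read as attributes below.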
    md = queue_get()

    if args.target_tensor_parallel_size is None:
        if hasattr(md, 'previous_tensor_parallel_size'):
            args.target_tensor_parallel_size = md.previous_tensor_parallel_size
        else:
            print("Loader did not provide a tensor parallel size and --target-tensor-parallel-size "
                  "was not provided on the command line. Defaulting to 1.")
            args.target_tensor_parallel_size = 1

    if args.target_pipeline_parallel_size is None:
        if hasattr(md, 'previous_pipeline_parallel_size'):
            args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
        else:
            print("Loader did not provide a pipeline parallel size and --target-pipeline-parallel-size "
                  "was not provided on the command line. Defaulting to 1.")
            args.target_pipeline_parallel_size = 1

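    # Megatron's argument handling looks at WORLD_SIZE, so pretend we are
    # running with target_tp * target_pp processes.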
    if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
        os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'

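    # Build a fresh sys.argv so Megatron's parse_args() sees only the settings
    # we construct here from the loader metadata and the target parallel sizes.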
    sys.argv = ['script.py',
                '--num-layers', str(md.num_layers),
                '--hidden-size', str(md.hidden_size),
                '--seq-length', str(md.seq_length),
                '--num-attention-heads', str(md.num_attention_heads),
                '--max-position-embeddings', str(md.max_position_embeddings),
                '--tokenizer-type', str(md.tokenizer_type),
                '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
                '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--save-interval', '1',
                '--save', args.save_dir
                ]

    if md.make_vocab_size_divisible_by is not None:
        sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
    if md.params_dtype == torch.float16:
        sys.argv.append('--fp16')
    elif md.params_dtype == torch.bfloat16:
        sys.argv.append('--bf16')

    if md.model_type == 'BERT' and not md.bert_binary_head:
        sys.argv.append('--bert-no-binary-head')

    margs = parse_args()
    validate_args(margs)
    set_global_variables(margs)

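    # Work with the global args object from here on, so the settings adjusted
    # below are visible to the rest of Megatron when saving.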
    margs = get_args()

    if hasattr(md, 'consumed_train_samples'):
        margs.consumed_train_samples = md.consumed_train_samples
        margs.consumed_valid_samples = md.consumed_valid_samples
        print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
              f" and consumed_valid_samples to {margs.consumed_valid_samples}")
    else:
        print("consumed_train_samples not provided.")

    if md.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif md.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {md.model_type}')

    def get_models(count, dtype, pre_process, post_process):
        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
        return models

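    # Fake the model-parallel state: this is a single process, so just set the
    # target sizes, act as rank 0 of each group, and load the fused kernels.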
    mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
    mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    fused_kernels.load(margs)

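    # The loader sends the embeddings first.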
    embeddings_msg = queue_get("embeddings")

    pos_embed = embeddings_msg.pop("position embeddings")
    orig_word_embed = embeddings_msg.pop("word embeddings")
    check_message(embeddings_msg)

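    # Deal with padding: if the loader reported the true (unpadded) vocab size,
    # recompute the padded size for the new tensor-parallel configuration and
    # trim or pad the word embedding table to match.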
    if md.true_vocab_size is not None:
        orig_vocab_size = orig_word_embed.shape[0]
        margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)

        # Cut out extra padding we don't need.
        if orig_vocab_size > margs.padded_vocab_size:
            full_word_embed = orig_word_embed[0:margs.padded_vocab_size, :]

        # Expand the table to the larger size by replicating the final entry.
        elif orig_vocab_size < margs.padded_vocab_size:
            padding_size = margs.padded_vocab_size - orig_vocab_size

            full_word_embed = torch.cat((
                orig_word_embed,
                orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))

        # Same size, nothing to do.
        else:
            full_word_embed = orig_word_embed
    else:
        print("Original vocab size not specified, leaving embedding table as-is. "
              "If you've changed the tensor parallel size this could cause problems.")
        margs.padded_vocab_size = orig_word_embed.shape[0]
        full_word_embed = orig_word_embed

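    # Split the padded word embedding table across the target tensor-parallel ranks.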
    out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)

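    # Build the models for the first pipeline stage and fill in its embeddings.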
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    post_process = args.target_pipeline_parallel_size == 1
    models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
    for tp_rank, model in enumerate(models):
        print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
        model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
        model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)

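    # Transformer layers: receive one layer at a time from the loader and copy
    # it into every tensor-parallel rank of the current pipeline stage.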
    total_layer_num = 0
    for pp_rank in range(args.target_pipeline_parallel_size):
        # For pipeline stages after the first, build a fresh set of models.
        if pp_rank > 0:
            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == args.target_pipeline_parallel_size - 1
            models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)

        for layer in range(len(models[0].language_model.encoder.layers)):
            msg = queue_get(f"transformer layer {total_layer_num}")

            # Tensors duplicated across tensor-parallel ranks.
            input_layernorm_weight = msg.pop("input layernorm weight")
            input_layernorm_bias = msg.pop("input layernorm bias")
            dense_bias = msg.pop("dense bias")
            post_layernorm_weight = msg.pop("post layernorm weight")
            post_layernorm_bias = msg.pop("post layernorm bias")
            mlp_l1_bias = msg.pop("mlp l1 bias")

            # Tensors split across tensor-parallel ranks.
            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
            dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
            mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
            mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
            mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)

            # Copy this layer into each tensor-parallel rank's model.
            for tp_rank in range(args.target_tensor_parallel_size):
                l = models[tp_rank].language_model.encoder.layers[layer]
                l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                l.input_layernorm.bias.data.copy_(input_layernorm_bias)
                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
                l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                l.self_attention.dense.bias.data.copy_(dense_bias)
                l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
                l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
                l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
                l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
                l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
                l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)
            total_layer_num = total_layer_num + 1
            check_message(msg)

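        # The last pipeline stage also receives the final layernorm and, for
        # models that have them, the pooler, LM head, and binary head.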
        if post_process:
            msg = queue_get("final layernorm")
            final_layernorm_weight = msg.pop("weight")
            final_layernorm_bias = msg.pop("bias")
            for tp_rank in range(args.target_tensor_parallel_size):
                models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
                models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
                if pp_rank != 0:
                    # Later pipeline stages hold their own copy of the word embeddings.
                    models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
            del final_layernorm_weight
            del final_layernorm_bias
            check_message(msg)

            msg = queue_get()
            if msg != "done" and msg["name"] == "pooler":
                if not hasattr(models[0].language_model, 'pooler'):
                    print("ERROR: got a pooler, but model does not have one")
                    exit(1)
                print("received pooler")
                pooler_weight = msg.pop("weight")
                pooler_bias = msg.pop("bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
                    models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
                del pooler_weight
                del pooler_bias
                check_message(msg)
                msg = queue_get()

if msg != "done" and msg["name"] == "lm head": |
|
if not hasattr(models[0], 'lm_head'): |
|
print("ERROR: got an lm head, but model does not have one") |
|
exit(1) |
|
print("received lm head") |
|
lm_head_dense_weight = msg.pop("dense weight") |
|
lm_head_dense_bias = msg.pop("dense bias") |
|
lm_head_layernorm_weight = msg.pop("layernorm weight") |
|
lm_head_layernorm_bias = msg.pop("layernorm bias") |
|
for tp_rank in range(args.target_tensor_parallel_size): |
|
models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight) |
|
models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias) |
|
models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight) |
|
models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias) |
|
check_message(msg) |
|
msg = queue_get() |
|
|
|
if msg != "done" and msg["name"] == "binary head": |
|
if not hasattr(models[0], 'binary_head'): |
|
print("ERROR: got a binary head, but model does not have one") |
|
exit(1) |
|
print("received binary head") |
|
binary_head_weight = msg.pop("weight") |
|
binary_head_bias = msg.pop("bias") |
|
for tp_rank in range(args.target_tensor_parallel_size): |
|
models[tp_rank].binary_head.weight.data.copy_(binary_head_weight) |
|
models[tp_rank].binary_head.bias.data.copy_(binary_head_bias) |
|
check_message(msg) |
|
msg = queue_get() |
|
|
|
if msg != "done": |
|
print("ERROR: got some more data but was expecting to be done") |
|
|
|
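        # Save a checkpoint for each tensor-parallel rank of this pipeline stage.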
        for tp_rank in range(args.target_tensor_parallel_size):
            mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
            save_checkpoint(md.iteration, [models[tp_rank]], None, None)

    print("Done!")