import json
import os
import sys
import types

import torch


def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron loader')

    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='Original size of the vocab; if specified, padding is trimmed from the embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified, it is used to determine the vocab size and '
                            'trim padding from the embedding table.')
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of the Megatron repository.')


def _load_checkpoint(queue, args):
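    # Make the Megatron package importable: the parent directory of this
    # script is searched by default, and --megatron-path takes precedence
    # when provided.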
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.arguments import parse_args, validate_args
        from megatron.global_vars import set_args, set_global_variables
        from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint
        from megatron.model import ModelType, module
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        queue.put("exit")
        exit(1)
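
    # We want the training arguments to come from the checkpoint itself, so
    # hand Megatron's parser a minimal argv: disable the fused kernels and
    # initialization, skip optimizer/RNG state, and point --load at the
    # checkpoint directory.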
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--load', args.load_dir
                ]

    margs = parse_args()
    margs = load_args_from_checkpoint(margs)
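
    # validate_args() sanity-checks the world size, but this script never
    # launches a distributed job, so give it the world size implied by the
    # checkpoint's tensor- and pipeline-parallel configuration.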
    margs.world_size = margs.tensor_model_parallel_size * margs.pipeline_model_parallel_size

    margs = validate_args(margs)
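
    # Abort if the checkpoint does not record an argument the conversion needs.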
    def check_for_arg(arg_name):
        if getattr(margs, arg_name, None) is None:
            print(f"Checkpoint does not specify the argument {arg_name}. Exiting.")
            print(f"Arguments: {margs}")
            queue.put("exit")
            exit(1)

    check_for_arg('tensor_model_parallel_size')
    check_for_arg('pipeline_model_parallel_size')
    check_for_arg('num_layers')
    check_for_arg('hidden_size')
    check_for_arg('seq_length')
    check_for_arg('num_attention_heads')
    check_for_arg('max_position_embeddings')
    check_for_arg('tokenizer_type')
    check_for_arg('iteration')
    check_for_arg('bert_binary_head')
    check_for_arg('params_dtype')
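
    # Determine how the model is built: import the model_provider that matches
    # the requested model type.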
    if args.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif args.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {args.model_type}')
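
    # Setting this flag suppresses the embedding warning that MegatronModule
    # would otherwise print, since torch.distributed is never initialized here.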
    module.MegatronModule.embedding_warning_printed = True

    consumed_train_samples = None
    consumed_valid_samples = None
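
    # Build one model replica per tensor-parallel rank for the current pipeline
    # stage and load its weights, checking that every rank reports the same
    # consumed train/valid sample counts.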
    def get_models(count, dtype, pre_process, post_process):
        nonlocal consumed_train_samples
        nonlocal consumed_valid_samples
        models = []
        for rank in range(count):
            mpu.initialize.set_tensor_model_parallel_rank(rank)
            model_ = [model_provider(pre_process, post_process).to(dtype)]
            margs.consumed_train_samples = 0
            margs.consumed_valid_samples = 0
            load_checkpoint(model_, None, None)
            assert len(model_) == 1
            model_ = model_[0]
            if consumed_train_samples is not None:
                assert margs.consumed_train_samples == consumed_train_samples
            else:
                consumed_train_samples = margs.consumed_train_samples
            if consumed_valid_samples is not None:
                assert margs.consumed_valid_samples == consumed_valid_samples
            else:
                consumed_valid_samples = margs.consumed_valid_samples
            models.append(model_)
        return models

    if margs.num_layers_per_virtual_pipeline_stage is not None:
        print("Models with an interleaved pipeline schedule are not yet supported.")
        queue.put("exit")
        exit(1)
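
    # Set up Megatron's global state and the model-parallel world sizes by
    # hand (no distributed launch), then load the fused kernels.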
    set_global_variables(margs)
    mpu.initialize.set_tensor_model_parallel_world_size(margs.tensor_model_parallel_size)
    mpu.initialize.set_pipeline_model_parallel_world_size(margs.pipeline_model_parallel_size)
    fused_kernels.load(margs)
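
    # Determine the true (un-padded) vocab size, either directly from
    # --true-vocab-size or by counting entries in --vocab-file.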
    if args.true_vocab_size is not None:
        true_vocab_size = args.true_vocab_size
        if args.vocab_file is not None:
            with open(args.vocab_file) as vocab_file:
                vocab = json.load(vocab_file)
            if len(vocab) != true_vocab_size:
                print("Both --true-vocab-size and --vocab-file specified and the vocab size does not match, aborting.")
                queue.put("exit")
                exit(1)
    elif args.vocab_file is not None:
        with open(args.vocab_file) as vocab_file:
            vocab = json.load(vocab_file)
        true_vocab_size = len(vocab)
    else:
        true_vocab_size = None

    tp_size = margs.tensor_model_parallel_size
    pp_size = margs.pipeline_model_parallel_size
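
    # Metadata describing the source checkpoint; this is the first message the
    # saver receives.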
    md = types.SimpleNamespace()
    md.model_type = args.model_type
    md.num_layers = margs.num_layers
    md.hidden_size = margs.hidden_size
    md.seq_length = margs.seq_length
    md.num_attention_heads = margs.num_attention_heads
    md.max_position_embeddings = margs.max_position_embeddings
    md.tokenizer_type = margs.tokenizer_type
    md.iteration = margs.iteration
    md.params_dtype = margs.params_dtype
    md.bert_binary_head = margs.bert_binary_head
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = true_vocab_size
    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by
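
    # Build the models for the first pipeline stage, record the consumed sample
    # counts, and send the metadata.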
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    post_process = pp_size == 1
    models = get_models(tp_size, md.params_dtype, True, post_process)

    md.consumed_train_samples = consumed_train_samples
    md.consumed_valid_samples = consumed_valid_samples
    queue.put(md)
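
    # Every subsequent message is a dict tagged with a "name" so the saver
    # knows which part of the model it is receiving.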
    def queue_put(name, msg):
        print(f"sending {name}")
        msg["name"] = name
        queue.put(msg)
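
    # Send the embeddings. Word embeddings are sharded over the vocab dimension
    # across tensor-parallel ranks, so concatenate the shards along dim 0.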
    message = {
        "position embeddings": models[0].language_model.embedding.position_embeddings.weight.data,
        "word embeddings": torch.cat(
            [models[tp_rank].language_model.embedding.word_embeddings.weight.data for tp_rank in range(tp_size)],
            dim=0)
    }

    queue_put("embeddings", message)
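
    # Walk the pipeline stages in order, reloading the per-rank models for each
    # later stage, and send one message per transformer layer.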
    total_layer_num = 0
    for pp_rank in range(pp_size):
        if pp_rank > 0:
            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == pp_size - 1
            models = get_models(tp_size, md.params_dtype, False, post_process)
        for layer_num in range(len(models[0].language_model.encoder.layers)):
            message = {}
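
            # Tensors that are replicated across tensor-parallel ranks are read
            # from rank 0.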
            layer = models[0].language_model.encoder.layers[layer_num]
            message["input layernorm weight"] = layer.input_layernorm.weight.data
            message["input layernorm bias"] = layer.input_layernorm.bias.data
            message["dense bias"] = layer.self_attention.dense.bias.data
            message["post layernorm weight"] = layer.post_attention_layernorm.weight.data
            message["post layernorm bias"] = layer.post_attention_layernorm.bias.data
            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data
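
            # Collect the tensor-parallel shards of this layer from every rank.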
            qkv_weight = []
            qkv_bias = []
            dense_weight = []
            mlp_l0_weight = []
            mlp_l0_bias = []
            mlp_l1_weight = []
            for tp_rank, model in enumerate(models):
                layer = model.language_model.encoder.layers[layer_num]
                qkv_weight.append(layer.self_attention.query_key_value.weight.data)
                qkv_bias.append(layer.self_attention.query_key_value.bias.data)
                dense_weight.append(layer.self_attention.dense.weight.data)
                mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
                mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
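
            # Reassemble the full tensors: column-parallel shards (QKV, MLP
            # h->4h) concatenate on dim 0, row-parallel shards (attention
            # dense, MLP 4h->h) on dim 1.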
message["qkv weight"] = torch.cat(qkv_weight, dim=0) |
|
message["qkv bias"] = torch.cat(qkv_bias, dim=0) |
|
message["dense weight"] = torch.cat(dense_weight, dim=1) |
|
message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0) |
|
message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0) |
|
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1) |
|
|
|
queue_put(f"transformer layer {total_layer_num}", message) |
|
|
|
total_layer_num = total_layer_num + 1 |
|
|
|
|
|
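
    # The final layernorm lives on the last pipeline stage and is replicated
    # across tensor-parallel ranks.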
    message = {
        "weight": models[0].language_model.encoder.final_layernorm.weight.data,
        "bias": models[0].language_model.encoder.final_layernorm.bias.data
    }
    queue_put("final layernorm", message)
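
    # BERT checkpoints also carry the pooler, the LM head and, optionally, the
    # binary head.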
    if md.model_type == 'BERT':
        print("Sending LM Pooler")
        message = {
            "weight": models[0].language_model.pooler.dense.weight.data,
            "bias": models[0].language_model.pooler.dense.bias.data
        }
        queue_put("pooler", message)

        message = {
            "dense weight": models[0].lm_head.dense.weight.data,
            "dense bias": models[0].lm_head.dense.bias.data,
            "layernorm weight": models[0].lm_head.layernorm.weight.data,
            "layernorm bias": models[0].lm_head.layernorm.bias.data
        }
        queue_put("lm head", message)

        if md.bert_binary_head:
            print("Sending BERT Binary head")
            message = {
                "weight": models[0].binary_head.weight.data,
                "bias": models[0].binary_head.bias.data
            }
            queue_put("binary head", message)

    queue.put("done")
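

# Public wrapper: on any failure, tell the saver to exit before re-raising.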
def load_checkpoint(queue, args):
    try:
        _load_checkpoint(queue, args)
    except:
        queue.put("exit")
        raise