"""Transformer based language model.""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from megatron import get_args |
|
from megatron import mpu |
|
from .module import MegatronModule |
|
from megatron.model.enums import LayerType, AttnMaskType |
|
from megatron.model.transformer import ParallelTransformer |
|
from megatron.model.utils import get_linear_layer |
|
from megatron.model.utils import init_method_normal, scaled_init_method_normal |
|
|
|
|
|
def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, |
|
bias=None): |
|
"""LM logits using word embedding weights.""" |
|
args = get_args() |
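
    # Parallel logits: decide whether the input is already in the
    # tensor-model-parallel region and whether the grad all-reduce can be
    # overlapped (async).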
if args.async_tensor_model_parallel_allreduce or\ |
|
args.sequence_parallel: |
|
input_parallel = input_ |
|
model_parallel = mpu.get_tensor_model_parallel_world_size() > 1 |
|
async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \ |
|
model_parallel and not args.sequence_parallel |
|
else: |
|
input_parallel = mpu.copy_to_tensor_model_parallel_region(input_) |
|
async_grad_allreduce = False |
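
    # Matrix multiply against the (vocab-parallel) word embedding weights,
    # with optional gradient-accumulation fusion and async grad all-reduce.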
logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply( |
|
input_parallel, word_embeddings_weight, bias, |
|
args.gradient_accumulation_fusion, |
|
async_grad_allreduce, args.sequence_parallel) |
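
    # Return vocab-parallel logits as-is, or gather them across
    # tensor-model-parallel ranks.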
if parallel_output: |
|
return logits_parallel |
|
|
|
return mpu.gather_from_tensor_model_parallel_region(logits_parallel) |
|
|
|
|
|
def get_language_model(num_tokentypes, add_pooler, |
|
encoder_attn_mask_type, init_method=None, |
|
scaled_init_method=None, add_encoder=True, |
|
add_decoder=False, |
|
decoder_attn_mask_type=AttnMaskType.causal, |
|
pre_process=True, post_process=True): |
|
"""Build language model and return along with the key to save.""" |
|
args = get_args() |
|
|
|
if init_method is None: |
|
init_method = init_method_normal(args.init_method_std) |
|
|
|
if scaled_init_method is None: |
|
scaled_init_method = scaled_init_method_normal(args.init_method_std, |
|
args.num_layers) |
|
|
|
|
|
language_model = TransformerLanguageModel( |
|
init_method, |
|
scaled_init_method, |
|
encoder_attn_mask_type, |
|
num_tokentypes=num_tokentypes, |
|
add_encoder=add_encoder, |
|
add_decoder=add_decoder, |
|
decoder_attn_mask_type=decoder_attn_mask_type, |
|
add_pooler=add_pooler, |
|
pre_process=pre_process, |
|
post_process=post_process |
|
) |
|
|
|
language_model_key = 'language_model' |
|
|
|
return language_model, language_model_key |
|
|
|
|
|
class Pooler(MegatronModule): |
|
"""Pooler layer. |
|
|
|
Pool hidden states of a specific token (for example start of the |
|
sequence) and add a linear transformation followed by a tanh. |
|
|
|
Arguments: |
|
hidden_size: hidden size |
|
init_method: weight initialization method for the linear layer. |
|
bias is set to zero. |
|
""" |
|
|
|
def __init__(self, hidden_size, init_method): |
|
super(Pooler, self).__init__() |
|
args = get_args() |
|
self.dense = get_linear_layer(hidden_size, hidden_size, init_method) |
|
self.sequence_parallel = args.sequence_parallel |
|
|
|
|
|
def forward(self, hidden_states, sequence_index=0): |
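        # hidden_states: [s, b, h]; sequence_index: token position to pool.
        # With sequence parallelism, gather the full sequence first so the
        # same pooler runs on every tensor-parallel rank.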
if self.sequence_parallel: |
|
hidden_states = mpu.gather_from_sequence_parallel_region( |
|
hidden_states, |
|
tensor_parallel_output_grad=False) |
|
|
|
pooled = hidden_states[sequence_index, :, :] |
|
pooled = self.dense(pooled) |
|
pooled = torch.tanh(pooled) |
|
return pooled |
|
|
|
|
|
class Embedding(MegatronModule): |
|
"""Language model embeddings. |
|
|
|
Arguments: |
|
hidden_size: hidden size |
|
vocab_size: vocabulary size |
|
max_sequence_length: maximum size of sequence. This |
|
is used for positional embedding |
|
embedding_dropout_prob: dropout probability for embeddings |
|
init_method: weight initialization method |
|
num_tokentypes: size of the token-type embeddings. 0 value |
|
will ignore this embedding |
|
""" |
|
|
|
def __init__(self, |
|
hidden_size, |
|
vocab_size, |
|
max_sequence_length, |
|
embedding_dropout_prob, |
|
init_method, |
|
num_tokentypes=0): |
|
super(Embedding, self).__init__() |
|
|
|
self.hidden_size = hidden_size |
|
self.init_method = init_method |
|
self.num_tokentypes = num_tokentypes |
|
|
|
args = get_args() |
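
        # Word embeddings (parallel over the vocabulary).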
self.word_embeddings = mpu.VocabParallelEmbedding( |
|
vocab_size, self.hidden_size, |
|
init_method=self.init_method) |
|
self._word_embeddings_key = 'word_embeddings' |
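
        # Position embeddings (replicated on every tensor-parallel rank).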
self.position_embeddings = torch.nn.Embedding( |
|
max_sequence_length, self.hidden_size) |
|
self._position_embeddings_key = 'position_embeddings' |
|
|
|
if args.perform_initialization: |
|
self.init_method(self.position_embeddings.weight) |
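
        # Token-type embeddings. Kept optional so that a pretrained model
        # without token types can be loaded first and the embedding added
        # later via add_tokentype_embeddings().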
self._tokentype_embeddings_key = 'tokentype_embeddings' |
|
if self.num_tokentypes > 0: |
|
self.tokentype_embeddings = torch.nn.Embedding(self.num_tokentypes, |
|
self.hidden_size) |
|
|
|
if args.perform_initialization: |
|
self.init_method(self.tokentype_embeddings.weight) |
|
else: |
|
self.tokentype_embeddings = None |
|
|
|
self.fp32_residual_connection = args.fp32_residual_connection |
|
self.sequence_parallel = args.sequence_parallel |
|
|
|
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob) |
|
|
|
def zero_parameters(self): |
|
"""Zero out all parameters in embedding.""" |
|
self.word_embeddings.weight.data.fill_(0) |
|
self.word_embeddings.weight.shared = True |
|
self.position_embeddings.weight.data.fill_(0) |
|
self.position_embeddings.weight.shared = True |
|
if self.num_tokentypes > 0: |
|
self.tokentype_embeddings.weight.data.fill_(0) |
|
self.tokentype_embeddings.weight.shared = True |
|
|
|
def add_tokentype_embeddings(self, num_tokentypes): |
|
"""Add token-type embedding. This function is provided so we can add |
|
token-type embeddings in case the pretrained model does not have it. |
|
This allows us to load the model normally and then add this embedding. |
|
""" |
|
if self.tokentype_embeddings is not None: |
|
            raise Exception('tokentype embeddings are already initialized')
|
if torch.distributed.get_rank() == 0: |
|
print('adding embedding for {} tokentypes'.format(num_tokentypes), |
|
flush=True) |
|
self.num_tokentypes = num_tokentypes |
|
self.tokentype_embeddings = torch.nn.Embedding(num_tokentypes, |
|
self.hidden_size) |
|
|
|
args = get_args() |
|
self.init_method(self.tokentype_embeddings.weight) |
|
|
|
def forward(self, input_ids, position_ids, tokentype_ids=None): |
|
|
|
words_embeddings = self.word_embeddings(input_ids) |
|
position_embeddings = self.position_embeddings(position_ids) |
|
embeddings = words_embeddings + position_embeddings |
|
if tokentype_ids is not None: |
|
assert self.tokentype_embeddings is not None |
|
embeddings = embeddings + self.tokentype_embeddings(tokentype_ids) |
|
        else:
            # Without tokentype_ids, the model should not have token-type
            # embeddings.
            assert self.tokentype_embeddings is None
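
        # Change data layout to avoid explicit transposes downstream:
        # [b, s, h] --> [s, b, h].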
embeddings = embeddings.transpose(0, 1).contiguous() |
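
        # If fp32 residual connections are enabled, convert the embeddings
        # to float32.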
if self.fp32_residual_connection: |
|
embeddings = embeddings.float() |
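
        # Embedding dropout. With sequence parallelism, scatter along the
        # sequence dimension first and apply dropout inside the forked CUDA
        # RNG tracker.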
if self.sequence_parallel: |
|
embeddings = mpu.scatter_to_sequence_parallel_region(embeddings) |
|
with mpu.get_cuda_rng_tracker().fork(): |
|
embeddings = self.embedding_dropout(embeddings) |
|
else: |
|
embeddings = self.embedding_dropout(embeddings) |
|
|
|
return embeddings |
|
|
|
def state_dict_for_save_checkpoint(self, destination=None, prefix='', |
|
keep_vars=False): |
|
"""For easy load.""" |
|
|
|
state_dict_ = {} |
|
state_dict_[self._word_embeddings_key] \ |
|
= self.word_embeddings.state_dict(destination, prefix, keep_vars) |
|
state_dict_[self._position_embeddings_key] \ |
|
= self.position_embeddings.state_dict( |
|
destination, prefix, keep_vars) |
|
if self.num_tokentypes > 0: |
|
state_dict_[self._tokentype_embeddings_key] \ |
|
= self.tokentype_embeddings.state_dict( |
|
destination, prefix, keep_vars) |
|
|
|
return state_dict_ |
|
|
|
def load_state_dict(self, state_dict, strict=True): |
|
"""Customized load.""" |
|
|
|
|
|
if self._word_embeddings_key in state_dict: |
|
state_dict_ = state_dict[self._word_embeddings_key] |
|
else: |
|
|
|
state_dict_ = {} |
|
for key in state_dict.keys(): |
|
if 'word_embeddings' in key: |
|
state_dict_[key.split('word_embeddings.')[1]] \ |
|
= state_dict[key] |
|
self.word_embeddings.load_state_dict(state_dict_, strict=strict) |
|
|
|
|
|
if self._position_embeddings_key in state_dict: |
|
state_dict_ = state_dict[self._position_embeddings_key] |
|
else: |
|
|
|
state_dict_ = {} |
|
for key in state_dict.keys(): |
|
if 'position_embeddings' in key: |
|
state_dict_[key.split('position_embeddings.')[1]] \ |
|
= state_dict[key] |
|
self.position_embeddings.load_state_dict(state_dict_, strict=strict) |
|
|
|
|
|
if self.num_tokentypes > 0: |
|
state_dict_ = {} |
|
if self._tokentype_embeddings_key in state_dict: |
|
state_dict_ = state_dict[self._tokentype_embeddings_key] |
|
else: |
|
|
|
for key in state_dict.keys(): |
|
if 'tokentype_embeddings' in key: |
|
state_dict_[key.split('tokentype_embeddings.')[1]] \ |
|
= state_dict[key] |
|
if len(state_dict_.keys()) > 0: |
|
self.tokentype_embeddings.load_state_dict(state_dict_, |
|
strict=strict) |
|
else: |
|
print('***WARNING*** expected tokentype embeddings in the ' |
|
'checkpoint but could not find it', flush=True) |
|
|
|
|
|
class TransformerLanguageModel(MegatronModule): |
|
"""Transformer language model. |
|
|
|
    Arguments:
        init_method: weight initialization method for the embeddings and the
            input projections of the transformer layers.
        output_layer_init_method: weight initialization method for the output
            projections of the transformer layers.
        encoder_attn_mask_type: attention mask type used by the encoder.
        num_tokentypes: size of the token-type embeddings. 0 value
            will ignore this embedding.
        add_encoder: whether to build the encoder stack.
        add_decoder: whether to build the decoder stack (encoder-decoder
            models only).
        decoder_attn_mask_type: attention mask type used by the decoder.
        add_pooler: whether to add a Pooler over the encoder output.
        pre_process: True if this pipeline stage computes the input embeddings.
        post_process: True if this pipeline stage computes the final outputs.
    """
|
|
|
def __init__(self, |
|
init_method, |
|
output_layer_init_method, |
|
encoder_attn_mask_type, |
|
num_tokentypes=0, |
|
add_encoder=True, |
|
add_decoder=False, |
|
decoder_attn_mask_type=AttnMaskType.causal, |
|
add_pooler=False, |
|
pre_process=True, |
|
post_process=True): |
|
super(TransformerLanguageModel, self).__init__() |
|
args = get_args() |
|
|
|
self.pre_process = pre_process |
|
self.post_process = post_process |
|
self.hidden_size = args.hidden_size |
|
self.num_tokentypes = num_tokentypes |
|
self.init_method = init_method |
|
self.add_encoder = add_encoder |
|
self.encoder_attn_mask_type = encoder_attn_mask_type |
|
self.add_decoder = add_decoder |
|
self.decoder_attn_mask_type = decoder_attn_mask_type |
|
self.add_pooler = add_pooler |
|
self.encoder_hidden_state = None |
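
        # Input embeddings (built only when this stage handles the model input).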
if self.pre_process: |
|
self.embedding = Embedding(self.hidden_size, |
|
args.padded_vocab_size, |
|
args.max_position_embeddings, |
|
args.hidden_dropout, |
|
self.init_method, |
|
self.num_tokentypes) |
|
self._embedding_key = 'embedding' |
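
        # Transformer encoder. add_encoder is usually True; it is False on
        # pipeline stages that hold only the decoder of an encoder-decoder
        # model.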
if self.add_encoder: |
|
self.encoder = ParallelTransformer( |
|
self.init_method, |
|
output_layer_init_method, |
|
self_attn_mask_type=self.encoder_attn_mask_type, |
|
pre_process=self.pre_process, |
|
post_process=self.post_process |
|
) |
|
self._encoder_key = 'encoder' |
|
else: |
|
self.encoder = None |
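
        # Transformer decoder. add_decoder is only True for encoder-decoder
        # models.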
if self.add_decoder: |
|
self.decoder = ParallelTransformer( |
|
self.init_method, |
|
output_layer_init_method, |
|
layer_type=LayerType.decoder, |
|
self_attn_mask_type=self.decoder_attn_mask_type, |
|
pre_process=self.pre_process, |
|
post_process=self.post_process) |
|
self._decoder_key = 'decoder' |
|
else: |
|
self.decoder = None |
|
|
|
if self.post_process: |
|
|
|
if self.add_pooler: |
|
self.pooler = Pooler(self.hidden_size, self.init_method) |
|
self._pooler_key = 'pooler' |
|
|
|
def set_input_tensor(self, input_tensor): |
|
""" See megatron.model.transformer.set_input_tensor()""" |
|
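        # This is usually handled in schedules.py, but some inference code
        # still passes a single tensor or None instead of a list.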
if not isinstance(input_tensor, list): |
|
input_tensor = [input_tensor] |
|
|
|
if self.add_encoder and self.add_decoder: |
|
assert len(input_tensor) == 1, \ |
|
'input_tensor should only be length 1 for stage with both encoder and decoder' |
|
self.encoder.set_input_tensor(input_tensor[0]) |
|
elif self.add_encoder: |
|
assert len(input_tensor) == 1, \ |
|
'input_tensor should only be length 1 for stage with only encoder' |
|
self.encoder.set_input_tensor(input_tensor[0]) |
|
elif self.add_decoder: |
|
if len(input_tensor) == 2: |
|
self.decoder.set_input_tensor(input_tensor[0]) |
|
self.encoder_hidden_state = input_tensor[1] |
|
elif len(input_tensor) == 1: |
|
self.decoder.set_input_tensor(None) |
|
self.encoder_hidden_state = input_tensor[0] |
|
else: |
|
raise Exception('input_tensor must have either length 1 or 2') |
|
else: |
|
raise Exception('Stage must have at least either encoder or decoder') |
|
|
|
def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask, |
|
dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, |
|
enc_dec_attn_mask=None, tokentype_ids=None, |
|
inference_params=None, |
|
pooling_sequence_index=0, |
|
enc_hidden_states=None, output_enc_hidden=False): |
|
|
|
|
|
if self.pre_process: |
|
encoder_input = self.embedding(enc_input_ids, enc_position_ids, |
|
tokentype_ids=tokentype_ids) |
|
else: |
|
encoder_input = None |
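
        # Run the encoder, unless pre-computed encoder hidden states were
        # provided.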
if enc_hidden_states is None: |
|
if self.encoder is not None: |
|
encoder_output = self.encoder( |
|
encoder_input, |
|
enc_attn_mask, |
|
inference_params=inference_params) |
|
else: |
|
encoder_output = self.encoder_hidden_state |
|
else: |
|
encoder_output = enc_hidden_states.to(encoder_input.dtype) |
|
|
|
if self.post_process: |
|
if self.add_pooler: |
|
pooled_output = self.pooler(encoder_output, |
|
pooling_sequence_index) |
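
        # output_enc_hidden is used when only the encoder's output is
        # needed, e.g. to compare two sequences by pooling their encoder
        # states.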
if not self.add_decoder or output_enc_hidden: |
|
if self.add_pooler and self.post_process: |
|
return encoder_output, pooled_output |
|
else: |
|
return encoder_output |
|
|
|
|
|
if self.pre_process: |
|
decoder_input = self.embedding(dec_input_ids, |
|
dec_position_ids) |
|
else: |
|
decoder_input = None |
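
        # Run the decoder with cross-attention over the encoder output.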
decoder_output = self.decoder( |
|
decoder_input, |
|
dec_attn_mask, |
|
encoder_output=encoder_output, |
|
enc_dec_attn_mask=enc_dec_attn_mask, |
|
inference_params=inference_params) |
|
|
|
if self.add_pooler and self.post_process: |
|
return decoder_output, encoder_output, pooled_output |
|
else: |
|
return decoder_output, encoder_output |
|
|
|
def state_dict_for_save_checkpoint(self, destination=None, prefix='', |
|
keep_vars=False): |
|
"""For easy load.""" |
|
|
|
state_dict_ = {} |
|
if self.pre_process: |
|
state_dict_[self._embedding_key] \ |
|
= self.embedding.state_dict_for_save_checkpoint( |
|
destination, prefix, keep_vars) |
|
if self.add_encoder: |
|
state_dict_[self._encoder_key] \ |
|
= self.encoder.state_dict_for_save_checkpoint( |
|
destination, prefix, keep_vars) |
|
if self.post_process: |
|
if self.add_pooler: |
|
state_dict_[self._pooler_key] \ |
|
= self.pooler.state_dict_for_save_checkpoint( |
|
destination, prefix, keep_vars) |
|
if self.add_decoder: |
|
state_dict_[self._decoder_key] \ |
|
= self.decoder.state_dict_for_save_checkpoint( |
|
destination, prefix, keep_vars) |
|
|
|
return state_dict_ |
|
|
|
def load_state_dict(self, state_dict, strict=True): |
|
"""Customized load.""" |
|
|
|
|
|
if self.pre_process: |
|
if self._embedding_key in state_dict: |
|
state_dict_ = state_dict[self._embedding_key] |
|
else: |
|
|
|
state_dict_ = {} |
|
for key in state_dict.keys(): |
|
if '_embeddings' in key: |
|
state_dict_[key] = state_dict[key] |
|
self.embedding.load_state_dict(state_dict_, strict=strict) |
|
|
|
|
|
if self.add_encoder: |
|
if self._encoder_key in state_dict: |
|
state_dict_ = state_dict[self._encoder_key] |
|
|
|
elif 'transformer' in state_dict: |
|
state_dict_ = state_dict['transformer'] |
|
else: |
|
|
|
state_dict_ = {} |
|
for key in state_dict.keys(): |
|
if 'transformer.' in key: |
|
state_dict_[key.split('transformer.')[1]] = state_dict[key] |
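
            # For backward compatibility: rename '.attention.' keys to
            # '.self_attention.'.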
state_dict_self_attention = {} |
|
for key in state_dict_.keys(): |
|
if '.attention.' in key: |
|
state_dict_self_attention[key.replace(".attention.", |
|
".self_attention.")] = state_dict_[key] |
|
else: |
|
state_dict_self_attention[key] = state_dict_[key] |
|
state_dict_ = state_dict_self_attention |
|
|
|
self.encoder.load_state_dict(state_dict_, strict=strict) |
|
|
|
|
|
if self.post_process: |
|
if self.add_pooler: |
|
assert 'pooler' in state_dict, \ |
|
'could not find data for pooler in the checkpoint' |
|
self.pooler.load_state_dict(state_dict[self._pooler_key], |
|
strict=strict) |
|
|
|
if self.add_decoder: |
|
            assert 'decoder' in state_dict, \
                'could not find data for decoder in the checkpoint'
|
self.decoder.load_state_dict(state_dict[self._decoder_key], |
|
strict=strict) |
|
|