import torch
import torch.nn as nn
from tencentpretrain.layers import *
from tencentpretrain.layers.transformer import TransformerDecoderLayer
from tencentpretrain.layers.layer_norm import LayerNorm, T5LayerNorm
from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding


class TransformerDecoder(nn.Module):
    """
    Transformer decoder: a stack of decoder layers, each combining masked
    self-attention over the target sequence with cross-attention over the
    encoder's memory bank.
    """
    def __init__(self, args):
        super(TransformerDecoder, self).__init__()
        self.layers_num = args.decoder_layers_num
        self.layernorm_positioning = args.layernorm_positioning
        self.relative_position_embedding = args.relative_position_embedding
        self.transformer_decoder = nn.ModuleList(
            [TransformerDecoderLayer(args) for _ in range(self.layers_num)]
        )
        if "deepspeed_checkpoint_activations" in args:
            self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
            self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
        else:
            self.deepspeed_checkpoint_activations = False

        has_bias = bool(1 - args.remove_transformer_bias)  # not used below

        # With pre-layernorm, a final normalization is applied to the output
        # of the last decoder layer.
        if self.layernorm_positioning == "pre":
            if args.layernorm == "t5":
                self.layer_norm = T5LayerNorm(args.hidden_size)
            else:
                self.layer_norm = LayerNorm(args.hidden_size)

        # Relative position embedding for decoder self-attention
        # (unidirectional, since the decoder is causal).
        if self.relative_position_embedding:
            self.self_pos_emb = RelativePositionEmbedding(bidirectional=False, heads_num=args.heads_num,
                                                          num_buckets=args.relative_attention_buckets_num)

    def forward(self, memory_bank, emb, additional_info):
        """
        Args:
            memory_bank: encoder output, [batch_size x src_seq_length x emb_size]
            emb: target embedding, [batch_size x tgt_seq_length x emb_size]
            additional_info: its first element holds the source-side ids used to
                build the cross-attention padding mask (positions > 0 are attended to),
                [batch_size x src_seq_length]
        Returns:
            hidden: [batch_size x tgt_seq_length x hidden_size]
        """
        _, src_seq_length, _ = memory_bank.size()
        batch_size, tgt_seq_length, _ = emb.size()

        # Cross-attention padding mask: [batch_size x 1 x tgt_seq_length x src_seq_length].
        # Padded source positions receive a large negative bias before softmax.
        mask_encoder = (additional_info[0] > 0). \
            unsqueeze(1). \
            repeat(1, tgt_seq_length, 1). \
            unsqueeze(1)
        mask_encoder = mask_encoder.float()
        mask_encoder = (1.0 - mask_encoder) * -10000.0

        # Causal self-attention mask: [batch_size x 1 x tgt_seq_length x tgt_seq_length].
        mask_decoder = torch.ones(tgt_seq_length, tgt_seq_length, device=emb.device)
        mask_decoder = torch.tril(mask_decoder)
        mask_decoder = (1.0 - mask_decoder) * -10000.0
        mask_decoder = mask_decoder.repeat(batch_size, 1, 1, 1)

        hidden = emb

        if self.relative_position_embedding:
            self_position_bias = self.self_pos_emb(hidden, hidden)
        else:
            self_position_bias = None

        if self.deepspeed_checkpoint_activations:
            from deepspeed import checkpointing

            # Run the decoder layers in chunks of deepspeed_checkpoint_layers_num,
            # recomputing activations during the backward pass to save memory.
            def custom(start, end):
                def custom_forward(*inputs):
                    x_, memory_bank_, self_position_bias_ = inputs
                    for index in range(start, end):
                        x_ = self.transformer_decoder[index](x_, memory_bank_, mask_decoder, mask_encoder,
                                                             self_position_bias_, None)
                    return x_
                return custom_forward

            l = 0
            while l < self.layers_num:
                hidden = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num),
                                                  hidden, memory_bank, self_position_bias)
                l += self.deepspeed_checkpoint_layers_num
        else:
            for i in range(self.layers_num):
                hidden = self.transformer_decoder[i](hidden, memory_bank, mask_decoder, mask_encoder,
                                                     self_position_bias, None)

        if self.layernorm_positioning == "pre":
            return self.layer_norm(hidden)
        else:
            return hidden
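

# A minimal, standalone sketch of the attention masks built in forward() above:
# a causal (lower-triangular) mask for decoder self-attention and a padding mask
# for cross-attention over the encoder memory bank. The example values below
# (src_tokens, the batch size, and the sequence lengths) are illustrative
# assumptions only; they are not used anywhere in the module.
if __name__ == "__main__":
    batch_size, src_seq_length, tgt_seq_length = 2, 4, 3

    # Hypothetical source token ids; 0 marks padding, matching the
    # (additional_info[0] > 0) test in forward().
    src_tokens = torch.tensor([[5, 8, 3, 0],
                               [7, 2, 0, 0]])

    # Encoder-side padding mask: [batch_size, 1, tgt_seq_length, src_seq_length].
    # Padded source positions get a large negative bias before softmax.
    demo_mask_encoder = (src_tokens > 0).unsqueeze(1).repeat(1, tgt_seq_length, 1).unsqueeze(1)
    demo_mask_encoder = (1.0 - demo_mask_encoder.float()) * -10000.0

    # Decoder-side causal mask: [batch_size, 1, tgt_seq_length, tgt_seq_length].
    # Position i may only attend to positions <= i.
    demo_mask_decoder = torch.tril(torch.ones(tgt_seq_length, tgt_seq_length))
    demo_mask_decoder = (1.0 - demo_mask_decoder) * -10000.0
    demo_mask_decoder = demo_mask_decoder.repeat(batch_size, 1, 1, 1)

    print(demo_mask_encoder.shape)  # torch.Size([2, 1, 3, 4])
    print(demo_mask_decoder.shape)  # torch.Size([2, 1, 3, 3])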