# VISOR-GPT/train/tencentpretrain/decoders/transformer_decoder.py
import torch
import torch.nn as nn
from tencentpretrain.layers import *
from tencentpretrain.layers.transformer import TransformerDecoderLayer
from tencentpretrain.layers.layer_norm import LayerNorm, T5LayerNorm
from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding


class TransformerDecoder(nn.Module):
    """
    Transformer decoder: a stack of TransformerDecoderLayer modules. Each layer
    applies masked self-attention over the target sequence and cross-attention
    over the encoder output (memory bank).
    """
    def __init__(self, args):
        super(TransformerDecoder, self).__init__()
        self.layers_num = args.decoder_layers_num
        self.layernorm_positioning = args.layernorm_positioning
        self.relative_position_embedding = args.relative_position_embedding

        # Stack of decoder layers.
        self.transformer_decoder = nn.ModuleList(
            [TransformerDecoderLayer(args) for _ in range(self.layers_num)]
        )

        # Optional DeepSpeed activation checkpointing to trade compute for memory.
        if "deepspeed_checkpoint_activations" in args:
            self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
            self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
        else:
            self.deepspeed_checkpoint_activations = False

        has_bias = bool(1 - args.remove_transformer_bias)  # not used further in this module

        # With pre-layernorm, a final normalization is applied to the decoder output.
        if self.layernorm_positioning == "pre":
            if args.layernorm == "t5":
                self.layer_norm = T5LayerNorm(args.hidden_size)
            else:
                self.layer_norm = LayerNorm(args.hidden_size)

        # Relative position bias (T5-style) for the decoder self-attention.
        if self.relative_position_embedding:
            self.self_pos_emb = RelativePositionEmbedding(bidirectional=False, heads_num=args.heads_num,
                                                          num_buckets=args.relative_attention_buckets_num)

    def forward(self, memory_bank, emb, additional_info):
        """
        Args:
            memory_bank: encoder output, [batch_size x src_seq_length x emb_size]
            emb: target-side embedding, [batch_size x tgt_seq_length x emb_size]
            additional_info: tuple whose first element holds the source token ids
                [batch_size x src_seq_length]; ids equal to 0 are treated as padding.
        Returns:
            hidden: [batch_size x tgt_seq_length x hidden_size]
        """
        _, src_seq_length, _ = memory_bank.size()
        batch_size, tgt_seq_length, _ = emb.size()

        # Cross-attention padding mask: 0 for real source tokens, -10000 for padding,
        # shaped [batch_size x 1 x tgt_seq_length x src_seq_length].
        mask_encoder = (additional_info[0] > 0). \
            unsqueeze(1). \
            repeat(1, tgt_seq_length, 1). \
            unsqueeze(1)
        mask_encoder = mask_encoder.float()
        mask_encoder = (1.0 - mask_encoder) * -10000.0
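        # Illustration (not executed): for a single source sequence with ids
        # [5, 7, 0] and tgt_seq_length = 2, every row of the mask is
        # [0, 0, -10000], so each target position may attend to the two real
        # source tokens but not to the padding position.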
        # Causal self-attention mask: each position may attend only to itself and
        # earlier target positions, shaped [batch_size x 1 x tgt_seq_length x tgt_seq_length].
        mask_decoder = torch.ones(tgt_seq_length, tgt_seq_length, device=emb.device)
        mask_decoder = torch.tril(mask_decoder)
        mask_decoder = (1.0 - mask_decoder) * -10000
        mask_decoder = mask_decoder.repeat(batch_size, 1, 1, 1)
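        # Illustration (not executed): for tgt_seq_length = 3 the additive mask is
        #     [[     0, -10000, -10000],
        #      [     0,      0, -10000],
        #      [     0,      0,      0]]
        # before being repeated to [batch_size x 1 x 3 x 3].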

        hidden = emb

        # Relative position bias for the decoder self-attention, if enabled.
        if self.relative_position_embedding:
            self_position_bias = self.self_pos_emb(hidden, hidden)
        else:
            self_position_bias = None

        if self.deepspeed_checkpoint_activations:
            # Run the layers in chunks under DeepSpeed activation checkpointing;
            # assumes layers_num is divisible by deepspeed_checkpoint_layers_num.
            from deepspeed import checkpointing

            def custom(start, end):
                def custom_forward(*inputs):
                    x_, memory_bank_, self_position_bias_ = inputs
                    for index in range(start, end):
                        x_ = self.transformer_decoder[index](x_, memory_bank_, mask_decoder, mask_encoder, self_position_bias_, None)
                    return x_
                return custom_forward

            l = 0
            while l < self.layers_num:
                hidden = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num), hidden, memory_bank, self_position_bias)
                l += self.deepspeed_checkpoint_layers_num
        else:
            for i in range(self.layers_num):
                hidden = self.transformer_decoder[i](hidden, memory_bank, mask_decoder, mask_encoder, self_position_bias, None)

        if self.layernorm_positioning == "pre":
            return self.layer_norm(hidden)
        else:
            return hidden
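

# Minimal sketch (an addition, not part of the upstream module) showing how the
# forward() masks are built and which shapes the decoder expects. It uses only
# torch and does not instantiate TransformerDecoder itself, since the full set
# of required args comes from TencentPretrain's config parsing.
if __name__ == "__main__":
    batch_size, src_seq_length, tgt_seq_length, emb_size = 2, 4, 3, 8

    # Dummy inputs mirroring forward()'s signature: encoder output, target
    # embedding, and source token ids (0 = padding) as additional_info[0].
    memory_bank = torch.zeros(batch_size, src_seq_length, emb_size)
    emb = torch.zeros(batch_size, tgt_seq_length, emb_size)
    src_ids = torch.tensor([[5, 7, 9, 0], [5, 0, 0, 0]])

    mask_encoder = (src_ids > 0).unsqueeze(1).repeat(1, tgt_seq_length, 1).unsqueeze(1)
    mask_encoder = (1.0 - mask_encoder.float()) * -10000.0

    mask_decoder = torch.tril(torch.ones(tgt_seq_length, tgt_seq_length))
    mask_decoder = ((1.0 - mask_decoder) * -10000.0).repeat(batch_size, 1, 1, 1)

    print(mask_encoder.shape)  # torch.Size([2, 1, 3, 4])
    print(mask_decoder.shape)  # torch.Size([2, 1, 3, 3])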