import torch
import torch.nn as nn
from tencentpretrain.utils.rope import precompute_freqs_cis
from tencentpretrain.layers.transformer import TransformerLayer
from tencentpretrain.layers.layer_norm import *
from tencentpretrain.layers.relative_position_embedding import RelativePositionEmbedding


class TransformerEncoder(nn.Module):
    """
    BERT-style encoder that stacks transformer layers (e.g., 12 or 24) to extract features.
    """

    def __init__(self, args):
        super(TransformerEncoder, self).__init__()
        self.mask = args.mask
        self.layers_num = args.layers_num
        self.parameter_sharing = args.parameter_sharing
        self.factorized_embedding_parameterization = args.factorized_embedding_parameterization
        self.layernorm_positioning = args.layernorm_positioning
        self.relative_position_embedding = args.relative_position_embedding
        self.rotary_position_embedding = args.rotary_position_embedding
        self.has_residual_attention = args.has_residual_attention
        if "deepspeed_checkpoint_activations" in args:
            self.deepspeed_checkpoint_activations = args.deepspeed_checkpoint_activations
            self.deepspeed_checkpoint_layers_num = args.deepspeed_checkpoint_layers_num
        else:
            self.deepspeed_checkpoint_activations = False

        has_bias = bool(1 - args.remove_transformer_bias)

        if self.factorized_embedding_parameterization:
            self.linear = nn.Linear(args.emb_size, args.hidden_size)

        # With parameter sharing, a single layer is reused for every pass;
        # otherwise each of the layers_num layers has its own parameters.
        if self.parameter_sharing:
            self.transformer = TransformerLayer(args)
        else:
            self.transformer = nn.ModuleList(
                [TransformerLayer(args) for _ in range(self.layers_num)]
            )

        if self.layernorm_positioning == "pre":
            if args.layernorm == "t5":
                self.layer_norm = T5LayerNorm(args.hidden_size)
            elif args.layernorm == "rms":
                self.layer_norm = RMSNorm(args.hidden_size)
            else:
                self.layer_norm = LayerNorm(args.hidden_size)

        if self.relative_position_embedding:
            self.relative_pos_emb = RelativePositionEmbedding(bidirectional=True, heads_num=args.heads_num,
                                                              num_buckets=args.relative_attention_buckets_num)
        elif self.rotary_position_embedding:
            self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)

    def forward(self, emb, seg):
        """
        Args:
            emb: [batch_size x seq_length x emb_size]
            seg: [batch_size x seq_length]
        Returns:
            hidden: [batch_size x seq_length x hidden_size]
        """
        if self.factorized_embedding_parameterization:
            emb = self.linear(emb)

        batch_size, seq_length, _ = emb.size()
        # Generate mask according to segment indicators.
        # mask: [batch_size x 1 x seq_length x seq_length]
        if self.mask == "fully_visible":
            # Bidirectional attention: every token attends to all non-padding tokens.
            mask = (seg > 0). \
                unsqueeze(1). \
                repeat(1, seq_length, 1). \
                unsqueeze(1)
            mask = mask.float()
            mask = (1.0 - mask) * -10000.0
        elif self.mask == "causal":
            # Autoregressive attention: each token attends to itself and earlier positions.
            mask = torch.ones(seq_length, seq_length, device=emb.device)
            mask = torch.tril(mask)
            mask = (1.0 - mask) * -10000.0
            mask = mask.repeat(batch_size, 1, 1, 1)
        else:
            # Mixed (prefix-LM style) attention: segment-1 tokens are visible to every
            # position, other non-padding tokens are visible only to subsequent positions,
            # and padding (segment 0) is always masked.
            mask_a = (seg == 1). \
                unsqueeze(1). \
                repeat(1, seq_length, 1). \
                unsqueeze(1).float()

            mask_b = (seg > 0). \
                unsqueeze(1). \
                repeat(1, seq_length, 1). \
                unsqueeze(1).float()

            mask_tril = torch.ones(seq_length, seq_length, device=emb.device)
            mask_tril = torch.tril(mask_tril)
            mask_tril = mask_tril.repeat(batch_size, 1, 1, 1)

            mask = (mask_a + mask_b + mask_tril >= 2).float()
            mask = (1.0 - mask) * -10000.0

        hidden = emb

        if self.relative_position_embedding:
            position_bias = self.relative_pos_emb(hidden, hidden)
        else:
            position_bias = None

        if self.rotary_position_embedding:
            freqs_cis = self.freqs_cis[:seq_length].to(hidden.device)
        else:
            freqs_cis = None

        prev_attn = None

        if self.deepspeed_checkpoint_activations:
            from deepspeed import checkpointing

            def custom(start, end):
                def custom_forward(*inputs):
                    x_, y_, position_bias_, freqs_cis_ = inputs
                    for index in range(start, end):
                        if self.parameter_sharing:
                            x_, y_ = self.transformer(x_, mask, position_bias=position_bias_,
                                                      has_residual_attention=self.has_residual_attention,
                                                      prev_attn=y_, freqs_cis=freqs_cis_)
                        else:
                            x_, y_ = self.transformer[index](x_, mask, position_bias=position_bias_,
                                                             has_residual_attention=self.has_residual_attention,
                                                             prev_attn=y_, freqs_cis=freqs_cis_)
                    return x_, y_
                return custom_forward

            # Run the layers in chunks of deepspeed_checkpoint_layers_num, checkpointing
            # each chunk's activations to trade recomputation for memory.
            l = 0
            while l < self.layers_num:
                hidden, prev_attn = checkpointing.checkpoint(custom(l, l + self.deepspeed_checkpoint_layers_num),
                                                             hidden, prev_attn, position_bias, freqs_cis)
                l += self.deepspeed_checkpoint_layers_num
        else:
            for i in range(self.layers_num):
                if self.parameter_sharing:
                    hidden, prev_attn = self.transformer(hidden, mask, position_bias=position_bias,
                                                         has_residual_attention=self.has_residual_attention,
                                                         prev_attn=prev_attn, freqs_cis=freqs_cis)
                else:
                    hidden, prev_attn = self.transformer[i](hidden, mask, position_bias=position_bias,
                                                            has_residual_attention=self.has_residual_attention,
                                                            prev_attn=prev_attn, freqs_cis=freqs_cis)

        # With pre-layernorm positioning, a final normalization is applied to the output.
        if self.layernorm_positioning == "pre":
            return self.layer_norm(hidden)
        else:
            return hidden
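

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a standalone rerun of
# the mask construction in TransformerEncoder.forward, so the three masking
# modes ("fully_visible", "causal", and the mixed prefix-LM default) can be
# inspected on a toy segment tensor. The helper name `build_attention_mask`
# and the toy inputs are assumptions for illustration; only torch is required.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def build_attention_mask(seg, mask_type):
        # seg: [batch_size x seq_length] -> mask: [batch_size x 1 x seq_length x seq_length]
        batch_size, seq_length = seg.size()
        if mask_type == "fully_visible":
            # Every token attends to all non-padding tokens.
            mask = (seg > 0).unsqueeze(1).repeat(1, seq_length, 1).unsqueeze(1).float()
        elif mask_type == "causal":
            # Each token attends to itself and earlier positions only.
            mask = torch.tril(torch.ones(seq_length, seq_length, device=seg.device))
            mask = mask.repeat(batch_size, 1, 1, 1)
        else:
            # Mixed mode: segment-1 tokens are visible everywhere; other
            # non-padding tokens are visible only causally; padding is masked.
            mask_a = (seg == 1).unsqueeze(1).repeat(1, seq_length, 1).unsqueeze(1).float()
            mask_b = (seg > 0).unsqueeze(1).repeat(1, seq_length, 1).unsqueeze(1).float()
            mask_tril = torch.tril(torch.ones(seq_length, seq_length, device=seg.device))
            mask_tril = mask_tril.repeat(batch_size, 1, 1, 1)
            mask = (mask_a + mask_b + mask_tril >= 2).float()
        # Visible positions get 0, masked positions get a large negative bias.
        return (1.0 - mask) * -10000.0

    # Toy example: two segment-1 tokens, two segment-2 tokens, one padding token.
    toy_seg = torch.tensor([[1, 1, 2, 2, 0]])
    for mode in ("fully_visible", "causal", "mixed"):
        print(mode)
        print(build_attention_mask(toy_seg, mode)[0, 0])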