import torch
import torch.nn as nn

from tencentpretrain.layers.layer_norm import LayerNorm
from tencentpretrain.utils import *


class MlmTarget(nn.Module):
    """
    Masked language modeling (MLM) target.
    BERT is pretrained with MLM together with next sentence prediction (NSP);
    this module implements only the MLM loss with a full-vocabulary softmax.
    """

    def __init__(self, args, vocab_size):
        super(MlmTarget, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = args.hidden_size
        self.emb_size = args.emb_size
        self.factorized_embedding_parameterization = args.factorized_embedding_parameterization
        self.act = str2act[args.hidden_act]

        if self.factorized_embedding_parameterization:
            # ALBERT-style factorized embeddings: project down to emb_size before the output layer.
            self.linear_1 = nn.Linear(args.hidden_size, args.emb_size)
            self.layer_norm = LayerNorm(args.emb_size)
            self.linear_2 = nn.Linear(args.emb_size, self.vocab_size)
        else:
            self.linear_1 = nn.Linear(args.hidden_size, args.hidden_size)
            self.layer_norm = LayerNorm(args.hidden_size)
            self.linear_2 = nn.Linear(args.hidden_size, self.vocab_size)

        self.softmax = nn.LogSoftmax(dim=-1)
        self.criterion = nn.NLLLoss()

    def mlm(self, memory_bank, tgt_mlm):
        # Masked language modeling (MLM) with full softmax prediction.
        output_mlm = self.act(self.linear_1(memory_bank))
        output_mlm = self.layer_norm(output_mlm)
        if self.factorized_embedding_parameterization:
            output_mlm = output_mlm.contiguous().view(-1, self.emb_size)
        else:
            output_mlm = output_mlm.contiguous().view(-1, self.hidden_size)
        tgt_mlm = tgt_mlm.contiguous().view(-1)
        # Keep only masked positions; label 0 marks positions that are not predicted.
        output_mlm = output_mlm[tgt_mlm > 0, :]
        tgt_mlm = tgt_mlm[tgt_mlm > 0]
        output_mlm = self.linear_2(output_mlm)
        output_mlm = self.softmax(output_mlm)
        # Small epsilon keeps the denominator non-zero when no tokens are masked.
        denominator = torch.tensor(output_mlm.size(0) + 1e-6)
        if output_mlm.size(0) == 0:
            correct_mlm = torch.tensor(0.0)
        else:
            correct_mlm = torch.sum((output_mlm.argmax(dim=-1).eq(tgt_mlm)).float())
        loss_mlm = self.criterion(output_mlm, tgt_mlm)
        return loss_mlm, correct_mlm, denominator

    def forward(self, memory_bank, tgt, seg):
        """
        Args:
            memory_bank: [batch_size x seq_length x hidden_size]
            tgt: [batch_size x seq_length]
            seg: [batch_size x seq_length], unused by the MLM target but kept
                 for interface consistency with other targets.

        Returns:
            loss: Masked language modeling loss.
            correct: Number of words that are predicted correctly.
            denominator: Number of masked words.
        """
        # Masked language model (MLM).
        loss, correct, denominator = self.mlm(memory_bank, tgt)

        return loss, correct, denominator
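

# --- Usage sketch (illustrative only, not part of the original module) ---
# A minimal example of driving MlmTarget with dummy tensors, assuming the
# tencentpretrain package is importable. The `Namespace` config below is a
# hypothetical stand-in for the real argparse arguments; its field names
# follow what __init__ reads (hidden_size, emb_size,
# factorized_embedding_parameterization, hidden_act). Values such as the
# vocabulary size and the label ids 42 and 512 are arbitrary.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(
        hidden_size=768,
        emb_size=128,
        factorized_embedding_parameterization=False,
        hidden_act="gelu",
    )
    vocab_size = 30000  # arbitrary example value

    target = MlmTarget(args, vocab_size)

    batch_size, seq_length = 2, 16
    # Encoder hidden states (random here for illustration).
    memory_bank = torch.randn(batch_size, seq_length, args.hidden_size)
    # MLM labels: 0 means "not masked"; masked positions carry the true token id.
    tgt = torch.zeros(batch_size, seq_length, dtype=torch.long)
    tgt[:, 3] = 42
    tgt[:, 7] = 512
    seg = torch.ones(batch_size, seq_length, dtype=torch.long)

    loss, correct, denominator = target(memory_bank, tgt, seg)
    print(loss.item(), correct.item(), denominator.item())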