from tencentpretrain.targets import *
from tencentpretrain.utils.misc import *


class BilmTarget(LmTarget):
    """
    Bi-directional Language Model Target
    """

    def __init__(self, args, vocab_size):
        # Each direction of the bi-directional LM uses half of the hidden states.
        args.hidden_size = args.hidden_size // 2
        super(BilmTarget, self).__init__(args, vocab_size)

    def forward(self, memory_bank, tgt, seg):
        """
        Args:
            memory_bank: [batch_size x seq_length x hidden_size]
            tgt: tuple of (tgt_forward, tgt_backward), each [batch_size x seq_length]
            seg: [batch_size x seq_length]

        Returns:
            loss_forward, loss_backward: Language modeling losses of the forward and backward directions.
            correct_forward, correct_backward: Numbers of words predicted correctly in each direction.
            denominator_backward: Number of predicted words.
        """
        assert type(tgt) == tuple
        tgt_forward, tgt_backward = tgt[0], tgt[1]
        # Forward LM: score the first half of the hidden states against the forward targets.
        loss_forward, correct_forward, _ = \
            self.lm(memory_bank[:, :, :self.hidden_size], tgt_forward)
        # Backward LM: score the second half of the hidden states against the backward targets.
        loss_backward, correct_backward, denominator_backward = \
            self.lm(memory_bank[:, :, self.hidden_size:], tgt_backward)

        return loss_forward, loss_backward, correct_forward, correct_backward, denominator_backward
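

# ---------------------------------------------------------------------------
# Standalone sketch (illustrative only): LmTarget's internals live elsewhere,
# so the output layer, shapes, and sizes below are assumptions made for the
# demonstration. The sketch only mirrors the splitting logic above: the
# concatenated hidden states are cut in half, the first half is scored
# against the forward (next-token) targets and the second half against the
# backward (previous-token) targets.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch
    import torch.nn as nn

    batch_size, seq_length, hidden_size, vocab_size = 2, 8, 16, 100
    half = hidden_size // 2

    # Concatenated forward/backward hidden states from a bidirectional encoder.
    memory_bank = torch.randn(batch_size, seq_length, hidden_size)
    tgt_forward = torch.randint(0, vocab_size, (batch_size, seq_length))
    tgt_backward = torch.randint(0, vocab_size, (batch_size, seq_length))

    # A single half-width output layer, shared across directions for brevity.
    output_layer = nn.Linear(half, vocab_size)
    criterion = nn.CrossEntropyLoss()

    logits_forward = output_layer(memory_bank[:, :, :half])
    logits_backward = output_layer(memory_bank[:, :, half:])

    loss_forward = criterion(logits_forward.reshape(-1, vocab_size), tgt_forward.reshape(-1))
    loss_backward = criterion(logits_backward.reshape(-1, vocab_size), tgt_backward.reshape(-1))
    print(loss_forward.item(), loss_backward.item())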