Spaces:
Runtime error
Runtime error
from tencentpretrain.targets import * | |
from tencentpretrain.utils.misc import * | |
class BilmTarget(LmTarget):
    """
    Bi-directional Language Model Target

    The last dimension of the encoder output is split at ``self.hidden_size``:
    the leading features are scored against the forward targets and the
    remaining features against the backward targets, both through ``self.lm``.
    """

    def __init__(self, args, vocab_size):
        """
        Args:
            args: Configuration object exposing ``hidden_size``.
            vocab_size: Passed through unchanged to the parent constructor.
        """
        # Each direction only sees half of the encoder features, so the parent
        # is constructed with half of the configured hidden size.
        # NOTE(review): this rewrites ``hidden_size`` on the caller's ``args``
        # object in place; any code reading ``args.hidden_size`` afterwards
        # sees the halved value, and constructing this target twice from the
        # same ``args`` halves it twice. Kept as-is because code outside this
        # file may depend on it -- confirm before changing.
        args.hidden_size = args.hidden_size // 2
        super(BilmTarget, self).__init__(args, vocab_size)

    def forward(self, memory_bank, tgt, seg):
        """
        Args:
            memory_bank: [batch_size x seq_length x hidden_size]
            tgt: tuple whose first two items are the forward and backward
                targets, each [batch_size x seq_length]
            seg: Not used by this method.

        Returns:
            loss_forward: Language modeling loss of the forward direction.
            loss_backward: Language modeling loss of the backward direction.
            correct_forward: Number of words predicted correctly (forward).
            correct_backward: Number of words predicted correctly (backward).
            denominator_backward: Number of predicted words, as reported for
                the backward direction (the forward count is discarded).
        """
        # isinstance (rather than an exact type comparison) also accepts tuple
        # subclasses such as namedtuples.
        assert isinstance(tgt, tuple), \
            "tgt must be a tuple of (forward, backward) targets"
        tgt_forward, tgt_backward = tgt[0], tgt[1]

        # Forward: features before index ``self.hidden_size``.
        # ``self.hidden_size`` and ``self.lm`` are not defined in this class;
        # presumably both come from LmTarget -- TODO confirm.
        loss_forward, correct_forward, _ = self.lm(
            memory_bank[:, :, :self.hidden_size], tgt_forward
        )

        # Backward: features from index ``self.hidden_size`` onwards.
        loss_backward, correct_backward, denominator_backward = self.lm(
            memory_bank[:, :, self.hidden_size:], tgt_backward
        )

        return (
            loss_forward,
            loss_backward,
            correct_forward,
            correct_backward,
            denominator_backward,
        )