import torch
import torch.nn as nn

from onmt.utils.misc import aeq
from onmt.utils.loss import CommonLossCompute


def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None,
                         batch_dim=1, batch_offset=None):
    """
    Given scores over a copy-extended vocabulary for a batch, sum each
    copy-token score into the matching entry of ``tgt_vocab`` whenever the
    source word also exists there, so an ambiguous word is scored once
    rather than split across two slots.
    """
    offset = len(tgt_vocab)
    for b in range(scores.size(batch_dim)):
        blank = []
        fill = []

        if src_vocabs is None:
            src_vocab = batch.src_ex_vocab[b]
        else:
            batch_id = batch_offset[b] if batch_offset is not None else b
            index = batch.indices.data[batch_id]
            src_vocab = src_vocabs[index]

        # Collect copy slots ("blank") whose source word also has a
        # regular target-vocab index ("fill").
        for i in range(1, len(src_vocab)):
            sw = src_vocab.itos[i]
            ti = tgt_vocab.stoi[sw]
            if ti != 0:
                blank.append(offset + i)
                fill.append(ti)
        if blank:
            blank = torch.Tensor(blank).type_as(batch.indices.data)
            fill = torch.Tensor(fill).type_as(batch.indices.data)
            score = scores[:, b] if batch_dim == 1 else scores[b]
            # Fold the copy-slot mass into the regular vocab entry,
            # then effectively zero out the duplicate copy slot.
            score.index_add_(1, fill, score.index_select(1, blank))
            score.index_fill_(1, blank, 1e-10)
    return scores
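

# A worked sketch of the collapse for one row, assuming a toy setup where
# copy slot c1 corresponds to a source word that also sits at index 2 of the
# target vocab (all numbers illustrative):
#
#     before: [p0, p1, p2, p3 | c1]
#     after:  [p0, p1, p2 + c1, p3 | 1e-10]
#
# The copied mass is merged into the regular entry so the word cannot win
# (or lose) a prediction twice under two different indices.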


class CopyGenerator(nn.Module):
    """An implementation of pointer-generator networks
    :cite:`DBLP:journals/corr/SeeLM17`.

    These networks consider copying words
    directly from the source sequence.

    The copy generator is an extended version of the standard
    generator that computes three values.

    * :math:`p_{softmax}` the standard softmax over `tgt_dict`
    * :math:`p(z)` the probability of copying a word from
      the source
    * :math:`p_{copy}` the probability of copying a particular word,
      taken directly from the attention distribution.

    The model returns a distribution over the extended dictionary,
    computed as

    :math:`p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)`

    .. mermaid::

       graph BT
          A[input]
          S[src_map]
          B[softmax]
          BB[switch]
          C[attn]
          D[copy]
          O[output]
          A --> B
          A --> BB
          S --> D
          C --> D
          D --> O
          B --> O
          BB --> O

    Args:
        input_size (int): size of input representation
        output_size (int): size of output vocabulary
        pad_idx (int): index of the padding token in the output vocabulary
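
    Example:
        A minimal usage sketch (all shapes and values below are illustrative;
        in practice ``src_map`` is built by the data iterator)::

            gen = CopyGenerator(input_size=8, output_size=10, pad_idx=1)
            hidden = torch.rand(6, 8)                   # (batch x tlen, input_size)
            attn = torch.softmax(torch.rand(6, 5), -1)  # (batch x tlen, slen)
            src_map = torch.zeros(5, 2, 3)              # (slen, batch, cvocab)
            probs = gen(hidden, attn, src_map)          # (batch x tlen, 10 + 3)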
    """

    def __init__(self, input_size, output_size, pad_idx):
        super(CopyGenerator, self).__init__()
        self.linear = nn.Linear(input_size, output_size)
        self.linear_copy = nn.Linear(input_size, 1)
        self.pad_idx = pad_idx

    def forward(self, hidden, attn, src_map):
        """
        Compute a distribution over the target dictionary
        extended by the dynamic dictionary implied by copying
        source words.

        Args:
            hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
            attn (FloatTensor): attention over the source for each step,
                ``(batch x tlen, slen)``
            src_map (FloatTensor):
                A sparse indicator matrix mapping each source word to
                its index in the extended vocabulary,
                ``(src_len, batch, extra_words)``
        """
        # Shape checks.
        batch_by_tlen, _ = hidden.size()
        batch_by_tlen_, slen = attn.size()
        slen_, batch, cvocab = src_map.size()
        aeq(batch_by_tlen, batch_by_tlen_)
        aeq(slen, slen_)

        # Probabilities over the fixed target vocabulary, with padding
        # masked out before the softmax.
        logits = self.linear(hidden)
        logits[:, self.pad_idx] = -float('inf')
        prob = torch.softmax(logits, 1)

        # Probability of copying, p(z=1).
        p_copy = torch.sigmoid(self.linear_copy(hidden))
        # p_{word}(w) = p(z=0) * p_{softmax}(w)
        out_prob = torch.mul(prob, 1 - p_copy)
        # Project the gated attention onto the extended vocabulary.
        mul_attn = torch.mul(attn, p_copy)
        copy_prob = torch.bmm(
            mul_attn.view(-1, batch, slen).transpose(0, 1),
            src_map.transpose(0, 1)
        ).transpose(0, 1)
        copy_prob = copy_prob.contiguous().view(-1, cvocab)
        return torch.cat([out_prob, copy_prob], 1)
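

# A worked instance of the mixture p(w) = p(z=1) p_copy(w) + p(z=0) p_softmax(w),
# with numbers chosen purely for illustration: if the copy gate gives
# p(z=1) = 0.3, the softmax assigns p_softmax(w) = 0.2, and the attention
# assigns p_copy(w) = 0.5, then
#
#     p(w) = 0.3 * 0.5 + 0.7 * 0.2 = 0.15 + 0.14 = 0.29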


class CopyGeneratorLoss(nn.Module):
    """Copy generator criterion."""
    def __init__(self, vocab_size, force_copy, unk_index=0,
                 ignore_index=-100, eps=1e-20):
        super(CopyGeneratorLoss, self).__init__()
        self.force_copy = force_copy
        self.eps = eps
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.unk_index = unk_index

    def forward(self, scores, align, target):
        """
        Args:
            scores (FloatTensor): ``(batch_size x tgt_len, dynamic_vocab_size)``
                probabilities over the extended vocabulary; each row sums to
                at most 1.
            align (LongTensor): ``(batch_size x tgt_len)``
            target (LongTensor): ``(batch_size x tgt_len)``
        """
        # Probability the model assigns to each gold token in the
        # fixed target vocabulary.
        vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1)

        # Probability of the gold token under the copy distribution
        # (copy slots start at self.vocab_size).
        copy_ix = align.unsqueeze(1) + self.vocab_size
        copy_tok_probs = scores.gather(1, copy_ix).squeeze(1)
        # Tokens with no copy source get probability 0; add eps to
        # avoid -inf when taking logs.
        copy_tok_probs[align == self.unk_index] = 0
        copy_tok_probs += self.eps

        # Find the positions at which the copy mechanism is not used.
        non_copy = align == self.unk_index
        if not self.force_copy:
            non_copy = non_copy | (target != self.unk_index)

        probs = torch.where(
            non_copy, copy_tok_probs + vocab_probs, copy_tok_probs
        )

        loss = -probs.log()
        # Drop padding positions.
        loss[target == self.ignore_index] = 0
        return loss
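

# A minimal sketch of the criterion, assuming unk_index=0 and a fixed vocab
# of 4 tokens, so copy slots start at index 4 (all tensors illustrative):
#
#     criterion = CopyGeneratorLoss(vocab_size=4, force_copy=False)
#     scores = torch.full((2, 6), 0.1)  # (batch*tlen, 4 + 2 copy slots)
#     align = torch.tensor([1, 0])      # second step has no copy source
#     target = torch.tensor([0, 2])     # first step is <unk> in tgt vocab
#     loss = criterion(scores, align, target)  # per-token NLL, shape (2,)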


class CommonCopyGeneratorLossCompute(CommonLossCompute):
    """Common Copy Generator Loss Computation."""
    def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
                 lambda_coverage=0.0, tgt_shift_index=1):
        super(CommonCopyGeneratorLossCompute, self).__init__(
            criterion, generator, lambda_coverage=lambda_coverage,
            tgt_shift_index=tgt_shift_index)
        self.tgt_vocab = tgt_vocab
        self.normalize_by_length = normalize_by_length

    def _compute_loss(self, batch, output, target, copy_attn, align,
                      std_attn=None, coverage_attn=None):
        """Compute the loss.

        The args must match :func:`self._make_shard_state()`.

        Args:
            batch: the current batch.
            output: the predicted output from the model.
            target: the validation target to compare output with.
            copy_attn: the copy attention value.
            align: the alignment info.
        """
        target = target.view(-1)
        align = align.view(-1)
        scores = self.generator(
            self._bottle(output), self._bottle(copy_attn), batch.src_map
        )
        loss = self.criterion(scores, align, target)

        if self.lambda_coverage != 0.0:
            coverage_loss = self._compute_coverage_loss(std_attn,
                                                        coverage_attn)
            loss += coverage_loss

        # The next two blocks do not affect the loss value; they only
        # prepare collapsed scores and corrected targets for statistics.
        scores_data = collapse_copy_scores(
            self._unbottle(scores.clone(), batch.batch_size),
            batch, self.tgt_vocab, None)
        scores_data = self._bottle(scores_data)

        # Replace <unk> targets that have a copy source with the
        # corresponding extended-vocab index:
        #     tgt[i] = align[i] + len(tgt_vocab)
        #     for i such that tgt[i] == unk and align[i] != unk
        # (the in-place += below relies on unk being 0, the default).
        target_data = target.clone()
        unk = self.criterion.unk_index
        correct_mask = (target_data == unk) & (align != unk)
        offset_align = align[correct_mask] + len(self.tgt_vocab)
        target_data[correct_mask] += offset_align

        stats = self._stats(loss.sum().clone(), scores_data, target_data)

        if self.normalize_by_length:
            # NLL divided by sequence length: total loss per sequence,
            # divided by its non-padding length, then summed.
            tgt_lens = batch.tgt[:, :, 0].ne(self.padding_idx).sum(0).float()
            loss = loss.view(-1, batch.batch_size).sum(0)
            loss = torch.div(loss, tgt_lens).sum()
        else:
            loss = loss.sum()

        return loss, stats

    def _make_shard_state(self, batch, output, range_, attns):
        """See base class for args description."""
        shard_state = super(CommonCopyGeneratorLossCompute,
                            self)._make_shard_state(batch, output,
                                                    range_, attns)

        start_range = range_[0] + self.tgt_shift_index
        end_range = range_[1]
        shard_state.update({
            "copy_attn": attns.get("copy"),
            "align": batch.alignment[start_range: end_range]
        })
        return shard_state


class CopyGeneratorLossCompute(CommonCopyGeneratorLossCompute):
    """Copy Generator Loss Computation."""
    def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
                 lambda_coverage=0.0):
        super(CopyGeneratorLossCompute, self).__init__(
            criterion, generator, tgt_vocab, normalize_by_length,
            lambda_coverage=lambda_coverage, tgt_shift_index=1)


class CopyGeneratorLMLossCompute(CommonCopyGeneratorLossCompute):
    """Copy Generator LM Loss Computation."""
    def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
                 lambda_coverage=0.0):
        super(CopyGeneratorLMLossCompute, self).__init__(
            criterion, generator, tgt_vocab, normalize_by_length,
            lambda_coverage=lambda_coverage, tgt_shift_index=0)