sakharamg's picture
Uploading all files
158b61b
import torch
import torch.nn as nn
from onmt.utils.misc import aeq
from onmt.utils.loss import CommonLossCompute
def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None,
batch_dim=1, batch_offset=None):
"""
Given scores from an expanded dictionary
corresponeding to a batch, sums together copies,
with a dictionary word when it is ambiguous.
"""
offset = len(tgt_vocab)
for b in range(scores.size(batch_dim)):
blank = []
fill = []
if src_vocabs is None:
src_vocab = batch.src_ex_vocab[b]
else:
batch_id = batch_offset[b] if batch_offset is not None else b
index = batch.indices.data[batch_id]
src_vocab = src_vocabs[index]
for i in range(1, len(src_vocab)):
sw = src_vocab.itos[i]
ti = tgt_vocab.stoi[sw]
if ti != 0:
blank.append(offset + i)
fill.append(ti)
if blank:
blank = torch.Tensor(blank).type_as(batch.indices.data)
fill = torch.Tensor(fill).type_as(batch.indices.data)
score = scores[:, b] if batch_dim == 1 else scores[b]
score.index_add_(1, fill, score.index_select(1, blank))
score.index_fill_(1, blank, 1e-10)
return scores
class CopyGenerator(nn.Module):
"""An implementation of pointer-generator networks
:cite:`DBLP:journals/corr/SeeLM17`.
These networks consider copying words
directly from the source sequence.
The copy generator is an extended version of the standard
generator that computes three values.
* :math:`p_{softmax}` the standard softmax over `tgt_dict`
* :math:`p(z)` the probability of copying a word from
the source
* :math:`p_{copy}` the probility of copying a particular word.
taken from the attention distribution directly.
The model returns a distribution over the extend dictionary,
computed as
:math:`p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)`
.. mermaid::
graph BT
A[input]
S[src_map]
B[softmax]
BB[switch]
C[attn]
D[copy]
O[output]
A --> B
A --> BB
S --> D
C --> D
D --> O
B --> O
BB --> O
Args:
input_size (int): size of input representation
output_size (int): size of output vocabulary
pad_idx (int)
"""
def __init__(self, input_size, output_size, pad_idx):
super(CopyGenerator, self).__init__()
self.linear = nn.Linear(input_size, output_size)
self.linear_copy = nn.Linear(input_size, 1)
self.pad_idx = pad_idx
def forward(self, hidden, attn, src_map):
"""
Compute a distribution over the target dictionary
extended by the dynamic dictionary implied by copying
source words.
Args:
hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
attn (FloatTensor): attn for each ``(batch x tlen, slen)``
src_map (FloatTensor):
A sparse indicator matrix mapping each source word to
its index in the "extended" vocab containing.
``(src_len, batch, extra_words)``
"""
# CHECKS
batch_by_tlen, _ = hidden.size()
batch_by_tlen_, slen = attn.size()
slen_, batch, cvocab = src_map.size()
aeq(batch_by_tlen, batch_by_tlen_)
aeq(slen, slen_)
# Original probabilities.
logits = self.linear(hidden)
logits[:, self.pad_idx] = -float('inf')
prob = torch.softmax(logits, 1)
# Probability of copying p(z=1) batch.
p_copy = torch.sigmoid(self.linear_copy(hidden))
# Probability of not copying: p_{word}(w) * (1 - p(z))
out_prob = torch.mul(prob, 1 - p_copy)
mul_attn = torch.mul(attn, p_copy)
copy_prob = torch.bmm(
mul_attn.view(-1, batch, slen).transpose(0, 1),
src_map.transpose(0, 1)
).transpose(0, 1)
copy_prob = copy_prob.contiguous().view(-1, cvocab)
return torch.cat([out_prob, copy_prob], 1)
class CopyGeneratorLoss(nn.Module):
"""Copy generator criterion."""
def __init__(self, vocab_size, force_copy, unk_index=0,
ignore_index=-100, eps=1e-20):
super(CopyGeneratorLoss, self).__init__()
self.force_copy = force_copy
self.eps = eps
self.vocab_size = vocab_size
self.ignore_index = ignore_index
self.unk_index = unk_index
def forward(self, scores, align, target):
"""
Args:
scores (FloatTensor): ``(batch_size*tgt_len)`` x dynamic vocab size
whose sum along dim 1 is less than or equal to 1, i.e. cols
softmaxed.
align (LongTensor): ``(batch_size x tgt_len)``
target (LongTensor): ``(batch_size x tgt_len)``
"""
# probabilities assigned by the model to the gold targets
vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1)
# probability of tokens copied from source
copy_ix = align.unsqueeze(1) + self.vocab_size
copy_tok_probs = scores.gather(1, copy_ix).squeeze(1)
# Set scores for unk to 0 and add eps
copy_tok_probs[align == self.unk_index] = 0
copy_tok_probs += self.eps # to avoid -inf logs
# find the indices in which you do not use the copy mechanism
non_copy = align == self.unk_index
if not self.force_copy:
non_copy = non_copy | (target != self.unk_index)
probs = torch.where(
non_copy, copy_tok_probs + vocab_probs, copy_tok_probs
)
loss = -probs.log() # just NLLLoss; can the module be incorporated?
# Drop padding.
loss[target == self.ignore_index] = 0
return loss
class CommonCopyGeneratorLossCompute(CommonLossCompute):
"""Common Copy Generator Loss Computation."""
def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
lambda_coverage=0.0, tgt_shift_index=1):
super(CommonCopyGeneratorLossCompute, self).__init__(
criterion, generator, lambda_coverage=lambda_coverage,
tgt_shift_index=tgt_shift_index)
self.tgt_vocab = tgt_vocab
self.normalize_by_length = normalize_by_length
def _compute_loss(self, batch, output, target, copy_attn, align,
std_attn=None, coverage_attn=None):
"""Compute the loss.
The args must match :func:`self._make_shard_state()`.
Args:
batch: the current batch.
output: the predict output from the model.
target: the validate target to compare output with.
copy_attn: the copy attention value.
align: the align info.
"""
target = target.view(-1)
align = align.view(-1)
scores = self.generator(
self._bottle(output), self._bottle(copy_attn), batch.src_map
)
loss = self.criterion(scores, align, target)
if self.lambda_coverage != 0.0:
coverage_loss = self._compute_coverage_loss(std_attn,
coverage_attn)
loss += coverage_loss
# this block does not depend on the loss value computed above
# and is used only for stats
scores_data = collapse_copy_scores(
self._unbottle(scores.clone(), batch.batch_size),
batch, self.tgt_vocab, None)
scores_data = self._bottle(scores_data)
# this block does not depend on the loss value computed above
# and is used only for stats
# Correct target copy token instead of <unk>
# tgt[i] = align[i] + len(tgt_vocab)
# for i such that tgt[i] == 0 and align[i] != 0
target_data = target.clone()
unk = self.criterion.unk_index
correct_mask = (target_data == unk) & (align != unk)
offset_align = align[correct_mask] + len(self.tgt_vocab)
target_data[correct_mask] += offset_align
# Compute sum of perplexities for stats
stats = self._stats(loss.sum().clone(), scores_data, target_data)
# this part looks like it belongs in CopyGeneratorLoss
if self.normalize_by_length:
# Compute Loss as NLL divided by seq length
tgt_lens = batch.tgt[:, :, 0].ne(self.padding_idx).sum(0).float()
# Compute Total Loss per sequence in batch
loss = loss.view(-1, batch.batch_size).sum(0)
# Divide by length of each sequence and sum
loss = torch.div(loss, tgt_lens).sum()
else:
loss = loss.sum()
return loss, stats
def _make_shard_state(self, batch, output, range_, attns):
"""See base class for args description."""
shard_state = super(CommonCopyGeneratorLossCompute,
self)._make_shard_state(batch, output,
range_, attns)
start_range = range_[0] + self.tgt_shift_index
end_range = range_[1]
shard_state.update({
"copy_attn": attns.get("copy"),
"align": batch.alignment[start_range: end_range]
})
return shard_state
class CopyGeneratorLossCompute(CommonCopyGeneratorLossCompute):
"""Copy Generator Loss Computation."""
def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
lambda_coverage=0.0):
super(CopyGeneratorLossCompute, self).__init__(criterion, generator,
tgt_vocab,
normalize_by_length,
lambda_coverage=0.0,
tgt_shift_index=1)
class CopyGeneratorLMLossCompute(CommonCopyGeneratorLossCompute):
"""Copy Generator LM Loss Computation."""
def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
lambda_coverage=0.0):
super(CopyGeneratorLMLossCompute, self).__init__(criterion, generator,
tgt_vocab,
normalize_by_length,
lambda_coverage=0.0,
tgt_shift_index=0)