import torch
import torch.nn as nn
from onmt.utils.misc import aeq
from onmt.utils.loss import CommonLossCompute
def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None,
batch_dim=1, batch_offset=None):
"""
Given scores from an expanded dictionary
corresponeding to a batch, sums together copies,
with a dictionary word when it is ambiguous.
"""
offset = len(tgt_vocab)
for b in range(scores.size(batch_dim)):
blank = []
fill = []
if src_vocabs is None:
src_vocab = batch.src_ex_vocab[b]
else:
batch_id = batch_offset[b] if batch_offset is not None else b
index = batch.indices.data[batch_id]
src_vocab = src_vocabs[index]
for i in range(1, len(src_vocab)):
sw = src_vocab.itos[i]
ti = tgt_vocab.stoi[sw]
if ti != 0:
blank.append(offset + i)
fill.append(ti)
if blank:
blank = torch.Tensor(blank).type_as(batch.indices.data)
fill = torch.Tensor(fill).type_as(batch.indices.data)
score = scores[:, b] if batch_dim == 1 else scores[b]
score.index_add_(1, fill, score.index_select(1, blank))
score.index_fill_(1, blank, 1e-10)
return scores
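# Illustrative sketch of the collapse step above (hypothetical values, not
# part of the module): suppose "cat" sits at index 5 of ``tgt_vocab`` and at
# index 3 of an example's source vocabulary. With ``offset = len(tgt_vocab)``
# the loop merges the copy score into the in-vocab slot and blanks the
# duplicate, roughly:
#
#     score[:, 5] += score[:, offset + 3]   # merge copy mass into "cat"
#     score[:, offset + 3] = 1e-10          # leave only a tiny residual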
class CopyGenerator(nn.Module):
"""An implementation of pointer-generator networks
:cite:`DBLP:journals/corr/SeeLM17`.
These networks consider copying words
directly from the source sequence.
The copy generator is an extended version of the standard
generator that computes three values.
* :math:`p_{softmax}` the standard softmax over `tgt_dict`
* :math:`p(z)` the probability of copying a word from
the source
    * :math:`p_{copy}` the probability of copying a particular word,
      taken from the attention distribution directly.
    The model returns a distribution over the extended dictionary,
computed as
:math:`p(w) = p(z=1) p_{copy}(w) + p(z=0) p_{softmax}(w)`
.. mermaid::
graph BT
A[input]
S[src_map]
B[softmax]
BB[switch]
C[attn]
D[copy]
O[output]
A --> B
A --> BB
S --> D
C --> D
D --> O
B --> O
BB --> O
Args:
input_size (int): size of input representation
output_size (int): size of output vocabulary
        pad_idx (int): index of the padding token in the output vocabulary
"""
def __init__(self, input_size, output_size, pad_idx):
super(CopyGenerator, self).__init__()
self.linear = nn.Linear(input_size, output_size)
self.linear_copy = nn.Linear(input_size, 1)
self.pad_idx = pad_idx
def forward(self, hidden, attn, src_map):
"""
Compute a distribution over the target dictionary
extended by the dynamic dictionary implied by copying
source words.
Args:
hidden (FloatTensor): hidden outputs ``(batch x tlen, input_size)``
            attn (FloatTensor): attention over the source for each
                output step ``(batch x tlen, slen)``
            src_map (FloatTensor):
                A sparse indicator matrix mapping each source word to
                its index in the "extended" vocabulary.
                ``(src_len, batch, extra_words)``
"""
# CHECKS
batch_by_tlen, _ = hidden.size()
batch_by_tlen_, slen = attn.size()
slen_, batch, cvocab = src_map.size()
aeq(batch_by_tlen, batch_by_tlen_)
aeq(slen, slen_)
# Original probabilities.
logits = self.linear(hidden)
logits[:, self.pad_idx] = -float('inf')
prob = torch.softmax(logits, 1)
# Probability of copying p(z=1) batch.
p_copy = torch.sigmoid(self.linear_copy(hidden))
# Probability of not copying: p_{word}(w) * (1 - p(z))
out_prob = torch.mul(prob, 1 - p_copy)
mul_attn = torch.mul(attn, p_copy)
copy_prob = torch.bmm(
mul_attn.view(-1, batch, slen).transpose(0, 1),
src_map.transpose(0, 1)
).transpose(0, 1)
copy_prob = copy_prob.contiguous().view(-1, cvocab)
return torch.cat([out_prob, copy_prob], 1)
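# A minimal usage sketch for ``CopyGenerator`` (the sizes below are assumed
# for illustration only):
#
#     gen = CopyGenerator(input_size=512, output_size=50004, pad_idx=1)
#     batch, tlen, slen, cvocab = 4, 7, 9, 12
#     hidden = torch.randn(batch * tlen, 512)
#     attn = torch.softmax(torch.randn(batch * tlen, slen), dim=1)
#     src_map = torch.zeros(slen, batch, cvocab)
#     probs = gen(hidden, attn, src_map)  # (batch * tlen, 50004 + cvocab)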
class CopyGeneratorLoss(nn.Module):
"""Copy generator criterion."""
def __init__(self, vocab_size, force_copy, unk_index=0,
ignore_index=-100, eps=1e-20):
super(CopyGeneratorLoss, self).__init__()
self.force_copy = force_copy
self.eps = eps
self.vocab_size = vocab_size
self.ignore_index = ignore_index
self.unk_index = unk_index
def forward(self, scores, align, target):
"""
Args:
            scores (FloatTensor): ``(batch_size*tgt_len, dynamic_vocab_size)``
                whose sum along dim 1 is less than or equal to 1, i.e. each
                row holds probabilities that need not sum exactly to 1.
align (LongTensor): ``(batch_size x tgt_len)``
target (LongTensor): ``(batch_size x tgt_len)``
"""
# probabilities assigned by the model to the gold targets
vocab_probs = scores.gather(1, target.unsqueeze(1)).squeeze(1)
# probability of tokens copied from source
copy_ix = align.unsqueeze(1) + self.vocab_size
copy_tok_probs = scores.gather(1, copy_ix).squeeze(1)
# Set scores for unk to 0 and add eps
copy_tok_probs[align == self.unk_index] = 0
copy_tok_probs += self.eps # to avoid -inf logs
# find the indices in which you do not use the copy mechanism
non_copy = align == self.unk_index
if not self.force_copy:
non_copy = non_copy | (target != self.unk_index)
probs = torch.where(
non_copy, copy_tok_probs + vocab_probs, copy_tok_probs
)
loss = -probs.log() # just NLLLoss; can the module be incorporated?
# Drop padding.
loss[target == self.ignore_index] = 0
return loss
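# A minimal sketch of the criterion above, reusing the shapes from the
# ``CopyGenerator`` sketch (all values assumed for illustration):
#
#     criterion = CopyGeneratorLoss(vocab_size=50004, force_copy=False,
#                                   unk_index=0, ignore_index=1)
#     scores = gen(hidden, attn, src_map)             # (batch * tlen, 50004 + cvocab)
#     align = torch.zeros(batch * tlen).long()        # copy index per gold token (0 = none)
#     target = torch.randint(2, 50004, (batch * tlen,))  # gold token ids in tgt_vocab
#     loss = criterion(scores, align, target)         # per-token NLL, shape (batch * tlen,)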
class CommonCopyGeneratorLossCompute(CommonLossCompute):
"""Common Copy Generator Loss Computation."""
def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
lambda_coverage=0.0, tgt_shift_index=1):
super(CommonCopyGeneratorLossCompute, self).__init__(
criterion, generator, lambda_coverage=lambda_coverage,
tgt_shift_index=tgt_shift_index)
self.tgt_vocab = tgt_vocab
self.normalize_by_length = normalize_by_length
def _compute_loss(self, batch, output, target, copy_attn, align,
std_attn=None, coverage_attn=None):
"""Compute the loss.
The args must match :func:`self._make_shard_state()`.
Args:
batch: the current batch.
            output: the predicted output from the model.
            target: the gold target to compare the output against.
copy_attn: the copy attention value.
align: the align info.
"""
target = target.view(-1)
align = align.view(-1)
scores = self.generator(
self._bottle(output), self._bottle(copy_attn), batch.src_map
)
loss = self.criterion(scores, align, target)
if self.lambda_coverage != 0.0:
coverage_loss = self._compute_coverage_loss(std_attn,
coverage_attn)
loss += coverage_loss
# this block does not depend on the loss value computed above
# and is used only for stats
scores_data = collapse_copy_scores(
self._unbottle(scores.clone(), batch.batch_size),
batch, self.tgt_vocab, None)
scores_data = self._bottle(scores_data)
# Correct target copy token instead of <unk>
# tgt[i] = align[i] + len(tgt_vocab)
# for i such that tgt[i] == 0 and align[i] != 0
target_data = target.clone()
unk = self.criterion.unk_index
correct_mask = (target_data == unk) & (align != unk)
offset_align = align[correct_mask] + len(self.tgt_vocab)
target_data[correct_mask] += offset_align
# Compute sum of perplexities for stats
stats = self._stats(loss.sum().clone(), scores_data, target_data)
# this part looks like it belongs in CopyGeneratorLoss
if self.normalize_by_length:
# Compute Loss as NLL divided by seq length
tgt_lens = batch.tgt[:, :, 0].ne(self.padding_idx).sum(0).float()
# Compute Total Loss per sequence in batch
loss = loss.view(-1, batch.batch_size).sum(0)
# Divide by length of each sequence and sum
loss = torch.div(loss, tgt_lens).sum()
else:
loss = loss.sum()
return loss, stats
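    # Sketch of the length normalization above (numbers assumed for
    # illustration): with ``normalize_by_length`` set, a batch of two
    # sequences with per-sequence NLL sums of 12.0 and 6.0 and non-pad
    # lengths of 4 and 3 contributes
    #
    #     loss = 12.0 / 4 + 6.0 / 3  # = 5.0
    #
    # instead of the plain sum 18.0 used otherwise.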
def _make_shard_state(self, batch, output, range_, attns):
"""See base class for args description."""
shard_state = super(CommonCopyGeneratorLossCompute,
self)._make_shard_state(batch, output,
range_, attns)
start_range = range_[0] + self.tgt_shift_index
end_range = range_[1]
shard_state.update({
"copy_attn": attns.get("copy"),
"align": batch.alignment[start_range: end_range]
})
return shard_state
class CopyGeneratorLossCompute(CommonCopyGeneratorLossCompute):
"""Copy Generator Loss Computation."""
def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
lambda_coverage=0.0):
        super(CopyGeneratorLossCompute, self).__init__(
            criterion, generator, tgt_vocab, normalize_by_length,
            lambda_coverage=lambda_coverage, tgt_shift_index=1)
class CopyGeneratorLMLossCompute(CommonCopyGeneratorLossCompute):
"""Copy Generator LM Loss Computation."""
def __init__(self, criterion, generator, tgt_vocab, normalize_by_length,
lambda_coverage=0.0):
        super(CopyGeneratorLMLossCompute, self).__init__(
            criterion, generator, tgt_vocab, normalize_by_length,
            lambda_coverage=lambda_coverage, tgt_shift_index=0)