""" |
This includes: LossComputeBase, the standard NMTLossCompute and
LMLossCompute, and the utilities for computing the loss in shards.
""" |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
import onmt |
from onmt.modules.sparse_losses import SparsemaxLoss |
from onmt.modules.sparse_activations import LogSparsemax |
from onmt.constants import ModelTask |
def build_loss_compute(model, tgt_field, opt, train=True):
    """
    Returns a LossCompute object which wraps around an nn.Module criterion
    (such as nn.NLLLoss) which defines the loss criterion. The LossCompute
    object allows this loss to be computed in shards and passes the relevant
    data to a Statistics object which handles training/validation logging.

    The criterion is chosen as follows:
    CopyGeneratorLoss when ``opt.copy_attn`` is set; LabelSmoothingLoss when
    ``opt.label_smoothing > 0`` and ``train`` is True; SparsemaxLoss when the
    last generator layer is a LogSparsemax; nn.NLLLoss otherwise.
    The compute wrapper is NMTLossCompute / LMLossCompute (or their copy
    generator counterparts when ``opt.copy_attn`` is set), depending on
    ``opt.model_task``.

    Args:
        model: the model; only ``model.generator`` is read.
        tgt_field: target field; its vocab, pad token and unk token are used.
        opt: options namespace (reads copy_attn, copy_attn_force,
            copy_loss_by_seqlength, label_smoothing, lambda_coverage,
            coverage_attn, lambda_align, model_task, plus whatever
            ``onmt.utils.misc.use_gpu`` inspects).
        train (bool): label smoothing is only applied when True.

    Returns:
        The loss compute module, moved to the CUDA or CPU device.

    Raises:
        AssertionError: on inconsistent coverage / alignment options.
        ValueError: when ``opt.model_task`` has no loss defined.
    """
    device = torch.device("cuda" if onmt.utils.misc.use_gpu(opt) else "cpu")

    padding_idx = tgt_field.vocab.stoi[tgt_field.pad_token]
    unk_idx = tgt_field.vocab.stoi[tgt_field.unk_token]

    if opt.lambda_coverage != 0:
        assert opt.coverage_attn, "--coverage_attn needs to be set in " \
            "order to use --lambda_coverage != 0"

    if opt.copy_attn:
        criterion = onmt.modules.CopyGeneratorLoss(
            len(tgt_field.vocab), opt.copy_attn_force,
            unk_index=unk_idx, ignore_index=padding_idx
        )
    elif opt.label_smoothing > 0 and train:
        criterion = LabelSmoothingLoss(
            opt.label_smoothing, len(tgt_field.vocab), ignore_index=padding_idx
        )
    elif isinstance(model.generator[-1], LogSparsemax):
        criterion = SparsemaxLoss(ignore_index=padding_idx, reduction='sum')
    else:
        criterion = nn.NLLLoss(ignore_index=padding_idx, reduction='sum')

    # SparsemaxLoss works on raw scores, so the final LogSparsemax layer of
    # the generator is skipped and only its first module is used.
    use_raw_logits = isinstance(criterion, SparsemaxLoss)
    loss_gen = model.generator[0] if use_raw_logits else model.generator

    if opt.copy_attn:
        if opt.model_task == ModelTask.SEQ2SEQ:
            compute = onmt.modules.CopyGeneratorLossCompute(
                criterion, loss_gen, tgt_field.vocab,
                opt.copy_loss_by_seqlength,
                lambda_coverage=opt.lambda_coverage
            )
        elif opt.model_task == ModelTask.LANGUAGE_MODEL:
            compute = onmt.modules.CopyGeneratorLMLossCompute(
                criterion, loss_gen, tgt_field.vocab,
                opt.copy_loss_by_seqlength,
                lambda_coverage=opt.lambda_coverage
            )
        else:
            raise ValueError(
                f"No copy generator loss defined for task {opt.model_task}"
            )
    else:
        if opt.model_task == ModelTask.SEQ2SEQ:
            compute = NMTLossCompute(
                criterion,
                loss_gen,
                lambda_coverage=opt.lambda_coverage,
                lambda_align=opt.lambda_align,
            )
        elif opt.model_task == ModelTask.LANGUAGE_MODEL:
            assert (
                opt.lambda_align == 0.0
            ), "lambda_align not supported in LM loss"
            compute = LMLossCompute(
                criterion,
                loss_gen,
                lambda_coverage=opt.lambda_coverage,
                lambda_align=opt.lambda_align,
            )
        else:
            raise ValueError(
                f"No compute loss defined for task {opt.model_task}"
            )
    compute.to(device)

    return compute
class LossComputeBase(nn.Module):
    """
    Class for managing efficient loss computation. Handles
    sharding next step predictions and accumulating multiple
    loss computations.

    Users can implement their own loss computation strategy by making
    subclass of this one. Users need to implement the _compute_loss()
    and _make_shard_state() methods.

    Args:
        criterion (:obj:`nn.Module`) :
             loss criterion; its ``ignore_index`` attribute is exposed as
             ``padding_idx``.
        generator (:obj:`nn.Module`) :
             module that maps the output of the decoder to a
             distribution over the target vocabulary.
    """

    def __init__(self, criterion, generator):
        super(LossComputeBase, self).__init__()
        self.criterion = criterion
        self.generator = generator

    @property
    def padding_idx(self):
        """Target index ignored by the criterion (``criterion.ignore_index``)."""
        return self.criterion.ignore_index

    def _make_shard_state(self, batch, output, range_, attns=None):
        """
        Make shard state dictionary for shards() to return iterable
        shards for efficient loss computation. Subclass must define
        this method to match its own _compute_loss() interface.

        Args:
            batch: the current batch.
            output: the predict output from the model.
            range_: the range of examples for computing, the whole
                    batch or a trunc of it?
            attns: the attns dictionary returned from the model.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        # Raising (rather than returning the exception class) makes a
        # missing override fail here instead of at the `**shard_state` call.
        raise NotImplementedError

    def _compute_loss(self, batch, output, target, **kwargs):
        """
        Compute the loss. Subclass must define this method.

        Args:
            batch: the current batch.
            output: the predict output from the model.
            target: the validate target to compare output with.
            **kwargs(optional): additional info for computing loss.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def __call__(self,
                 batch,
                 output,
                 attns,
                 normalization=1.0,
                 shard_size=0,
                 trunc_start=0,
                 trunc_size=None):
        """Compute the forward loss, possibly in shards in which case this
        method also runs the backward pass and returns ``None`` as the loss
        value.

        Also supports truncated BPTT for long sequences by taking a
        range in the decoder output sequence to back propagate in.
        Range is from `(trunc_start, trunc_start + trunc_size)`.

        Note sharding is an exact efficiency trick to relieve memory
        required for the generation buffers. Truncation is an
        approximate efficiency trick to relieve the memory required
        in the RNN buffers.

        Args:
          batch (batch) : batch of labeled examples
          output (:obj:`FloatTensor`) :
              output of decoder model `[tgt_len x batch x hidden]`
          attns (dict) : dictionary of attention distributions
              `[tgt_len x batch x src_len]`
          normalization: Optional normalization factor.
          shard_size (int) : maximum number of examples in a shard;
              0 disables sharding.
          trunc_start (int) : starting position of truncation window
          trunc_size (int) : length of truncation window; defaults to the
              rest of ``batch.tgt`` after ``trunc_start``.

        Returns:
            A tuple with the loss and a :obj:`onmt.utils.Statistics` instance.
            The loss is ``None`` when sharding, since backward has already
            been run on every shard.
        """
        if trunc_size is None:
            trunc_size = batch.tgt.size(0) - trunc_start
        trunc_range = (trunc_start, trunc_start + trunc_size)
        shard_state = self._make_shard_state(batch, output, trunc_range, attns)
        if shard_size == 0:
            loss, stats = self._compute_loss(batch, **shard_state)
            return loss / float(normalization), stats
        batch_stats = onmt.utils.Statistics()
        for shard in shards(shard_state, shard_size):
            loss, stats = self._compute_loss(batch, **shard)
            loss.div(float(normalization)).backward()
            batch_stats.update(stats)
        return None, batch_stats

    def _stats(self, loss, scores, target):
        """
        Args:
            loss (:obj:`FloatTensor`): the loss computed by the loss criterion.
            scores (:obj:`FloatTensor`): a score for each possible output
            target (:obj:`FloatTensor`): true targets

        Returns:
            :obj:`onmt.utils.Statistics` : statistics for this batch.
        """
        # Accuracy is measured on the argmax prediction over dim 1,
        # counting only non-padding target positions.
        pred = scores.max(1)[1]
        non_padding = target.ne(self.padding_idx)
        num_correct = pred.eq(target).masked_select(non_padding).sum().item()
        num_non_padding = non_padding.sum().item()
        return onmt.utils.Statistics(loss.item(), num_non_padding, num_correct)

    def _bottle(self, _v):
        """Flatten a 3-D tensor to 2-D, keeping its last dimension."""
        return _v.view(-1, _v.size(2))

    def _unbottle(self, _v, batch_size):
        """Inverse of ``_bottle``: reshape to ``(-1, batch_size, _v.size(1))``."""
        return _v.view(-1, batch_size, _v.size(1))
class LabelSmoothingLoss(nn.Module):
    """
    With label smoothing,
    KL-divergence between q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.

    Args:
        label_smoothing (float): probability mass in (0, 1] removed from the
            gold label and spread uniformly over the other classes.
        tgt_vocab_size (int): number of output classes.
        ignore_index (int): target index (e.g. padding) whose column gets no
            smoothed mass and whose rows contribute no loss. A value outside
            ``[0, tgt_vocab_size)`` (such as the default ``-100``) excludes
            no column.
    """

    def __init__(self, label_smoothing, tgt_vocab_size, ignore_index=-100):
        assert 0.0 < label_smoothing <= 1.0
        super(LabelSmoothingLoss, self).__init__()
        self.ignore_index = ignore_index

        # Only a real vocabulary column can be excluded. Indexing with a
        # negative sentinel would wrap around and zero an unrelated column
        # (or raise IndexError for small vocabularies).
        excludes_column = 0 <= ignore_index < tgt_vocab_size
        # The smoothing mass is shared by every class except the gold one
        # and, when applicable, the ignored column, so each row sums to 1.
        n_smoothed = tgt_vocab_size - (2 if excludes_column else 1)
        smoothing_value = label_smoothing / n_smoothed
        one_hot = torch.full((tgt_vocab_size,), smoothing_value)
        if excludes_column:
            one_hot[ignore_index] = 0
        self.register_buffer('one_hot', one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output, target):
        """
        output (FloatTensor): batch_size x n_classes
            (passed as the ``input`` of ``F.kl_div``, so expected to hold
            log-probabilities)
        target (LongTensor): batch_size

        Returns the summed KL-divergence; rows whose target equals
        ``ignore_index`` are zeroed and contribute nothing.
        """
        model_prob = self.one_hot.repeat(target.size(0), 1)
        model_prob.scatter_(1, target.unsqueeze(1), self.confidence)
        model_prob.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)

        return F.kl_div(output, model_prob, reduction='sum')
class CommonLossCompute(LossComputeBase):
    """
    Loss Computation parent for NMTLossCompute and LMLossCompute.

    Implements a loss compatible with coverage and alignment shards.

    Args:
        criterion (:obj:`nn.Module`): criterion applied to the generator
            scores and the flattened gold target.
        generator (:obj:`nn.Module`): maps the bottled decoder output to
            scores over the target vocabulary.
        normalization (str): accepted for interface compatibility; it is
            not read by this class.
        lambda_coverage (float): weight of the coverage loss (0 disables it).
        lambda_align (float): weight of the guided-alignment loss
            (0 disables it).
        tgt_shift_index (int): offset added to the start of the decoding
            range when slicing the gold target (NMTLossCompute passes 1,
            LMLossCompute passes 0).
    """

    def __init__(self, criterion, generator, normalization="sents",
                 lambda_coverage=0.0, lambda_align=0.0, tgt_shift_index=1):
        super(CommonLossCompute, self).__init__(criterion, generator)
        self.lambda_coverage = lambda_coverage
        self.lambda_align = lambda_align
        self.tgt_shift_index = tgt_shift_index

    def _add_coverage_shard_state(self, shard_state, attns):
        """Add ``std_attn`` and ``coverage_attn`` from ``attns`` to shard_state."""
        # Validate the dict before calling .get() on it, so a missing
        # attention dict fails on this assertion instead of AttributeError.
        assert attns is not None, (
            "lambda_coverage != 0.0 requires the attention dictionary"
            " returned by the model."
        )
        coverage = attns.get("coverage", None)
        std = attns.get("std", None)
        assert coverage is not None, (
            "lambda_coverage != 0.0 requires coverage attention"
            " that could not be found in the model."
            " Transformer decoders do not implement coverage"
        )
        assert std is not None, (
            "lambda_coverage != 0.0 requires attention mechanism"
            " that could not be found in the model."
        )
        shard_state.update({"std_attn": std,
                            "coverage_attn": coverage})

    def _compute_loss(self, batch, output, target, std_attn=None,
                      coverage_attn=None, align_head=None, ref_align=None):
        """
        Compute the criterion loss plus the optional coverage and
        alignment terms.

        Args:
            batch: the current batch (not read here).
            output: decoder output; flattened with ``_bottle`` before the
                generator is applied.
            target: gold indices, flattened with ``view(-1)``.
            std_attn, coverage_attn: required when lambda_coverage != 0.
            align_head, ref_align: required when lambda_align != 0.

        Returns:
            (loss, stats): the loss tensor and its Statistics.
        """
        bottled_output = self._bottle(output)

        scores = self.generator(bottled_output)
        gtruth = target.view(-1)

        loss = self.criterion(scores, gtruth)
        if self.lambda_coverage != 0.0:
            coverage_loss = self._compute_coverage_loss(
                std_attn=std_attn, coverage_attn=coverage_attn)
            loss += coverage_loss
        if self.lambda_align != 0.0:
            # Match the loss precision (e.g. under mixed precision).
            if align_head.dtype != loss.dtype:
                align_head = align_head.to(loss.dtype)
            if ref_align.dtype != loss.dtype:
                ref_align = ref_align.to(loss.dtype)
            align_loss = self._compute_alignement_loss(
                align_head=align_head, ref_align=ref_align)
            loss += align_loss

        stats = self._stats(loss.clone(), scores, gtruth)

        return loss, stats

    def _compute_coverage_loss(self, std_attn, coverage_attn):
        """Return ``lambda_coverage * sum(min(std_attn, coverage_attn))``."""
        covloss = torch.min(std_attn, coverage_attn).sum()
        covloss *= self.lambda_coverage
        return covloss

    def _add_align_shard_state(self, shard_state, batch, range_start,
                               range_end, attns):
        """Add ``align_head`` and the sliced reference alignment to shard_state."""
        # Validate inputs before dereferencing them.
        assert attns is not None, (
            "lambda_align != 0.0 requires the attention dictionary"
            " returned by the model."
        )
        attn_align = attns.get("align", None)
        align_idx = getattr(batch, "align", None)
        assert attn_align is not None, (
            "lambda_align != 0.0 requires " "alignement attention head"
        )
        assert align_idx is not None, (
            "lambda_align != 0.0 requires " "provide guided alignement"
        )
        pad_tgt_size, batch_size, _ = batch.tgt.size()
        # NOTE(review): batch.src appears to be a tuple whose first item is
        # the source tensor (length along dim 0) — confirm against the batch.
        pad_src_size = batch.src[0].size(0)
        align_matrix_size = [batch_size, pad_tgt_size, pad_src_size]
        ref_align = onmt.utils.make_batch_align_matrix(
            align_idx, align_matrix_size, normalize=True
        )
        # Keep only the target steps covered by the current range.
        shard_state.update(
            {
                "align_head": attn_align,
                "ref_align": ref_align[:, range_start:range_end, :],
            }
        )

    def _compute_alignement_loss(self, align_head, ref_align):
        """Compute loss between 2 partial alignment matrix."""
        # Clamp before log to avoid log(0) = -inf.
        align_loss = -align_head.clamp(min=1e-18).log().mul(ref_align).sum()
        align_loss *= self.lambda_align
        return align_loss

    def _make_shard_state(self, batch, output, range_, attns=None):
        """
        Build the shard state: decoder output, gold target slice and, when
        enabled, the coverage / alignment entries.
        """
        range_start = range_[0] + self.tgt_shift_index
        range_end = range_[1]
        # NOTE(review): index 0 of the last tgt dimension is presumably the
        # word-index feature — confirm.
        shard_state = {
            "output": output,
            "target": batch.tgt[range_start:range_end, :, 0],
        }
        if self.lambda_coverage != 0.0:
            self._add_coverage_shard_state(shard_state, attns)
        if self.lambda_align != 0.0:
            self._add_align_shard_state(
                shard_state, batch, range_start, range_end, attns
            )
        return shard_state
class NMTLossCompute(CommonLossCompute):
    """
    Standard NMT Loss Computation.

    Thin specialization of CommonLossCompute whose gold target is sliced
    starting one position after the decoding range start
    (``tgt_shift_index=1``). All other arguments are forwarded unchanged.
    """

    def __init__(self, criterion, generator, normalization="sents",
                 lambda_coverage=0.0, lambda_align=0.0):
        super().__init__(
            criterion,
            generator,
            normalization=normalization,
            lambda_coverage=lambda_coverage,
            lambda_align=lambda_align,
            tgt_shift_index=1,
        )
class LMLossCompute(CommonLossCompute):
    """
    Standard LM Loss Computation.

    Thin specialization of CommonLossCompute whose gold target is sliced
    starting exactly at the decoding range start (``tgt_shift_index=0``).
    All other arguments are forwarded unchanged.
    """

    def __init__(self, criterion, generator, normalization="sents",
                 lambda_coverage=0.0, lambda_align=0.0):
        super().__init__(
            criterion,
            generator,
            normalization=normalization,
            lambda_coverage=lambda_coverage,
            lambda_align=lambda_align,
            tgt_shift_index=0,
        )
def filter_shard_state(state, shard_size=None):
    """
    Iterate over ``state`` and prepare its entries for sharding.

    Args:
        state (dict): shard state as returned by
            ``*LossCompute._make_shard_state()``; values are tensors or None.
        shard_size (int or None): chunk size along dimension 0. When None,
            every ``(key, value)`` pair is yielded unchanged (None included).

    Yields:
        With ``shard_size`` None: ``(key, value)``.
        Otherwise, for every non-None value: ``(key, (value, chunks))``, where
        ``chunks`` are detached leaf copies of ``torch.split(value,
        shard_size)`` carrying ``value.requires_grad`` (an empty list for
        non-tensor values).
    """
    for key, value in state.items():
        if shard_size is None:
            yield key, value
            # Nothing to split. Skipping the rest avoids torch.split(v, None)
            # (which fails) and yielding the same key a second time.
            continue
        if value is None:
            continue
        chunks = []
        if isinstance(value, torch.Tensor):
            for chunk in torch.split(value, shard_size):
                # Detached leaf copy: the per-shard backward accumulates into
                # chunk.grad, which shards() later pushes back into `value`.
                chunk = chunk.detach().clone()
                chunk.requires_grad = value.requires_grad
                chunks.append(chunk)
        yield key, (value, chunks)


def shards(state, shard_size, eval_only=False):
    """
    Args:
        state: A dictionary which corresponds to the output of
               *LossCompute._make_shard_state(). The values for
               those keys are Tensor-like or None.
        shard_size: The maximum size of the shards yielded by the model.
        eval_only: If True, only yield the state, nothing else.
              Otherwise, yield shards.

    Yields:
        Each yielded shard is a dict. With ``eval_only``, a single dict
        holding the whole, unsplit state is yielded.

    Side effect:
        Unless ``eval_only``, after the last shard this function
        back-propagates the gradients accumulated on the shard chunks into
        the original state tensors that require grad.
    """
    if eval_only:
        yield dict(filter_shard_state(state))
        return

    non_none = dict(filter_shard_state(state, shard_size))
    if not non_none:
        return

    # Shard i holds the i-th chunk of every entry.
    keys = list(non_none)
    values = [chunks for _, chunks in non_none.values()]
    for shard_tensors in zip(*values):
        yield dict(zip(keys, shard_tensors))

    # Assumed backprop'd: the caller ran backward() on each shard's loss, so
    # the chunk leaves hold their gradients; forward them to the full tensors.
    inputs, grads = [], []
    for key, (value, chunks) in non_none.items():
        if not (isinstance(value, torch.Tensor) and state[key].requires_grad):
            continue
        for piece, chunk in zip(torch.split(state[key], shard_size), chunks):
            # Chunks that did not take part in any loss have no gradient.
            if chunk.grad is not None:
                inputs.append(piece)
                grads.append(chunk.grad)
    # Nothing requires grad (e.g. evaluation): there is nothing to propagate.
    if inputs:
        torch.autograd.backward(inputs, grads)