from typing import Optional

import torch
from torch import Tensor

from examples.simultaneous_translation.utils.functions import (
    exclusive_cumprod,
    prob_check,
    moving_sum,
)


def expected_alignment_from_p_choose(
    p_choose: Tensor,
    padding_mask: Optional[Tensor] = None,
    eps: float = 1e-6
):
""" |
|
Calculating expected alignment for from stepwise probability |
|
|
|
Reference: |
|
Online and Linear-Time Attention by Enforcing Monotonic Alignments |
|
https://arxiv.org/pdf/1704.00784.pdf |
|
|
|
q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j} |
|
a_ij = p_ij q_ij |
|
|
|
Parallel solution: |
|
ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi)) |
|
|
|
============================================================ |
|
Expected input size |
|
p_choose: bsz, tgt_len, src_len |
|
""" |
|
    prob_check(p_choose)

    # p_choose: bsz, tgt_len, src_len
    bsz, tgt_len, src_len = p_choose.size()
    dtype = p_choose.dtype

    # Compute in float32; the cumulative products below are prone to
    # underflow in fp16.
    p_choose = p_choose.float()

    if padding_mask is not None:
        p_choose = p_choose.masked_fill(padding_mask.unsqueeze(1), 0.0)

    # cumprod_1mp[:, i, j] = prod_{k < j} (1 - p_choose[:, i, k])
    cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=eps)
    # Clamped copy used as a denominator, to avoid division by zero.
    cumprod_1mp_clamp = torch.clamp(cumprod_1mp, eps, 1.0)

    # Initial alignment: all probability mass on the first source position.
    alpha_0 = p_choose.new_zeros([bsz, 1, src_len])
    alpha_0[:, :, 0] = 1.0

    previous_alpha = [alpha_0]

    for i in range(tgt_len):
        # p_choose[:, i]: bsz, src_len
        # cumprod_1mp[:, i]: bsz, src_len
        # previous_alpha[i]: bsz, 1, src_len
        # alpha_i: bsz, src_len
        alpha_i = (
            p_choose[:, i]
            * cumprod_1mp[:, i]
            * torch.cumsum(
                previous_alpha[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1
            )
        ).clamp(0, 1.0)

        previous_alpha.append(alpha_i.unsqueeze(1))

    # alpha: bsz, tgt_len, src_len
    alpha = torch.cat(previous_alpha[1:], dim=1)

    # Restore the original dtype (e.g. fp16) after the float32 computation.
    alpha = alpha.type(dtype)

    prob_check(alpha)

    return alpha
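

# Illustrative sketch, not part of the original module: it compares the
# parallel (closed-form) alignment computed above with the step-by-step
# recurrence quoted in the docstring, on a small random batch. The batch
# shapes and the random p_choose below are arbitrary assumptions made for
# demonstration only; the two results should agree up to the eps/clamping
# used in the parallel version.
def _example_expected_alignment_sketch():
    torch.manual_seed(0)
    bsz, tgt_len, src_len = 2, 3, 5
    p_choose = torch.rand(bsz, tgt_len, src_len)

    # Parallel solution implemented above.
    alpha_parallel = expected_alignment_from_p_choose(p_choose)

    # Sequential recurrence:
    #   q_{i,j} = (1 - p_{i,j-1}) * q_{i,j-1} + a_{i-1,j}
    #   a_{i,j} = p_{i,j} * q_{i,j}
    # with the previous alignment initialised to [1, 0, ..., 0].
    alpha_prev = p_choose.new_zeros(bsz, src_len)
    alpha_prev[:, 0] = 1.0
    alpha_rows = []
    for i in range(tgt_len):
        q = p_choose.new_zeros(bsz, src_len)
        for j in range(src_len):
            if j == 0:
                q[:, j] = alpha_prev[:, j]
            else:
                q[:, j] = (1 - p_choose[:, i, j - 1]) * q[:, j - 1] + alpha_prev[:, j]
        alpha_prev = p_choose[:, i] * q
        alpha_rows.append(alpha_prev.unsqueeze(1))
    alpha_sequential = torch.cat(alpha_rows, dim=1)

    print("max abs difference:", (alpha_parallel - alpha_sequential).abs().max().item())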


def expected_soft_attention(
    alpha: Tensor,
    soft_energy: Tensor,
    padding_mask: Optional[Tensor] = None,
    chunk_size: Optional[int] = None,
    eps: float = 1e-10
):
""" |
|
Function to compute expected soft attention for |
|
monotonic infinite lookback attention from |
|
expected alignment and soft energy. |
|
|
|
Reference: |
|
Monotonic Chunkwise Attention |
|
https://arxiv.org/abs/1712.05382 |
|
|
|
Monotonic Infinite Lookback Attention for Simultaneous Machine Translation |
|
https://arxiv.org/abs/1906.05218 |
|
|
|
alpha: bsz, tgt_len, src_len |
|
soft_energy: bsz, tgt_len, src_len |
|
padding_mask: bsz, src_len |
|
left_padding: bool |
|
""" |
|
    if padding_mask is not None:
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)
        soft_energy = soft_energy.masked_fill(
            padding_mask.unsqueeze(1), -float("inf")
        )

    prob_check(alpha)

    dtype = alpha.dtype

    alpha = alpha.float()
    soft_energy = soft_energy.float()

    # Subtract the per-row maximum for numerical stability before exponentiating.
    soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
    exp_soft_energy = torch.exp(soft_energy) + eps

    if chunk_size is not None:
        # Chunkwise attention (MoChA): energies are normalised over a moving
        # window of `chunk_size` source positions.
        beta = (
            exp_soft_energy
            * moving_sum(
                alpha / (eps + moving_sum(exp_soft_energy, chunk_size, 1)),
                1, chunk_size
            )
        )
    else:
        # Infinite lookback attention (MILk): attend over the whole source
        # prefix up to each expected alignment position.
        inner_items = alpha / (eps + torch.cumsum(exp_soft_energy, dim=2))

        beta = (
            exp_soft_energy
            * torch.cumsum(inner_items.flip(dims=[2]), dim=2)
            .flip(dims=[2])
        )

    if padding_mask is not None:
        beta = beta.masked_fill(
            padding_mask.unsqueeze(1).to(torch.bool), 0.0)

    # Restore the original dtype (e.g. fp16) after the float32 computation.
    beta = beta.type(dtype)

    beta = beta.clamp(0, 1)

    prob_check(beta)

    return beta
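

# Illustrative sketch, not part of the original module: it shows one way the
# soft attention weights `beta` might be used downstream, namely to pool a
# set of value vectors into per-target-step context vectors. The dimensions,
# the random energies, and the bmm pooling are assumptions made for
# demonstration only.
def _example_expected_soft_attention_sketch():
    torch.manual_seed(0)
    bsz, tgt_len, src_len, dim = 2, 3, 5, 4

    p_choose = torch.rand(bsz, tgt_len, src_len)
    alpha = expected_alignment_from_p_choose(p_choose)

    # Any real-valued score tensor of shape (bsz, tgt_len, src_len) works here.
    soft_energy = torch.randn(bsz, tgt_len, src_len)

    # Infinite lookback (chunk_size=None) and chunkwise variants.
    beta_milk = expected_soft_attention(alpha, soft_energy)
    beta_mocha = expected_soft_attention(alpha, soft_energy, chunk_size=2)

    # Expected context vectors: bsz, tgt_len, dim
    value = torch.randn(bsz, src_len, dim)
    context = torch.bmm(beta_milk, value)
    print(context.size(), beta_mocha.size())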


def mass_preservation(
    alpha: Tensor,
    padding_mask: Optional[Tensor] = None,
    left_padding: bool = False
):
""" |
|
Function to compute the mass perservation for alpha. |
|
This means that the residual weights of alpha will be assigned |
|
to the last token. |
|
|
|
Reference: |
|
Monotonic Infinite Lookback Attention for Simultaneous Machine Translation |
|
https://arxiv.org/abs/1906.05218 |
|
|
|
alpha: bsz, tgt_len, src_len |
|
padding_mask: bsz, src_len |
|
left_padding: bool |
|
""" |
|
|
|
    prob_check(alpha)

    if padding_mask is not None:
        if not left_padding:
            assert not padding_mask[:, 0].any(), (
                "Found padding at the beginning of the sequence."
            )
        alpha = alpha.masked_fill(padding_mask.unsqueeze(1), 0.0)

    if left_padding or padding_mask is None:
        # With left padding (or no padding) the last column is always a real
        # source position, so the residual mass can be added there directly.
        residuals = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0, 1)
        alpha[:, :, -1] = residuals
    else:
        # With right padding, add the residual mass to the last unpadded
        # source position of each sequence.
        _, tgt_len, src_len = alpha.size()
        residuals = 1 - alpha.sum(dim=-1, keepdim=True).clamp(0, 1)
        src_lens = src_len - padding_mask.sum(dim=1, keepdim=True)
        src_lens = src_lens.expand(-1, tgt_len).contiguous()

        residuals += alpha.gather(2, src_lens.unsqueeze(2) - 1)
        alpha = alpha.scatter(2, src_lens.unsqueeze(2) - 1, residuals)

    prob_check(alpha)

    return alpha
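

# Illustrative sketch, not part of the original module: it demonstrates how
# mass_preservation pushes the residual probability mass onto the last
# unpadded source position of a right-padded batch. The batch shapes and the
# padding pattern below are arbitrary assumptions made for demonstration only.
def _example_mass_preservation_sketch():
    torch.manual_seed(0)
    bsz, tgt_len, src_len = 2, 3, 5

    # The second sequence has two padded positions at the end (right padding).
    padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
    padding_mask[1, -2:] = True

    p_choose = torch.rand(bsz, tgt_len, src_len)
    alpha = expected_alignment_from_p_choose(p_choose, padding_mask)
    alpha = mass_preservation(alpha, padding_mask)

    # After mass preservation every target step should carry (approximately)
    # one unit of probability mass over the unpadded source positions.
    print(alpha.sum(dim=-1))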