import math
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch import Tensor

from examples.simultaneous_translation.utils.p_choose_strategy import (
    learnable_p_choose,
    waitk_p_choose,
)
from examples.simultaneous_translation.utils.monotonic_attention import (
    expected_alignment_from_p_choose,
    expected_soft_attention,
    mass_preservation,
)
from fairseq.modules import MultiheadAttention

from . import register_monotonic_attention


@register_monotonic_attention("hard_aligned") |
|
class MonotonicAttention(MultiheadAttention): |
|
""" |
|
Abstract class of monotonic attentions |
|
""" |
|
k_in_proj: Dict[str, nn.Linear] |
|
q_in_proj: Dict[str, nn.Linear] |
|
|
|
def __init__(self, args): |
|
super().__init__( |
|
embed_dim=args.decoder_embed_dim, |
|
num_heads=args.decoder_attention_heads, |
|
kdim=getattr(args, "encoder_embed_dim", None), |
|
vdim=getattr(args, "encoder_embed_dim", None), |
|
dropout=args.attention_dropout, |
|
encoder_decoder_attention=True, |
|
) |
|
|
|
self.soft_attention = False |
|
|
|
self.eps = getattr(args, "attention_eps", True) |
|
self.mass_preservation = getattr(args, "mass_preservation", True) |
|
|
|
self.noise_type = args.noise_type |
|
self.noise_mean = args.noise_mean |
|
self.noise_var = args.noise_var |
|
|
|
self.energy_bias_init = args.energy_bias_init |
|
self.energy_bias = ( |
|
nn.Parameter(self.energy_bias_init * torch.ones([1])) |
|
if args.energy_bias is True |
|
else 0 |
|
) |
|
|
|
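
        # Projections used to compute the monotonic (hard) selection energies.
        # Subclasses that also compute soft attention register an extra "soft"
        # entry in these dictionaries (see MonotonicInfiniteLookbackAttention).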
        self.k_in_proj = {"monotonic": self.k_proj}
        self.q_in_proj = {"monotonic": self.q_proj}
        self.chunk_size = None

    @staticmethod
    def add_args(parser):
        parser.add_argument('--no-mass-preservation', action="store_false",
                            dest="mass_preservation",
                            help='Do not stay on the last token when decoding')
        parser.add_argument('--mass-preservation', action="store_true",
                            dest="mass_preservation",
                            help='Stay on the last token when decoding')
        parser.set_defaults(mass_preservation=True)
        parser.add_argument('--noise-var', type=float, default=1.0,
                            help='Variance of discreteness noise')
        parser.add_argument('--noise-mean', type=float, default=0.0,
                            help='Mean of discreteness noise')
        parser.add_argument('--noise-type', type=str, default="flat",
                            help='Type of discreteness noise')
        parser.add_argument('--energy-bias', action="store_true",
                            default=False,
                            help='Bias for energy')
        parser.add_argument('--energy-bias-init', type=float, default=-2.0,
                            help='Initial value of the bias for energy')
        parser.add_argument('--attention-eps', type=float, default=1e-6,
                            help='Epsilon when calculating expected attention')

    def energy_from_qk(
        self,
        query: Tensor,
        key: Tensor,
        energy_type: str,
        key_padding_mask: Optional[Tensor] = None,
        bias: int = 0
    ):
        """
        Compute attention energies from query and key.

        query size: tgt_len, bsz, embed_dim
        key size: src_len, bsz, embed_dim
        key_padding_mask size: bsz * num_heads, src_len
        Returns energies of size bsz * num_heads, tgt_len, src_len.
        """
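        # Project query/key with the energy-type specific projections, then
        # reshape to (bsz * num_heads, length, head_dim) so energies can be
        # computed per attention head with a single bmm.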
        length, bsz, _ = query.size()
        q = self.q_in_proj[energy_type].forward(query)
        q = (
            q.contiguous()
            .view(length, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        q = q * self.scaling
        length, bsz, _ = key.size()
        k = self.k_in_proj[energy_type].forward(key)
        k = (
            k.contiguous()
            .view(length, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

        energy = torch.bmm(q, k.transpose(1, 2)) + bias

        if key_padding_mask is not None:
            energy = energy.masked_fill(
                key_padding_mask.unsqueeze(1).to(torch.bool),
                -float("inf")
            )

        return energy

    def p_choose_from_qk(self, query, key, key_padding_mask, incremental_states=None):
        monotonic_energy = self.energy_from_qk(
            query,
            key,
            "monotonic",
            key_padding_mask=key_padding_mask,
            bias=self.energy_bias,
        )
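
        # Turn the monotonic energies into stepwise selection probabilities
        # p_choose in [0, 1]. The noise mean/variance regularize the discrete
        # read/write decision; whether noise is applied is controlled by the
        # self.training flag (see learnable_p_choose).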
        p_choose = learnable_p_choose(
            monotonic_energy,
            self.noise_mean,
            self.noise_var,
            self.training
        )
        return p_choose

    def p_choose(self, query, key, key_padding_mask, incremental_states=None):
        return self.p_choose_from_qk(query, key, key_padding_mask, incremental_states)

    def monotonic_attention_process_infer(
        self,
        query: Optional[Tensor],
        key: Optional[Tensor],
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    ):
        """
        Monotonic attention at inference time.
        Notice that this function is designed for SimulEval, not sequence_generator.
        """
        assert query is not None
        assert key is not None

        if query.size(1) != 1:
            raise RuntimeError(
                "Simultaneous translation models don't support batch decoding."
            )
        p_choose = self.p_choose(
            query, key, None, incremental_state
        ).squeeze(1)

        src_len = key.size(0)

        max_steps = src_len - 1 if self.mass_preservation else src_len
        monotonic_cache = self._get_monotonic_buffer(incremental_state)

        monotonic_step = monotonic_cache.get(
            'head_step',
            p_choose.new_zeros(1, self.num_heads).long()
        )
        assert monotonic_step is not None
        finish_read = monotonic_step.eq(max_steps)
        p_choose_i = torch.tensor(1)
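
        # Advance each head's read pointer until every head has either decided
        # to write (p_choose_i >= 0.5 at its current position) or has read the
        # entire source. read_one_step is 1 for heads that keep reading
        # (p_choose_i < 0.5) and 0 for heads that stop.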
        while finish_read.sum().item() < self.num_heads:
            p_choose_i = (
                p_choose.gather(
                    1,
                    monotonic_step
                    .clamp(0, src_len - 1),
                )
            )

            read_one_step = (
                (p_choose_i < 0.5)
                .type_as(monotonic_step)
                .masked_fill(finish_read, 0)
            )

            monotonic_step += read_one_step

            finish_read = monotonic_step.eq(max_steps) | (read_one_step == 0)

        p_choose_i = (
            p_choose.gather(
                1,
                monotonic_step
                .clamp(0, src_len - 1),
            )
        )

monotonic_cache["head_step"] = monotonic_step |
|
|
|
monotonic_cache["head_read"] = ( |
|
monotonic_step.eq(max_steps) & (p_choose_i < 0.5) |
|
) |
|
self._set_monotonic_buffer(incremental_state, monotonic_cache) |
|
|
|
|
|
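        # alpha is the hard alignment: a one-hot vector per head, placed at the
        # head's current read position (clamped to the last source token).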
        alpha = (
            p_choose
            .new_zeros([self.num_heads, src_len])
            .scatter(
                1,
                monotonic_step
                .view(self.num_heads, 1).clamp(0, src_len - 1),
                1
            )
        )

        if not self.mass_preservation:
            alpha = alpha.masked_fill(
                (monotonic_step == max_steps)
                .view(self.num_heads, 1),
                0
            )

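        # With soft attention enabled (infinite lookback), beta re-attends over
        # all source positions up to each head's monotonic step; otherwise the
        # hard alignment alpha is used directly as the attention weights.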
        if self.soft_attention:
            monotonic_step = monotonic_step.t()
            beta_mask = (
                torch.arange(src_len, device=alpha.device)
                .expand_as(alpha)
                .gt(monotonic_step)
                .unsqueeze(1)
            )

            soft_energy = self.energy_from_qk(
                query,
                key,
                "soft"
            )
            beta = torch.nn.functional.softmax(
                soft_energy.masked_fill(beta_mask, -float("inf")), dim=-1
            )

            beta = beta.masked_fill(monotonic_step.eq(0).unsqueeze(1), 0)
        else:
            beta = alpha

        return p_choose, alpha, beta

    def monotonic_attention_process_train(
        self,
        query: Optional[Tensor],
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
    ):
        """
        Calculate the monotonic attention process for training.
        Including:
            stepwise probability: p_choose
            expected hard alignment: alpha
            expected soft attention: beta
        """
        assert query is not None
        assert key is not None
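
        # During training no discrete read/write decisions are taken; alpha and
        # beta are computed in closed form as expectations over p_choose.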
        p_choose = self.p_choose_from_qk(query, key, key_padding_mask)

        alpha = expected_alignment_from_p_choose(
            p_choose,
            key_padding_mask,
            eps=self.eps,
        )

        if self.mass_preservation:
            alpha = mass_preservation(
                alpha, key_padding_mask
            )

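        # Soft attention re-weights the soft energies around the expected
        # alignment. chunk_size (set by ChunkwiseAttention) restricts it to a
        # fixed-size window; when it is None the lookback is unrestricted
        # (infinite lookback).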
        if self.soft_attention:
            soft_energy = self.energy_from_qk(
                query,
                key,
                "soft",
                key_padding_mask=None,
            )

            beta = expected_soft_attention(
                alpha,
                soft_energy,
                padding_mask=key_padding_mask,
                chunk_size=self.chunk_size,
                eps=self.eps,
            )
        else:
            beta = alpha
            soft_energy = alpha

        return p_choose, alpha, beta, soft_energy

    def forward(
        self,
        query: Optional[Tensor],
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        need_head_weights: bool = False,
    ):
        """
        query: tgt_len, bsz, embed_dim
        key: src_len, bsz, embed_dim
        value: src_len, bsz, embed_dim
        """

        assert attn_mask is None
        assert query is not None
        assert key is not None
        assert value is not None

        tgt_len, bsz, embed_dim = query.size()
        src_len = value.size(0)
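
        # Only right padding is supported; expand the padding mask across heads
        # so it matches the (bsz * num_heads, ...) layout used for the energies.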
        if key_padding_mask is not None:
            assert not key_padding_mask[:, 0].any(), (
                "Only right padding is supported."
            )
            key_padding_mask = (
                key_padding_mask
                .unsqueeze(1)
                .expand([bsz, self.num_heads, src_len])
                .contiguous()
                .view(-1, src_len)
            )

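        # Incremental decoding (SimulEval) takes discrete read/write decisions;
        # training computes the expected alignment and attention instead.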
        if incremental_state is not None:
            (
                p_choose, alpha, beta
            ) = self.monotonic_attention_process_infer(
                query, key, incremental_state
            )
            soft_energy = beta
        else:
            (
                p_choose, alpha, beta, soft_energy
            ) = self.monotonic_attention_process_train(
                query, key, key_padding_mask
            )

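        # beta acts as the attention weights over the projected values, exactly
        # like the softmax weights in standard multihead attention.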
        v = self.v_proj(value)
        length, bsz, _ = v.size()
        v = (
            v.contiguous()
            .view(length, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

        attn = torch.bmm(beta.type_as(v), v)

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        attn = self.out_proj(attn)

        p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)
        alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
        beta = beta.view(bsz, self.num_heads, tgt_len, src_len)

        return attn, {
            "p_choose": p_choose,
            "alpha": alpha,
            "beta": beta,
        }

    def _get_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]):
        maybe_incremental_state = self.get_incremental_state(
            incremental_state,
            'monotonic',
        )
        if maybe_incremental_state is None:
            typed_empty_dict: Dict[str, Optional[Tensor]] = {}
            return typed_empty_dict
        else:
            return maybe_incremental_state

    def _set_monotonic_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]], buffer: Dict[str, Optional[Tensor]]):
        self.set_incremental_state(
            incremental_state,
            'monotonic',
            buffer,
        )


@register_monotonic_attention("infinite_lookback") |
|
class MonotonicInfiniteLookbackAttention( |
|
MonotonicAttention |
|
): |
|
def __init__(self, args): |
|
super().__init__(args) |
|
self.soft_attention = True |
|
self.init_soft_attention() |
|
|
|
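
    # Separate projections for the soft-attention energies, initialized the
    # same way MultiheadAttention initializes its own q/k projections
    # (Xavier with 1/sqrt(2) gain when q, k, v share the same dimension).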
    def init_soft_attention(self):
        self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
        self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.k_in_proj["soft"] = self.k_proj_soft
        self.q_in_proj["soft"] = self.q_proj_soft

        if self.qkv_same_dim:
            nn.init.xavier_uniform_(
                self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
            nn.init.xavier_uniform_(
                self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
        else:
            nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
            nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)


@register_monotonic_attention("waitk") |
|
class WaitKAttention( |
|
MonotonicInfiniteLookbackAttention |
|
): |
|
""" |
|
STACL: Simultaneous Translation with Implicit Anticipation and |
|
Controllable Latency using Prefix-to-Prefix Framework |
|
https://www.aclweb.org/anthology/P19-1289/ |
|
""" |
|
def __init__(self, args): |
|
super().__init__(args) |
|
self.q_in_proj["soft"] = self.q_in_proj["monotonic"] |
|
self.k_in_proj["soft"] = self.k_in_proj["monotonic"] |
|
|
|
self.waitk_lagging = args.waitk_lagging |
|
assert self.waitk_lagging > 0, ( |
|
f"Lagging has to been larger than 0, get {self.waitk_lagging}." |
|
) |
|
|
|
    @staticmethod
    def add_args(parser):
        super(
            MonotonicInfiniteLookbackAttention,
            MonotonicInfiniteLookbackAttention
        ).add_args(parser)

        parser.add_argument(
            "--waitk-lagging", type=int, required=True, help="Wait K lagging"
        )
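
    # p_choose is generated deterministically from the wait-k lagging schedule
    # (waitk_p_choose) instead of being predicted from attention energies.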
    def p_choose_from_qk(
        self,
        query: Optional[Tensor],
        key: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    ):
        assert query is not None
        assert key is not None

        p_choose = waitk_p_choose(
            tgt_len=query.size(0),
            src_len=key.size(0),
            bsz=query.size(1) * self.num_heads,
            waitk_lagging=self.waitk_lagging,
            key_padding_mask=key_padding_mask,
            incremental_state=incremental_state,
        )

        return p_choose.to(query)


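# ChunkwiseAttention (MoChA-style): soft attention is restricted to a
# fixed-size chunk via the chunk_size passed to expected_soft_attention.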
@register_monotonic_attention("chunkwise") |
|
class ChunkwiseAttention( |
|
MonotonicInfiniteLookbackAttention |
|
): |
|
def __init__(self, args): |
|
super().__init__(args) |
|
self.chunk_size = args.mocha_chunk_size |
|
assert self.chunk_size > 1 |
|
|
|
    @staticmethod
    def add_args(parser):
        super(
            MonotonicInfiniteLookbackAttention,
            MonotonicInfiniteLookbackAttention
        ).add_args(parser)

        parser.add_argument(
            "--mocha-chunk-size", type=int,
            required=True, help="Mocha chunk size"
        )