|
from typing import Optional, Dict |
|
from torch import Tensor |
|
import torch |
|
|
|
|
|
def waitk(
    query,
    key,
    waitk_lagging: int,
    num_heads: int,
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
    """Build the hard (0/1) wait-k selection matrix ``p_choose``.

    Target step ``n`` gets a single 1 at source column
    ``n + waitk_lagging - 1``. When the batch is left-padded, that column is
    shifted right by each row's padding count. Every other entry is 0.

    Args:
        query: target-side tensor. When ``incremental_state`` is None, its
            ``size()`` unpacks as ``(tgt_len, bsz, _)``. Its dtype and device
            determine those of the result.
        key: source-side tensor whose ``size()`` unpacks as
            ``(src_len, bsz, _)``.
        waitk_lagging: the ``k`` of wait-k.
        num_heads: number of heads the matrix is replicated over.
        key_padding_mask: optional mask, assumed ``(bsz, src_len)`` (TODO
            confirm against callers). Nonzero/True entries are counted as
            padding. It is only used when some row is padded at column 0
            (left padding).
        incremental_state: when given,
            ``incremental_state["steps"]["tgt"]`` supplies the number of
            target steps, and only the last row of the matrix is returned.

    Returns:
        Tensor of shape ``(bsz * num_heads, tgt_len, src_len)`` with the
        dtype/device of ``query`` (``tgt_len`` is 1 in incremental mode).
        It is all zeros when ``src_len < waitk_lagging``.
    """
    if incremental_state is not None:
        # The query only holds the current step here, so the number of
        # target steps decoded so far is taken from the incremental state.
        tgt_len = incremental_state["steps"]["tgt"]
        assert tgt_len is not None
        tgt_len = int(tgt_len)
    else:
        tgt_len, bsz, _ = query.size()

    max_src_len, bsz, _ = key.size()

    if max_src_len < waitk_lagging:
        # Fewer than k source positions: nothing can be selected yet.
        if incremental_state is not None:
            tgt_len = 1
        return query.new_zeros(bsz * num_heads, tgt_len, max_src_len)

    # Flatten the (tgt_len, max_src_len) matrix row by row. The 1 for target
    # step n then sits at flat index n * (max_src_len + 1) + waitk_lagging - 1.
    # The indices are built as int64 on query's device. Passing them through
    # query's floating dtype would round them once they exceed the exact
    # integer range of float16 (2048) or bfloat16 (256).
    steps = torch.arange(tgt_len, dtype=torch.long, device=query.device)
    activate_indices_offset = (
        (steps * (max_src_len + 1) + waitk_lagging - 1)
        .unsqueeze(0)
        .repeat(bsz, 1)
    )

    if key_padding_mask is not None and key_padding_mask[:, 0].any():
        # Left padding: shift each batch row right by its padding count.
        # The add is out-of-place, so it never writes into a broadcast view.
        activate_indices_offset = activate_indices_offset + (
            key_padding_mask.sum(dim=1, keepdim=True).long()
        )

    # Clamp indices that would run past the last usable source column.
    # Without padding, any row n > max_src_len - waitk_lagging collapses onto
    # the last cell of row (max_src_len - waitk_lagging), so those later rows
    # end up with no 1.
    max_index = min(tgt_len, max_src_len - waitk_lagging + 1) * max_src_len - 1
    activate_indices_offset = activate_indices_offset.clamp(0, max_index)

    p_choose = (
        query.new_zeros(bsz, tgt_len * max_src_len)
        .scatter(1, activate_indices_offset, 1.0)
        .view(bsz, tgt_len, max_src_len)
    )

    if incremental_state is not None:
        # Only the current target step is needed during incremental decoding.
        p_choose = p_choose[:, -1:]
        tgt_len = 1

    # Replicate over heads: (bsz, tgt, src) -> (bsz * num_heads, tgt, src).
    return (
        p_choose.unsqueeze(1)
        .expand(-1, num_heads, -1, -1)
        .contiguous()
        .view(-1, tgt_len, max_src_len)
    )
|
|
|
|
|
def hard_aligned(
    q_proj: Optional[Tensor],
    k_proj: Optional[Tensor],
    attn_energy,
    noise_mean: float = 0.0,
    noise_var: float = 0.0,
    training: bool = True,
):
    """Map attention energies to step-wise selection probabilities.

    Computes ``sigmoid(attn_energy + noise)``, where Gaussian noise is only
    added when ``training`` is true. The values are interpreted as
    1 to read, 0 to write.

    Args:
        q_proj: unused; kept for interface compatibility.
        k_proj: unused; kept for interface compatibility.
        attn_energy: 4-D tensor whose last two dims are taken as
            ``(tgt_len, src_len)``.
        noise_mean: mean of the training noise.
        noise_var: passed to ``torch.normal`` as its ``std`` argument, so it
            acts as a standard deviation despite its name.
        training: whether to perturb the energies with noise.

    Returns:
        Tensor viewed as ``(-1, tgt_len, src_len)``, i.e. the two leading
        dims of ``attn_energy`` are merged.
    """
    energy = attn_energy
    if training:
        # Sample with torch.normal's default placement, then match the
        # energies' dtype and device before adding.
        sampled = torch.normal(noise_mean, noise_var, attn_energy.size())
        energy = attn_energy + sampled.type_as(attn_energy).to(attn_energy.device)

    probs = torch.sigmoid(energy)
    _, _, n_tgt, n_src = probs.size()
    return probs.view(-1, n_tgt, n_src)
|
|