from typing import Optional, Dict
from torch import Tensor
import torch


def waitk(
    query,
    key,
    waitk_lagging: int,
    num_heads: int,
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
    if incremental_state is not None:
        # Retrieve target length from incremental states.
        # For inference the length of query is always 1.
        tgt_len = incremental_state["steps"]["tgt"]
        assert tgt_len is not None
        tgt_len = int(tgt_len)
    else:
        tgt_len, bsz, _ = query.size()

    max_src_len, bsz, _ = key.size()

    if max_src_len < waitk_lagging:
        if incremental_state is not None:
            tgt_len = 1
        return query.new_zeros(bsz * num_heads, tgt_len, max_src_len)

    # Assuming p_choose looks like this for wait-k with k = 3,
    # src_len = 6, tgt_len = 5:
    #   [0, 0, 1, 0, 0, 0]
    #   [0, 0, 0, 1, 0, 0]
    #   [0, 0, 0, 0, 1, 0]
    #   [0, 0, 0, 0, 0, 1]
    #   [0, 0, 0, 0, 0, 1]
    # Linearize the p_choose matrix:
    #   [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, ...]
    # The indices of the linearized matrix that equal 1 are
    #   2 + 6 * 0
    #   3 + 6 * 1
    #   ...
    #   n + src_len * n + k - 1 = n * (src_len + 1) + k - 1
    # for n from 0 to tgt_len - 1.
    #
    # First, generate the indices (activate_indices_offset: bsz, tgt_len).
    # Second, scatter into a zeros tensor of shape (bsz, tgt_len * src_len)
    # at activate_indices_offset.
    # Third, reshape the tensor to (bsz, tgt_len, src_len).
    activate_indices_offset = (
        (
            torch.arange(tgt_len) * (max_src_len + 1)
            + waitk_lagging - 1
        )
        .unsqueeze(0)
        .expand(bsz, tgt_len)
        .to(query)  # move to query's device; cast back to long for scatter
        .long()
    )

    if key_padding_mask is not None:
        if key_padding_mask[:, 0].any():
            # Left padding: shift the active indices past the padding tokens.
            activate_indices_offset += key_padding_mask.sum(
                dim=1, keepdim=True
            )

    # Clamp the indices that are too large, i.e. rows whose read head
    # would run past the end of the source.
    activate_indices_offset = activate_indices_offset.clamp(
        0,
        min(tgt_len, max_src_len - waitk_lagging + 1) * max_src_len - 1,
    )

    p_choose = torch.zeros(bsz, tgt_len * max_src_len).to(query)
    p_choose = p_choose.scatter(
        1, activate_indices_offset, 1.0
    ).view(bsz, tgt_len, max_src_len)

    if incremental_state is not None:
        # Only the current (last) target step is needed at inference time.
        p_choose = p_choose[:, -1:]
        tgt_len = 1

    # Extend to each head: (bsz * num_heads, tgt_len, src_len).
    p_choose = (
        p_choose.contiguous()
        .unsqueeze(1)
        .expand(-1, num_heads, -1, -1)
        .contiguous()
        .view(-1, tgt_len, max_src_len)
    )

    return p_choose


def hard_aligned(
    q_proj: Optional[Tensor],
    k_proj: Optional[Tensor],
    attn_energy,
    noise_mean: float = 0.0,
    noise_var: float = 0.0,
    training: bool = True,
):
    """
    Calculate the step-wise probability of reading versus writing:
    1 to read, 0 to write.

    q_proj and k_proj are unused by this strategy; the decision is made
    from the precomputed attention energies.
    """
    noise = 0
    if training:
        # Add noise here to encourage discreteness. Note that torch.normal
        # treats its second argument as a standard deviation.
        noise = (
            torch.normal(noise_mean, noise_var, attn_energy.size())
            .type_as(attn_energy)
            .to(attn_energy.device)
        )

    p_choose = torch.sigmoid(attn_energy + noise)
    _, _, tgt_len, src_len = p_choose.size()

    # p_choose: (bsz * num_heads, tgt_len, src_len)
    return p_choose.view(-1, tgt_len, src_len)
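
# --- Illustrative usage (not part of the original module) -------------------
# A minimal sanity check for the scatter construction in waitk(), under
# assumed shapes: query and key follow the (seq_len, bsz, dim) layout the
# function unpacks, and the tensor contents are irrelevant because only the
# shapes are read. The helper name _demo_waitk is hypothetical.
def _demo_waitk():
    tgt_len, src_len, bsz, dim, num_heads, k = 5, 6, 2, 8, 4, 3
    query = torch.zeros(tgt_len, bsz, dim)
    key = torch.zeros(src_len, bsz, dim)
    p_choose = waitk(query, key, waitk_lagging=k, num_heads=num_heads)
    # Expect shape (bsz * num_heads, tgt_len, src_len) = (8, 5, 6), with
    # row n of each matrix holding its single 1 at column n + k - 1
    # (rows clamped at the end of the source may come out all zero).
    print(p_choose.size())
    print(p_choose[0])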
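
# A similar hypothetical check for hard_aligned(): attn_energy is assumed
# to arrive as (bsz, num_heads, tgt_len, src_len); with training=True the
# Gaussian noise pushes the sigmoid outputs toward 0 or 1.
def _demo_hard_aligned():
    bsz, num_heads, tgt_len, src_len = 2, 4, 5, 6
    attn_energy = torch.randn(bsz, num_heads, tgt_len, src_len)
    p_choose = hard_aligned(
        None, None, attn_energy, noise_var=1.0, training=True
    )
    # Expect shape (bsz * num_heads, tgt_len, src_len), values in (0, 1).
    print(p_choose.size(), float(p_choose.min()), float(p_choose.max()))


if __name__ == "__main__":
    _demo_waitk()
    _demo_hard_aligned()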