from typing import Optional, Dict
from torch import Tensor
import torch


def waitk(
    query: Tensor,
    key: Tensor,
    waitk_lagging: int,
    num_heads: int,
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
):
    if incremental_state is not None:
        # Retrieve target length from incremental states
        # For inference the length of query is always 1
        tgt_len = incremental_state["steps"]["tgt"]
        assert tgt_len is not None
        tgt_len = int(tgt_len)
    else:
        tgt_len, bsz, _ = query.size()

    max_src_len, bsz, _ = key.size()

    if max_src_len < waitk_lagging:
        if incremental_state is not None:
            tgt_len = 1
        return query.new_zeros(
            bsz * num_heads, tgt_len, max_src_len
        )

    # Assuming the p_choose looks like this for wait k=3
    # src_len = 6, tgt_len = 5
    #   [0, 0, 1, 0, 0, 0]
    #   [0, 0, 0, 1, 0, 0]
    #   [0, 0, 0, 0, 1, 0]
    #   [0, 0, 0, 0, 0, 1]
    #   [0, 0, 0, 0, 0, 1]
    # Linearize the p_choose matrix:
    # [0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0...]
    # The indices of the linearized matrix that equal 1 are
    # 2 + 6 * 0
    # 3 + 6 * 1
    # ...
    # n + k - 1 + src_len * n = n * (src_len + 1) + k - 1
    # for n from 0 to tgt_len - 1
    #
    # First, generate the indices (activate_indices_offset: bsz, tgt_len)
    # Second, scatter 1s into a zeros tensor (bsz, tgt_len * src_len)
    # at activate_indices_offset
    # Third, reshape the tensor to (bsz, tgt_len, src_len)
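    #
    # Worked example (assumed toy sizes, matching the sketch above):
    # with k = 3, src_len = 6, tgt_len = 5 the raw offsets are
    # torch.arange(5) * (6 + 1) + 3 - 1 = [2, 9, 16, 23, 30].
    # Offset 30 falls outside the 5 * 6 = 30-element buffer, which is
    # why the indices are clamped further below.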

    # Flat indices (into the linearized matrix) of the 1 entries,
    # one per target step
    activate_indices_offset = (
        (
            torch.arange(tgt_len) * (max_src_len + 1)
            + waitk_lagging - 1
        )
        .unsqueeze(0)
        .expand(bsz, tgt_len)
        .to(query.device)  # follow query's device but keep the long dtype
        .long()
    )

    if key_padding_mask is not None:
        if key_padding_mask[:, 0].any():
            # Left padding
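            # e.g. with 2 pad tokens on the left, every offset shifts
            # 2 positions right, so the wait-k diagonal starts after
            # the padding rather than inside it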
            activate_indices_offset += (
                key_padding_mask.sum(dim=1, keepdim=True)
            )

    # Clamp indices that would run past the end of the flattened
    # (tgt_len * max_src_len) buffer
    activate_indices_offset = activate_indices_offset.clamp(
        0,
        min(tgt_len, max_src_len - waitk_lagging + 1) * max_src_len - 1,
    )
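    # Continuing the worked example: the offsets are capped at
    # min(5, 6 - 3 + 1) * 6 - 1 = 23, giving [2, 9, 16, 23, 23].
    # Note that scatter collapses duplicate indices, so target steps
    # clamped onto the same offset share a single activation and any
    # rows past that offset stay all-zero.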

    p_choose = torch.zeros(bsz, tgt_len * max_src_len).to(query)

    p_choose = p_choose.scatter(
        1,
        activate_indices_offset,
        1.0
    ).view(bsz, tgt_len, max_src_len)

    if incremental_state is not None:
        p_choose = p_choose[:, -1:]
        tgt_len = 1

    # Extend to each head
    p_choose = (
        p_choose.contiguous()
        .unsqueeze(1)
        .expand(-1, num_heads, -1, -1)
        .contiguous()
        .view(-1, tgt_len, max_src_len)
    )

    return p_choose


def hard_aligned(
    q_proj: Optional[Tensor],
    k_proj: Optional[Tensor],
    attn_energy: Tensor,
    noise_mean: float = 0.0,
    noise_var: float = 0.0,
    training: bool = True,
):
    """
    Calculate the step-wise probability of reading versus writing:
    1 means read, 0 means write.

    q_proj and k_proj are accepted but unused here.
    """

    noise = 0
    if training:
        # Add noise during training to encourage discreteness.
        # Note: torch.normal takes a standard deviation as its second
        # argument, so noise_var is used as a std here despite its name.
        noise = (
            torch.normal(noise_mean, noise_var, attn_energy.size())
            .type_as(attn_energy)
            .to(attn_energy.device)
        )

    p_choose = torch.sigmoid(attn_energy + noise)
    _, _, tgt_len, src_len = p_choose.size()

    # p_choose: (bsz * num_heads, tgt_len, src_len)
    return p_choose.view(-1, tgt_len, src_len)
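

if __name__ == "__main__":
    # Minimal smoke test, a sketch under assumed shapes: fairseq-style
    # attention inputs are (seq_len, batch, embed_dim). The tensor
    # contents do not matter for waitk; only the shapes do.
    bsz, num_heads, embed_dim = 2, 4, 8
    tgt_len, src_len = 5, 6

    query = torch.zeros(tgt_len, bsz, embed_dim)
    key = torch.zeros(src_len, bsz, embed_dim)
    p_choose = waitk(query, key, waitk_lagging=3, num_heads=num_heads)
    print(p_choose.shape)  # torch.Size([8, 5, 6]) = (bsz * num_heads, tgt_len, src_len)
    print(p_choose[0])  # one head: a shifted diagonal starting at column k - 1 = 2

    # hard_aligned expects energies shaped (bsz, num_heads, tgt_len, src_len);
    # with training=False no noise is added, so sigmoid(0) gives 0.5 everywhere.
    attn_energy = torch.zeros(bsz, num_heads, tgt_len, src_len)
    p_choose = hard_aligned(None, None, attn_energy, training=False)
    print(p_choose.shape)  # torch.Size([8, 5, 6])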