import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from .nrtr_decoder import TransformerBlock, Embeddings


class CPA(nn.Module):
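    """Expands a globally pooled visual feature into ``max_len + 1``
    position-specific features: a learned positional embedding is added to
    the pooled feature, passed through a tanh/softmax gate, and used to
    reweight the feature channel-wise (see ``forward`` for shapes)."""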

    def __init__(self, dim, max_len=25):
        super(CPA, self).__init__()

        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.fc3 = nn.Linear(dim, dim)
        self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, dim],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, feat):
        # feat: (B, L, dim) visual feature sequence
        feat = feat.mean(1).unsqueeze(1)  # global average pooling -> (B, 1, dim)
        x = self.fc1(feat) + self.pos_embed  # (B, max_len + 1, dim)
        x = F.softmax(self.fc2(torch.tanh(x)), -1)  # per-position channel attention weights
        x = self.fc3(feat * x)  # (B, max_len + 1, dim) position-specific features
        return x


class ARDecoder(nn.Module):
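    """Autoregressive Transformer decoder over the position-aware features
    produced by ``CPA``: teacher forcing in ``forward_train`` and greedy,
    token-by-token decoding with early EOS exit in ``forward_test``."""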

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(ARDecoder, self).__init__()
        self.out_channels = out_channels
        self.ignore_index = out_channels - 1
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        nhead = nhead if nhead is not None else d_model // 32
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, d_model],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        trunc_normal_(self.pos_embed, std=0.02)
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=False,
            ) for i in range(num_decoder_layers)
        ])

        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_train(self, src, tgt):
        tgt = tgt[:, :-1]

        tgt = self.embedding(
            tgt) + src[:, :tgt.shape[1]] + self.pos_embed[:, :tgt.shape[1]]
        tgt_mask = self.generate_square_subsequent_mask(
            tgt.shape[1], device=src.device)

        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, data=None):

        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :2 + max_len]
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res

    def forward_test(self, src):
        bs = src.shape[0]
        src = src + self.pos_embed
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.device)
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            dec_seq_embed = self.embedding(
                dec_seq[:, :len_dec_seq + 1])  # (bs, step + 1, dim): BOS plus tokens decoded so far
            dec_seq_embed = dec_seq_embed + src[:, :len_dec_seq + 1]
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.device)
            tgt = dec_seq_embed
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, self_mask=tgt_mask)
            dec_output = tgt
            dec_output = dec_output[:, -1:, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # Greedy decoding: append the most likely next token to the input sequence.
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
                # Early exit: stop once every sequence in the batch contains an EOS token.
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
        logits = torch.cat(logits, dim=1)
        return logits

    def generate_square_subsequent_mask(self, sz, device):
        """Generate a square causal mask for the sequence.

        Masked positions are filled with float('-inf'); unmasked positions
        are filled with float(0.0).
        """
        mask = torch.triu(
            torch.full((sz, sz), fill_value=float('-inf'), dtype=torch.float32),
            diagonal=1,
        )
        return mask.unsqueeze(0).unsqueeze(0).to(device)


class OTEDecoder(nn.Module):
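    """Decoder head that applies ``CPA`` to the visual features and predicts
    characters either with a single linear classifier (``ar=False``) or with
    the autoregressive ``ARDecoder`` (``ar=True``)."""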

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 num_heads=None,
                 ar=False,
                 num_decoder_layers=1,
                 **kwargs):
        super(OTEDecoder, self).__init__()

        self.out_channels = out_channels - 2  # exclude BOS and ignore/padding (e.g. EOS + 26 letters + 10 digits)
        dim = in_channels
        self.dim = dim
        self.max_len = max_len + 1  # max_len + eos

        self.cpa = CPA(dim=dim, max_len=max_len)
        self.ar = ar
        if ar:
            self.ar_decoder = ARDecoder(in_channels=dim,
                                        out_channels=out_channels,
                                        nhead=num_heads,
                                        num_decoder_layers=num_decoder_layers,
                                        max_len=max_len)
        else:
            self.fc = nn.Linear(dim, self.out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}

    def forward(self, x, data=None):
        x = self.cpa(x)
        if self.ar:
            return self.ar_decoder(x, data=data)
        logits = self.fc(x)  # (B, max_len + 1, num_classes)
        if self.training:
            logits = logits[:, :data[1].max() + 1]
        return logits
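

# Minimal usage sketch (illustrative only; the feature dimension, vocabulary
# size and input shapes below are assumptions, not values from a real config):
#
#   decoder = OTEDecoder(in_channels=256, out_channels=39, max_len=25, ar=False)
#   decoder.eval()
#   feats = torch.randn(2, 64, 256)   # (B, L, dim) visual features from the encoder
#   logits = decoder(feats)           # (B, max_len + 1, out_channels - 2)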