import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import ones_, trunc_normal_, zeros_

from .nrtr_decoder import TransformerBlock, Embeddings


class CPA(nn.Module):
    """Expands a globally pooled encoder feature into `max_len + 1`
    position-specific character features, using learned positional
    embeddings and a channel-wise attention gate."""

    def __init__(self, dim, max_len=25):
        super(CPA, self).__init__()
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, dim)
        self.fc3 = nn.Linear(dim, dim)
        # One positional query per character position, plus one for <eos>.
        self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, dim],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, feat):
        # feat: (B, L, dim)
        feat = feat.mean(1).unsqueeze(1)  # (B, 1, dim) global pooling
        x = self.fc1(feat) + self.pos_embed  # (B, max_len + 1, dim)
        # Channel-wise attention weights for each position
        # (torch.tanh replaces the deprecated F.tanh).
        x = F.softmax(self.fc2(torch.tanh(x)), -1)  # (B, max_len + 1, dim)
        x = self.fc3(feat * x)  # (B, max_len + 1, dim)
        return x
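
# Illustrative CPA shape walk-through (a sketch; the batch size, sequence
# length, and `dim` below are arbitrary, not values fixed by this module):
#
#     cpa = CPA(dim=256, max_len=25)
#     feat = torch.randn(2, 64, 256)  # (B, L, dim) encoder features
#     out = cpa(feat)                 # (2, 26, 256): one feature vector per
#                                     # character position, incl. <eos>
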

class ARDecoder(nn.Module):
    """Autoregressive transformer decoder. It uses self-attention only: the
    CPA output `src` is added to the token embeddings rather than attended
    to via cross-attention."""

    def __init__(
        self,
        in_channels,
        out_channels,
        nhead=None,
        num_decoder_layers=6,
        max_len=25,
        attention_dropout_rate=0.0,
        residual_dropout_rate=0.1,
        scale_embedding=True,
    ):
        super(ARDecoder, self).__init__()
        self.out_channels = out_channels
        self.ignore_index = out_channels - 1  # padding class
        self.bos = out_channels - 2
        self.eos = 0
        self.max_len = max_len
        d_model = in_channels
        dim_feedforward = d_model * 4
        nhead = nhead if nhead is not None else d_model // 32
        self.embedding = Embeddings(
            d_model=d_model,
            vocab=self.out_channels,
            padding_idx=0,
            scale_embedding=scale_embedding,
        )
        self.pos_embed = nn.Parameter(torch.zeros([1, max_len + 1, d_model],
                                                  dtype=torch.float32),
                                      requires_grad=True)
        trunc_normal_(self.pos_embed, std=0.02)
        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model,
                nhead,
                dim_feedforward,
                attention_dropout_rate,
                residual_dropout_rate,
                with_self_attn=True,
                with_cross_attn=False,
            ) for _ in range(num_decoder_layers)
        ])

        # Project to out_channels - 2 classes: <bos> and the padding index
        # are never predicted.
        self.tgt_word_prj = nn.Linear(d_model,
                                      self.out_channels - 2,
                                      bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward_train(self, src, tgt):
        # Teacher forcing: drop the last target token, predict the shifted
        # sequence from <bos> onward.
        tgt = tgt[:, :-1]
        tgt = self.embedding(
            tgt) + src[:, :tgt.shape[1]] + self.pos_embed[:, :tgt.shape[1]]
        tgt_mask = self.generate_square_subsequent_mask(tgt.shape[1],
                                                        device=src.device)
        for decoder_layer in self.decoder:
            tgt = decoder_layer(tgt, self_mask=tgt_mask)
        output = tgt
        logit = self.tgt_word_prj(output)
        return logit

    def forward(self, src, data=None):
        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :2 + max_len]  # <bos> + chars + <eos>
            res = self.forward_train(src, tgt)
        else:
            res = self.forward_test(src)
        return res

    def forward_test(self, src):
        bs = src.shape[0]
        src = src + self.pos_embed
        # `src.device` replaces `src.get_device()`, which fails for CPU
        # tensors.
        dec_seq = torch.full((bs, self.max_len + 1),
                             self.ignore_index,
                             dtype=torch.int64,
                             device=src.device)
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            dec_seq_embed = self.embedding(dec_seq[:, :len_dec_seq + 1])
            dec_seq_embed = dec_seq_embed + src[:, :len_dec_seq + 1]
            tgt_mask = self.generate_square_subsequent_mask(
                dec_seq_embed.shape[1], src.device)
            tgt = dec_seq_embed  # (bs, len_dec_seq + 1, dim): <bos>, a, b, ...
            for decoder_layer in self.decoder:
                tgt = decoder_layer(tgt, self_mask=tgt_mask)
            dec_output = tgt[:, -1:, :]  # keep only the newest position
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            # Greedy decoding: append the most likely next token to the
            # target input.
            dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
            # Efficient batch decoding: once every sequence in the batch
            # contains at least one <eos> token, stop decoding.
            if (dec_seq == self.eos).any(dim=-1).all():
                break
        logits = torch.cat(logits, dim=1)
        return logits

    def generate_square_subsequent_mask(self, sz, device):
        """Generate a square causal mask for the sequence. Masked (future)
        positions are filled with float('-inf'); unmasked positions with
        float(0.0)."""
        mask = torch.triu(
            torch.full((sz, sz), -torch.inf, dtype=torch.float32),
            diagonal=1,
        )
        return mask.unsqueeze(0).unsqueeze(0).to(device)


class OTEDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 max_len=25,
                 num_heads=None,
                 ar=False,
                 num_decoder_layers=1,
                 **kwargs):
        super(OTEDecoder, self).__init__()
        # Predictable classes exclude <bos> and padding,
        # e.g. <eos>/none + 26 letters + 10 digits.
        self.out_channels = out_channels - 2
        dim = in_channels
        self.dim = dim
        self.max_len = max_len + 1  # max_len + <eos>

        self.cpa = CPA(dim=dim, max_len=max_len)
        self.ar = ar
        if ar:
            self.ar_decoder = ARDecoder(in_channels=dim,
                                        out_channels=out_channels,
                                        nhead=num_heads,
                                        num_decoder_layers=num_decoder_layers,
                                        max_len=max_len)
        else:
            self.fc = nn.Linear(dim, self.out_channels)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            zeros_(m.bias)
            ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed'}

    def forward(self, x, data=None):
        x = self.cpa(x)
        if self.ar:
            return self.ar_decoder(x, data=data)

        logits = self.fc(x)  # (B, max_len + 1, out_channels - 2)
        if self.training:
            # Trim to the longest label in the batch (+1 for <eos>).
            logits = logits[:, :data[1].max() + 1]
        return logits
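

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative only, not part of the decoder). Because of
# the relative import at the top, it must be run inside the package, e.g.
# `python -m <package>.ote_decoder` (package path assumed). `out_channels=39`
# assumes <eos>/none + 26 letters + 10 digits, plus <bos> and padding.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    decoder = OTEDecoder(in_channels=256, out_channels=39, max_len=25)
    decoder.eval()
    feats = torch.randn(2, 64, 256)  # (B, L, dim) random encoder features
    with torch.no_grad():
        preds = decoder(feats)
    print(preds.shape)  # expected: torch.Size([2, 26, 37])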