"""This code is refer from: https://github.com/jjwei66/BUSNet """ import torch import torch.nn as nn import torch.nn.functional as F from .nrtr_decoder import PositionalEncoding, TransformerBlock from .abinet_decoder import _get_mask, _get_length class BUSDecoder(nn.Module): def __init__(self, in_channels, out_channels, nhead=8, num_layers=4, dim_feedforward=2048, dropout=0.1, max_length=25, ignore_index=100, pretraining=False, detach=True): super().__init__() d_model = in_channels self.ignore_index = ignore_index self.pretraining = pretraining self.d_model = d_model self.detach = detach self.max_length = max_length + 1 # additional stop token self.out_channels = out_channels # -------------------------------------------------------------------------- # decoder specifics self.proj = nn.Linear(out_channels, d_model, False) self.token_encoder = PositionalEncoding(dropout=0.1, dim=d_model, max_len=self.max_length) self.pos_encoder = PositionalEncoding(dropout=0.1, dim=d_model, max_len=self.max_length) self.decoder = nn.ModuleList([ TransformerBlock( d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, attention_dropout_rate=dropout, residual_dropout_rate=dropout, with_self_attn=False, with_cross_attn=True, ) for i in range(num_layers) ]) v_mask = torch.empty((1, 1, d_model)) l_mask = torch.empty((1, 1, d_model)) self.v_mask = nn.Parameter(v_mask) self.l_mask = nn.Parameter(l_mask) torch.nn.init.uniform_(self.v_mask, -0.001, 0.001) torch.nn.init.uniform_(self.l_mask, -0.001, 0.001) v_embeding = torch.empty((1, 1, d_model)) l_embeding = torch.empty((1, 1, d_model)) self.v_embeding = nn.Parameter(v_embeding) self.l_embeding = nn.Parameter(l_embeding) torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001) torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001) self.cls = nn.Linear(d_model, out_channels) def forward_decoder(self, q, x, mask=None): for decoder_layer in self.decoder: q = decoder_layer(q, x, cross_mask=mask) output = q # (N, T, E) logits = self.cls(output) # (N, T, C) return logits def forward(self, img_feat, data=None): """ Args: tokens: (N, T, C) where T is length, N is batch size and C is classes number lengths: (N,) """ img_feat = img_feat + self.v_embeding B, L, C = img_feat.shape # -------------------------------------------------------------------------- # decoder procedure T = self.max_length zeros = img_feat.new_zeros((B, T, C)) zeros_len = img_feat.new_zeros(B) query = self.pos_encoder(zeros) # 1. vision decode v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)), dim=1) # v padding_mask = _get_mask( self.max_length + zeros_len, self.max_length) # 对tokens长度以外的padding # B, maxlen maxlen v_mask = torch.zeros((1, 1, self.max_length, L), device=img_feat.device).tile([B, 1, 1, 1]) # maxlen L mask = torch.cat((v_mask, padding_mask), 3) v_logits = self.forward_decoder(query, v_embed, mask=mask) # 2. 

        # 2. language decode: tokens come from the ground truth (pretraining)
        # or from the vision prediction
        if self.training and self.pretraining:
            tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
            tokens = F.one_hot(tgt, num_classes=self.out_channels)
            tokens = tokens.float()
            lengths = data[-1]
        else:
            tokens = torch.softmax(v_logits, dim=-1)
            lengths = _get_length(v_logits)
            if self.detach:
                # keep gradients from flowing back into the vision pass
                tokens = tokens.detach()
        token_embed = self.proj(tokens)  # (N, T, E)
        token_embed = self.token_encoder(token_embed)  # (N, T, E)
        token_embed = token_embed + self.l_embeding

        # mask positions beyond the token length
        padding_mask = _get_mask(lengths, self.max_length)
        mask = torch.cat((v_mask, padding_mask), 3)
        l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
        l_logits = self.forward_decoder(query, l_embed, mask=mask)

        # 3. vision-language decode: both modalities visible
        vl_embed = torch.cat((img_feat, token_embed), dim=1)
        vl_logits = self.forward_decoder(query, vl_embed, mask=mask)

        if self.training:
            return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
        else:
            return F.softmax(vl_logits, -1)
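

# ------------------------------------------------------------------------------
# Minimal smoke test: a sketch, not part of the upstream BUSNet code. The
# feature shape (batch 2, visual length 64, d_model 512) and charset size 37
# are illustrative assumptions; because of the relative imports above, this
# file must be executed as a module of its package for the block to run.
if __name__ == '__main__':
    decoder = BUSDecoder(in_channels=512, out_channels=37)
    decoder.eval()  # inference path returns softmax probabilities
    img_feat = torch.randn(2, 64, 512)  # (N, L, E) features from a vision backbone
    with torch.no_grad():
        probs = decoder(img_feat)
    print(probs.shape)  # expected: torch.Size([2, 26, 37]) -> (N, max_length + 1, C)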