import torch
import torch.nn as nn
import torch.nn.functional as F

from openrec.modeling.decoders.nrtr_decoder import Embeddings, PositionalEncoding, TransformerBlock  # , Beam
from openrec.modeling.decoders.visionlan_decoder import Transformer_Encoder


def generate_square_subsequent_mask(sz):
    r"""Generate a square mask for the sequence.

    The masked positions are filled with float('-inf'). Unmasked positions
    are filled with float(0.0).
    """
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = (mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(
        mask == 1, float(0.0)))
    return mask


class SEM_Pre(nn.Module):
    """Semantic branch: token embedding plus positional encoding and a causal mask."""

    def __init__(
        self,
        d_model=512,
        dst_vocab_size=40,
        residual_dropout_rate=0.1,
    ):
        super(SEM_Pre, self).__init__()
        self.embedding = Embeddings(d_model=d_model, vocab=dst_vocab_size)
        self.positional_encoding = PositionalEncoding(
            dropout=residual_dropout_rate,
            dim=d_model,
        )

    def forward(self, tgt):
        tgt = self.embedding(tgt)
        tgt = self.positional_encoding(tgt)
        tgt_mask = generate_square_subsequent_mask(tgt.shape[1]).to(tgt.device)
        return tgt, tgt_mask


class POS_Pre(nn.Module):
    """Positional branch: positional encoding refined by a small feed-forward block."""

    def __init__(
        self,
        d_model=512,
    ):
        super(POS_Pre, self).__init__()
        self.pos_encoding = PositionalEncoding(
            dropout=0.1,
            dim=d_model,
        )
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, tgt):
        pos = tgt.new_zeros(*tgt.shape)
        pos = self.pos_encoding(pos)
        pos2 = self.linear2(F.relu(self.linear1(pos)))
        pos = self.norm2(pos + pos2)
        return pos


class DSF(nn.Module):
    """Dynamic Shared Fusion: gated fusion of two feature streams."""

    def __init__(self, d_model, fusion_num):
        super(DSF, self).__init__()
        self.w_att = nn.Linear(fusion_num * d_model, d_model)

    def forward(self, l_feature, v_feature):
        """
        Args:
            l_feature: (N, T, E) where N is batch size, T is length and E is
                the model dimension.
            v_feature: (N, T, E), same shape as l_feature.
        """
        f = torch.cat((l_feature, v_feature), dim=2)
        f_att = torch.sigmoid(self.w_att(f))
        output = f_att * v_feature + (1 - f_att) * l_feature
        return output


class MDCDP(nn.Module):
    r"""Multi-Domain Character Distance Perception."""

    def __init__(self, d_model, n_head, d_inner, num_layers):
        super(MDCDP, self).__init__()
        self.num_layers = num_layers

        # step 1: SAE
        self.layers_pos = nn.ModuleList([
            TransformerBlock(d_model, n_head, d_inner)
            for _ in range(num_layers)
        ])

        # step 2: CBI
        self.layers2 = nn.ModuleList([
            TransformerBlock(
                d_model,
                n_head,
                d_inner,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])
        self.layers3 = nn.ModuleList([
            TransformerBlock(
                d_model,
                n_head,
                d_inner,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])

        # step 3: DSF
        self.dynamic_shared_fusion = DSF(d_model, 2)

    def forward(
        self,
        sem,
        vis,
        pos,
        tgt_mask=None,
        memory_mask=None,
    ):
        for i in range(self.num_layers):
            # ---------- step 1 ----------: SAE: Self-Attention Enhancement
            pos = self.layers_pos[i](pos, self_mask=tgt_mask)

            # ---------- step 2 ----------: CBI: Cross-Branch Interaction
            # CBI-V
            pos_vis = self.layers2[i](
                pos,
                vis,
                cross_mask=memory_mask,
            )
            # CBI-S
            pos_sem = self.layers3[i](
                pos,
                sem,
                cross_mask=tgt_mask,
            )

            # ---------- step 3 ----------: DSF: Dynamic Shared Fusion
            pos = self.dynamic_shared_fusion(pos_vis, pos_sem)

        output = pos
        return output


class ConvBnRelu(nn.Module):
    # adapt padding for kernel_size change

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        conv=nn.Conv2d,
        stride=2,
        inplace=True,
    ):
        super().__init__()
        p_size = [int(k // 2) for k in kernel_size]
        # p_size = int(kernel_size // 2)
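        # For each dimension, k // 2 gives SAME-style padding for odd kernel
        # sizes, e.g. the default kernel_size=(1, 3) used below -> padding=(0, 1).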
        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=p_size,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class CDistNetDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 n_head=None,
                 num_encoder_blocks=3,
                 num_decoder_blocks=3,
                 beam_size=0,
                 max_len=25,
                 residual_dropout_rate=0.1,
                 add_conv=False,
                 **kwargs):
        super(CDistNetDecoder, self).__init__()
        dst_vocab_size = out_channels
        self.ignore_index = dst_vocab_size - 1
        self.bos = dst_vocab_size - 2
        self.eos = 0
        self.beam_size = beam_size
        self.max_len = max_len
        self.add_conv = add_conv
        d_model = in_channels
        dim_feedforward = d_model * 4
        n_head = n_head if n_head is not None else d_model // 32

        if add_conv:
            self.convbnrelu = ConvBnRelu(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(1, 3),
                stride=(1, 2),
            )
        if num_encoder_blocks > 0:
            self.positional_encoding = PositionalEncoding(
                dropout=0.1,
                dim=d_model,
            )
            self.trans_encoder = Transformer_Encoder(
                n_layers=num_encoder_blocks,
                n_head=n_head,
                d_model=d_model,
                d_inner=dim_feedforward,
            )
        else:
            self.trans_encoder = None
        self.semantic_branch = SEM_Pre(
            d_model=d_model,
            dst_vocab_size=dst_vocab_size,
            residual_dropout_rate=residual_dropout_rate,
        )
        self.positional_branch = POS_Pre(d_model=d_model)
        self.mdcdp = MDCDP(d_model, n_head, dim_feedforward // 2,
                           num_decoder_blocks)
        self._reset_parameters()
        self.tgt_word_prj = nn.Linear(
            d_model, dst_vocab_size - 2,
            bias=False)  # We don't predict <bos> nor the padding token
        self.tgt_word_prj.weight.data.normal_(mean=0.0, std=d_model**-0.5)

    def forward(self, x, data=None):
        if self.add_conv:
            x = self.convbnrelu(x)
        # x = rearrange(x, "b c h w -> b (w h) c")
        x = x.flatten(2).transpose(1, 2)
        if self.trans_encoder is not None:
            x = self.positional_encoding(x)
            vis_feat = self.trans_encoder(x, src_mask=None)
        else:
            vis_feat = x
        if self.training:
            max_len = data[1].max()
            tgt = data[0][:, :1 + max_len]
            res = self.forward_train(vis_feat, tgt)
        else:
            if self.beam_size > 0:
                res = self.forward_beam(vis_feat)
            else:
                res = self.forward_test(vis_feat)
        return res

    def forward_train(self, vis_feat, tgt):
        sem_feat, sem_mask = self.semantic_branch(tgt)
        pos_feat = self.positional_branch(sem_feat)
        output = self.mdcdp(
            sem_feat,
            vis_feat,
            pos_feat,
            tgt_mask=sem_mask,
            memory_mask=None,
        )
        logit = self.tgt_word_prj(output)
        return logit

    def forward_test(self, vis_feat):
        bs = vis_feat.size(0)
        dec_seq = torch.full(
            (bs, self.max_len + 1),
            self.ignore_index,
            dtype=torch.int64,
            device=vis_feat.device,
        )
        dec_seq[:, 0] = self.bos
        logits = []
        for len_dec_seq in range(0, self.max_len):
            sem_feat, sem_mask = self.semantic_branch(dec_seq[:, :len_dec_seq + 1])
            pos_feat = self.positional_branch(sem_feat)
            output = self.mdcdp(
                sem_feat,
                vis_feat,
                pos_feat,
                tgt_mask=sem_mask,
                memory_mask=None,
            )
            dec_output = output[:, -1:, :]
            word_prob = F.softmax(self.tgt_word_prj(dec_output), dim=-1)
            logits.append(word_prob)
            if len_dec_seq < self.max_len:
                # greedy decode: add the next token index to the target input
                dec_seq[:, len_dec_seq + 1] = word_prob.squeeze(1).argmax(-1)
                # Efficient batch decoding: if every sequence contains at least one EOS token, stop decoding.
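                # dec_seq is pre-filled with the padding index, so only tokens
                # actually predicted as EOS (index 0) can trigger the early exit.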
                if (dec_seq == self.eos).any(dim=-1).all():
                    break
        logits = torch.cat(logits, dim=1)
        return logits

    def forward_beam(self, x):
        """Translation work in one batch."""
        # to do

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
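

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). The
# channel count, charset size, and feature-map shape below are assumptions;
# in practice the decoder is built from a config and fed backbone/encoder
# features of shape (B, C, H, W).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    vocab_size = 40  # assumed charset size, including the EOS/BOS/padding tokens
    decoder = CDistNetDecoder(
        in_channels=512,
        out_channels=vocab_size,
        num_encoder_blocks=3,
        num_decoder_blocks=3,
        max_len=25,
    )
    decoder.eval()

    feats = torch.randn(2, 512, 8, 32)  # dummy (B, C, H, W) visual feature map
    with torch.no_grad():
        probs = decoder(feats)  # greedy decoding path (beam_size == 0)
    # (B, <=max_len, vocab_size - 2): per-step softmax over the predictable charset
    print(probs.shape)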