"""This code is refer from:
https://github.com/jjwei66/BUSNet
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from .nrtr_decoder import PositionalEncoding, TransformerBlock
from .abinet_decoder import _get_mask, _get_length


class BUSDecoder(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 nhead=8,
                 num_layers=4,
                 dim_feedforward=2048,
                 dropout=0.1,
                 max_length=25,
                 ignore_index=100,
                 pretraining=False,
                 detach=True):
        super().__init__()
        d_model = in_channels
        self.ignore_index = ignore_index
        self.pretraining = pretraining
        self.d_model = d_model
        self.detach = detach
        self.max_length = max_length + 1  # additional stop token
        self.out_channels = out_channels
        # --------------------------------------------------------------------------
        # decoder specifics
        self.proj = nn.Linear(out_channels, d_model, bias=False)
        self.token_encoder = PositionalEncoding(dropout=0.1,
                                                dim=d_model,
                                                max_len=self.max_length)
        self.pos_encoder = PositionalEncoding(dropout=0.1,
                                              dim=d_model,
                                              max_len=self.max_length)

        self.decoder = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                attention_dropout_rate=dropout,
                residual_dropout_rate=dropout,
                with_self_attn=False,
                with_cross_attn=True,
            ) for _ in range(num_layers)
        ])

        v_mask = torch.empty((1, 1, d_model))
        l_mask = torch.empty((1, 1, d_model))
        self.v_mask = nn.Parameter(v_mask)
        self.l_mask = nn.Parameter(l_mask)
        torch.nn.init.uniform_(self.v_mask, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_mask, -0.001, 0.001)

        v_embeding = torch.empty((1, 1, d_model))
        l_embeding = torch.empty((1, 1, d_model))
        self.v_embeding = nn.Parameter(v_embeding)
        self.l_embeding = nn.Parameter(l_embeding)
        torch.nn.init.uniform_(self.v_embeding, -0.001, 0.001)
        torch.nn.init.uniform_(self.l_embeding, -0.001, 0.001)
        self.cls = nn.Linear(d_model, out_channels)

    def forward_decoder(self, q, x, mask=None):
        for decoder_layer in self.decoder:
            q = decoder_layer(q, x, cross_mask=mask)
        output = q  # (N, T, E)
        logits = self.cls(output)  # (N, T, C)
        return logits

    def forward(self, img_feat, data=None):
        """
        Args:
            tokens: (N, T, C) where T is length, N is batch size and C is classes number
            lengths: (N,)
        """
        img_feat = img_feat + self.v_embeding
        B, L, C = img_feat.shape

        # --------------------------------------------------------------------------
        # decoder procedure
        T = self.max_length
        zeros = img_feat.new_zeros((B, T, C))
        zeros_len = img_feat.new_zeros(B)
        query = self.pos_encoder(zeros)

        # 1. vision decode
        v_embed = torch.cat((img_feat, self.l_mask.repeat(B, T, 1)),
                            dim=1)  # v
        padding_mask = _get_mask(
            self.max_length + zeros_len,
            self.max_length)  # mask positions beyond the token length; B, maxlen, maxlen
        v_mask = torch.zeros((1, 1, self.max_length, L),
                             device=img_feat.device).tile([B, 1, 1,
                                                           1])  # maxlen L
        mask = torch.cat((v_mask, padding_mask), 3)
        v_logits = self.forward_decoder(query, v_embed, mask=mask)

        # 2. language decode
        if self.training and self.pretraining:
            tgt = torch.where(data[0] == self.ignore_index, 0, data[0])
            tokens = F.one_hot(tgt, num_classes=self.out_channels)
            tokens = tokens.float()
            lengths = data[-1]
        else:
            tokens = torch.softmax(v_logits, dim=-1)
            lengths = _get_length(v_logits)
            tokens = tokens.detach()
        token_embed = self.proj(tokens)  # (N, T, E)
        token_embed = self.token_encoder(token_embed)  # (T, N, E)
        token_embed = token_embed + self.l_embeding

        padding_mask = _get_mask(lengths,
                                 self.max_length)  # mask positions beyond the token length
        mask = torch.cat((v_mask, padding_mask), 3)
        l_embed = torch.cat((self.v_mask.repeat(B, L, 1), token_embed), dim=1)
        l_logits = self.forward_decoder(query, l_embed, mask=mask)

        # 3. vision language decode
        vl_embed = torch.cat((img_feat, token_embed), dim=1)
        vl_logits = self.forward_decoder(query, vl_embed, mask=mask)

        if self.training:
            return {'align': [vl_logits], 'lang': l_logits, 'vision': v_logits}
        else:
            return F.softmax(vl_logits, -1)
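

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original BUSNet code). The feature
# length L=64, d_model=512 and out_channels=37 below are illustrative
# assumptions; in practice ``img_feat`` comes from the vision encoder and
# ``out_channels`` matches the character dictionary size. Because of the
# relative imports above, this is meant to be run from within the package.
if __name__ == '__main__':
    decoder = BUSDecoder(in_channels=512, out_channels=37, max_length=25)
    decoder.eval()

    dummy_feat = torch.randn(2, 64, 512)  # (N, L, E) encoder output
    with torch.no_grad():
        probs = decoder(dummy_feat)  # (N, max_length + 1, out_channels)
    print(probs.shape)  # expected: torch.Size([2, 26, 37])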