File size: 1,745 Bytes
1c3eb47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn

from mmpl.registry import MODELS


@MODELS.register_module()
class TransformerEDecoderNeck(nn.Module):
    """Transformer encoder-decoder neck that reduces a sequence to one vector.

    The input token sequence is fed to the encoder of an ``nn.Transformer``.
    The decoder target is the same sequence with a learnable class token
    prepended. The decoder output at the class-token position is taken as
    the pooled feature and refined by a small residual MLP.

    Args:
        model_dim (int): Feature size of every input token. Used as the
            transformer ``d_model``; the feed-forward width is
            ``2 * model_dim``. Must be divisible by ``nhead``.
        num_encoder_layers (int): Number of encoder layers. The same number
            of decoder layers is used. Default: 3
        nhead (int): Number of attention heads. Default: 8
        dropout (float): Dropout probability inside the transformer.
            Default: 0.1
    """

    def __init__(self, model_dim, num_encoder_layers=3, nhead=8, dropout=0.1):
        super().__init__()
        self.embed_dims = model_dim
        # Always True here; kept as an attribute because ``forward`` branches
        # on it.
        self.with_cls_token = True
        # Learnable token, broadcast over the batch in ``forward``.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

        self.transformer_encoder_decoder = nn.Transformer(
            d_model=model_dim,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_encoder_layers,
            dim_feedforward=model_dim * 2,
            batch_first=True,
            dropout=dropout,
        )
        # Bottleneck MLP whose output is added back to the pooled feature.
        self.out_linear_layer = nn.Sequential(
            nn.Linear(model_dim, model_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(model_dim // 2, model_dim),
        )

    def init_weights(self):
        """Apply Xavier-uniform init to every parameter with more than one dim.

        This covers the weight matrices and also ``cls_token`` (it is 3-D);
        1-D parameters such as biases and LayerNorm scales are left as is.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, inputs):
        """Pool a token sequence into a single feature vector per sample.

        Args:
            inputs (torch.Tensor): Tokens of shape ``(B, N, model_dim)``.

        Returns:
            torch.Tensor: Pooled features of shape ``(B, model_dim)``.
        """
        batch_size = inputs.shape[0]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        # Decoder target: class token followed by the input tokens.
        target = torch.cat((cls_tokens, inputs), dim=1)
        # Encoder sees the raw inputs; no attention masks are supplied, so
        # the decoder attends over the whole target sequence.
        x = self.transformer_encoder_decoder(inputs, target)
        if self.with_cls_token:
            # Keep only the output at the class-token position.
            x = x[:, 0]

        residual = self.out_linear_layer(x)
        x = x + residual

        return x