import torch
import torch.nn as nn
from mmpl.registry import MODELS
@MODELS.register_module()
class TransformerEDecoderNeck(nn.Module):
"""Global Average Pooling neck.
Note that we use `view` to remove extra channel after pooling. We do not
use `squeeze` as it will also remove the batch dimension when the tensor
has a batch dimension of size 1, which can lead to unexpected errors.
Args:
dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}.
Default: 2
"""

    def __init__(self, model_dim, num_encoder_layers=3):
        super().__init__()
        self.embed_dims = model_dim
        self.with_cls_token = True
        # Learnable class token prepended to the decoder target sequence.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        self.transformer_encoder_decoder = nn.Transformer(
            d_model=model_dim,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_encoder_layers,
            dim_feedforward=model_dim * 2,
            batch_first=True,
            dropout=0.1)
        # Residual MLP applied to the decoded class-token feature.
        self.out_linear_layer = nn.Sequential(
            nn.Linear(model_dim, model_dim // 2),
            nn.LeakyReLU(),
            nn.Linear(model_dim // 2, model_dim))
    def init_weights(self):
        # Xavier-initialize all matrix-shaped parameters.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, inputs):
        B = inputs.shape[0]
        # Prepend the learnable class token to build the decoder target.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, inputs), dim=1)
        # Encoder source is the raw sequence; decoder target is the
        # class-token-prefixed sequence.
        x = self.transformer_encoder_decoder(inputs, x)
        if self.with_cls_token:
            # Keep only the decoded class-token feature.
            x = x[:, 0]
        # Residual refinement of the pooled feature.
        residual = self.out_linear_layer(x)
        x = x + residual
        return x
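

if __name__ == "__main__":
    # Minimal standalone sketch (an illustrative assumption, not part of the
    # original module): run a hypothetical token sequence of shape
    # (batch, num_tokens, model_dim) through the neck and inspect the output.
    neck = TransformerEDecoderNeck(model_dim=256, num_encoder_layers=3)
    neck.init_weights()

    tokens = torch.randn(2, 16, 256)  # assumed: batch of 2, 16 tokens, dim 256
    out = neck(tokens)
    print(out.shape)  # expected: torch.Size([2, 256]), one feature vector per sample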