import copy

import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import build_transformer_layer
from mmengine.model import BaseModule

from mmpl.registry import MODELS

@MODELS.register_module()
class TransformerEncoderNeck(BaseModule):
"""Global Average Pooling neck. | |
Note that we use `view` to remove extra channel after pooling. We do not | |
use `squeeze` as it will also remove the batch dimension when the tensor | |
has a batch dimension of size 1, which can lead to unexpected errors. | |
Args: | |
dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}. | |
Default: 2 | |
""" | |
    def __init__(self,
                 model_dim,
                 with_pe=True,
                 max_position_embeddings=24,
                 with_cls_token=True,
                 num_encoder_layers=3):
        super().__init__()
        self.embed_dims = model_dim
        self.with_cls_token = with_cls_token
        self.with_pe = with_pe
        if self.with_cls_token:
            # Learnable [CLS] token prepended to the input sequence.
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        if self.with_pe:
            # Learnable position embeddings, sliced to the actual sequence
            # length in ``forward``.
            self.pe = nn.Parameter(
                torch.zeros(1, max_position_embeddings, self.embed_dims))
        mlp_ratio = 4
        embed_dims = model_dim
        transformer_layer = dict(
            type='BaseTransformerLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=embed_dims,
                    num_heads=8,
                    attn_drop=0.1,
                    proj_drop=0.1,
                    dropout_layer=dict(type='Dropout', drop_prob=0.1)),
            ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=embed_dims,
                feedforward_channels=embed_dims * mlp_ratio,
                num_fcs=2,
                act_cfg=dict(type='GELU'),
                ffn_drop=0.1,
                add_identity=True),
            # Pre-norm ordering: LN -> self-attention -> LN -> FFN.
            operation_order=('norm', 'self_attn', 'norm', 'ffn'),
            norm_cfg=dict(type='LN'),
            batch_first=True)
        self.layers = nn.ModuleList()
        transformer_layers = [
            copy.deepcopy(transformer_layer)
            for _ in range(num_encoder_layers)
        ]
        for i in range(num_encoder_layers):
            self.layers.append(build_transformer_layer(transformer_layers[i]))
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm
    def init_weights(self):
        # Xavier-initialise every parameter with more than one dimension
        # (attention/FFN weights, the [CLS] token and position embeddings);
        # biases and other 1-D parameters keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
    def forward(self, x):
        B = x.shape[0]
        if self.with_cls_token:
            # Prepend a per-sample copy of the learnable [CLS] token.
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        if self.with_pe:
            x = x + self.pe[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        if self.with_cls_token:
            # Return the [CLS] embedding and the full token sequence.
            return x[:, 0], x
        return None, x
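

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file): run a dummy
    # batch through the neck to show the expected shapes. Assumes mmcv's
    # 'BaseTransformerLayer' and 'MultiheadAttention' are registered, as
    # the imports above require.
    neck = TransformerEncoderNeck(model_dim=256)
    neck.init_weights()
    feats = torch.randn(2, 16, 256)  # (batch, num_tokens, model_dim)
    cls_feat, tokens = neck(feats)
    assert cls_feat.shape == (2, 256)    # pooled [CLS] representation
    assert tokens.shape == (2, 17, 256)  # [CLS] token + 16 input tokens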