# RSPrompter/mmpl/models/necks/transformer_neck.py
import copy
import torch
import torch.nn as nn
from mmpl.registry import MODELS
from mmengine.model import BaseModule
from mmcv.cnn.bricks.transformer import build_transformer_layer


@MODELS.register_module()
class TransformerEncoderNeck(BaseModule):
"""Global Average Pooling neck.
Note that we use `view` to remove extra channel after pooling. We do not
use `squeeze` as it will also remove the batch dimension when the tensor
has a batch dimension of size 1, which can lead to unexpected errors.
Args:
dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}.
Default: 2
"""
    def __init__(self,
                 model_dim,
                 with_pe=True,
                 max_position_embeddings=24,
                 with_cls_token=True,
                 num_encoder_layers=3):
        super().__init__()
        self.embed_dims = model_dim
        self.with_cls_token = with_cls_token
        self.with_pe = with_pe
        if self.with_cls_token:
            # Learnable [CLS] token prepended to the input sequence.
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        if self.with_pe:
            # Learnable position embeddings, truncated to the actual
            # sequence length in forward().
            self.pe = nn.Parameter(
                torch.zeros(1, max_position_embeddings, self.embed_dims))
        mlp_ratio = 4
        embed_dims = model_dim
        # Pre-norm encoder layer: LN -> self-attention -> LN -> FFN.
        transformer_layer = dict(
            type='BaseTransformerLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=embed_dims,
                    num_heads=8,
                    attn_drop=0.1,
                    proj_drop=0.1,
                    dropout_layer=dict(type='Dropout', drop_prob=0.1)),
            ],
            ffn_cfgs=dict(
                type='FFN',
                embed_dims=embed_dims,
                feedforward_channels=embed_dims * mlp_ratio,
                num_fcs=2,
                act_cfg=dict(type='GELU'),
                ffn_drop=0.1,
                add_identity=True),
            operation_order=('norm', 'self_attn', 'norm', 'ffn'),
            norm_cfg=dict(type='LN'),
            batch_first=True)
        # Each layer is built from its own deep copy of the config.
        self.layers = nn.ModuleList([
            build_transformer_layer(copy.deepcopy(transformer_layer))
            for _ in range(num_encoder_layers)
        ])
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm

    def init_weights(self):
        # Xavier-initialize every matrix-shaped parameter; note this also
        # overwrites the zero init of cls_token and pe.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x):
        """Encode a token sequence of shape (B, N, C).

        Returns:
            tuple: ``(cls_feature, tokens)``, where ``cls_feature`` is the
            encoded [CLS] token of shape (B, C) (or None when
            ``with_cls_token`` is False) and ``tokens`` is the full encoded
            sequence.
        """
        B = x.shape[0]
        if self.with_cls_token:
            # Prepend the [CLS] token, broadcast over the batch.
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)
        if self.with_pe:
            # Truncate the position embeddings to the current length.
            x = x + self.pe[:, :x.shape[1], :]
        for layer in self.layers:
            x = layer(x)
        if self.with_cls_token:
            return x[:, 0], x
        return None, x
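

# A minimal smoke-test sketch of how this neck might be used. It assumes
# mmcv/mmengine are installed and that the input is a (batch, tokens, dim)
# tensor whose token count, plus the [CLS] token, stays within
# max_position_embeddings. The shapes below are illustrative, not taken
# from the RSPrompter configs.
if __name__ == '__main__':
    neck = TransformerEncoderNeck(model_dim=256)
    neck.init_weights()
    x = torch.randn(2, 16, 256)  # (B=2, N=16, C=256); N + 1 <= 24
    cls_feat, tokens = neck(x)
    print(cls_feat.shape)        # torch.Size([2, 256])
    print(tokens.shape)          # torch.Size([2, 17, 256])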