# ovsam/app/models/transformer_neck.py
# (commit 9cc3eb2, "Add model", Haobo Yuan)
from functools import partial
from typing import List, Optional, Tuple

import torch
from torch import Tensor, nn

from mmengine.model import BaseModule, normal_init
from mmdet.registry import MODELS
from mmdet.models.layers import PatchEmbed

from ext.meta.sam_meta import checkpoint_dict
from ext.sam.common import LayerNorm2d
from ext.sam.image_encoder import Block
from utils.load_checkpoint import load_checkpoint_with_prefix


@MODELS.register_module()
class MultiLayerTransformerNeck(BaseModule):
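    """Fuse multi-scale backbone features into a SAM-style image embedding.

    Selected feature levels are projected onto a shared 1/16-resolution token
    grid, tagged with per-level embeddings, summed, combined with a (possibly
    SAM-pretrained) positional embedding, refined by a short stack of SAM ViT
    blocks, and reduced to ``out_channels`` by a SAM-style convolutional neck.
    """
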
    STRIDE = 16  # token grid is input_size / 16, matching SAM's ViT patch stride

    def __init__(
            self,
            input_size: Tuple[int, int],
            in_channels: List[int],
            embed_channels: int,
            out_channels: int,
            layer_ids: Tuple[int, ...] = (0, 1, 2, 3),
            strides: Tuple[int, ...] = (4, 8, 16, 32),
            embedding_path: Optional[str] = None,
            fix: bool = False,
            init_cfg: Optional[dict] = None
    ) -> None:
        # init_cfg is handled manually at the end of __init__ (the weights are
        # loaded there rather than by BaseModule), so it is not forwarded here.
        super().__init__(init_cfg=None)

        self.transformer_size = (input_size[0] // self.STRIDE, input_size[1] // self.STRIDE)
        self.layer_ids = layer_ids

        # Project every selected feature level onto the shared 1/16 token grid:
        # levels coarser than 1/16 (stride > 16) are upsampled with a transposed
        # conv; finer levels are downsampled with a strided conv.
        self.patch_embeds = nn.ModuleList()
        for idx, in_ch in enumerate(in_channels):
            if idx in layer_ids:
                if strides[idx] > self.STRIDE:
                    patch_embed = PatchEmbed(
                        conv_type=nn.ConvTranspose2d,
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=strides[idx] // self.STRIDE,
                        stride=strides[idx] // self.STRIDE,
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                else:
                    patch_embed = PatchEmbed(
                        in_channels=in_ch,
                        embed_dims=embed_channels,
                        kernel_size=self.STRIDE // strides[idx],
                        stride=self.STRIDE // strides[idx],
                        input_size=(input_size[0] // strides[idx], input_size[1] // strides[idx])
                    )
                self.patch_embeds.append(patch_embed)
            else:
                self.patch_embeds.append(nn.Identity())

        # Positional embedding: either copied from a pretrained SAM image
        # encoder (embedding_path of the form 'sam_<checkpoint-name>') or
        # zero-initialized and expected to be overwritten by a checkpoint.
        if embedding_path is not None:
            assert embedding_path.startswith('sam_')
            embedding_ckpt = embedding_path.split('_', maxsplit=1)[1]
            path = checkpoint_dict[embedding_ckpt]
            state_dict = load_checkpoint_with_prefix(path, prefix='image_encoder')
            pos_embed = state_dict['pos_embed']
        else:
            # For loading from checkpoint
            pos_embed = torch.zeros(1, input_size[0] // self.STRIDE, input_size[1] // self.STRIDE, embed_channels)
        self.register_buffer('pos_embed', pos_embed)
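        # Note: no interpolation is applied here, so when embedding_path is set
        # the transformer grid must match the pretrained pos_embed grid
        # (64x64 for SAM's 1024-input image encoders).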

        # Learnable per-level embedding, added so the transformer can tell the
        # fused feature levels apart.
        self.level_encoding = nn.Embedding(len(layer_ids), embed_channels)

        # A short stack of SAM ViT blocks: windowed attention everywhere except
        # the blocks listed in global_attn_indexes, which attend globally
        # (window_size=0).
        depth = 5
        global_attn_indexes = [4]
        window_size = 14

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_channels,
                num_heads=16,
                mlp_ratio=4,
                qkv_bias=True,
                norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
                act_layer=nn.GELU,
                use_rel_pos=True,
                rel_pos_zero_init=True,
                window_size=window_size if i not in global_attn_indexes else 0,
                input_size=self.transformer_size,
            )
            self.blocks.append(block)

        # SAM-style neck: a 1x1 conv to reduce channels, then a 3x3 conv, each
        # followed by a channels-first LayerNorm.
        self.neck = nn.Sequential(
            nn.Conv2d(
                embed_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm2d(out_channels),
        )

        # When fix=True the module is frozen: kept in eval mode and excluded
        # from gradient computation.
        self.fix = fix
        if self.fix:
            self.train(mode=False)
            for param in self.parameters():
                param.requires_grad = False

        # Pretrained weights are loaded here instead of through BaseModule's
        # init_cfg machinery; _is_init marks the module as initialized so
        # init_weights() does not overwrite the loaded weights.
        if init_cfg is not None:
            assert init_cfg['type'] == 'Pretrained'
            checkpoint_path = init_cfg['checkpoint']
            state_dict = load_checkpoint_with_prefix(checkpoint_path, prefix=init_cfg['prefix'])
            self.load_state_dict(state_dict, strict=True)
            self._is_init = True

    def init_weights(self) -> None:
        normal_init(self.level_encoding, mean=0, std=1)

    def train(self, mode: bool = True) -> 'MultiLayerTransformerNeck':
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        # A frozen neck stays in eval mode regardless of the requested mode.
        if self.fix:
            super().train(mode=False)
        else:
            super().train(mode=mode)
        return self

    def forward(self, inputs: Tuple[Tensor, ...]) -> Tensor:
        # Embed each selected level onto the shared 1/16 grid, tag it with its
        # level embedding, and sum the levels into a single token map.
        input_embeddings = []
        level_cnt = 0
        for idx, feat in enumerate(inputs):
            if idx not in self.layer_ids:
                continue
            feat, size = self.patch_embeds[idx](feat)  # (B, H*W, C), (H, W)
            feat = feat.unflatten(1, size)  # (B, H, W, C)
            feat = feat + self.level_encoding.weight[level_cnt]
            input_embeddings.append(feat)
            level_cnt += 1

        feat = sum(input_embeddings)
        feat = feat + self.pos_embed
        for block in self.blocks:
            feat = block(feat)
        # SAM blocks run channels-last; the neck convs expect (B, C, H, W).
        feat = feat.permute(0, 3, 1, 2).contiguous()
        feat = self.neck(feat)
        return feat
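

# A minimal usage sketch (hypothetical shapes and channel counts, not the
# project's actual config): fuse four FPN-style levels of a 1024x1024 input
# into a single 64x64, 256-channel SAM-style feature map.
if __name__ == '__main__':
    neck = MultiLayerTransformerNeck(
        input_size=(1024, 1024),
        in_channels=[256, 256, 256, 256],
        embed_channels=1280,
        out_channels=256,
    )
    feats = tuple(
        torch.rand(2, 256, 1024 // s, 1024 // s) for s in (4, 8, 16, 32)
    )
    out = neck(feats)
    print(out.shape)  # torch.Size([2, 256, 64, 64])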