"""
MSHT (Multi-Stage Hybrid Transformer) backbone models

Models ver: OCT 27th 20:00 official release

By the authors; check our GitHub page:
https://github.com/sagizty/Multi-Stage-Hybrid-Transformer

ResNet stages' feature maps:

# input = 3, 384, 384
torch.Size([1, 256, 96, 96])
torch.Size([1, 512, 48, 48])
torch.Size([1, 1024, 24, 24])
torch.Size([1, 2048, 12, 12])
torch.Size([1, 1000])

# input = 3, 224, 224
torch.Size([1, 256, 56, 56])
torch.Size([1, 512, 28, 28])
torch.Size([1, 1024, 14, 14])
torch.Size([1, 2048, 7, 7])
torch.Size([1, 1000])

ref:
https://note.youdao.com/ynoteshare1/index.html?id=5a7dbe1a71713c317062ddeedd97d98e&type=note
"""
import torch
from torch import nn
from functools import partial

from Backbone import Transformer_blocks


class Bottleneck_block_constructor(nn.Module):
    # ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand by a factor of `extention`

    extention = 4

    def __init__(self, inplane, midplane, stride, downsample=None):
        super(Bottleneck_block_constructor, self).__init__()

        outplane = midplane * self.extention

        self.conv1 = nn.Conv2d(inplane, midplane, kernel_size=1, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(midplane)

        self.conv2 = nn.Conv2d(midplane, midplane, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(midplane)

        self.conv3 = nn.Conv2d(midplane, outplane, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(outplane)

        self.relu = nn.ReLU(inplace=False)

        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        # no ReLU before the residual addition (the standard ResNet bottleneck
        # ordering, which the torchvision ResNet-50 pretrained weights expect)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x

        out += residual
        out = self.relu(out)

        return out
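

# Minimal shape sanity check (a sketch, assuming this file is run standalone):
# with a 1x1 downsample projection, a (1, 64, 56, 56) input expands to
# (1, 256, 56, 56), i.e. the stage-1 geometry listed in the module docstring.
if __name__ == '__main__':
    _demo_block = Bottleneck_block_constructor(
        inplane=64, midplane=64, stride=1,
        downsample=nn.Sequential(nn.Conv2d(64, 256, kernel_size=1, bias=False),
                                 nn.BatchNorm2d(256)))
    print(_demo_block(torch.zeros(1, 64, 56, 56)).shape)  # torch.Size([1, 256, 56, 56])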


class Hybrid_backbone_4(nn.Module):
    # 4-stage ResNet-style backbone (the defaults give the ResNet-50 layout)

    def __init__(self, block_constructor, bottleneck_channels_setting=None, identity_layers_setting=None,
                 stage_stride_setting=None, fc_num_classes=None, feature_idx=None):
        super(Hybrid_backbone_4, self).__init__()

        if bottleneck_channels_setting is None:
            bottleneck_channels_setting = [64, 128, 256, 512]
        if identity_layers_setting is None:
            identity_layers_setting = [3, 4, 6, 3]
        if stage_stride_setting is None:
            stage_stride_setting = [1, 2, 2, 2]

        self.inplane = 64
        self.fc_num_classes = fc_num_classes
        self.feature_idx = feature_idx

        self.block_constructor = block_constructor
        self.bcs = bottleneck_channels_setting
        self.ils = identity_layers_setting
        self.sss = stage_stride_setting

        # stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (4x spatial reduction)
        self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)

        self.layer1 = self.make_stage_layer(self.block_constructor, self.bcs[0], self.ils[0], self.sss[0])
        self.layer2 = self.make_stage_layer(self.block_constructor, self.bcs[1], self.ils[1], self.sss[1])
        self.layer3 = self.make_stage_layer(self.block_constructor, self.bcs[2], self.ils[2], self.sss[2])
        self.layer4 = self.make_stage_layer(self.block_constructor, self.bcs[3], self.ils[3], self.sss[3])

        if self.fc_num_classes is not None:
            # adaptive pooling covers both the 224 (7x7) and 384 (12x12) stage-4 maps
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(512 * self.block_constructor.extention, fc_num_classes)

    def forward(self, x):
        # stem
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        stem_out = self.maxpool(out)

        # the 4 ResNet stages
        stage1_out = self.layer1(stem_out)
        stage2_out = self.layer2(stage1_out)
        stage3_out = self.layer3(stage2_out)
        stage4_out = self.layer4(stage3_out)

        if self.fc_num_classes is not None:
            fc_out = self.avgpool(stage4_out)
            fc_out = torch.flatten(fc_out, 1)
            fc_out = self.fc(fc_out)

        # feature_idx selects which intermediate feature maps are returned
        if self.feature_idx == 'stages':
            if self.fc_num_classes is not None:
                return stage1_out, stage2_out, stage3_out, stage4_out, fc_out
            else:
                return stage1_out, stage2_out, stage3_out, stage4_out
        elif self.feature_idx == 'features':
            if self.fc_num_classes is not None:
                return stem_out, stage1_out, stage2_out, stage3_out, stage4_out, fc_out
            else:
                return stem_out, stage1_out, stage2_out, stage3_out, stage4_out
        else:
            if self.fc_num_classes is not None:
                return fc_out
            else:
                return stage4_out

    def make_stage_layer(self, block_constructor, midplane, block_num, stride=1):
        """
        block_constructor: the residual block class to stack
        midplane: bottleneck width, usually = output channels / 4
        block_num: number of blocks in the stage
        stride: stride of the first (conv) block of the stage
        """
        block_list = []

        outplane = midplane * block_constructor.extention

        # a projection shortcut is needed whenever the residual changes shape
        if stride != 1 or self.inplane != outplane:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, outplane, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(outplane)
            )
        else:
            downsample = None

        conv_block = block_constructor(self.inplane, midplane, stride=stride, downsample=downsample)
        block_list.append(conv_block)

        self.inplane = outplane

        # the remaining blocks are identity blocks (stride 1, no projection)
        for i in range(1, block_num):
            block_list.append(block_constructor(self.inplane, midplane, stride=1, downsample=None))

        return nn.Sequential(*block_list)
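

# Minimal usage sketch (assumes running this file directly): the 4-stage backbone
# with feature_idx='stages' should reproduce the 224-input shapes listed in the
# module docstring, from (1, 256, 56, 56) down to (1, 2048, 7, 7).
if __name__ == '__main__':
    _backbone4 = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
                                   fc_num_classes=None, feature_idx='stages')
    for _feat in _backbone4(torch.zeros(1, 3, 224, 224)):
        print(_feat.shape)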


class Hybrid_backbone_3(nn.Module):
    # 3-stage variant: only the first three ResNet-50 stages

    def __init__(self, block_constructor, bottleneck_channels_setting=None, identity_layers_setting=None,
                 stage_stride_setting=None, fc_num_classes=None, feature_idx=None):
        super(Hybrid_backbone_3, self).__init__()

        if bottleneck_channels_setting is None:
            bottleneck_channels_setting = [64, 128, 256]
        if identity_layers_setting is None:
            identity_layers_setting = [3, 4, 6]
        if stage_stride_setting is None:
            stage_stride_setting = [1, 2, 2]

        self.inplane = 64
        self.fc_num_classes = fc_num_classes
        self.feature_idx = feature_idx

        self.block_constructor = block_constructor
        self.bcs = bottleneck_channels_setting
        self.ils = identity_layers_setting
        self.sss = stage_stride_setting

        # stem: 7x7 stride-2 conv + 3x3 stride-2 max-pool (4x spatial reduction)
        self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplane)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)

        self.layer1 = self.make_stage_layer(self.block_constructor, self.bcs[0], self.ils[0], self.sss[0])
        self.layer2 = self.make_stage_layer(self.block_constructor, self.bcs[1], self.ils[1], self.sss[1])
        self.layer3 = self.make_stage_layer(self.block_constructor, self.bcs[2], self.ils[2], self.sss[2])

        if self.fc_num_classes is not None:
            # adaptive pooling covers both the 224 (14x14) and 384 (24x24) stage-3 maps
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(self.bcs[-1] * self.block_constructor.extention, fc_num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        stem_out = self.maxpool(out)

        stage1_out = self.layer1(stem_out)
        stage2_out = self.layer2(stage1_out)
        stage3_out = self.layer3(stage2_out)

        if self.fc_num_classes is not None:
            fc_out = self.avgpool(stage3_out)
            fc_out = torch.flatten(fc_out, 1)
            fc_out = self.fc(fc_out)

        if self.feature_idx == 'stages':
            if self.fc_num_classes is not None:
                return stage1_out, stage2_out, stage3_out, fc_out
            else:
                return stage1_out, stage2_out, stage3_out
        elif self.feature_idx == 'features':
            if self.fc_num_classes is not None:
                return stem_out, stage1_out, stage2_out, stage3_out, fc_out
            else:
                return stem_out, stage1_out, stage2_out, stage3_out
        else:
            if self.fc_num_classes is not None:
                return fc_out
            else:
                return stage3_out

    def make_stage_layer(self, block_constructor, midplane, block_num, stride=1):
        """
        block_constructor: the residual block class to stack
        midplane: bottleneck width, usually = output channels / 4
        block_num: number of blocks in the stage
        stride: stride of the first (conv) block of the stage
        """
        block_list = []

        outplane = midplane * block_constructor.extention

        if stride != 1 or self.inplane != outplane:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, outplane, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(outplane)
            )
        else:
            downsample = None

        conv_block = block_constructor(self.inplane, midplane, stride=stride, downsample=downsample)
        block_list.append(conv_block)

        self.inplane = outplane

        for i in range(1, block_num):
            block_list.append(block_constructor(self.inplane, midplane, stride=1, downsample=None))

        return nn.Sequential(*block_list)
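

# Companion sketch for the 3-stage variant (same standalone-run assumption):
# at 384 input the last returned map should be (1, 1024, 24, 24).
if __name__ == '__main__':
    _backbone3 = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
                                   fc_num_classes=None, feature_idx='stages')
    for _feat in _backbone3(torch.zeros(1, 3, 384, 384)):
        print(_feat.shape)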


def Hybrid_a(backbone, img_size=224, patch_size=1, in_chans=3, num_classes=1000, embed_dim=768, depth=8,
             num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0.,
             drop_path_rate=0., norm_layer=None, act_layer=None):
    # Hybrid1 design: the CNN backbone serves as the patch-embedding module of a ViT
    embed_layer = partial(Transformer_blocks.Hybrid_feature_map_Embed, backbone=backbone)

    Hybrid_model = Transformer_blocks.VisionTransformer(img_size, patch_size, in_chans, num_classes, embed_dim, depth,
                                                        num_heads, mlp_ratio, qkv_bias, representation_size,
                                                        drop_rate, attn_drop_rate, drop_path_rate, embed_layer,
                                                        norm_layer, act_layer)

    return Hybrid_model
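

# Usage sketch for Hybrid_a, mirroring the Hybrid1_224 branch of create_model
# below (assumes Backbone.Transformer_blocks is importable, as at the top of
# this file); the backbone keeps fc_num_classes=None / feature_idx=None.
if __name__ == '__main__':
    _h1_backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
                                     fc_num_classes=None, feature_idx=None)
    _h1_model = Hybrid_a(_h1_backbone, img_size=224, patch_size=1, num_classes=2)
    print(_h1_model(torch.zeros(1, 3, 224, 224)).shape)  # expected: torch.Size([1, 2])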


def create_model(model_idx, edge_size, pretrained=True, num_classes=2, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM'):
    """
    get one of the MSHT models

    :param model_idx: the model to use, in the format of Model_size_other_info
    :param edge_size: the input edge size of the dataloader
    :param pretrained: whether the backbone CNN is initialized randomly or from its official pretrained weights
    :param num_classes: number of classes in your dataset

    :param drop_rate: dropout probability of the proposed models
    :param attn_drop_rate: dropout probability right after the MHSA / MHGA blocks
    :param drop_path_rate: probability of stochastic depth

    :param use_cls_token: whether to use the class token
    :param use_pos_embedding: whether to use the positional embedding
    :param use_att_module: which attention module to use in the FGD Focus block
    # use_att_module in ['SimAM', 'CBAM', 'SE']: the different attention modules applied in the ablation study

    :return: prepared model
    """
    if pretrained:
        from torchvision import models
        backbone_weights = models.resnet50(pretrained=True).state_dict()
    else:
        backbone_weights = None

    def load_backbone_weights(backbone):
        # load the official ResNet-50 weights into a backbone when pretrained is set;
        # strict=False skips the keys that have no counterpart in the backbone
        if pretrained:
            try:
                backbone.load_state_dict(backbone_weights, strict=False)
            except Exception:
                print("backbone not loaded")
            else:
                print("backbone loaded")

    if model_idx.startswith('Hybrid1_224') and edge_size == 224:
        # Hybrid1: the whole 4-stage backbone acts as the ViT patch-embedding CNN
        backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
                                     bottleneck_channels_setting=[64, 128, 256, 512],
                                     identity_layers_setting=[3, 4, 6, 3],
                                     stage_stride_setting=[1, 2, 2, 2],
                                     fc_num_classes=None,
                                     feature_idx=None)
        load_backbone_weights(backbone)

        model = Hybrid_a(backbone, img_size=edge_size, patch_size=1, in_chans=3, num_classes=num_classes,
                         embed_dim=768, depth=8, num_heads=12, mlp_ratio=4., qkv_bias=True,
                         representation_size=None, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, norm_layer=None, act_layer=None)

    elif model_idx.startswith('Hybrid1_384') and edge_size == 384:
        backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
                                     bottleneck_channels_setting=[64, 128, 256, 512],
                                     identity_layers_setting=[3, 4, 6, 3],
                                     stage_stride_setting=[1, 2, 2, 2],
                                     fc_num_classes=None,
                                     feature_idx=None)
        load_backbone_weights(backbone)

        model = Hybrid_a(backbone, img_size=edge_size, patch_size=1, in_chans=3, num_classes=num_classes,
                         embed_dim=768, depth=8, num_heads=12, mlp_ratio=4., qkv_bias=True,
                         representation_size=None, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, norm_layer=None, act_layer=None)

    elif model_idx.startswith('Hybrid2_224') and edge_size == 224:
        # Hybrid2: stage-wise feature maps feed the stage-wise hybrid Transformer
        backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
                                     bottleneck_channels_setting=[64, 128, 256, 512],
                                     identity_layers_setting=[3, 4, 6, 3],
                                     stage_stride_setting=[1, 2, 2, 2],
                                     fc_num_classes=None,
                                     feature_idx='stages')
        load_backbone_weights(backbone)

        model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
                                                                 drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                                                                 drop_path_rate=drop_path_rate,
                                                                 use_cls_token=use_cls_token,
                                                                 use_pos_embedding=use_pos_embedding,
                                                                 use_att_module=use_att_module,
                                                                 stage_size=(56, 28, 14, 7),
                                                                 stage_dim=[256, 512, 1024, 2048])

    elif model_idx.startswith('Hybrid2_384') and edge_size == 384:
        backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
                                     bottleneck_channels_setting=[64, 128, 256, 512],
                                     identity_layers_setting=[3, 4, 6, 3],
                                     stage_stride_setting=[1, 2, 2, 2],
                                     fc_num_classes=None,
                                     feature_idx='stages')
        load_backbone_weights(backbone)

        model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
                                                                 drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                                                                 drop_path_rate=drop_path_rate,
                                                                 use_cls_token=use_cls_token,
                                                                 use_pos_embedding=use_pos_embedding,
                                                                 use_att_module=use_att_module,
                                                                 stage_size=(96, 48, 24, 12),
                                                                 stage_dim=[256, 512, 1024, 2048])

    elif model_idx.startswith('Hybrid3_224') and edge_size == 224:
        # Hybrid3: same design as Hybrid2, but on the 3-stage backbone variant
        backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
                                     bottleneck_channels_setting=[64, 128, 256],
                                     identity_layers_setting=[3, 4, 6],
                                     stage_stride_setting=[1, 2, 2],
                                     fc_num_classes=None,
                                     feature_idx='stages')
        load_backbone_weights(backbone)

        model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
                                                                 drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                                                                 drop_path_rate=drop_path_rate,
                                                                 use_cls_token=use_cls_token,
                                                                 use_pos_embedding=use_pos_embedding,
                                                                 use_att_module=use_att_module,
                                                                 stage_size=(56, 28, 14),
                                                                 stage_dim=[256, 512, 1024])

    elif model_idx.startswith('Hybrid3_384') and edge_size == 384:
        backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
                                     bottleneck_channels_setting=[64, 128, 256],
                                     identity_layers_setting=[3, 4, 6],
                                     stage_stride_setting=[1, 2, 2],
                                     fc_num_classes=None,
                                     feature_idx='stages')
        load_backbone_weights(backbone)

        model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
                                                                 drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                                                                 drop_path_rate=drop_path_rate,
                                                                 use_cls_token=use_cls_token,
                                                                 use_pos_embedding=use_pos_embedding,
                                                                 use_att_module=use_att_module,
                                                                 stage_size=(96, 48, 24),
                                                                 stage_dim=[256, 512, 1024])

    else:
        print('not a valid hybrid model')
        return -1

    return model
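

# End-to-end usage sketch (assumes the Backbone package and torchvision are
# available; pretrained=False skips the ImageNet weight download): build the
# Hybrid2 model at 384 and run a dummy batch through it.
if __name__ == '__main__':
    _model = create_model('Hybrid2_384_demo', edge_size=384, pretrained=False, num_classes=2)
    if not isinstance(_model, int):  # create_model returns -1 on an invalid model_idx
        print(_model(torch.zeros(1, 3, 384, 384)).shape)  # expected: torch.Size([1, 2])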