"""
MSHT
Models ver: OCT 27th 20:00 official release
by the authors, check our github page:
https://github.com/sagizty/Multi-Stage-Hybrid-Transformer
ResNet stages' feature map
# input = 3, 384, 384
torch.Size([1, 256, 96, 96])
torch.Size([1, 512, 48, 48])
torch.Size([1, 1024, 24, 24])
torch.Size([1, 2048, 12, 12])
torch.Size([1, 1000])
# input = 3, 224, 224
torch.Size([1, 256, 56, 56])
torch.Size([1, 512, 28, 28])
torch.Size([1, 1024, 14, 14])
torch.Size([1, 2048, 7, 7])
torch.Size([1, 1000])
ref
https://note.youdao.com/ynoteshare1/index.html?id=5a7dbe1a71713c317062ddeedd97d98e&type=note
"""
import torch
from torch import nn
from functools import partial
from Backbone import Transformer_blocks
# ResNet Bottleneck_block_constructor
class Bottleneck_block_constructor(nn.Module):
extention = 4
    # define the network layers and parameters at initialization
def __init__(self, inplane, midplane, stride, downsample=None):
super(Bottleneck_block_constructor, self).__init__()
outplane = midplane * self.extention
self.conv1 = nn.Conv2d(inplane, midplane, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(midplane)
self.conv2 = nn.Conv2d(midplane, midplane, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(midplane)
self.conv3 = nn.Conv2d(midplane, outplane, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(outplane)
self.relu = nn.ReLU(inplace=False)
self.downsample = downsample
self.stride = stride
    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        # keep the pre-addition branch linear, as in the standard ResNet bottleneck
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            residual = self.downsample(x)
        else:
            residual = x

        out = self.relu(out + residual)
        return out
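
# A minimal shape sketch (illustrative only, not part of the released model): with a
# 1x1 downsample projection, one bottleneck block maps (N, inplane, H, W) to
# (N, midplane * 4, H / stride, W / stride). The function name is ours, for the demo.
def _demo_bottleneck_shapes():
    downsample = nn.Sequential(
        nn.Conv2d(64, 256, kernel_size=1, bias=False),
        nn.BatchNorm2d(256)
    )
    block = Bottleneck_block_constructor(inplane=64, midplane=64, stride=1, downsample=downsample)
    out = block(torch.randn(1, 64, 56, 56))
    assert out.shape == (1, 256, 56, 56)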
# Hybrid_backbone of ResNets
class Hybrid_backbone_4(nn.Module):
def __init__(self, block_constructor, bottleneck_channels_setting=None, identity_layers_setting=None,
stage_stride_setting=None, fc_num_classes=None, feature_idx=None):
        super(Hybrid_backbone_4, self).__init__()

        if bottleneck_channels_setting is None:
            bottleneck_channels_setting = [64, 128, 256, 512]
        if identity_layers_setting is None:
            identity_layers_setting = [3, 4, 6, 3]
        if stage_stride_setting is None:
            stage_stride_setting = [1, 2, 2, 2]

        self.inplane = 64
        self.fc_num_classes = fc_num_classes
        self.feature_idx = feature_idx
        self.block_constructor = block_constructor  # Bottleneck_block_constructor
        self.bcs = bottleneck_channels_setting  # [64, 128, 256, 512]
        self.ils = identity_layers_setting  # [3, 4, 6, 3]
        self.sss = stage_stride_setting  # [1, 2, 2, 2]
# stem
        # map the 3-channel RGB input to inplane channels
self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.inplane)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
# ResNet stages
self.layer1 = self.make_stage_layer(self.block_constructor, self.bcs[0], self.ils[0], self.sss[0])
self.layer2 = self.make_stage_layer(self.block_constructor, self.bcs[1], self.ils[1], self.sss[1])
self.layer3 = self.make_stage_layer(self.block_constructor, self.bcs[2], self.ils[2], self.sss[2])
self.layer4 = self.make_stage_layer(self.block_constructor, self.bcs[3], self.ils[3], self.sss[3])
        # classification head
        if self.fc_num_classes is not None:
            # adaptive pooling gives a 1x1 map for both 224 and 384 inputs
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(512 * self.block_constructor.extention, fc_num_classes)
def forward(self, x):
# stem
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
stem_out = self.maxpool(out)
        # the 4 ResNet stages
stage1_out = self.layer1(stem_out)
stage2_out = self.layer2(stage1_out)
stage3_out = self.layer3(stage2_out)
stage4_out = self.layer4(stage3_out)
if self.fc_num_classes is not None:
            # run the classification head if requested
fc_out = self.avgpool(stage4_out)
fc_out = torch.flatten(fc_out, 1)
fc_out = self.fc(fc_out)
        # return the features required for the requested usage
if self.feature_idx == 'stages':
if self.fc_num_classes is not None:
return stage1_out, stage2_out, stage3_out, stage4_out, fc_out
else:
return stage1_out, stage2_out, stage3_out, stage4_out
elif self.feature_idx == 'features':
if self.fc_num_classes is not None:
return stem_out, stage1_out, stage2_out, stage3_out, stage4_out, fc_out
else:
return stem_out, stage1_out, stage2_out, stage3_out, stage4_out
else: # self.feature_idx is None
if self.fc_num_classes is not None:
return fc_out
else:
return stage4_out
    def make_stage_layer(self, block_constructor, midplane, block_num, stride=1):
        """
        block_constructor: block class used to build this stage (e.g. Bottleneck_block_constructor)
        midplane: bottleneck (middle) channel count; the stage outputs midplane * extention channels
        block_num: number of blocks stacked in this stage
        stride: stride of the first (conv) block of the stage
        """
        block_list = []

        outplane = midplane * block_constructor.extention
        if stride != 1 or self.inplane != outplane:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, outplane, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(outplane)
            )
        else:
            downsample = None

        # Conv Block: the first block may change the stride and channel count
        conv_block = block_constructor(self.inplane, midplane, stride=stride, downsample=downsample)
        block_list.append(conv_block)
        self.inplane = outplane  # update inplane for the next stage

        # Identity Blocks: the remaining blocks keep the shape unchanged
        for _ in range(1, block_num):
            block_list.append(block_constructor(self.inplane, midplane, stride=1, downsample=None))

        return nn.Sequential(*block_list)  # stack the blocks
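
# Shape sketch for the 4-stage backbone (a quick check, assuming the default
# ResNet-50 settings): a 3x384x384 input yields the stage shapes listed in the
# module docstring above. The function name is ours, for the demo.
def _demo_backbone4_shapes():
    backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor, feature_idx='stages')
    backbone.eval()
    with torch.no_grad():
        stage1, stage2, stage3, stage4 = backbone(torch.randn(1, 3, 384, 384))
    assert stage1.shape == (1, 256, 96, 96)
    assert stage4.shape == (1, 2048, 12, 12)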
class Hybrid_backbone_3(nn.Module):  # 3-stage version
def __init__(self, block_constructor, bottleneck_channels_setting=None, identity_layers_setting=None,
stage_stride_setting=None, fc_num_classes=None, feature_idx=None):
        super(Hybrid_backbone_3, self).__init__()

        if bottleneck_channels_setting is None:
            bottleneck_channels_setting = [64, 128, 256]
        if identity_layers_setting is None:
            identity_layers_setting = [3, 4, 6]
        if stage_stride_setting is None:
            stage_stride_setting = [1, 2, 2]

        self.inplane = 64
        self.fc_num_classes = fc_num_classes
        self.feature_idx = feature_idx
        self.block_constructor = block_constructor  # Bottleneck_block_constructor
        self.bcs = bottleneck_channels_setting  # [64, 128, 256]
        self.ils = identity_layers_setting  # [3, 4, 6]
        self.sss = stage_stride_setting  # [1, 2, 2]
# stem
self.conv1 = nn.Conv2d(3, self.inplane, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.inplane)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, padding=1, stride=2)
# ResNet 3 stages
self.layer1 = self.make_stage_layer(self.block_constructor, self.bcs[0], self.ils[0], self.sss[0])
self.layer2 = self.make_stage_layer(self.block_constructor, self.bcs[1], self.ils[1], self.sss[1])
self.layer3 = self.make_stage_layer(self.block_constructor, self.bcs[2], self.ils[2], self.sss[2])
        if self.fc_num_classes is not None:
            # adaptive pooling gives a 1x1 map whether stage 3 is 14x14 (224 input) or 24x24 (384 input)
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.fc = nn.Linear(self.bcs[-1] * self.block_constructor.extention, fc_num_classes)
def forward(self, x):
        # stem: conv + bn + relu + maxpool
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
stem_out = self.maxpool(out)
        # the 3 ResNet stages
stage1_out = self.layer1(stem_out)
stage2_out = self.layer2(stage1_out)
stage3_out = self.layer3(stage2_out)
if self.fc_num_classes is not None:
fc_out = self.avgpool(stage3_out)
fc_out = torch.flatten(fc_out, 1)
fc_out = self.fc(fc_out)
if self.feature_idx == 'stages':
if self.fc_num_classes is not None:
return stage1_out, stage2_out, stage3_out, fc_out
else:
return stage1_out, stage2_out, stage3_out
elif self.feature_idx == 'features':
if self.fc_num_classes is not None:
return stem_out, stage1_out, stage2_out, stage3_out, fc_out
else:
return stem_out, stage1_out, stage2_out, stage3_out
else: # self.feature_idx is None
if self.fc_num_classes is not None:
return fc_out
else:
return stage3_out
    def make_stage_layer(self, block_constructor, midplane, block_num, stride=1):
        """
        block_constructor: block class used to build this stage (e.g. Bottleneck_block_constructor)
        midplane: bottleneck (middle) channel count; the stage outputs midplane * extention channels
        block_num: number of blocks stacked in this stage
        stride: stride of the first (conv) block of the stage
        """
        block_list = []

        outplane = midplane * block_constructor.extention
        if stride != 1 or self.inplane != outplane:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplane, outplane, stride=stride, kernel_size=1, bias=False),
                nn.BatchNorm2d(outplane)
            )
        else:
            downsample = None

        # Conv Block: the first block may change the stride and channel count
        conv_block = block_constructor(self.inplane, midplane, stride=stride, downsample=downsample)
        block_list.append(conv_block)
        self.inplane = outplane  # update inplane for the next stage

        # Identity Blocks: the remaining blocks keep the shape unchanged
        for _ in range(1, block_num):
            block_list.append(block_constructor(self.inplane, midplane, stride=1, downsample=None))

        return nn.Sequential(*block_list)  # stack the blocks
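
# The 3-stage counterpart stops at a 1024-channel map: under the same default
# settings, a 3x224x224 input gives 56/28/14 spatial sizes (an illustrative
# sketch; the function name is ours).
def _demo_backbone3_shapes():
    backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor, feature_idx='stages')
    backbone.eval()
    with torch.no_grad():
        stage1, stage2, stage3 = backbone(torch.randn(1, 3, 224, 224))
    assert stage3.shape == (1, 1024, 14, 14)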
def Hybrid_a(backbone, img_size=224, patch_size=1, in_chans=3, num_classes=1000, embed_dim=768, depth=8,
num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=None, act_layer=None):
    # directly stack the CNN backbone and Transformer blocks
embed_layer = partial(Transformer_blocks.Hybrid_feature_map_Embed, backbone=backbone)
Hybrid_model = Transformer_blocks.VisionTransformer(img_size, patch_size, in_chans, num_classes, embed_dim, depth,
num_heads, mlp_ratio, qkv_bias, representation_size,
drop_rate, attn_drop_rate, drop_path_rate, embed_layer,
norm_layer, act_layer)
return Hybrid_model
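
# Usage sketch (assumes the Backbone.Transformer_blocks package imported above is
# on the path): Hybrid_a expects a backbone whose forward returns a single feature
# map (fc_num_classes=None, feature_idx=None), which the hybrid embed layer then
# tokenizes. The function name is ours, for the demo.
def _demo_hybrid_a():
    backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor)
    model = Hybrid_a(backbone, img_size=224, num_classes=2)
    logits = model(torch.randn(1, 3, 224, 224))  # expected shape: (1, 2)
    return logits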
def create_model(model_idx, edge_size, pretrained=True, num_classes=2, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM'):
"""
get one of MSHT models
:param model_idx: the model we are going to use. by the format of Model_size_other_info
:param edge_size: the input edge size of the dataloder
:param pretrained: The backbone CNN is initiate randomly or by its official Pretrained models
:param num_classes: classification required number of your dataset
:param drop_rate: The dropout layer's probility of proposed models
:param attn_drop_rate: The dropout layer(right after the MHSA block or MHGA block)'s probility of proposed models
:param drop_path_rate: The probility of stochastic depth
:param use_cls_token: To use the class token
:param use_pos_embedding: To use the positional enbedding
:param use_att_module: To use which attention module in the FGD Focus block
# use_att_module in ['SimAM', 'CBAM', 'SE'] different attention module we applied in the ablation study
:return: prepared model
"""
    if pretrained:
        from torchvision import models
        # use the official pretrained ResNet-50 weights for the backbone
        backbone_weights = models.resnet50(pretrained=True).state_dict()
    else:
        backbone_weights = None  # the backbone will be randomly initialized
    if model_idx[0:11] == 'Hybrid1_224' and edge_size == 224:  # ablation study: no Focus block, depth=8, edge_size == 224
backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
bottleneck_channels_setting=[64, 128, 256, 512],
identity_layers_setting=[3, 4, 6, 3],
stage_stride_setting=[1, 2, 2, 2],
fc_num_classes=None,
feature_idx=None)
        if pretrained:
            try:
                backbone.load_state_dict(backbone_weights, strict=False)
            except Exception:
                print("backbone not loaded")
            else:
                print("backbone loaded")
model = Hybrid_a(backbone, img_size=edge_size, patch_size=1, in_chans=3, num_classes=num_classes, embed_dim=768,
depth=8, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
norm_layer=None, act_layer=None)
    elif model_idx[0:11] == 'Hybrid1_384' and edge_size == 384:  # ablation study: no Focus block, depth=8, edge_size == 384
backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
bottleneck_channels_setting=[64, 128, 256, 512],
identity_layers_setting=[3, 4, 6, 3],
stage_stride_setting=[1, 2, 2, 2],
fc_num_classes=None,
feature_idx=None)
        if pretrained:
            try:
                backbone.load_state_dict(backbone_weights, strict=False)
            except Exception:
                print("backbone not loaded")
            else:
                print("backbone loaded")
model = Hybrid_a(backbone, img_size=edge_size, patch_size=1, in_chans=3, num_classes=num_classes, embed_dim=768,
depth=8, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
norm_layer=None, act_layer=None)
elif model_idx[0:11] == 'Hybrid2_224' and edge_size == 224: # Proposed model ablation study: edge_size==224
backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
bottleneck_channels_setting=[64, 128, 256, 512],
identity_layers_setting=[3, 4, 6, 3],
stage_stride_setting=[1, 2, 2, 2],
fc_num_classes=None,
feature_idx='stages')
        if pretrained:
            try:
                backbone.load_state_dict(backbone_weights, strict=False)
            except Exception:
                print("backbone not loaded")
            else:
                print("backbone loaded")
model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
use_cls_token=use_cls_token,
use_pos_embedding=use_pos_embedding,
use_att_module=use_att_module,
stage_size=(56, 28, 14, 7),
stage_dim=[256, 512, 1024, 2048])
    elif model_idx[0:11] == 'Hybrid2_384' and edge_size == 384:  # the proposed model, edge_size == 384
backbone = Hybrid_backbone_4(block_constructor=Bottleneck_block_constructor,
bottleneck_channels_setting=[64, 128, 256, 512],
identity_layers_setting=[3, 4, 6, 3],
stage_stride_setting=[1, 2, 2, 2],
fc_num_classes=None,
feature_idx='stages')
        if pretrained:
            try:
                backbone.load_state_dict(backbone_weights, strict=False)
            except Exception:
                print("backbone not loaded")
            else:
                print("backbone loaded")
model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
use_cls_token=use_cls_token,
use_pos_embedding=use_pos_embedding,
use_att_module=use_att_module,
stage_size=(96, 48, 24, 12),
stage_dim=[256, 512, 1024, 2048])
    elif model_idx[0:11] == 'Hybrid3_224' and edge_size == 224:  # 3-stage ablation of the proposed model, edge_size == 224
backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
bottleneck_channels_setting=[64, 128, 256],
identity_layers_setting=[3, 4, 6],
stage_stride_setting=[1, 2, 2],
fc_num_classes=None,
feature_idx='stages')
        if pretrained:
            try:
                backbone.load_state_dict(backbone_weights, strict=False)
            except Exception:
                print("backbone not loaded")
            else:
                print("backbone loaded")
model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
use_cls_token=use_cls_token,
use_pos_embedding=use_pos_embedding,
use_att_module=use_att_module,
stage_size=(56, 28, 14),
stage_dim=[256, 512, 1024])
    elif model_idx[0:11] == 'Hybrid3_384' and edge_size == 384:  # 3-stage ablation of the proposed model, edge_size == 384
backbone = Hybrid_backbone_3(block_constructor=Bottleneck_block_constructor,
bottleneck_channels_setting=[64, 128, 256],
identity_layers_setting=[3, 4, 6],
stage_stride_setting=[1, 2, 2],
fc_num_classes=None,
feature_idx='stages')
        if pretrained:
            try:
                backbone.load_state_dict(backbone_weights, strict=False)
            except Exception:
                print("backbone not loaded")
            else:
                print("backbone loaded")
model = Transformer_blocks.Stage_wise_hybrid_Transformer(backbone, num_classes=num_classes,
drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
use_cls_token=use_cls_token,
use_pos_embedding=use_pos_embedding,
use_att_module=use_att_module,
stage_size=(96, 48, 24),
stage_dim=[256, 512, 1024])
    else:
        raise ValueError('not a valid hybrid model: ' + model_idx + ' with edge_size ' + str(edge_size))
return model
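
if __name__ == '__main__':
    # Smoke-test sketch (assumes the Backbone package is importable; not part of
    # the released training code): build the proposed 384 model without pretrained
    # weights and run a dummy batch through it.
    model = create_model('Hybrid2_384', edge_size=384, pretrained=False)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 384, 384))
    print(logits.shape)  # expected: torch.Size([1, 2])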