Tianyinus's picture
init submit
edcf5ee verified
"""
get_model(): build a classification model selected by model_idx    Script  ver: Dec 5th 14:20
"""
import os
import sys
sys.path.append(os.path.realpath('.'))
import torch
import torch.nn as nn
from torchvision import models
from Backbone import ResHybrid
# get model
# Ordered lookup table for the models that are created through timm:
#   (model_idx prefix, timm.list_models() pattern, {edge_size: timm model name})
# A ``None`` key means "any edge size". A prefix that extends another one must be listed
# before it (e.g. 'ViT_h' before 'ViT', 'vgg16_bn' before 'vgg16'), the first match wins.
_TIMM_MODELS = (
    # Transfer learning for ViT
    ('ViT_h', '*vit*', {224: 'vit_huge_patch14_224_in21k'}),
    ('ViT_l', '*vit*', {224: 'vit_large_patch16_224', 384: 'vit_large_patch16_384'}),
    ('ViT_s', '*vit*', {224: 'vit_small_patch16_224', 384: 'vit_small_patch16_384'}),
    ('ViT_t', '*vit*', {224: 'vit_tiny_patch16_224', 384: 'vit_tiny_patch16_384'}),
    # 'ViT_b' and every other name starting with 'ViT' fall back to vit_base
    ('ViT', '*vit*', {224: 'vit_base_patch16_224', 384: 'vit_base_patch16_384'}),
    # Transfer learning for vgg
    ('vgg16_bn', '*vgg*', {None: 'vgg16_bn'}),
    ('vgg16', '*vgg*', {None: 'vgg16'}),
    ('vgg19_bn', '*vgg*', {None: 'vgg19_bn'}),
    ('vgg19', '*vgg*', {None: 'vgg19'}),
    # Transfer learning for DeiT
    ('deit', '*deit*', {224: 'deit_base_patch16_224', 384: 'deit_base_patch16_384'}),
    # Transfer learning for twins
    ('twins', '*twins*', {None: 'twins_pcpvt_base'}),
    # Transfer learning for PiT
    ('pit_b', '*pit*', {224: 'pit_b_224'}),
    # Transfer learning for gcvit
    ('gcvit', '*gcvit*', {224: 'gcvit_base'}),
    # Transfer learning for XCiT
    ('xcit_s', '*xcit*', {224: 'xcit_small_12_p16_224_dist', 384: 'xcit_small_12_p16_384_dist'}),
    ('xcit_m', '*xcit*', {224: 'xcit_medium_24_p16_224_dist', 384: 'xcit_medium_24_p16_384_dist'}),
    # Transfer learning for MViT v2 small  fixme bug in model!
    ('mvitv2', '*mvitv2*', {None: 'mvitv2_small_cls'}),
    # Transfer learning for ConViT  fixme bug in model!
    ('convit', '*convit*', {224: 'convit_base'}),
    # Model: BoT. NOTICE: we find no weight for BoT in timm
    # ['botnet26t_256', 'botnet50ts_256', 'eca_botnext26ts_256']
    ('bot_256', '*bot*', {256: 'botnet26t_256'}),
    # Transfer learning for densenet
    ('densenet', '*densenet*', {None: 'densenet121'}),
    # Transfer learning for Xception
    ('xception', '*xception*', {None: 'xception'}),
    # Transfer learning for PVT v2 (todo not okey with torch summary)
    ('pvt_v2_b0', '*pvt_v2*', {None: 'pvt_v2_b0'}),
    # Transfer learning for Visformer
    ('visformer', '*visformer*', {224: 'visformer_small'}),
    # Transfer learning for coat_mini
    ('coat_mini', '*coat*', {224: 'coat_mini'}),
    # Transfer learning for Swin Transformer
    ('swin_b_384', '*swin*', {384: 'swin_base_patch4_window12_384'}),
    ('swin_b_224', '*swin*', {224: 'swin_base_patch4_window7_224'}),
    # Transfer learning for mobilenetv3
    ('mobilenetv3', '*mobilenet*', {None: 'mobilenetv3_large_100'}),
    # Transfer learning for mobilevit_s
    ('mobilevit_s', '*mobilevit*', {None: 'mobilevit_s'}),
    # Transfer learning for Inception v3
    ('inceptionv3', '*inception*', {None: 'inception_v3'}),
    # Transfer learning for crossvit_base (todo not okey with torch summary)
    ('crossvit_base', '*crossvit_base*', {None: 'crossvit_base_240'}),
    # ResNet + ViT hybrid model at 384
    ('ResN50_ViT_384', '*vit_base_resnet*', {None: 'vit_base_resnet50_384'}),
    # Transfer learning for coat_lite_small
    ('coat_lite_small', '*coat*', {224: 'coat_lite_small'}),
)


def _resolve_timm_model(model_idx, edge_size):
    """
    Map a model_idx / edge_size pair to the timm model that should be created.

    :param model_idx: the model name given to get_model (a str)
    :param edge_size: the input edge size

    :return: None when model_idx matches no timm prefix, otherwise a tuple
        (timm.list_models pattern, timm model name); the model name is None when the prefix is
        known but has no variant for this edge_size
    """
    for prefix, list_pattern, variants in _TIMM_MODELS:
        if model_idx.startswith(prefix):
            return list_pattern, variants.get(edge_size, variants.get(None))

    # these two families take the timm name directly from the head of model_idx
    if model_idx.startswith('efficientnet_b'):  # Transfer learning for efficientnet_b3,4
        return '*efficientnet*', model_idx[0:15]
    if model_idx.startswith('efficientformer_l'):  # Transfer learning for efficientformer
        return '*efficientformer*', (model_idx[0:18] if edge_size == 224 else None)

    return None


def _build_conformer(num_classes, pretrained_backbone):
    """Build the Conformer base model, optionally from the locally downloaded checkpoint."""
    from Backbone.counterpart_models import conformer

    embed_dim = 576
    channel_ratio = 6

    if not pretrained_backbone:
        return conformer.Conformer(num_classes=num_classes, patch_size=16, channel_ratio=channel_ratio,
                                   embed_dim=embed_dim, depth=12, num_heads=9, mlp_ratio=4, qkv_bias=True)

    # build with the 1000-class heads first so the checkpoint can be loaded, then replace the heads
    model = conformer.Conformer(num_classes=1000, patch_size=16, channel_ratio=channel_ratio,
                                embed_dim=embed_dim, depth=12, num_heads=9, mlp_ratio=4, qkv_bias=True)
    # this is the related path to <code>, not <Backbone>
    save_model_path = '../saved_models/Conformer_base_patch16.pth'  # fixme model is downloaded at this path
    # downloaded from official model state at https://github.com/pengzhiliang/Conformer
    # NOTE: torch.load unpickles the file, only point it at a trusted checkpoint.
    # map_location='cpu' lets a checkpoint saved from GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(save_model_path, map_location='cpu'), strict=False)

    model.trans_cls_head = nn.Linear(embed_dim, num_classes)
    model.conv_cls_head = nn.Linear(int(256 * channel_ratio), num_classes)
    model.cls_head = nn.Linear(int(2 * num_classes), num_classes)
    return model


def _build_cross_former(num_classes, edge_size, pretrained_backbone):
    """Build the CrossFormer base backbone and wrap it with a classification head."""
    from Backbone.counterpart_models import crossformer

    backbone = crossformer.CrossFormer(img_size=edge_size,
                                       patch_size=[4, 8, 16, 32],
                                       in_chans=3,
                                       num_classes=0,  # get backbone only
                                       embed_dim=96,
                                       depths=[2, 2, 18, 2],
                                       num_heads=[3, 6, 12, 24],
                                       group_size=[7, 7, 7, 7],
                                       mlp_ratio=4.,
                                       qkv_bias=True,
                                       qk_scale=None,
                                       drop_rate=0.0,
                                       drop_path_rate=0.3,
                                       ape=False,
                                       patch_norm=True,
                                       use_checkpoint=False,
                                       merge_size=[[2, 4], [2, 4], [2, 4]], )
    if pretrained_backbone:
        save_model_path = '../saved_models/crossformer-b.pth'  # fixme model is downloaded at this path
        # downloaded from official model state at https://github.com/cheerss/CrossFormer
        # NOTE: torch.load unpickles the file, only point it at a trusted checkpoint.
        backbone.load_state_dict(torch.load(save_model_path, map_location='cpu')['model'], strict=False)

    return crossformer.cross_former_cls_head_warp(backbone, num_classes)


def get_model(num_classes=1000, edge_size=224, model_idx=None, drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0,
              pretrained_backbone=True, use_cls_token=True, use_pos_embedding=True, use_att_module='SimAM'):
    """
    Build the model selected by the prefix of model_idx and check it with one forward pass.

    :param num_classes: classification required number of your dataset
    :param edge_size: the input edge size of the dataloader
    :param model_idx: the model we are going to use. by the format of Model_size_other_info
    :param drop_rate: The dropout layer's probability of proposed models
    :param attn_drop_rate: The dropout layer(right after the MHSA block or MHGA block)'s probability of proposed models
    :param drop_path_rate: The probability of stochastic depth
    :param pretrained_backbone: The backbone CNN is initiate randomly or by its official Pretrained models
    :param use_cls_token: To use the class token
    :param use_pos_embedding: To use the positional embedding
    :param use_att_module: To use which attention module in the FGD Focus block

    drop_rate, attn_drop_rate, drop_path_rate, use_cls_token, use_pos_embedding and use_att_module
    are only forwarded to ResHybrid.create_model (the 'Backbone...' models).

    :return: prepared model, or -1 when model_idx is missing / not defined for this edge_size,
        or when the forward check on a random (1, 3, edge_size, edge_size) input fails
    """
    if not isinstance(model_idx, str):
        # covers the default model_idx=None, which cannot be sliced / prefix-matched
        print('\nThe model', model_idx, 'with the edge size of', edge_size)
        print("is not defined in the script!!", '\n')
        return -1

    if model_idx.startswith('ResNet'):  # Transfer learning for the ResNets
        if model_idx.startswith('ResNet34'):
            model = models.resnet34(pretrained=pretrained_backbone)
        elif model_idx.startswith('ResNet50'):
            model = models.resnet50(pretrained=pretrained_backbone)
        elif model_idx.startswith('ResNet101'):
            model = models.resnet101(pretrained=pretrained_backbone)
        else:
            print('this model is not defined in get model')
            return -1
        # replace the classification layer to match the dataset
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, num_classes)

    elif model_idx.startswith('Backbone'):  # ours: MSHT
        # The whole 8-character prefix is compared here; comparing only model_idx[0:6] with
        # 'Backbone' can never be true and would leave this branch unreachable.
        # NOTICE: HERE 'pretrained' controls only The backbone CNN is initiate randomly
        # or by its official Pretrained models
        model = ResHybrid.create_model(model_idx, edge_size, pretrained=pretrained_backbone, num_classes=num_classes,
                                       drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                                       drop_path_rate=drop_path_rate, use_cls_token=use_cls_token,
                                       use_pos_embedding=use_pos_embedding, use_att_module=use_att_module)

    elif model_idx.startswith('conformer'):  # Transfer learning for Conformer base
        model = _build_conformer(num_classes, pretrained_backbone)

    elif model_idx.startswith('cross_former') and edge_size == 224:  # Transfer learning for crossformer base
        model = _build_cross_former(num_classes, edge_size, pretrained_backbone)

    else:  # everything else is created through timm
        resolved = _resolve_timm_model(model_idx, edge_size)
        if resolved is None:
            print('\nThe model', model_idx, 'with the edge size of', edge_size)
            print("is not defined in the script!!", '\n')
            return -1

        list_pattern, timm_name = resolved
        if timm_name is None:
            # known model family, but no variant for this edge size: stop here instead of
            # running into an unbound `model` later
            print('not a avaliable image size with', model_idx)
            return -1

        # imported lazily so that timm is only required when a timm model is requested
        import timm
        from pprint import pprint

        model_names = timm.list_models(list_pattern)
        pprint(model_names)
        # num_classes is always taken from the caller (DeiT included)
        model = timm.create_model(timm_name, pretrained=pretrained_backbone, num_classes=num_classes)

    # smoke test: one forward pass on random input of the requested size
    try:
        with torch.no_grad():  # no autograd graph is needed for the check
            img = torch.randn(1, 3, edge_size, edge_size)
            preds = model(img)  # (1, class_number)
        print('test model output:', preds)
    except Exception as err:  # not a bare except: Ctrl-C / SystemExit must still propagate
        print("Problem exist in the model defining process!!")
        print('reason:', repr(err))
        return -1

    print('model is ready now!')
    return model