"""
VPT Script ver: Oct 17th 14:30
based on
timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
"""
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer, PatchEmbed


class VPT_ViT(VisionTransformer):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 embed_layer=PatchEmbed, norm_layer=None, act_layer=None, Prompt_Token_num=1,
                 VPT_type="Shallow", basic_state_dict=None):

        # Recreate ViT
        super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
                         embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, embed_layer=embed_layer,
                         norm_layer=norm_layer, act_layer=act_layer)

        # load basic state_dict (e.g. a pretrained ViT backbone), non-strict
        if basic_state_dict is not None:
            self.load_state_dict(basic_state_dict, strict=False)

        self.VPT_type = VPT_type
        if VPT_type == "Deep":
            # Deep VPT: a separate group of prompt tokens is injected at every block
            self.Prompt_Tokens = nn.Parameter(torch.zeros(depth, Prompt_Token_num, embed_dim))
        else:  # "Shallow"
            # Shallow VPT: a single group of prompt tokens is appended only at the input
            self.Prompt_Tokens = nn.Parameter(torch.zeros(1, Prompt_Token_num, embed_dim))

    def New_CLS_head(self, new_classes=15):
        # replace the classification head for a new downstream task
        if new_classes != 0:
            self.head = nn.Linear(self.embed_dim, new_classes)
        else:
            self.head = nn.Identity()

    def Freeze(self):
        # freeze the whole backbone, then re-enable gradients only for the
        # prompt tokens and the classification head
        for param in self.parameters():
            param.requires_grad = False

        self.Prompt_Tokens.requires_grad = True
        try:
            for param in self.head.parameters():
                param.requires_grad = True
        except Exception:
            pass

    def UnFreeze(self):
        for param in self.parameters():
            param.requires_grad = True

    def obtain_prompt(self):
        prompt_state_dict = {'head': self.head.state_dict(),
                             'Prompt_Tokens': self.Prompt_Tokens}
        # print(prompt_state_dict)
        return prompt_state_dict

    def load_prompt(self, prompt_state_dict):
        try:
            self.head.load_state_dict(prompt_state_dict['head'], strict=False)
        except Exception:
            print('head does not match, so the head is skipped')
        else:
            print('prompt head matched')

        if self.Prompt_Tokens.shape == prompt_state_dict['Prompt_Tokens'].shape:
            # move the loaded prompt tokens onto the device of the current ones
            Prompt_Tokens = nn.Parameter(
                prompt_state_dict['Prompt_Tokens'].detach().clone().to(self.Prompt_Tokens.device))
            self.Prompt_Tokens = Prompt_Tokens
        else:
            print('\n !!! cannot load prompt')
            print('shape of prompt required by the model:', self.Prompt_Tokens.shape)
            print('shape of the given prompt:', prompt_state_dict['Prompt_Tokens'].shape)
            print('')

    def forward_features(self, x):
        x = self.patch_embed(x)
        # print(x.shape, self.pos_embed.shape)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        # concatenate CLS token
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        if self.VPT_type == "Deep":
            Prompt_Token_num = self.Prompt_Tokens.shape[1]

            for i in range(len(self.blocks)):
                # concatenate this block's Prompt_Tokens to the sequence
                Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)
                x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
                num_tokens = x.shape[1]
                # run the block, then drop the prompt tokens from its output
                x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num]

        else:  # self.VPT_type == "Shallow"
            Prompt_Token_num = self.Prompt_Tokens.shape[1]

            # concatenate Prompt_Tokens once, before the first block
            Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
            x = torch.cat((x, Prompt_Tokens), dim=1)
            num_tokens = x.shape[1]
            # sequentially process all blocks, then drop the prompt tokens
            x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]

        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)

        # use the CLS token for the classification head;
        # older timm versions expose pre_logits, newer ones use fc_norm
        try:
            x = self.pre_logits(x[:, 0, :])
        except AttributeError:
            x = self.fc_norm(x[:, 0, :])

        x = self.head(x)
        return x
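

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, hyper-parameters are assumed):
    # build a VPT-Deep ViT-Base/16 with 10 prompt tokens, attach a 5-class head,
    # freeze the backbone so only the prompts and head remain trainable,
    # and run a dummy batch.
    model = VPT_ViT(Prompt_Token_num=10, VPT_type="Deep")
    model.New_CLS_head(new_classes=5)
    model.Freeze()

    trainable = [name for name, p in model.named_parameters() if p.requires_grad]
    print('trainable parameters:', trainable)  # expected: Prompt_Tokens, head.weight, head.bias

    dummy = torch.randn(2, 3, 224, 224)
    print('output shape:', model(dummy).shape)  # expected: torch.Size([2, 5])

    # prompt checkpointing: export only the prompt tokens + head, then load them back
    prompt_state = model.obtain_prompt()
    model.load_prompt(prompt_state)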