"""
VPT Script ver: Oct 17th 14:30
based on
timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
"""
import torch
import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer, PatchEmbed
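
# Visual Prompt Tuning (VPT, Jia et al., ECCV 2022) keeps the pretrained ViT
# backbone frozen and trains only a small number of prompt tokens (plus the
# task head). "Shallow" inserts one set of prompts before the first block;
# "Deep" gives every Transformer block its own set of prompts.
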
class VPT_ViT(VisionTransformer):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 embed_layer=PatchEmbed, norm_layer=None, act_layer=None, Prompt_Token_num=1,
                 VPT_type="Shallow", basic_state_dict=None):

        # Recreate the ViT backbone
        super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
                         embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, embed_layer=embed_layer,
                         norm_layer=norm_layer, act_layer=act_layer)

        # Load a pretrained backbone state_dict if one is given
        # (strict=False tolerates missing/unexpected keys such as the head)
        if basic_state_dict is not None:
            self.load_state_dict(basic_state_dict, strict=False)

        self.VPT_type = VPT_type
        if VPT_type == "Deep":
            # one independent set of prompt tokens per Transformer block
            self.Prompt_Tokens = nn.Parameter(torch.zeros(depth, Prompt_Token_num, embed_dim))
        else:  # "Shallow": a single set of prompt tokens, inserted once before the first block
            self.Prompt_Tokens = nn.Parameter(torch.zeros(1, Prompt_Token_num, embed_dim))
    def New_CLS_head(self, new_classes=15):
        # Replace the classification head for a new downstream task
        if new_classes != 0:
            self.head = nn.Linear(self.embed_dim, new_classes)
        else:
            self.head = nn.Identity()

    def Freeze(self):
        # Freeze the whole backbone; only the prompt tokens and the head stay trainable
        for param in self.parameters():
            param.requires_grad = False

        self.Prompt_Tokens.requires_grad = True
        for param in self.head.parameters():  # nn.Identity simply yields no parameters
            param.requires_grad = True
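
    # After Freeze(), an optimizer can be restricted to the trainable subset
    # (a minimal sketch, not part of the original script):
    #   optimizer = torch.optim.AdamW(
    #       (p for p in model.parameters() if p.requires_grad), lr=1e-3)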

    def UnFreeze(self):
        # Make every parameter trainable again (full fine-tuning)
        for param in self.parameters():
            param.requires_grad = True

    def obtain_prompt(self):
        # Export only the task-specific parameters: the head weights and the prompt tokens
        prompt_state_dict = {'head': self.head.state_dict(),
                             'Prompt_Tokens': self.Prompt_Tokens}
        return prompt_state_dict

    def load_prompt(self, prompt_state_dict):
        try:
            self.head.load_state_dict(prompt_state_dict['head'], strict=False)
        except RuntimeError:
            print('head does not match, skipping head')
        else:
            print('prompt head matched')

        if self.Prompt_Tokens.shape == prompt_state_dict['Prompt_Tokens'].shape:
            # Tensor.to() is not in-place, so the moved tensor must be reassigned;
            # keep the loaded prompt on the same device as the current one
            device = self.Prompt_Tokens.device
            self.Prompt_Tokens = nn.Parameter(prompt_state_dict['Prompt_Tokens'].detach().to(device))
        else:
            print('\n !!! cannot load prompt')
            print('shape of the prompt expected by the model:', self.Prompt_Tokens.shape)
            print('shape of the prompt that was given:', prompt_state_dict['Prompt_Tokens'].shape)
            print('')
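
    # Typical round-trip for obtain_prompt()/load_prompt() (a sketch; the
    # file name below is hypothetical):
    #   torch.save(model.obtain_prompt(), 'prompt_checkpoint.pt')
    #   model.load_prompt(torch.load('prompt_checkpoint.pt'))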

    def forward_features(self, x):
        x = self.patch_embed(x)

        # prepend the CLS token to the patch tokens
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        if self.VPT_type == "Deep":
            Prompt_Token_num = self.Prompt_Tokens.shape[1]
            for i in range(len(self.blocks)):
                # append this block's own prompt tokens to the sequence
                Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)
                x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
                num_tokens = x.shape[1]
                # run the block, then drop the prompt outputs so the next block
                # sees only CLS + patch tokens (plus its own fresh prompts)
                x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num]

        else:  # self.VPT_type == "Shallow"
            Prompt_Token_num = self.Prompt_Tokens.shape[1]
            # append the prompt tokens once, before the first block
            Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
            x = torch.cat((x, Prompt_Tokens), dim=1)
            num_tokens = x.shape[1]
            # sequentially process through all blocks, then drop the prompt outputs
            x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]

        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)

        # classification uses the CLS token; older timm versions expose
        # pre_logits, newer ones use fc_norm instead
        try:
            x = self.pre_logits(x[:, 0, :])
        except AttributeError:
            x = self.fc_norm(x[:, 0, :])

        x = self.head(x)
        return x
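

if __name__ == '__main__':
    # Minimal usage sketch, added for illustration (not part of the original
    # training pipeline); the prompt count and class count are arbitrary.
    model = VPT_ViT(Prompt_Token_num=10, VPT_type="Deep")
    model.New_CLS_head(new_classes=15)  # attach a fresh 15-class head
    model.Freeze()                      # only prompts and head stay trainable

    images = torch.randn(2, 3, 224, 224)
    logits = model(images)
    print(logits.shape)  # expected: torch.Size([2, 15])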