|
""" |
|
VPT Script ver: Oct 17th 14:30 |
|
|
|
based on |
|
timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm |
|
""" |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from timm.models.vision_transformer import VisionTransformer, PatchEmbed |
|
|
|
|
|
class VPT_ViT(VisionTransformer):
    """Vision Transformer with Visual Prompt Tuning (VPT).

    Learnable prompt tokens are concatenated after the patch/CLS tokens before
    the transformer blocks and stripped again afterwards, so only the prompts
    (and optionally the head) need to be trained.

    VPT_type:
        "Deep"    -- a separate group of ``Prompt_Token_num`` prompts is
                     injected before every transformer block.
        "Shallow" -- a single group of prompts is injected once at the input
                     and carried through all blocks (default for any value
                     other than "Deep").

    basic_state_dict: optional state dict of a pretrained backbone, loaded
    non-strictly so that extra/missing keys (e.g. the prompts) are tolerated.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 embed_layer=PatchEmbed, norm_layer=None, act_layer=None, Prompt_Token_num=1,
                 VPT_type="Shallow", basic_state_dict=None):

        super().__init__(img_size=img_size, patch_size=patch_size, in_chans=in_chans, num_classes=num_classes,
                         embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop_rate=attn_drop_rate,
                         drop_path_rate=drop_path_rate, embed_layer=embed_layer,
                         norm_layer=norm_layer, act_layer=act_layer)

        # Load a pretrained backbone non-strictly (strict=False): the prompt
        # parameters below are not present in a vanilla ViT checkpoint.
        if basic_state_dict is not None:
            self.load_state_dict(basic_state_dict, False)

        self.VPT_type = VPT_type
        if VPT_type == "Deep":
            # One prompt group per transformer block.
            self.Prompt_Tokens = nn.Parameter(torch.zeros(depth, Prompt_Token_num, embed_dim))
        else:
            # Shallow: a single shared prompt group.
            self.Prompt_Tokens = nn.Parameter(torch.zeros(1, Prompt_Token_num, embed_dim))

    def New_CLS_head(self, new_classes=15):
        """Replace the classification head for ``new_classes`` outputs.

        ``new_classes == 0`` installs ``nn.Identity`` so the model returns
        the (pre-logits / fc-normed) CLS feature directly.
        """
        if new_classes != 0:
            self.head = nn.Linear(self.embed_dim, new_classes)
        else:
            self.head = nn.Identity()

    def Freeze(self):
        """Freeze the backbone; keep prompts (and the head, if any) trainable."""
        for param in self.parameters():
            param.requires_grad = False

        self.Prompt_Tokens.requires_grad = True
        # Best-effort: the head may have been replaced by something without
        # parameters (e.g. nn.Identity); only AttributeError is tolerated so
        # real errors are not silently swallowed (original used a bare except).
        try:
            for param in self.head.parameters():
                param.requires_grad = True
        except AttributeError:
            pass

    def UnFreeze(self):
        """Make every parameter of the model trainable again."""
        for param in self.parameters():
            param.requires_grad = True

    def obtain_prompt(self):
        """Return a dict holding the head state_dict and the prompt tokens."""
        prompt_state_dict = {'head': self.head.state_dict(),
                             'Prompt_Tokens': self.Prompt_Tokens}
        return prompt_state_dict

    def load_prompt(self, prompt_state_dict):
        """Load a dict produced by :meth:`obtain_prompt`.

        The head is loaded non-strictly and skipped on mismatch; the prompt
        tokens are only loaded when their shape matches the current model.
        """
        try:
            self.head.load_state_dict(prompt_state_dict['head'], False)
        except (RuntimeError, KeyError):
            # Shape/key mismatch between saved head and current head.
            print('head not match, so skip head')
        else:
            print('prompt head match')

        if self.Prompt_Tokens.shape == prompt_state_dict['Prompt_Tokens'].shape:
            # BUGFIX: the original called ``Parameter.to(device)`` without
            # assigning the result -- ``Tensor.to`` is NOT in-place, so the
            # loaded prompts stayed on CPU and a GPU model crashed with a
            # device mismatch on the next forward. Build the parameter
            # directly on the current prompt device instead.
            device = self.Prompt_Tokens.device
            self.Prompt_Tokens = nn.Parameter(
                prompt_state_dict['Prompt_Tokens'].detach().clone().to(device))

        else:
            print('\n !!! cannot load prompt')
            print('shape of model req prompt', self.Prompt_Tokens.shape)
            print('shape of model given prompt', prompt_state_dict['Prompt_Tokens'].shape)
            print('')

    def forward_features(self, x):
        """Embed patches, inject prompt tokens, run the transformer blocks.

        Prompts are appended after the [CLS]+patch tokens and sliced off
        again so the returned sequence length matches a vanilla ViT.
        """
        x = self.patch_embed(x)

        cls_token = self.cls_token.expand(x.shape[0], -1, -1)

        x = torch.cat((cls_token, x), dim=1)
        x = self.pos_drop(x + self.pos_embed)

        if self.VPT_type == "Deep":

            Prompt_Token_num = self.Prompt_Tokens.shape[1]

            for i in range(len(self.blocks)):
                # Block-specific prompt group, broadcast over the batch.
                Prompt_Tokens = self.Prompt_Tokens[i].unsqueeze(0)

                x = torch.cat((x, Prompt_Tokens.expand(x.shape[0], -1, -1)), dim=1)
                num_tokens = x.shape[1]
                # Strip this block's prompts before injecting the next group.
                x = self.blocks[i](x)[:, :num_tokens - Prompt_Token_num]

        else:  # Shallow
            Prompt_Token_num = self.Prompt_Tokens.shape[1]

            # A single prompt group rides through all blocks.
            Prompt_Tokens = self.Prompt_Tokens.expand(x.shape[0], -1, -1)
            x = torch.cat((x, Prompt_Tokens), dim=1)
            num_tokens = x.shape[1]

            x = self.blocks(x)[:, :num_tokens - Prompt_Token_num]

        x = self.norm(x)
        return x

    def forward(self, x):
        """Return class logits (or CLS features if the head is Identity)."""
        x = self.forward_features(x)

        # timm API difference: older VisionTransformer versions expose
        # ``pre_logits``, newer ones use ``fc_norm``. Catch only
        # AttributeError (original used a bare except + dead ``else: pass``).
        try:
            x = self.pre_logits(x[:, 0, :])
        except AttributeError:
            x = self.fc_norm(x[:, 0, :])

        x = self.head(x)
        return x
|
|