|
import timm |
|
import os |
|
from Backbone.GetPromptModel import build_promptmodel |
|
from pprint import pprint |
|
|
|
|
|
def get_PuzzleTuning_VPT_model(num_classes=0, edge_size=224, prompt_state_dict=None, base_state_dict='timm'):
    """
    Build the PuzzleTuning visual-prompt-tuned ViT via build_promptmodel.

    The architecture is fixed by this function: model_idx='ViT', patch_size=16,
    Prompt_Token_num=20 and VPT_type="Deep" are always passed to build_promptmodel.

    :param num_classes: classification required number of classes of your dataset,
        0 for taking the features.
    :param edge_size: the input edge size of the images produced by the dataloader.
    :param prompt_state_dict: forwarded unchanged to build_promptmodel; presumably the
        prompt-token weights to load (None = none loaded) -- confirm in Backbone.GetPromptModel.
    :param base_state_dict: forwarded unchanged to build_promptmodel; the default 'timm'
        presumably selects timm's pretrained backbone weights -- confirm in Backbone.GetPromptModel.

    :return: the prepared model returned by build_promptmodel.
    """
    # Forward the caller's num_classes instead of a hard-coded 0; otherwise the
    # argument is silently ignored and a classification head can never be requested.
    model = build_promptmodel(
        num_classes=num_classes,
        edge_size=edge_size, model_idx='ViT', patch_size=16,
        Prompt_Token_num=20, VPT_type="Deep",
        prompt_state_dict=prompt_state_dict,
        base_state_dict=base_state_dict)

    return model
|
|