# PuzzleTuning_VPT / PuzzleTuning / Get_PuzzleTuning_model.py
# Tianyinus's picture
# init submit
# edcf5ee verified
import timm
import os
from Backbone.GetPromptModel import build_promptmodel
from pprint import pprint
def get_PuzzleTuning_VPT_model(num_classes=0, edge_size=224, prompt_state_dict=None, base_state_dict='timm'):
    """
    Build the ViT prompt model used by PuzzleTuning.

    This is a thin wrapper around ``build_promptmodel`` that pins the
    architecture settings (``model_idx='ViT'``, ``patch_size=16``,
    ``Prompt_Token_num=20``, ``VPT_type="Deep"``) and forwards everything
    else from the caller.

    :param num_classes: classification required number of your dataset,
        0 (the default) for taking the feature, i.e. a feature extractor
        model whose output is the CLS token
    :param edge_size: the input edge size of the dataloader
    :param prompt_state_dict: forwarded unchanged to ``build_promptmodel``
        as ``prompt_state_dict``; ``None`` by default
    :param base_state_dict: forwarded unchanged to ``build_promptmodel``
        as ``base_state_dict``; ``'timm'`` by default
    :return: prepared model, exactly as returned by ``build_promptmodel``
    """
    # NOTE(review): num_classes used to be hard-coded to 0 here, which silently
    # discarded the caller's argument. It is now forwarded; the default of 0
    # keeps the previous feature-extractor behaviour for existing callers.
    # How build_promptmodel treats a prompt_state_dict whose head does not
    # match a non-zero num_classes is not visible from this file - confirm.
    model = build_promptmodel(
        num_classes=num_classes,
        edge_size=edge_size,
        model_idx='ViT',
        patch_size=16,
        Prompt_Token_num=20,
        VPT_type="Deep",
        prompt_state_dict=prompt_state_dict,
        base_state_dict=base_state_dict,
    )
    return model