File size: 936 Bytes
edcf5ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import timm
import os
from Backbone.GetPromptModel import build_promptmodel
from pprint import pprint


def get_PuzzleTuning_VPT_model(num_classes=0, edge_size=224, prompt_state_dict=None, base_state_dict='timm'):
    """
    Build a prompt-tuned ViT model through ``build_promptmodel``.

    The architecture settings are fixed here: ``model_idx='ViT'``,
    ``patch_size=16``, ``Prompt_Token_num=20`` and ``VPT_type="Deep"``.
    The remaining settings are forwarded from the arguments below.

    :param num_classes: classification number required by your dataset,
        forwarded to ``build_promptmodel``. The default 0 is intended to give
        a feature-extractor model (output is the CLS token).
    :param edge_size: the input edge size of the dataloader
    :param prompt_state_dict: forwarded unchanged to ``build_promptmodel``.
        NOTE(review): presumably the prompt weights to load, with None
        meaning "do not load any" - confirm against ``build_promptmodel``.
    :param base_state_dict: forwarded unchanged to ``build_promptmodel``.
        NOTE(review): the default ``'timm'`` presumably selects the timm
        pretrained backbone weights - confirm against ``build_promptmodel``.

    :return: prepared model
    """

    model = build_promptmodel(
        # Honor the caller's value instead of hard-coding 0; the default (0)
        # keeps the feature-extractor behaviour.
        num_classes=num_classes,
        edge_size=edge_size, model_idx='ViT', patch_size=16,
        Prompt_Token_num=20, VPT_type="Deep",
        prompt_state_dict=prompt_state_dict,
        base_state_dict=base_state_dict)

    return model