"""
build_promptmodel   Script  ver: Oct 17th 14:20

"""

try:
    from Backbone.VPT_structure import *
except:
    from Backbone.VPT_structure import *


def build_promptmodel(num_classes=1000, edge_size=224, model_idx='ViT', patch_size=16,
                      Prompt_Token_num=20, VPT_type="Deep", prompt_state_dict=None, base_state_dict='timm'):
    """
    following the https://github.com/sagizty/VPT
    this build the VPT (prompt version of ViT), with additional prompt tokens,
    each layer the information become [B, N_patch + N_prompt, Dim]

    During training only the prompt tokens and the head layer are
    set to be learnable while the rest of Transformer layers are frozen

    # VPT_type = "Shallow" / "Deep"
        - Shallow: concatenate N_prompt of prompt tokens before the first Transformer Encoder block,
                each layer the information become [B, N_patch + N_prompt, Dim]
        - Deep: concatenate N_prompt of prompt tokens to each Transformer Encoder block,
                this will replace the output prompt tokens learnt form previous encoder.
    """

    if model_idx.startswith('ViT'):

        if base_state_dict is None:
            basic_state_dict = None

        elif isinstance(base_state_dict, str):
            if base_state_dict == 'timm':
                # ViT_Prompt
                import timm
                # from pprint import pprint
                # model_names = timm.list_models('*vit*')
                # pprint(model_names)

                basic_model = timm.create_model(f'vit_base_patch{patch_size}_{edge_size}',
                                                pretrained=True)
                basic_state_dict = basic_model.state_dict()
                print('in prompt model building, timm ViT loaded for base_state_dict')

            else:
                basic_state_dict = None
                print('in prompt model building, no valid str for base_state_dict')

        else:  # state dict: collections.OrderedDict
            basic_state_dict = base_state_dict
            print('in prompt model building, a .pth base_state_dict loaded')

        model = VPT_ViT(img_size=edge_size, patch_size=patch_size, Prompt_Token_num=Prompt_Token_num,
                        VPT_type=VPT_type, basic_state_dict=basic_state_dict)

        # replace the classification head to match the target number of classes
        model.New_CLS_head(num_classes)

        if prompt_state_dict is not None:
            try:
                model.load_prompt(prompt_state_dict)
            except Exception:
                print('error in loading the .pth prompt_state_dict')
            else:
                print('in prompt model building, a .pth prompt_state_dict loaded')

        # freeze the Transformer blocks; only the prompt tokens and the head stay trainable
        model.Freeze()
    else:
        print("The model is not difined in the Prompt script!!")
        return -1

    try:
        # dummy forward pass to check that the model has been built correctly
        img = torch.randn(1, 3, edge_size, edge_size)
        preds = model(img)  # (1, num_classes)
        print('Build VPT model with in/out shape: ', img.shape, ' -> ', preds.shape)

    except Exception:
        print("A problem exists in the model building process!")
        return -1
    else:
        print('model is ready now!')
        return model


if __name__ == '__main__':
    model = build_promptmodel(prompt_state_dict=None, base_state_dict='timm', num_classes=0)
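
    # Sanity-check sketch: assuming VPT_ViT behaves as a standard torch.nn.Module,
    # list which parameters stay trainable after model.Freeze().  Per the docstring
    # above, only the prompt tokens and the new CLS head should require gradients.
    trainable_params, frozen_params = 0, 0
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params += param.numel()
            print('trainable:', name, tuple(param.shape))
        else:
            frozen_params += param.numel()
    print('trainable params:', trainable_params, '  frozen params:', frozen_params)

    # Alternative (hypothetical example of the OrderedDict branch above): a state
    # dict can be passed directly instead of the 'timm' keyword.
    # import timm
    # base_sd = timm.create_model('vit_base_patch16_224', pretrained=True).state_dict()
    # shallow_model = build_promptmodel(base_state_dict=base_sd, VPT_type='Shallow')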