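# NOTE: this module is currently disabled (everything below is commented out).
# It defines ViT backbone constructors (ViT-B/16, ViT-L/16, ViT-H/16) on top of
# the local raw_vit implementation, with hyperparameters matching Table 1 of the
# ViT paper (Dosovitskiy et al., 2021).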
# import os
# import torch

# from .raw_vit import ViT


# def vit_b_16(pretrained_backbone=True):
#     vit = ViT(
#         image_size = 224,
#         patch_size = 16,
#         num_classes = 1000,
#         dim = 768, # encoder layer/attention input/output size (Hidden Size D in the paper)
#         depth = 12,
#         heads = 12, # (Heads in the paper)
#         dim_head = 64, # attention hidden size per head (seems to be the default; never change this)
#         mlp_dim = 3072, # mlp layer hidden size (MLP size in the paper)
#         dropout = 0.,
#         emb_dropout = 0.
#     )
    
#     if pretrained_backbone:
#         ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'weights/base_p16_224_backbone.pth'))
#         vit.load_state_dict(ckpt)
#     return vit
    
    
# def vit_l_16(pretrained_backbone=True):
#     vit = ViT(
#         image_size = 224,
#         patch_size = 16,
#         num_classes = 1000,
#         dim = 1024, # encoder layer/attention input/output size (Hidden Size D in the paper)
#         depth = 24,
#         heads = 16, # (Heads in the paper)
#         dim_head = 64, # attention hidden size per head (seems to be the default; never change this)
#         mlp_dim = 4096, # mlp layer hidden size (MLP size in the paper)
#         dropout = 0.,
#         emb_dropout = 0.
#     )
    
#     if pretrained_backbone:
#         # https://huggingface.co/timm/vit_large_patch16_224.augreg_in21k_ft_in1k
#         ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'weights/pytorch_model.bin'))
#         # alternative local checkpoint: 'weights/large_p16_224_backbone.pth'
#         vit.load_state_dict(ckpt)
#     return vit
    
    
    
# def vit_h_16():
#     return ViT(
#         image_size = 224,
#         patch_size = 16,
#         num_classes = 1000,
#         dim = 1280, # encoder layer/attention input/output size (Hidden Size D in the paper)
#         depth = 32,
#         heads = 16, # (Heads in the paper)
#         dim_head = 64, # attention hidden size per head (seems to be the default; never change this)
#         mlp_dim = 5120, # mlp layer hidden size (MLP size in the paper)
#         dropout = 0.,
#         emb_dropout = 0.
#     )
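

# Minimal usage sketch (a hypothetical example, not part of the original source):
# it assumes the code above is re-enabled, that raw_vit.ViT follows the usual
# vit-pytorch ViT interface (forward: (B, 3, 224, 224) -> (B, num_classes)),
# and that the referenced weight files exist under ./weights/.
#
#     model = vit_b_16(pretrained_backbone=True)
#     model.eval()
#     with torch.no_grad():
#         logits = model(torch.randn(1, 3, 224, 224))  # expected shape: (1, 1000)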