import os

import torch

from .raw_vit import ViT


def vit_b_16(pretrained_backbone=True):
    """ViT-Base/16 with an ImageNet-1k head: Hidden Size D=768, Depth=12, Heads=12, MLP size=3072."""
    vit = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 768,       # encoder layer / attention input-output size (Hidden Size D in the paper)
        depth = 12,
        heads = 12,      # Heads in the paper
        dim_head = 64,   # per-head attention size (default value; never change this)
        mlp_dim = 3072,  # MLP hidden size (MLP size in the paper)
        dropout = 0.,
        emb_dropout = 0.
    )
    if pretrained_backbone:
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'weights/base_p16_224_backbone.pth'),
                          map_location='cpu')  # load on CPU; move to the target device afterwards
        vit.load_state_dict(ckpt)
    return vit


def vit_l_16(pretrained_backbone=True):
    """ViT-Large/16 with an ImageNet-1k head: Hidden Size D=1024, Depth=24, Heads=16, MLP size=4096."""
    vit = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,      # encoder layer / attention input-output size (Hidden Size D in the paper)
        depth = 24,
        heads = 16,      # Heads in the paper
        dim_head = 64,   # per-head attention size (default value; never change this)
        mlp_dim = 4096,  # MLP hidden size (MLP size in the paper)
        dropout = 0.,
        emb_dropout = 0.
    )
    if pretrained_backbone:
        # Weights from https://huggingface.co/timm/vit_large_patch16_224.augreg_in21k_ft_in1k
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'weights/pytorch_model.bin'),
                          map_location='cpu')  # load on CPU; move to the target device afterwards
        # Alternative local checkpoint:
        # ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'weights/large_p16_224_backbone.pth'))
        # NOTE: timm checkpoints name their parameters differently from a vit-pytorch-style ViT
        # (e.g. 'blocks.0.attn.qkv.weight'); if load_state_dict reports missing/unexpected keys,
        # the checkpoint has to be remapped to this model's state-dict layout first.
        vit.load_state_dict(ckpt)
    return vit
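

# Sketch of how one might verify that a downloaded checkpoint matches this ViT's state-dict
# layout before calling load_state_dict. The helper below is illustrative only, is not used
# elsewhere in this module, and assumes the file holds a plain state dict.
def _report_key_mismatch(model, ckpt_path):
    """Print parameter names that appear only in the checkpoint or only in the model."""
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(ckpt.keys())
    print('only in checkpoint:', sorted(ckpt_keys - model_keys))
    print('only in model:     ', sorted(model_keys - ckpt_keys))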


def vit_h_16():
    """ViT-Huge/16 with an ImageNet-1k head: Hidden Size D=1280, Depth=32, Heads=16, MLP size=5120.

    No pretrained backbone is wired up for this variant.
    """
    return ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 1280,      # encoder layer / attention input-output size (Hidden Size D in the paper)
        depth = 32,
        heads = 16,      # Heads in the paper
        dim_head = 64,   # per-head attention size (default value; never change this)
        mlp_dim = 5120,  # MLP hidden size (MLP size in the paper)
        dropout = 0.,
        emb_dropout = 0.
    )
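

# Minimal smoke test, assuming the ViT defined in raw_vit follows the usual vit-pytorch
# interface: input of shape (B, 3, 224, 224), output of shape (B, num_classes) logits.
# Because of the relative import above, run this with `python -m <package>.<module>`.
if __name__ == '__main__':
    model = vit_h_16()  # no pretrained weights are loaded for this variant
    x = torch.randn(1, 3, 224, 224)
    logits = model(x)
    print(logits.shape)  # expected: torch.Size([1, 1000])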