import os

import torch

from .raw_vit import ViT


def vit_b_16(pretrained_backbone=True):
    vit = ViT(
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=768,        # encoder layer / attention input-output size (Hidden Size D in the paper)
        depth=12,
        heads=12,       # number of attention heads (Heads in the paper)
        dim_head=64,    # per-head attention size (left at the default; do not change)
        mlp_dim=3072,   # MLP hidden size (MLP Size in the paper)
        dropout=0.,
        emb_dropout=0.
    )
    if pretrained_backbone:
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'weights/base_p16_224_backbone.pth'))
        vit.load_state_dict(ckpt)
    return vit


def vit_l_16(pretrained_backbone=True):
    vit = ViT(
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=1024,       # encoder layer / attention input-output size (Hidden Size D in the paper)
        depth=24,
        heads=16,       # number of attention heads (Heads in the paper)
        dim_head=64,    # per-head attention size (left at the default; do not change)
        mlp_dim=4096,   # MLP hidden size (MLP Size in the paper)
        dropout=0.,
        emb_dropout=0.
    )
    if pretrained_backbone:
        # weights from https://huggingface.co/timm/vit_large_patch16_224.augreg_in21k_ft_in1k
        ckpt = torch.load(os.path.join(os.path.dirname(__file__), 'weights/pytorch_model.bin'))
        vit.load_state_dict(ckpt)
    return vit


def vit_h_16():
    return ViT(
        image_size=224,
        patch_size=16,
        num_classes=1000,
        dim=1280,       # encoder layer / attention input-output size (Hidden Size D in the paper)
        depth=32,
        heads=16,       # number of attention heads (Heads in the paper)
        dim_head=64,    # per-head attention size (left at the default; do not change)
        mlp_dim=5120,   # MLP hidden size (MLP Size in the paper)
        dropout=0.,
        emb_dropout=0.
    )
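

# Minimal usage sketch (an addition, not part of the original file): it assumes the
# ViT imported from .raw_vit follows the vit-pytorch convention of mapping a
# (B, 3, 224, 224) float tensor to (B, num_classes) logits. Because of the relative
# import above, run it as a module (e.g. `python -m <package>.vit`) rather than as a script.
if __name__ == '__main__':
    model = vit_b_16(pretrained_backbone=False)  # skip weight loading for a quick shape check
    dummy = torch.randn(1, 3, 224, 224)          # one fake 224x224 RGB image
    logits = model(dummy)
    print(logits.shape)                          # expected: torch.Size([1, 1000])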