SwinTExCo / src /models /vit /factory.py
duongttr's picture
Upload folder using huggingface_hub
62ef5f4
import os
import torch
from timm.models.vision_transformer import default_cfgs
from timm.models.helpers import load_pretrained, load_custom_pretrained
from src.models.vit.utils import checkpoint_filter_fn
from src.models.vit.vit import VisionTransformer
def create_vit(model_cfg):
    """Build a ``VisionTransformer`` and load pretrained weights into it.

    Args:
        model_cfg: Mapping of model settings. It is shallow-copied, so the
            caller's mapping is not modified. It must contain ``"backbone"``,
            ``"normalization"``, ``"d_model"`` and ``"image_size"`` (indexable
            at positions 0 and 1). Every key left after ``"backbone"`` and
            ``"normalization"`` are removed is forwarded to
            ``VisionTransformer`` as a keyword argument.

    Returns:
        The constructed ``VisionTransformer`` after weight loading.

    Raises:
        KeyError: If one of the required keys is missing from ``model_cfg``.
    """
    model_cfg = model_cfg.copy()
    backbone = model_cfg.pop("backbone")

    # Removed so it is not forwarded to VisionTransformer(**model_cfg); the
    # value itself is not used in this function.
    model_cfg.pop("normalization")
    model_cfg["n_cls"] = 1000
    mlp_expansion_ratio = 4
    model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]

    if backbone in default_cfgs:
        # Copy the registry entry: "input_size" is overwritten below, and
        # writing into the shared module-level ``default_cfgs`` entry would
        # leak this call's image size into every later user of that entry.
        default_cfg = dict(default_cfgs[backbone])
    else:
        default_cfg = dict(
            pretrained=False,
            num_classes=1000,
            drop_rate=0.0,
            drop_path_rate=0.0,
            drop_block_rate=None,
        )

    default_cfg["input_size"] = (
        3,
        model_cfg["image_size"][0],
        model_cfg["image_size"][1],
    )

    model = VisionTransformer(**model_cfg)
    if backbone == "vit_base_patch8_384":
        if "TORCH_HOME" in os.environ:
            path = os.path.expandvars(
                "$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth"
            )
        else:
            # expandvars() leaves "$TORCH_HOME" verbatim when the variable is
            # unset, producing a meaningless relative path. Fall back to the
            # hub directory torch itself resolves.
            path = os.path.join(
                torch.hub.get_dir(), "checkpoints", "vit_base_patch8_384.pth"
            )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        filtered_dict = checkpoint_filter_fn(state_dict, model)
        model.load_state_dict(filtered_dict, strict=True)
    elif "deit" in backbone:
        load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
    else:
        load_custom_pretrained(model, default_cfg)

    return model