import os
import torch

from timm.models.vision_transformer import default_cfgs
from timm.models.helpers import load_pretrained, load_custom_pretrained

from src.models.vit.utils import checkpoint_filter_fn
from src.models.vit.vit import VisionTransformer


def create_vit(model_cfg):
    """Build a ``VisionTransformer`` from a config mapping and load weights.

    Args:
        model_cfg: Mapping of model settings. It must contain ``"backbone"``,
            ``"normalization"``, ``"d_model"`` and ``"image_size"`` (indexable,
            at least two entries). ``"backbone"`` and ``"normalization"`` are
            removed, ``"n_cls"`` is forced to 1000, ``"d_ff"`` is set to
            ``4 * d_model``, and everything left is forwarded as keyword
            arguments to ``VisionTransformer``. The caller's mapping is not
            modified; a shallow copy is edited instead.

    Returns:
        The constructed ``VisionTransformer`` after the weight-loading step
        selected by the backbone name has run.

    Raises:
        KeyError: If one of the required keys is missing from ``model_cfg``.
    """
    model_cfg = model_cfg.copy()
    backbone = model_cfg.pop("backbone")
    # "normalization" is not a VisionTransformer argument here; drop it so it
    # is not forwarded through **model_cfg below.
    model_cfg.pop("normalization")

    model_cfg["n_cls"] = 1000
    mlp_expansion_ratio = 4
    model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]

    if backbone in default_cfgs:
        # Copy the registry entry: `default_cfgs` is a module-level dict shared
        # by every user of timm, and "input_size" is overridden just below.
        # Writing into the shared entry would leak this call's image size into
        # later calls and into unrelated code reading the registry.
        default_cfg = dict(default_cfgs[backbone])
    else:
        default_cfg = dict(
            pretrained=False,
            num_classes=1000,
            drop_rate=0.0,
            drop_path_rate=0.0,
            drop_block_rate=None,
        )

    image_size = model_cfg["image_size"]
    default_cfg["input_size"] = (3, image_size[0], image_size[1])

    model = VisionTransformer(**model_cfg)

    if backbone == "vit_base_patch8_384":
        # This backbone is read from a local checkpoint under $TORCH_HOME.
        # If TORCH_HOME is unset, expandvars leaves the "$TORCH_HOME" text in
        # place and torch.load will fail on the literal path.
        path = os.path.expandvars("$TORCH_HOME/hub/checkpoints/vit_base_patch8_384.pth")
        # NOTE(review): torch.load unpickles the file; only point this at a
        # checkpoint from a trusted source.
        state_dict = torch.load(path, map_location="cpu")
        filtered_dict = checkpoint_filter_fn(state_dict, model)
        model.load_state_dict(filtered_dict, strict=True)
    elif "deit" in backbone:
        load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
    else:
        load_custom_pretrained(model, default_cfg)

    return model