Spaces:
Sleeping
Sleeping
import os | |
import torch | |
from timm.models.vision_transformer import default_cfgs | |
from timm.models.helpers import load_pretrained, load_custom_pretrained | |
from src.models.vit.utils import checkpoint_filter_fn | |
from src.models.vit.vit import VisionTransformer | |
def create_vit(model_cfg):
    """Build a ``VisionTransformer`` from a config dict and load pretrained weights.

    ``model_cfg`` is shallow-copied, so the caller's dict is not modified.
    On the copy:

    * ``"backbone"`` is removed;
    * ``"normalization"`` (if present) is removed, so it is not forwarded to
      ``VisionTransformer``;
    * ``"n_cls"`` is forced to 1000;
    * ``"d_ff"`` is set to ``4 * d_model``.

    Every remaining key is passed to ``VisionTransformer(**model_cfg)``.

    Pretrained weights are then loaded depending on the backbone name:

    * ``"vit_base_patch8_384"``: from the local file
      ``<torch.hub.get_dir()>/checkpoints/vit_base_patch8_384.pth``, passed
      through ``checkpoint_filter_fn`` and loaded with ``strict=True``;
    * any name containing ``"deit"``: via timm's ``load_pretrained`` with
      ``checkpoint_filter_fn``;
    * anything else: via timm's ``load_custom_pretrained``.

    Args:
        model_cfg: Dict of model options. Must contain ``"backbone"``,
            ``"d_model"`` and ``"image_size"``. ``image_size`` is indexed as
            ``image_size[0]`` and ``image_size[1]``; presumably these are
            (height, width), matching timm's ``(C, H, W)`` ``input_size``
            (TODO confirm against callers).

    Returns:
        The constructed ``VisionTransformer`` with pretrained weights applied.

    Raises:
        KeyError: If ``"backbone"``, ``"d_model"`` or ``"image_size"`` is missing.
        FileNotFoundError: If the backbone is ``"vit_base_patch8_384"`` and the
            local checkpoint file does not exist.
    """
    model_cfg = model_cfg.copy()
    backbone = model_cfg.pop("backbone")
    # Not forwarded to VisionTransformer; tolerate configs that omit it.
    model_cfg.pop("normalization", None)
    # Presumably 1000 to match the classifier heads of the pretrained
    # checkpoints (the fallback timm cfg below also uses num_classes=1000).
    model_cfg["n_cls"] = 1000
    mlp_expansion_ratio = 4
    model_cfg["d_ff"] = mlp_expansion_ratio * model_cfg["d_model"]

    if backbone in default_cfgs:
        # Copy the entry: default_cfgs is timm's module-level registry, and
        # writing "input_size" into the shared dict would leak into every
        # other user of this backbone's config.
        default_cfg = dict(default_cfgs[backbone])
    else:
        default_cfg = dict(
            pretrained=False,
            num_classes=1000,
            drop_rate=0.0,
            drop_path_rate=0.0,
            drop_block_rate=None,
        )
    default_cfg["input_size"] = (
        3,
        model_cfg["image_size"][0],
        model_cfg["image_size"][1],
    )

    model = VisionTransformer(**model_cfg)

    if backbone == "vit_base_patch8_384":
        # torch.hub.get_dir() honours $TORCH_HOME and falls back to torch's
        # default cache directory. A raw expandvars("$TORCH_HOME/...") would
        # keep the literal "$TORCH_HOME" when the variable is unset.
        path = os.path.join(
            torch.hub.get_dir(), "checkpoints", "vit_base_patch8_384.pth"
        )
        if not os.path.isfile(path):
            raise FileNotFoundError(
                f"Pretrained checkpoint for 'vit_base_patch8_384' not found at "
                f"{path!r}; download the weights to that location (the directory "
                f"follows $TORCH_HOME, see torch.hub.get_dir())."
            )
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(path, map_location="cpu")
        filtered_dict = checkpoint_filter_fn(state_dict, model)
        model.load_state_dict(filtered_dict, strict=True)
    elif "deit" in backbone:
        load_pretrained(model, default_cfg, filter_fn=checkpoint_filter_fn)
    else:
        load_custom_pretrained(model, default_cfg)
    return model