# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from typing import Optional

from efficientvit.models.efficientvit import (
    EfficientViTSeg,
    efficientvit_seg_b0,
    efficientvit_seg_b1,
    efficientvit_seg_b2,
    efficientvit_seg_b3,
    efficientvit_seg_l1,
    efficientvit_seg_l2,
)
from efficientvit.models.nn.norm import set_norm_eps
from efficientvit.models.utils import load_state_dict_from_file

__all__ = ["create_seg_model"]


# Default checkpoint location for each (dataset, model name) pair.
# The paths are relative strings; presumably they are resolved against the
# current working directory by ``load_state_dict_from_file`` -- TODO confirm.
REGISTERED_SEG_MODEL: dict[str, dict[str, str]] = {
    "cityscapes": {
        "b0": "assets/checkpoints/seg/cityscapes/b0.pt",
        "b1": "assets/checkpoints/seg/cityscapes/b1.pt",
        "b2": "assets/checkpoints/seg/cityscapes/b2.pt",
        "b3": "assets/checkpoints/seg/cityscapes/b3.pt",
        ################################################
        "l1": "assets/checkpoints/seg/cityscapes/l1.pt",
        "l2": "assets/checkpoints/seg/cityscapes/l2.pt",
    },
    "ade20k": {
        "b1": "assets/checkpoints/seg/ade20k/b1.pt",
        "b2": "assets/checkpoints/seg/ade20k/b2.pt",
        "b3": "assets/checkpoints/seg/ade20k/b3.pt",
        ################################################
        "l1": "assets/checkpoints/seg/ade20k/l1.pt",
        "l2": "assets/checkpoints/seg/ade20k/l2.pt",
    },
}


def create_seg_model(
    name: str,
    dataset: str,
    pretrained=True,
    weight_url: Optional[str] = None,
    **kwargs,
) -> EfficientViTSeg:
    """Build an EfficientViT segmentation model, optionally loading weights.

    Args:
        name: Model name. The part before the first ``"-"`` selects the
            builder (``"b0"`` ... ``"b3"``, ``"l1"``, ``"l2"``); the full
            ``name`` is the key used to look up the default checkpoint.
        dataset: Dataset identifier, forwarded to the model builder and used
            to select the checkpoint table in ``REGISTERED_SEG_MODEL``.
        pretrained: When truthy, load a state dict into the model.
        weight_url: Explicit checkpoint location. When falsy, the registered
            default for ``(dataset, name)`` is used instead.
        **kwargs: Forwarded unchanged to the model builder.

    Returns:
        The constructed model, with weights loaded when ``pretrained`` is
        truthy.

    Raises:
        ValueError: If the model id is unknown, or if ``pretrained`` is
            truthy and no checkpoint can be resolved for ``(dataset, name)``.
    """
    model_dict = {
        "b0": efficientvit_seg_b0,
        "b1": efficientvit_seg_b1,
        "b2": efficientvit_seg_b2,
        "b3": efficientvit_seg_b3,
        #########################
        "l1": efficientvit_seg_l1,
        "l2": efficientvit_seg_l2,
    }

    model_id = name.split("-")[0]
    if model_id not in model_dict:
        raise ValueError(
            f"Do not find {name} in the model zoo. List of models: {list(model_dict.keys())}"
        )

    model = model_dict[model_id](dataset=dataset, **kwargs)

    # The L-series models get their normalization epsilon overridden.
    if model_id in ["l1", "l2"]:
        set_norm_eps(model, 1e-7)

    if not pretrained:
        return model

    # An explicit weight_url wins. Otherwise fall back to the registry; a
    # dataset without a registry entry is treated like a missing model name
    # so the caller gets the descriptive ValueError rather than a KeyError.
    weight_url = weight_url or REGISTERED_SEG_MODEL.get(dataset, {}).get(name)
    if weight_url is None:
        raise ValueError(f"Do not find the pretrained weight of {name}.")

    weight = load_state_dict_from_file(weight_url)
    model.load_state_dict(weight)
    return model