File size: 3,308 Bytes
6e9c433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b8ec7b
6e9c433
 
 
9b8ec7b
9a0bf16
 
 
 
9b8ec7b
6e9c433
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from typing import Tuple
import streamlit as st
import os
import torch

from .app_env import SOD_MODEL_TYPE
from .app_utils import count_parameters
from .smultimae_model import RGBDSMultiMAEModel
from .base_model import BaseRGBDModel
from .device import device

from s_multimae.da.dav6 import DataAugmentationV6
from s_multimae.configs.base_config import base_cfg
from s_multimae.configs.experiment_config import arg_cfg
from s_multimae.model_pl import ModelPL

# from spnet_model import SPNetModel


@st.cache_resource
def load_smultimae_model(
    sod_model_config_key: str, top: int
) -> Tuple[BaseRGBDModel, base_cfg, DataAugmentationV6]:
    """Build an S-MultiMAE model, load its pretrained weights, and wrap it.

    Steps:
    1. Instantiate the config registered under ``sod_model_config_key`` in
       ``arg_cfg``.
    2. Locate ``weights/s-multimae-<experiment_name>-top<top>.pth``,
       downloading it from the ``RGBD-SOD/S-MultiMAE`` Hugging Face repo
       when it is not present locally.
    3. Load the checkpoint state dict (on CPU) into ``ModelPL(cfg).model``
       and wrap the result in ``RGBDSMultiMAEModel``.

    NOTE(review): this function does not move the model to ``device``
    itself; presumably ``RGBDSMultiMAEModel`` does — confirm.

    The result is cached by Streamlit (``st.cache_resource``) per argument
    combination, so the heavy load happens once per config/top pair.

    Args:
        sod_model_config_key: Key into ``arg_cfg`` selecting the experiment
            config.
        top: Rank suffix used in the checkpoint filename (``...-top{top}.pth``).

    Returns:
        ``(model, cfg, da)``: the wrapped model, its config, and the
        ``DataAugmentationV6`` built from the same config.

    Raises:
        FileNotFoundError: If the checkpoint file is still missing after the
            download attempt.
    """
    cfg = arg_cfg[sod_model_config_key]()

    weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
    ckpt_path = os.path.join("weights", weights_fname)
    print(ckpt_path)
    if not os.path.isfile(ckpt_path):
        # Imported lazily so huggingface_hub is only needed when the weights
        # are not already on disk.
        from huggingface_hub import hf_hub_download

        hf_hub_download(
            repo_id="RGBD-SOD/S-MultiMAE",
            filename=weights_fname,
            local_dir="weights",
        )
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    sod_model = ModelPL(cfg)
    # NOTE(review): torch.load unpickles the file, and this checkpoint may have
    # just been downloaded; consider ``weights_only=True`` if the installed
    # torch version supports it.
    # strict=False: missing/unexpected keys in the checkpoint are ignored
    # silently instead of raising.
    sod_model.model.load_state_dict(
        torch.load(ckpt_path, map_location="cpu"), strict=False
    )
    da = DataAugmentationV6(cfg)
    return RGBDSMultiMAEModel(cfg, sod_model), cfg, da


# @st.cache_resource
# def load_spnet_model() -> BaseRGBDModel:
#     """
#     1. Construct model
#     2. Load pretrained weights
#     3. Load model into device
#     """
#     sod_model = SPNetModel()
#     return sod_model


# @st.cache_resource
# def load_bbsnet_model() -> BaseRGBDModel:
#     """
#     1. Construct model
#     2. Load pretrained weights
#     3. Load model into device
#     """
#     sod_model = BBSNetModel()
#     return sod_model


def sod_selection_ui() -> Tuple[BaseRGBDModel, DataAugmentationV6]:
    """Render the SOD model pickers and return the selected model.

    Shows a select box for the model family and, for S-MultiMAE, a second
    select box for the checkpoint config. The chosen model is loaded through
    ``load_smultimae_model`` and its parameter count is displayed.

    Returns:
        ``(sod_model, da)``: the loaded model and the data-augmentation object
        built from its config.

    Raises:
        ValueError: If the selected model type has no loader. Without this,
            the code below would fail with an ``UnboundLocalError``.
    """
    sod_model_type = st.selectbox(
        "Choose SOD model",
        (
            SOD_MODEL_TYPE.S_MULTIMAE,
            # SOD_MODEL_TYPE.SPNET,
            # SOD_MODEL_TYPE.BBSNET,
        ),
        key="sod_model_type",
    )

    if sod_model_type == SOD_MODEL_TYPE.S_MULTIMAE:
        # Display name -> arguments for load_smultimae_model:
        # "cfg" is the key into arg_cfg, "top" the checkpoint rank suffix.
        smultimae_configs = {
            "S-MultiMAE [ViT-L] Multi-GT": {"top": 1, "cfg": "cfgv4_0_2006"},
            "S-MultiMAE [ViT-B] Multi-GT": {"top": 1, "cfg": "cfgv4_0_2007"},
        }

        sod_model_config_key = st.selectbox(
            "Choose config",
            list(smultimae_configs.keys()),
            key="sod_model_config_key",
        )
        selected = smultimae_configs[sod_model_config_key]
        sod_model, _cfg, da = load_smultimae_model(selected["cfg"], selected["top"])
        # st.text(f"Model description: {_cfg.description}")
    # elif sod_model_type == SOD_MODEL_TYPE.SPNET:
    #     sod_model = load_spnet_model()
    #     st.text(f"Model description: SPNet (https://github.com/taozh2017/SPNet)")
    # elif sod_model_type == SOD_MODEL_TYPE.BBSNET:
    #     sod_model = load_bbsnet_model()
    #     st.text(f"Model description: BBSNet (https://github.com/DengPingFan/BBS-Net)")
    else:
        raise ValueError(f"Unsupported SOD model type: {sod_model_type}")

    st.text(f"Number of parameters {count_parameters(sod_model)}")

    return sod_model, da