File size: 3,383 Bytes
2cdd41c
 
 
1615d09
 
 
 
 
2cdd41c
1615d09
 
 
 
 
 
 
 
 
 
 
2cdd41c
1615d09
 
 
 
2cdd41c
 
1615d09
2cdd41c
 
 
 
 
 
 
 
1615d09
2cdd41c
 
 
 
 
1615d09
2cdd41c
1615d09
2cdd41c
 
1615d09
 
 
 
 
 
 
 
 
2cdd41c
 
 
 
1615d09
 
 
2cdd41c
 
1615d09
 
 
 
 
 
2cdd41c
 
 
1615d09
 
 
2cdd41c
 
 
1615d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
 
 
1615d09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2cdd41c
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
from isegm.inference.transforms import ZoomIn
from isegm.model.is_hrnet_model import HRNetModel

from .base import BasePredictor
from .brs import (FeatureBRSPredictor, HRNetFeatureBRSPredictor,
                  InputBRSPredictor)
from .brs_functors import InputOptimizer, ScaleBiasOptimizer


def get_predictor(
    net,
    brs_mode,
    device,
    prob_thresh=0.49,
    with_flip=True,
    zoom_in_params=dict(),
    predictor_params=None,
    brs_opt_func_params=None,
    lbfgs_params=None,
):
    """Build a predictor for ``net`` according to ``brs_mode``.

    Args:
        net: The segmentation model. A list/tuple of models (multi-stage) is
            accepted only together with ``brs_mode == "NoBRS"``.
        brs_mode: One of ``"NoBRS"``, ``"f-BRS-A"``, ``"f-BRS-B"``,
            ``"f-BRS-C"``, ``"RGB-BRS"`` or ``"DistMap-BRS"``.
        device: Passed through to the predictor constructor.
        prob_thresh: Forwarded to the BRS optimizer functor (not used by
            ``"NoBRS"``).
        with_flip: Forwarded to the predictor and, for BRS modes, to the
            optimizer functor.
        zoom_in_params: Keyword arguments for ``ZoomIn``; ``None`` disables
            zoom-in. The default empty dict is only unpacked, never mutated.
        predictor_params: Overrides merged on top of the mode's default
            predictor keyword arguments.
        brs_opt_func_params: Extra keyword arguments for the optimizer functor
            (BRS modes only).
        lbfgs_params: Overrides for the optimizer settings passed as
            ``optimizer_params``. ``maxiter`` is always recomputed as
            ``2 * maxfun``, so a caller-supplied ``maxiter`` is ignored.

    Returns:
        The constructed predictor instance.

    Raises:
        NotImplementedError: If ``brs_mode`` is not a supported mode.
        ValueError: If a multi-stage ``net`` is combined with a BRS mode.
    """
    fbrs_insertion_modes = {
        "f-BRS-A": "after_c4",
        "f-BRS-B": "after_aspp",
        "f-BRS-C": "after_deeplab",
    }
    input_brs_targets = {"RGB-BRS": "rgb", "DistMap-BRS": "dmaps"}

    # Validate before building anything. An unknown "f-BRS-*" variant then
    # raises NotImplementedError instead of a bare KeyError from the table.
    if (
        brs_mode != "NoBRS"
        and brs_mode not in fbrs_insertion_modes
        and brs_mode not in input_brs_targets
    ):
        raise NotImplementedError(f"Unsupported brs_mode: {brs_mode!r}")
    # Explicit check rather than ``assert``, which is stripped under ``python -O``.
    if isinstance(net, (list, tuple)) and brs_mode != "NoBRS":
        raise ValueError("Multi-stage models support only NoBRS mode.")

    zoom_in = ZoomIn(**zoom_in_params) if zoom_in_params is not None else None

    lbfgs_params_ = {
        "m": 20,
        "factr": 0,
        "pgtol": 1e-8,
        "maxfun": 20,
    }
    if lbfgs_params is not None:
        lbfgs_params_.update(lbfgs_params)
    # maxiter is tied to maxfun and deliberately overrides any caller value.
    lbfgs_params_["maxiter"] = 2 * lbfgs_params_["maxfun"]

    if brs_opt_func_params is None:
        brs_opt_func_params = dict()

    # Precedence: common defaults < mode-specific defaults < caller overrides.
    predictor_params_ = {"optimize_after_n_clicks": 1}
    if brs_mode in fbrs_insertion_modes:
        predictor_params_["net_clicks_limit"] = 8
    elif brs_mode in input_brs_targets:
        predictor_params_["net_clicks_limit"] = 5
    if predictor_params is not None:
        predictor_params_.update(predictor_params)

    if brs_mode == "NoBRS":
        return BasePredictor(
            net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_
        )

    if brs_mode in fbrs_insertion_modes:
        insertion_mode = fbrs_insertion_modes[brs_mode]

        opt_functor = ScaleBiasOptimizer(
            prob_thresh=prob_thresh,
            with_flip=with_flip,
            optimizer_params=lbfgs_params_,
            **brs_opt_func_params
        )

        if isinstance(net, HRNetModel):
            # HRNet models use their own insertion-mode labels.
            FeaturePredictor = HRNetFeatureBRSPredictor
            insertion_mode = {"after_c4": "A", "after_aspp": "A", "after_deeplab": "C"}[
                insertion_mode
            ]
        else:
            FeaturePredictor = FeatureBRSPredictor

        return FeaturePredictor(
            net,
            device,
            opt_functor=opt_functor,
            with_flip=with_flip,
            insertion_mode=insertion_mode,
            zoom_in=zoom_in,
            **predictor_params_
        )

    # Remaining modes: "RGB-BRS" / "DistMap-BRS" (input-space optimization).
    opt_functor = InputOptimizer(
        prob_thresh=prob_thresh,
        with_flip=with_flip,
        optimizer_params=lbfgs_params_,
        **brs_opt_func_params
    )

    return InputBRSPredictor(
        net,
        device,
        optimize_target=input_brs_targets[brs_mode],
        opt_functor=opt_functor,
        with_flip=with_flip,
        zoom_in=zoom_in,
        **predictor_params_
    )