Spaces:
Runtime error
Runtime error
File size: 3,383 Bytes
2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
from isegm.inference.transforms import ZoomIn
from isegm.model.is_hrnet_model import HRNetModel
from .base import BasePredictor
from .brs import (FeatureBRSPredictor, HRNetFeatureBRSPredictor,
InputBRSPredictor)
from .brs_functors import InputOptimizer, ScaleBiasOptimizer
def get_predictor(
    net,
    brs_mode,
    device,
    prob_thresh=0.49,
    with_flip=True,
    zoom_in_params=dict(),
    predictor_params=None,
    brs_opt_func_params=None,
    lbfgs_params=None,
):
    """Build an interactive-segmentation predictor for the requested BRS mode.

    Args:
        net: The network to wrap. A list/tuple of networks (multi-stage model)
            is accepted only together with ``brs_mode == "NoBRS"``.
        brs_mode: One of ``"NoBRS"``, ``"f-BRS-A"``, ``"f-BRS-B"``,
            ``"f-BRS-C"``, ``"RGB-BRS"`` or ``"DistMap-BRS"``.
        device: Forwarded unchanged to the predictor constructor.
        prob_thresh: Forwarded to the BRS optimizer functor; not used by the
            ``"NoBRS"`` mode.
        with_flip: Forwarded to the predictor and, for BRS modes, also to the
            optimizer functor.
        zoom_in_params: Keyword arguments for ``ZoomIn``; pass ``None`` to
            disable zoom-in. The default empty dict is only unpacked with
            ``**`` and never mutated, so sharing it across calls is safe.
        predictor_params: Optional overrides merged on top of the per-mode
            predictor defaults (``optimize_after_n_clicks=1`` plus, for BRS
            modes, a mode-specific ``net_clicks_limit``).
        brs_opt_func_params: Optional extra keyword arguments for the BRS
            optimizer functor.
        lbfgs_params: Optional overrides for the L-BFGS settings
            (``m``, ``factr``, ``pgtol``, ``maxfun``). ``maxiter`` is always
            recomputed as ``2 * maxfun``, so a caller-supplied ``maxiter`` is
            overwritten.

    Returns:
        A ``BasePredictor`` (NoBRS), ``FeatureBRSPredictor`` /
        ``HRNetFeatureBRSPredictor`` (f-BRS-*), or ``InputBRSPredictor``
        (RGB-BRS / DistMap-BRS).

    Raises:
        AssertionError: If a multi-stage ``net`` (list/tuple) is combined with
            a mode other than ``"NoBRS"``.
        KeyError: If ``brs_mode`` starts with ``"f-BRS"`` but is not one of the
            A/B/C variants.
        NotImplementedError: For any other unrecognised ``brs_mode``.
    """
    lbfgs_params_ = {
        "m": 20,
        "factr": 0,
        "pgtol": 1e-8,
        "maxfun": 20,
    }
    predictor_params_ = {"optimize_after_n_clicks": 1}

    # Constructed before any mode validation, so bad zoom_in_params fail first.
    zoom_in = ZoomIn(**zoom_in_params) if zoom_in_params is not None else None

    if lbfgs_params is not None:
        lbfgs_params_.update(lbfgs_params)
    # Always overrides any caller-supplied "maxiter".
    lbfgs_params_["maxiter"] = 2 * lbfgs_params_["maxfun"]

    if brs_opt_func_params is None:
        brs_opt_func_params = dict()

    # Explicit raise instead of ``assert``: the check must stay active under
    # ``python -O``; type and message match the original assertion.
    if isinstance(net, (list, tuple)) and brs_mode != "NoBRS":
        raise AssertionError("Multi-stage models support only NoBRS mode.")

    if brs_mode == "NoBRS":
        if predictor_params is not None:
            predictor_params_.update(predictor_params)
        predictor = BasePredictor(
            net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_params_
        )
    elif brs_mode.startswith("f-BRS"):
        # Mode default first, so caller-supplied predictor_params win.
        predictor_params_["net_clicks_limit"] = 8
        if predictor_params is not None:
            predictor_params_.update(predictor_params)
        # Unknown f-BRS variants raise KeyError here.
        insertion_mode = {
            "f-BRS-A": "after_c4",
            "f-BRS-B": "after_aspp",
            "f-BRS-C": "after_deeplab",
        }[brs_mode]
        opt_functor = ScaleBiasOptimizer(
            prob_thresh=prob_thresh,
            with_flip=with_flip,
            optimizer_params=lbfgs_params_,
            **brs_opt_func_params
        )
        if isinstance(net, HRNetModel):
            FeaturePredictor = HRNetFeatureBRSPredictor
            # For HRNet, the A and B variants both map to insertion mode "A".
            insertion_mode = {
                "after_c4": "A",
                "after_aspp": "A",
                "after_deeplab": "C",
            }[insertion_mode]
        else:
            FeaturePredictor = FeatureBRSPredictor
        predictor = FeaturePredictor(
            net,
            device,
            opt_functor=opt_functor,
            with_flip=with_flip,
            insertion_mode=insertion_mode,
            zoom_in=zoom_in,
            **predictor_params_
        )
    elif brs_mode in ("RGB-BRS", "DistMap-BRS"):
        use_dmaps = brs_mode == "DistMap-BRS"
        # Mode default first, so caller-supplied predictor_params win.
        predictor_params_["net_clicks_limit"] = 5
        if predictor_params is not None:
            predictor_params_.update(predictor_params)
        opt_functor = InputOptimizer(
            prob_thresh=prob_thresh,
            with_flip=with_flip,
            optimizer_params=lbfgs_params_,
            **brs_opt_func_params
        )
        predictor = InputBRSPredictor(
            net,
            device,
            optimize_target="dmaps" if use_dmaps else "rgb",
            opt_functor=opt_functor,
            with_flip=with_flip,
            zoom_in=zoom_in,
            **predictor_params_
        )
    else:
        raise NotImplementedError(f"Unsupported brs_mode: {brs_mode!r}")
    return predictor
|