Spaces:
Runtime error
Runtime error
from isegm.inference.transforms import ZoomIn | |
from isegm.model.is_hrnet_model import HRNetModel | |
from .base import BasePredictor | |
from .brs import (FeatureBRSPredictor, HRNetFeatureBRSPredictor, | |
InputBRSPredictor) | |
from .brs_functors import InputOptimizer, ScaleBiasOptimizer | |
def get_predictor(
    net,
    brs_mode,
    device,
    prob_thresh=0.49,
    with_flip=True,
    zoom_in_params=dict(),
    predictor_params=None,
    brs_opt_func_params=None,
    lbfgs_params=None,
):
    """Build the predictor object that corresponds to ``brs_mode``.

    Args:
        net: Model handed to the predictor. A list/tuple is accepted only
            together with ``brs_mode == "NoBRS"``.
        brs_mode: One of ``"NoBRS"``, ``"f-BRS-A"``, ``"f-BRS-B"``,
            ``"f-BRS-C"``, ``"RGB-BRS"`` or ``"DistMap-BRS"``.
        device: Passed through unchanged to the predictor.
        prob_thresh: Forwarded to the BRS optimisation functor (not used
            in ``"NoBRS"`` mode).
        with_flip: Forwarded to the predictor and, in BRS modes, to the
            optimisation functor.
        zoom_in_params: Keyword arguments for ``ZoomIn``. ``None`` disables
            zoom-in; the default empty dict builds ``ZoomIn`` with its own
            defaults. The dict is only unpacked, never mutated.
        predictor_params: Optional overrides applied on top of the
            per-mode predictor defaults.
        brs_opt_func_params: Optional extra keyword arguments for the
            optimisation functor.
        lbfgs_params: Optional overrides for the optimiser settings
            (``m``, ``factr``, ``pgtol``, ``maxfun``). ``maxiter`` is always
            recomputed as ``2 * maxfun`` and cannot be overridden here.
            NOTE(review): key names look like SciPy L-BFGS-B options;
            confirm against the functor implementation.

    Returns:
        A ``BasePredictor``, ``FeatureBRSPredictor``,
        ``HRNetFeatureBRSPredictor`` or ``InputBRSPredictor`` instance,
        depending on ``brs_mode`` and the type of ``net``.

    Raises:
        AssertionError: If ``net`` is a list/tuple and ``brs_mode`` is not
            ``"NoBRS"``.
        KeyError: If ``brs_mode`` starts with ``"f-BRS"`` but is not one of
            the three supported variants.
        NotImplementedError: If ``brs_mode`` is not recognised at all.
    """
    optimizer_settings = {
        "m": 20,
        "factr": 0,
        "pgtol": 1e-8,
        "maxfun": 20,
    }
    predictor_settings = {"optimize_after_n_clicks": 1}

    zoom_in = ZoomIn(**zoom_in_params) if zoom_in_params is not None else None

    if lbfgs_params is not None:
        optimizer_settings.update(lbfgs_params)
    # Derived after the overrides so it always tracks the effective maxfun.
    optimizer_settings["maxiter"] = 2 * optimizer_settings["maxfun"]

    if brs_opt_func_params is None:
        brs_opt_func_params = dict()

    # Explicit raise instead of an ``assert`` statement: asserts are removed
    # under ``python -O``, which would let an unsupported combination through.
    # AssertionError is kept so existing callers see the same exception type.
    if isinstance(net, (list, tuple)) and brs_mode != "NoBRS":
        raise AssertionError("Multi-stage models support only NoBRS mode.")

    if brs_mode == "NoBRS":
        if predictor_params is not None:
            predictor_settings.update(predictor_params)
        predictor = BasePredictor(
            net, device, zoom_in=zoom_in, with_flip=with_flip, **predictor_settings
        )
    elif brs_mode.startswith("f-BRS"):
        predictor_settings.update({"net_clicks_limit": 8})
        if predictor_params is not None:
            predictor_settings.update(predictor_params)

        # Unknown f-BRS suffixes fail here with KeyError, before any object
        # is constructed.
        insertion_mode = {
            "f-BRS-A": "after_c4",
            "f-BRS-B": "after_aspp",
            "f-BRS-C": "after_deeplab",
        }[brs_mode]

        opt_functor = ScaleBiasOptimizer(
            prob_thresh=prob_thresh,
            with_flip=with_flip,
            optimizer_params=optimizer_settings,
            **brs_opt_func_params
        )

        if isinstance(net, HRNetModel):
            FeaturePredictor = HRNetFeatureBRSPredictor
            # The HRNet predictor takes single-letter insertion points;
            # both "after_c4" and "after_aspp" map to "A".
            insertion_mode = {
                "after_c4": "A",
                "after_aspp": "A",
                "after_deeplab": "C",
            }[insertion_mode]
        else:
            FeaturePredictor = FeatureBRSPredictor

        predictor = FeaturePredictor(
            net,
            device,
            opt_functor=opt_functor,
            with_flip=with_flip,
            insertion_mode=insertion_mode,
            zoom_in=zoom_in,
            **predictor_settings
        )
    elif brs_mode == "RGB-BRS" or brs_mode == "DistMap-BRS":
        use_dmaps = brs_mode == "DistMap-BRS"
        predictor_settings.update({"net_clicks_limit": 5})
        if predictor_params is not None:
            predictor_settings.update(predictor_params)

        opt_functor = InputOptimizer(
            prob_thresh=prob_thresh,
            with_flip=with_flip,
            optimizer_params=optimizer_settings,
            **brs_opt_func_params
        )

        predictor = InputBRSPredictor(
            net,
            device,
            optimize_target="dmaps" if use_dmaps else "rgb",
            opt_functor=opt_functor,
            with_flip=with_flip,
            zoom_in=zoom_in,
            **predictor_settings
        )
    else:
        raise NotImplementedError("Unsupported brs_mode: {!r}".format(brs_mode))

    return predictor