Spaces:
Runtime error
Runtime error
from .base import BasePredictor | |
from .brs import InputBRSPredictor, FeatureBRSPredictor, HRNetFeatureBRSPredictor | |
from .brs_functors import InputOptimizer, ScaleBiasOptimizer | |
from isegm.inference.transforms import ZoomIn | |
from isegm.model.is_hrnet_model import HRNetModel | |
def get_predictor(net, brs_mode, device,
                  prob_thresh=0.49,
                  with_flip=True,
                  zoom_in_params=dict(),
                  predictor_params=None,
                  brs_opt_func_params=None,
                  lbfgs_params=None):
    """Build the predictor object that corresponds to ``brs_mode``.

    Args:
        net: The model handed to the predictor. A list/tuple of models is
            accepted only together with ``brs_mode == 'NoBRS'``.
        brs_mode: One of ``'NoBRS'``, ``'f-BRS-A'``, ``'f-BRS-B'``,
            ``'f-BRS-C'``, ``'RGB-BRS'`` or ``'DistMap-BRS'``.
        device: Forwarded unchanged to the predictor constructor.
        prob_thresh: Forwarded to the optimizer functor (BRS modes only).
        with_flip: Forwarded to both the predictor and the optimizer functor.
        zoom_in_params: Keyword arguments for ``ZoomIn``. ``None`` disables
            zoom-in; the default empty dict builds ``ZoomIn()`` with its own
            defaults. The default is a shared dict on purpose (``None``
            already has a meaning) and is only unpacked, never mutated.
        predictor_params: Extra keyword arguments for the predictor. They
            override the per-mode defaults set here
            (``optimize_after_n_clicks`` and, for BRS modes,
            ``net_clicks_limit``).
        brs_opt_func_params: Extra keyword arguments for the optimizer
            functor (BRS modes only).
        lbfgs_params: Overrides for the default L-BFGS settings. ``maxiter``
            is always recomputed as ``2 * maxfun`` after the overrides are
            applied, so a caller-supplied ``maxiter`` is ignored.

    Returns:
        A ``BasePredictor``, ``FeatureBRSPredictor``,
        ``HRNetFeatureBRSPredictor`` or ``InputBRSPredictor`` instance.

    Raises:
        AssertionError: ``net`` is a list/tuple and ``brs_mode`` is not
            ``'NoBRS'``.
        KeyError: ``brs_mode`` starts with ``'f-BRS'`` but is not one of the
            three known variants.
        NotImplementedError: ``brs_mode`` is not recognised at all.
    """
    optimizer_settings = {
        'm': 20,
        'factr': 0,
        'pgtol': 1e-8,
        'maxfun': 20,
    }
    predictor_kwargs = {
        'optimize_after_n_clicks': 1,
    }

    zoom_in = ZoomIn(**zoom_in_params) if zoom_in_params is not None else None

    if lbfgs_params is not None:
        optimizer_settings.update(lbfgs_params)
    # Derived after the overrides so it always tracks the effective maxfun.
    optimizer_settings['maxiter'] = 2 * optimizer_settings['maxfun']

    if brs_opt_func_params is None:
        brs_opt_func_params = dict()

    if isinstance(net, (list, tuple)) and brs_mode != 'NoBRS':
        # Raised explicitly (instead of using `assert`) so that the check is
        # still enforced when Python runs with -O; the exception type and
        # message are the same as before.
        raise AssertionError("Multi-stage models support only NoBRS mode.")

    if brs_mode == 'NoBRS':
        if predictor_params is not None:
            predictor_kwargs.update(predictor_params)
        predictor = BasePredictor(net, device,
                                  zoom_in=zoom_in,
                                  with_flip=with_flip,
                                  **predictor_kwargs)
    elif brs_mode.startswith('f-BRS'):
        predictor_kwargs.update({
            'net_clicks_limit': 8,
        })
        if predictor_params is not None:
            predictor_kwargs.update(predictor_params)

        # Looked up before any object is constructed; an unknown f-BRS
        # variant surfaces as KeyError here.
        insertion_mode = {
            'f-BRS-A': 'after_c4',
            'f-BRS-B': 'after_aspp',
            'f-BRS-C': 'after_deeplab',
        }[brs_mode]

        opt_functor = ScaleBiasOptimizer(prob_thresh=prob_thresh,
                                         with_flip=with_flip,
                                         optimizer_params=optimizer_settings,
                                         **brs_opt_func_params)

        if isinstance(net, HRNetModel):
            FeaturePredictor = HRNetFeatureBRSPredictor
            # The HRNet predictor takes single-letter insertion modes; both
            # 'after_c4' and 'after_aspp' map to 'A'.
            insertion_mode = {
                'after_c4': 'A',
                'after_aspp': 'A',
                'after_deeplab': 'C',
            }[insertion_mode]
        else:
            FeaturePredictor = FeatureBRSPredictor

        predictor = FeaturePredictor(net, device,
                                     opt_functor=opt_functor,
                                     with_flip=with_flip,
                                     insertion_mode=insertion_mode,
                                     zoom_in=zoom_in,
                                     **predictor_kwargs)
    elif brs_mode == 'RGB-BRS' or brs_mode == 'DistMap-BRS':
        use_dmaps = brs_mode == 'DistMap-BRS'
        predictor_kwargs.update({
            'net_clicks_limit': 5,
        })
        if predictor_params is not None:
            predictor_kwargs.update(predictor_params)

        opt_functor = InputOptimizer(prob_thresh=prob_thresh,
                                     with_flip=with_flip,
                                     optimizer_params=optimizer_settings,
                                     **brs_opt_func_params)

        predictor = InputBRSPredictor(net, device,
                                      optimize_target='dmaps' if use_dmaps else 'rgb',
                                      opt_functor=opt_functor,
                                      with_flip=with_flip,
                                      zoom_in=zoom_in,
                                      **predictor_kwargs)
    else:
        raise NotImplementedError(f"Unsupported brs_mode: {brs_mode!r}")

    return predictor