|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Factory to provide model configs.""" |
|
|
|
from official.modeling.hyperparams import params_dict |
|
from official.vision.detection.configs import maskrcnn_config |
|
from official.vision.detection.configs import retinanet_config |
|
from official.vision.detection.configs import shapemask_config |
|
|
|
|
|
def config_generator(model):
  """Builds the default hyperparameter config for a detection model.

  Args:
    model: name of the detection model. Supported values are 'retinanet',
      'mask_rcnn' and 'shapemask'.

  Returns:
    A `params_dict.ParamsDict` constructed from the model's default config
    dictionary and its restriction list.

  Raises:
    ValueError: if `model` is not one of the supported names.
  """
  if model == 'retinanet':
    default_config = retinanet_config.RETINANET_CFG
    restrictions = retinanet_config.RETINANET_RESTRICTIONS
  elif model == 'mask_rcnn':
    default_config = maskrcnn_config.MASKRCNN_CFG
    restrictions = maskrcnn_config.MASKRCNN_RESTRICTIONS
  elif model == 'shapemask':
    default_config = shapemask_config.SHAPEMASK_CFG
    restrictions = shapemask_config.SHAPEMASK_RESTRICTIONS
  else:
    # Wrap `model` in a 1-tuple so %-formatting never tries to unpack it;
    # a bare tuple argument would otherwise raise TypeError instead of the
    # intended ValueError.
    raise ValueError('Model %s is not supported.' % (model,))

  return params_dict.ParamsDict(default_config, restrictions)
|
|