|
|
|
from uniperceiver.utils.registry import Registry |
|
|
|
# Module-level registry that the builder functions in this file query by name.
PREDICTOR_REGISTRY = Registry("PREDICTOR")
PREDICTOR_REGISTRY.__doc__ = """
Registry for PREDICTOR
"""
|
|
|
def build_predictor(cfg):
    """Build the predictor selected by ``cfg.MODEL.PREDICTOR``.

    Args:
        cfg: Config object exposing ``cfg.MODEL.PREDICTOR``, the name under
            which the predictor was registered in ``PREDICTOR_REGISTRY``.

    Returns:
        The result of calling the registered entry with ``cfg``, or ``None``
        when ``cfg.MODEL.PREDICTOR`` is empty or ``None``.
    """
    predictor_name = cfg.MODEL.PREDICTOR
    # Truthiness (rather than ``len(...) > 0``) treats both "" and None as
    # "no predictor configured" instead of raising TypeError on None.
    if not predictor_name:
        return None
    # NOTE(review): Registry.get presumably raises for unknown names --
    # confirm in uniperceiver.utils.registry.
    return PREDICTOR_REGISTRY.get(predictor_name)(cfg)
|
|
|
def build_v_predictor(cfg):
    """Build the predictor selected by ``cfg.MODEL.V_PREDICTOR``.

    Args:
        cfg: Config object exposing ``cfg.MODEL.V_PREDICTOR``, the name under
            which the predictor was registered in ``PREDICTOR_REGISTRY``.

    Returns:
        The result of calling the registered entry with ``cfg``, or ``None``
        when ``cfg.MODEL.V_PREDICTOR`` is empty or ``None``.
    """
    predictor_name = cfg.MODEL.V_PREDICTOR
    # Truthiness (rather than ``len(...) > 0``) treats both "" and None as
    # "no predictor configured" instead of raising TypeError on None.
    if not predictor_name:
        return None
    # NOTE(review): Registry.get presumably raises for unknown names --
    # confirm in uniperceiver.utils.registry.
    return PREDICTOR_REGISTRY.get(predictor_name)(cfg)
|
|
|
def build_predictor_with_name(cfg, name):
    """Build the predictor registered under an explicitly given name.

    Args:
        cfg: Config object passed unchanged to the registered entry.
        name: Name to look up in ``PREDICTOR_REGISTRY``.

    Returns:
        The result of calling the registered entry with ``cfg``, or ``None``
        when ``name`` is empty or ``None``.
    """
    # Truthiness (rather than ``len(name) > 0``) treats both "" and None as
    # "nothing to build" instead of raising TypeError on None.
    if not name:
        return None
    # NOTE(review): Registry.get presumably raises for unknown names --
    # confirm in uniperceiver.utils.registry.
    return PREDICTOR_REGISTRY.get(name)(cfg)
|
|
|
def add_predictor_config(cfg, tmp_cfg):
    """Let the selected predictor add its config entries to ``cfg``.

    The predictor is chosen from ``tmp_cfg.MODEL.PREDICTOR``; its
    ``add_config`` is then called with ``cfg`` (not ``tmp_cfg``).

    Args:
        cfg: Config object handed to the predictor's ``add_config``.
        tmp_cfg: Config object whose ``MODEL.PREDICTOR`` names the predictor.

    Returns:
        None. Does nothing when ``tmp_cfg.MODEL.PREDICTOR`` is empty or
        ``None``.
    """
    predictor_name = tmp_cfg.MODEL.PREDICTOR
    # Truthiness (rather than ``len(...) > 0``) makes a None option a no-op
    # instead of raising TypeError.
    if not predictor_name:
        return
    PREDICTOR_REGISTRY.get(predictor_name).add_config(cfg)