from uniperceiver.utils.registry import Registry
# `ModuleDict` lives in `torch.nn`; the top-level `torch` package does not
# export it, so `from torch import ModuleDict` fails at import time. The name
# is unused in this module but kept importable in case other modules import
# it from here.
from torch.nn import ModuleDict  # noqa: F401

ENCODER_REGISTRY = Registry("ENCODER")
ENCODER_REGISTRY.__doc__ = """
Registry for encoder
"""


def build_encoder(cfg):
    """Instantiate the encoder selected by ``cfg.MODEL.ENCODER``.

    Args:
        cfg: Global config node. ``cfg.MODEL.ENCODER`` is the key looked up
            in ``ENCODER_REGISTRY``.

    Returns:
        The object produced by calling the registered encoder with ``cfg``,
        or ``None`` when ``cfg.MODEL.ENCODER`` is empty (or ``None``).
    """
    encoder_name = cfg.MODEL.ENCODER
    # An empty name means "no encoder configured".
    if not encoder_name:
        return None
    return ENCODER_REGISTRY.get(encoder_name)(cfg)


def build_unfused_encoders(cfg):
    """Build one encoder per entry in ``cfg.ENCODERS``.

    Each entry is wrapped in a ``CfgNode``. Its ``TYPE`` field selects the
    registered encoder, which is called as ``Encoder(entry_cfg, cfg)``.

    Args:
        cfg: Global config node whose ``ENCODERS`` field is an iterable of
            per-encoder configs.

    Returns:
        dict mapping each entry's ``NAME`` to the built encoder, or to
        ``None`` when that entry's ``TYPE`` is empty. A later entry with the
        same ``NAME`` overwrites an earlier one.
    """
    # Imported locally, presumably to avoid a circular import between the
    # config and modeling packages -- TODO confirm.
    from uniperceiver.config import CfgNode

    encoder_dict = {}
    for config in cfg.ENCODERS:
        encoder_cfg = CfgNode(config)
        encoder_type = encoder_cfg.TYPE
        if encoder_type:
            encoder = ENCODER_REGISTRY.get(encoder_type)(encoder_cfg, cfg)
        else:
            encoder = None
        encoder_dict[encoder_cfg.NAME] = encoder
    return encoder_dict


def add_encoder_config(cfg, tmp_cfg):
    """Let the encoder selected in ``tmp_cfg`` add its options to ``cfg``.

    Args:
        cfg: Config node passed to the encoder class's ``add_config``.
        tmp_cfg: Config node whose ``MODEL.ENCODER`` names the registered
            encoder class. Nothing is done when that name is empty.
    """
    encoder_name = tmp_cfg.MODEL.ENCODER
    if encoder_name:
        ENCODER_REGISTRY.get(encoder_name).add_config(cfg)