herrius's picture
Upload 259 files
32b542e
from uniperceiver.utils.registry import Registry
from torch import ModuleDict
# Module-level registry under which encoder implementations are looked up by
# name (see the ``ENCODER_REGISTRY.get(...)`` calls in the builders below).
ENCODER_REGISTRY = Registry("ENCODER")
ENCODER_REGISTRY.__doc__ = "\nRegistry for encoder\n"
def build_encoder(cfg):
    """Instantiate the encoder named by ``cfg.MODEL.ENCODER``.

    Args:
        cfg: Config object exposing ``MODEL.ENCODER``; it is also passed as the
            sole argument to the registered encoder.

    Returns:
        The object produced by calling the registry entry with ``cfg``, or
        ``None`` when ``cfg.MODEL.ENCODER`` has length zero.
    """
    # An empty name means "no encoder configured": skip the registry entirely.
    if len(cfg.MODEL.ENCODER) == 0:
        return None

    encoder_factory = ENCODER_REGISTRY.get(cfg.MODEL.ENCODER)
    return encoder_factory(cfg)
def build_unfused_encoders(cfg):
    """Build one encoder per entry of ``cfg.ENCODERS``.

    Each entry is wrapped in a ``CfgNode``. When the entry's ``TYPE`` is
    non-empty, the registry entry of that name is called with
    ``(entry_cfg, cfg)``; otherwise the entry maps to ``None``.

    Args:
        cfg: Config object exposing an iterable ``ENCODERS``; it is also
            forwarded as the second argument to every encoder built.

    Returns:
        dict: Maps each entry's ``NAME`` to the built encoder (or ``None``).
        A later entry with the same ``NAME`` overwrites an earlier one.
    """
    # Imported inside the function, as in the original code.
    from uniperceiver.config import CfgNode

    built = {}
    for raw_entry in cfg.ENCODERS:
        entry_cfg = CfgNode(raw_entry)
        if len(entry_cfg.TYPE) > 0:
            encoder_factory = ENCODER_REGISTRY.get(entry_cfg.TYPE)
            instance = encoder_factory(entry_cfg, cfg)
        else:
            # No type given: keep the slot but leave it unpopulated.
            instance = None
        built[entry_cfg.NAME] = instance
    return built
def add_encoder_config(cfg, tmp_cfg):
    """Let the encoder selected in ``tmp_cfg`` add its options to ``cfg``.

    Args:
        cfg: Config object handed to the encoder's ``add_config`` hook.
        tmp_cfg: Config object whose ``MODEL.ENCODER`` names the encoder to
            look up in ``ENCODER_REGISTRY``.

    Returns:
        None. Nothing happens when ``tmp_cfg.MODEL.ENCODER`` has length zero.
    """
    # No encoder selected: there is nothing to contribute to the config.
    if len(tmp_cfg.MODEL.ENCODER) == 0:
        return

    selected_encoder = ENCODER_REGISTRY.get(tmp_cfg.MODEL.ENCODER)
    selected_encoder.add_config(cfg)