|
from uniperceiver.utils.registry import Registry |
|
from torch import ModuleDict |
|
|
|
|
|
# Module-level registry of encoders. The builder functions in this module
# resolve entries from it by name through ENCODER_REGISTRY.get(...).
ENCODER_REGISTRY = Registry("ENCODER")

# Human-readable description attached to the registry object itself.
ENCODER_REGISTRY.__doc__ = """

Registry for encoder

"""
|
|
|
def build_encoder(cfg):
    """Instantiate the encoder selected by ``cfg.MODEL.ENCODER``.

    The value of ``cfg.MODEL.ENCODER`` is used as the lookup key in
    ``ENCODER_REGISTRY``; the registered object is then called with ``cfg``.

    Args:
        cfg: Config object exposing ``MODEL.ENCODER``. The value must support
            ``len()`` (presumably the registered encoder name as a string --
            confirm against the config schema).

    Returns:
        Whatever the registered entry returns when called with ``cfg``, or
        ``None`` when ``cfg.MODEL.ENCODER`` is empty.
    """
    encoder_name = cfg.MODEL.ENCODER

    # An empty selector means "no encoder configured".
    if len(encoder_name) == 0:
        return None

    encoder_factory = ENCODER_REGISTRY.get(encoder_name)
    return encoder_factory(cfg)
|
|
|
def build_unfused_encoders(cfg):
    """Build one encoder per entry of ``cfg.ENCODERS``, keyed by entry name.

    Every entry is wrapped in a ``CfgNode``. Its ``TYPE`` is the lookup key in
    ``ENCODER_REGISTRY``, and the registered object is called with
    ``(entry_cfg, cfg)``. Entries whose ``TYPE`` is empty are still recorded,
    with ``None`` as their value.

    Args:
        cfg: Config object exposing an iterable ``ENCODERS``; each element
            must be accepted by ``CfgNode(...)`` and provide ``TYPE`` and
            ``NAME``.

    Returns:
        dict: Maps each entry's ``NAME`` to the built encoder (or ``None``).
        If two entries share a ``NAME``, the later one replaces the earlier.
    """
    # Kept as a function-scope import, exactly like the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from uniperceiver.config import CfgNode

    encoders_by_name = {}

    for raw_entry in cfg.ENCODERS:
        entry_cfg = CfgNode(raw_entry)

        if len(entry_cfg.TYPE) > 0:
            encoder_factory = ENCODER_REGISTRY.get(entry_cfg.TYPE)
            built = encoder_factory(entry_cfg, cfg)
        else:
            # No encoder type given for this entry.
            built = None

        encoders_by_name[entry_cfg.NAME] = built

    return encoders_by_name
|
|
|
|
|
def add_encoder_config(cfg, tmp_cfg):
    """Let the selected encoder add its own options to ``cfg``.

    Looks up ``tmp_cfg.MODEL.ENCODER`` in ``ENCODER_REGISTRY`` and calls the
    registered object's ``add_config(cfg)``. Nothing happens when the
    selector is empty.

    Args:
        cfg: Config object handed to the encoder's ``add_config``.
        tmp_cfg: Config object exposing ``MODEL.ENCODER``, which must
            support ``len()``.

    Returns:
        None.
    """
    encoder_name = tmp_cfg.MODEL.ENCODER

    # No encoder selected: leave cfg untouched.
    if len(encoder_name) == 0:
        return

    registered = ENCODER_REGISTRY.get(encoder_name)
    registered.add_config(cfg)