File size: 831 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from uniperceiver.utils.registry import Registry
from torch import ModuleDict


# Module-wide registry of encoder classes; the builder functions in this
# module look classes up in it by name via ENCODER_REGISTRY.get(...).
ENCODER_REGISTRY = Registry("ENCODER")
ENCODER_REGISTRY.__doc__ = "\nRegistry for encoder\n"

def build_encoder(cfg):
    """Instantiate the encoder selected by ``cfg.MODEL.ENCODER``.

    The class registered under that name in ``ENCODER_REGISTRY`` is looked up
    and called with ``cfg``.

    Args:
        cfg: config object exposing ``cfg.MODEL.ENCODER`` (the registered
            encoder name).

    Returns:
        The constructed encoder, or ``None`` when ``cfg.MODEL.ENCODER`` is
        empty or ``None`` (no encoder configured).
    """
    encoder_name = cfg.MODEL.ENCODER
    # Truthiness covers both "" and None; a ``len(...) > 0`` check would raise
    # TypeError when the option is left as None.
    if not encoder_name:
        return None
    return ENCODER_REGISTRY.get(encoder_name)(cfg)

def build_unfused_encoders(cfg):
    """Build one encoder per entry of ``cfg.ENCODERS``, keyed by entry name.

    Each entry is wrapped in a ``CfgNode``. Its ``TYPE`` selects the class in
    ``ENCODER_REGISTRY``, which is called as ``cls(entry_cfg, cfg)``.

    Args:
        cfg: global config exposing an iterable ``ENCODERS`` of per-encoder
            configs, each providing ``TYPE`` and ``NAME``.

    Returns:
        dict mapping each entry's ``NAME`` to its encoder, or to ``None`` when
        that entry's ``TYPE`` is empty or ``None``. If two entries share a
        ``NAME``, the later one overwrites the earlier.
    """
    # Function-scope import kept from the original (presumably to avoid an
    # import cycle with uniperceiver.config — TODO confirm).
    from uniperceiver.config import CfgNode

    encoder_dict = {}
    for config in cfg.ENCODERS:
        enc_cfg = CfgNode(config)
        # Truthiness treats "" and None alike as "no encoder"; a
        # ``len(...) > 0`` check would raise TypeError on None.
        encoder = (
            ENCODER_REGISTRY.get(enc_cfg.TYPE)(enc_cfg, cfg) if enc_cfg.TYPE else None
        )
        encoder_dict[enc_cfg.NAME] = encoder

    return encoder_dict


def add_encoder_config(cfg, tmp_cfg):
    """Call ``add_config(cfg)`` on the encoder class selected by ``tmp_cfg``.

    The encoder name is read from ``tmp_cfg.MODEL.ENCODER``. The matching class
    in ``ENCODER_REGISTRY`` receives ``cfg`` (presumably so it can register
    its own config options — confirm against the encoder classes). Nothing
    happens when the name is empty or ``None``.

    Args:
        cfg: config object handed to the encoder class's ``add_config``.
        tmp_cfg: config object whose ``MODEL.ENCODER`` selects the encoder.
    """
    encoder_name = tmp_cfg.MODEL.ENCODER
    # Truthiness covers both "" and None; a ``len(...) > 0`` check would raise
    # TypeError when the option is None.
    if not encoder_name:
        return
    ENCODER_REGISTRY.get(encoder_name).add_config(cfg)