|
import warnings |
|
|
|
from mmcv.utils import Registry, build_from_cfg |
|
from torch import nn |
|
|
|
# Module-level registries, one per kind of component. Each is created with
# the name string shown and is consumed by the matching ``build_*`` helper.
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
HEADS = Registry('head')
LOSSES = Registry('loss')
# Registry used by ``build_segmentor``.
SEGMENTORS = Registry('segmentor')
|
|
|
|
|
def build(cfg, registry, default_args=None):
    """Build a module, or a sequence of modules, from config.

    Args:
        cfg (dict, list[dict]): The config of the module(s). It is either a
            single config dict or a list of config dicts.
        registry (:obj:`Registry`): The registry the module belongs to.
        default_args (dict, optional): Default arguments forwarded to
            ``build_from_cfg`` for every config. Defaults to None.

    Returns:
        nn.Module: The built module. When ``cfg`` is a list, the modules are
        built in order and wrapped in an ``nn.Sequential``.
    """
    # Single config: delegate directly.
    if not isinstance(cfg, list):
        return build_from_cfg(cfg, registry, default_args)

    # List of configs: build each one in order and chain them.
    built = []
    for item_cfg in cfg:
        built.append(build_from_cfg(item_cfg, registry, default_args))
    return nn.Sequential(*built)
|
|
|
|
|
def build_backbone(cfg):
    """Build a backbone from ``cfg`` via the ``BACKBONES`` registry.

    Args:
        cfg (dict, list[dict]): Config (or list of configs) passed to
            :func:`build`.

    Returns:
        The object returned by :func:`build` for the ``BACKBONES`` registry.
    """
    backbone = build(cfg, registry=BACKBONES)
    return backbone
|
|
|
|
|
def build_neck(cfg):
    """Build a neck from ``cfg`` via the ``NECKS`` registry.

    Args:
        cfg (dict, list[dict]): Config (or list of configs) passed to
            :func:`build`.

    Returns:
        The object returned by :func:`build` for the ``NECKS`` registry.
    """
    neck = build(cfg, registry=NECKS)
    return neck
|
|
|
|
|
def build_head(cfg):
    """Build a head from ``cfg`` via the ``HEADS`` registry.

    Args:
        cfg (dict, list[dict]): Config (or list of configs) passed to
            :func:`build`.

    Returns:
        The object returned by :func:`build` for the ``HEADS`` registry.
    """
    head = build(cfg, registry=HEADS)
    return head
|
|
|
|
|
def build_loss(cfg):
    """Build a loss from ``cfg`` via the ``LOSSES`` registry.

    Args:
        cfg (dict, list[dict]): Config (or list of configs) passed to
            :func:`build`.

    Returns:
        The object returned by :func:`build` for the ``LOSSES`` registry.
    """
    loss = build(cfg, registry=LOSSES)
    return loss
|
|
|
|
|
def build_segmentor(cfg, train_cfg=None, test_cfg=None):
    """Build a segmentor from ``cfg`` via the ``SEGMENTORS`` registry.

    Args:
        cfg: Model config. It must support ``.get`` (e.g. a dict); it may
            carry its own ``train_cfg`` / ``test_cfg`` entries.
        train_cfg (optional): Deprecated outer-level training config.
            Passing a non-None value emits a ``UserWarning``.
            Defaults to None.
        test_cfg (optional): Deprecated outer-level testing config.
            Passing a non-None value emits a ``UserWarning``.
            Defaults to None.

    Returns:
        The object returned by :func:`build`, called with ``train_cfg`` and
        ``test_cfg`` supplied as default arguments.

    Raises:
        AssertionError: If ``train_cfg`` (or ``test_cfg``) is given both as
            an argument and inside ``cfg``.
    """
    outer_cfg_given = not (train_cfg is None and test_cfg is None)
    if outer_cfg_given:
        warnings.warn(
            'train_cfg and test_cfg is deprecated, '
            'please specify them in model', UserWarning)

    # Each setting may come from at most one place: the outer argument or
    # the model config itself.
    assert cfg.get('train_cfg') is None or train_cfg is None, \
        'train_cfg specified in both outer field and model field '
    assert cfg.get('test_cfg') is None or test_cfg is None, \
        'test_cfg specified in both outer field and model field '

    default_args = dict(train_cfg=train_cfg, test_cfg=test_cfg)
    return build(cfg, SEGMENTORS, default_args)
|
|