|
|
|
import inspect
|
|
import warnings
|
|
from functools import partial
|
|
|
|
from .misc import is_seq_of
|
|
|
|
|
|
def build_from_cfg(cfg, registry, default_args=None):
    """Build a module from config dict.

    Args:
        cfg (dict): Config dict. It should at least contain the key "type".
        registry (:obj:`Registry`): The registry to search the type from.
        default_args (dict, optional): Default initialization arguments.
            Keys already present in ``cfg`` take precedence. Default: None.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg``, ``registry`` or ``default_args`` has the wrong
            type, or if the resolved "type" is neither a str nor a class.
        KeyError: If neither ``cfg`` nor ``default_args`` provides "type", or
            if a string "type" is not found in ``registry``.
        Exception: An exception raised by the constructor is re-raised with
            the same type and the class name prefixed to its message, with
            the original attached as ``__cause__``. If that exception type
            cannot be rebuilt from a single message, the original exception
            is re-raised unchanged.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        if default_args is None or 'type' not in default_args:
            raise KeyError(
                '`cfg` or `default_args` must contain the key "type", '
                f'but got {cfg}\n{default_args}')
    if not isinstance(registry, Registry):
        raise TypeError('registry must be an mmcv.Registry object, '
                        f'but got {type(registry)}')
    if not (isinstance(default_args, dict) or default_args is None):
        raise TypeError('default_args must be a dict or None, '
                        f'but got {type(default_args)}')

    # Shallow copy so that popping "type" and merging defaults never mutates
    # the caller's config; values from ``cfg`` win over ``default_args``.
    args = cfg.copy()
    if default_args is not None:
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_type = args.pop('type')
    if isinstance(obj_type, str):
        obj_cls = registry.get(obj_type)
        if obj_cls is None:
            raise KeyError(
                f'{obj_type} is not in the {registry.name} registry')
    elif inspect.isclass(obj_type):
        obj_cls = obj_type
    else:
        raise TypeError(
            f'type must be a str or valid type, but got {type(obj_type)}')

    try:
        return obj_cls(**args)
    except Exception as e:
        # Prefix the class name so failures deep inside a config-driven build
        # are attributable, while keeping the exception type so callers'
        # ``except`` clauses still match.
        try:
            wrapped = type(e)(f'{obj_cls.__name__}: {e}')
        except Exception:
            # Some exception types (e.g. UnicodeDecodeError) cannot be built
            # from a single message; don't mask the real error with that.
            wrapped = None
        if wrapped is None:
            raise
        raise wrapped from e
|
|
|
|
|
|
class Registry:
    """A registry to map strings to classes.

    Registered object could be built from registry.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet = MODELS.build(dict(type='ResNet'))

    Please refer to
    https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
    advanced usage.

    Args:
        name (str): Registry name.
        build_func (func, optional): Build function to construct instances
            from the registry. :func:`build_from_cfg` is used if neither
            ``parent`` nor ``build_func`` is specified. If ``parent`` is
            specified and ``build_func`` is not given, ``build_func`` will be
            inherited from ``parent``. Default: None.
        parent (Registry, optional): Parent registry. The class registered in
            children registry could be built from parent. Default: None.
        scope (str, optional): The scope of registry. It is the key to search
            for children registry. If not specified, scope will be the
            top-level package name of the module that creates the registry,
            e.g. mmdet, mmcls, mmseg. Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # ``infer_scope`` looks two frames up, so it must be called directly
        # from ``__init__`` to see the module that instantiated the registry.
        self._scope = self.infer_scope() if scope is None else scope

        # Validate ``parent`` before reading ``parent.build_func`` so that a
        # wrong type trips this assertion instead of an AttributeError.
        if parent is not None:
            assert isinstance(parent, Registry)

        # Priority: explicit ``build_func`` > the parent's ``build_func`` >
        # the module-level ``build_from_cfg``.
        if build_func is not None:
            self.build_func = build_func
        elif parent is not None:
            self.build_func = parent.build_func
        else:
            self.build_func = build_from_cfg

        if parent is not None:
            parent._add_children(self)
        self.parent = parent

    def __len__(self):
        """Return the number of names registered directly in this registry."""
        return len(self._module_dict)

    def __contains__(self, key):
        """Return True if ``key`` (optionally scoped) resolves via ``get``."""
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        The top-level package name of the module that instantiates the
        registry will be returned.

        Example:
            # in mmdet/models/backbone/resnet.py
            >>> MODELS = Registry('models')
            >>> @MODELS.register_module()
            >>> class ResNet:
            >>>     pass
            The scope of ``ResNet`` will be ``mmdet``.

        Returns:
            scope (str): The inferred scope name.
        """
        # Frame 0 is this function and frame 1 is ``Registry.__init__``, so
        # frame 2 is the code that called ``Registry(...)``. Walking
        # ``f_back`` avoids ``inspect.stack()``, which builds a FrameInfo
        # (including source context) for every frame on the stack.
        frame = inspect.currentframe()
        try:
            caller = frame.f_back.f_back
            module = inspect.getmodule(caller)
            if module is not None:
                module_name = module.__name__
            else:
                # Code run through exec/eval, doctest or notebooks may not map
                # to an importable module; use the namespace's ``__name__``.
                module_name = caller.f_globals.get('__name__', '')
        finally:
            # Drop the frame reference to avoid a reference cycle, as the
            # ``inspect`` documentation recommends.
            del frame
        return module_name.split('.')[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            scope (str, None): The first scope.
            key (str): The remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        A key without a scope, or one whose scope equals this registry's
        scope, is looked up locally. A key with a child's scope is delegated
        to that child. Any other scope is resolved from the root registry.

        Args:
            key (str): The class name in string format, optionally prefixed
                with a scope, e.g. ``'mmdet.ResNet'``.

        Returns:
            class: The corresponding class, or None if it cannot be found.
        """
        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # Look up in this registry only.
            return self._module_dict.get(real_key)
        if scope in self._children:
            # Delegate the remainder of the key to the owning child.
            return self._children[scope].get(real_key)
        if self.parent is None:
            # Already at the root and no registry owns this scope.
            return None
        # Restart from the root, which can reach every sibling scope.
        root = self.parent
        while root.parent is not None:
            root = root.parent
        return root.get(key)

    def build(self, *args, **kwargs):
        """Build an object by calling ``self.build_func`` with this registry.

        All arguments are forwarded, and ``registry=self`` is added.
        """
        return self.build_func(*args, **kwargs, registry=self)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope.
        The parent registry could build objects from children registry.

        Example:
            >>> models = Registry('models')
            >>> mmdet_models = Registry('models', parent=models)
            >>> @mmdet_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet = models.build(dict(type='mmdet.ResNet'))
        """
        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        """Register ``module_class`` under one or more names.

        Args:
            module_class (type): The class to register.
            module_name (str | Sequence[str] | None): Name(s) to register
                under. Defaults to ``module_class.__name__``.
            force (bool): Whether to overwrite existing entries.

        Raises:
            TypeError: If ``module_class`` is not a class.
            KeyError: If ``force`` is False and any name is already
                registered or repeated. Nothing is registered in that case.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]
        else:
            module_name = list(module_name)

        # Validate every alias before inserting any of them, so that a clash
        # on a later alias doesn't leave earlier ones half-registered.
        if not force:
            seen = set()
            for name in module_name:
                if name in self._module_dict or name in seen:
                    raise KeyError(f'{name} is already registered '
                                   f'in {self.name}')
                seen.add(name)
        for name in module_name:
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        """Register ``cls`` using the old API, emitting a warning.

        If ``cls`` is None, a decorator bound to ``force`` is returned.
        """
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.')
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(module=ResNet)

        Args:
            name (str | Sequence[str] | None): The module name(s) to be
                registered. If not specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.

        Returns:
            type | callable: ``module`` when it is given; otherwise a
            decorator that registers and returns the decorated class.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        # Backward compatibility: the old API passed the class positionally
        # as the first argument, which lands in ``name``.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f' of str, but got {type(name)}')

        # Used as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # Used as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register
|
|
|