|
import inspect
|
|
import platform
|
|
|
|
from .registry import PLUGIN_LAYERS
|
|
|
|
if platform.system() == 'Windows':
|
|
import regex as re
|
|
else:
|
|
import re
|
|
|
|
|
|
def infer_abbr(class_type):
    """Infer an abbreviation for a class.

    The abbreviation is used to map class types to short names.

    Rule 1: If the class defines the attribute ``_abbr_``, return it as-is.
    Rule 2: Otherwise, fall back to the snake case of the class name,
    e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.

    Args:
        class_type (type): The class whose abbreviation is inferred.

    Returns:
        str: The inferred abbreviation. When Rule 1 applies, the value of
        ``_abbr_`` is returned unchanged.

    Raises:
        TypeError: If ``class_type`` is not a class.
    """

    def camel2snake(word):
        """Convert a camel-case word into snake case.

        Modified from `inflection lib
        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

        Example::

            >>> camel2snake("FancyBlock")
            'fancy_block'
        """
        # Split an acronym from a following capitalized word: "HTTPServer"
        # -> "HTTP_Server".
        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        # Split a lowercase letter/digit from a following capital:
        # "FancyBlock" -> "Fancy_Block".
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    return camel2snake(class_type.__name__)
|
|
|
|
|
|
def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build a plugin layer from a config dict.

    Args:
        cfg (dict): Layer config. It should contain:

            - type (str): Registered name of the plugin layer type.
            - layer args: Arguments needed to instantiate the plugin layer.
        postfix (int, str): Appended to the inferred abbreviation of the
            plugin layer to create the layer name. Default: ''.
        **kwargs: Extra keyword arguments passed to the layer constructor
            together with the remaining (non-``type``) items of ``cfg``.

    Returns:
        tuple[str, nn.Module]:

            name (str): abbreviation + postfix
            layer (nn.Module): created plugin layer

    Raises:
        TypeError: If ``cfg`` is not a dict or ``postfix`` is neither an
            int nor a str.
        KeyError: If ``cfg`` has no ``type`` key or the type is not
            registered in ``PLUGIN_LAYERS``.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')
    # Explicit check instead of ``assert``, which is removed under
    # ``python -O``; validated before touching the registry.
    if not isinstance(postfix, (int, str)):
        raise TypeError(
            f'postfix must be an int or str, but got {type(postfix)}')

    # Shallow copy so popping 'type' does not mutate the caller's cfg.
    cfg_ = cfg.copy()
    layer_type = cfg_.pop('type')
    if layer_type not in PLUGIN_LAYERS:
        raise KeyError(f'Unrecognized plugin type {layer_type}')
    plugin_layer = PLUGIN_LAYERS.get(layer_type)

    name = infer_abbr(plugin_layer) + str(postfix)
    layer = plugin_layer(**kwargs, **cfg_)
    return name, layer
|
|
|