Spaces:
Sleeping
Sleeping
import inspect | |
import os | |
from collections import OrderedDict | |
from typing import Optional, Iterable, Callable | |
# Module-level flag related to build-failure diagnostics; starts out True.
_innest_error = True
# Registration tracing is opt-in: it is on only when the DIENGINEREGTRACE environment
# variable equals 'ON' (compared case-insensitively); an unset variable counts as 'OFF'.
_DI_ENGINE_REG_TRACE_IS_ON = os.getenv('DIENGINEREGTRACE', 'OFF').upper() == 'ON'
class Registry(dict):
    """
    Overview:
        A helper class for managing registered modules. It extends a dictionary \
        and provides a ``register`` function.
    Interfaces:
        ``__init__``, ``register``, ``get``, ``build``, ``query``, ``query_details``
    Examples:
        Creating a registry:
            >>> some_registry = Registry({"default": default_module})

        There are two ways of registering new modules:
        1): the normal way is just calling the register function:
            >>> def foo():
            >>>     ...
            >>> some_registry.register("foo_module", foo)
        2): used as a decorator when declaring the module:
            >>> @some_registry.register("foo_module")
            >>> @some_registry.register("foo_module_nickname")
            >>> def foo():
            >>>     ...

        Access of a module is just like using a dictionary, e.g.:
            >>> f = some_registry["foo_module"]
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Overview:
            Initialize the Registry object.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                dict.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                dict.
        """
        super(Registry, self).__init__(*args, **kwargs)
        # Registered name -> (filename, lineno) of the ``register`` call. It is filled only while
        # registration tracing (DIENGINEREGTRACE=ON) is enabled.
        self.__trace__ = dict()

    def register(
            self,
            module_name: Optional[str] = None,
            module: Optional[Callable] = None,
            force_overwrite: bool = False
    ) -> Callable:
        """
        Overview:
            Register a module, either directly or as a decorator.
        Arguments:
            - module_name (:obj:`Optional[str]`): The name of the module. When used as a decorator and \
                left as ``None``, the decorated callable's ``__name__`` is used.
            - module (:obj:`Optional[Callable]`): The module to be registered. If given, it is registered \
                immediately (``module_name`` is then required) and ``None`` is returned.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Returns:
            - register_fn (:obj:`Callable`): When ``module`` is ``None``, a decorator that registers the \
                decorated callable and returns it unchanged; otherwise ``None``.
        Raises:
            - AssertionError: If ``module`` is given without ``module_name``, or if the name is already \
                registered and ``force_overwrite`` is False.
        """
        if _DI_ENGINE_REG_TRACE_IS_ON:
            # context=0: only the caller's location is needed, so skip reading source lines
            # for every frame on the stack.
            frame = inspect.stack(0)[1][0]
            info = inspect.getframeinfo(frame, 0)
            filename = info.filename
            lineno = info.lineno
        # used as function call
        if module is not None:
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            if module_name is None:
                raise AssertionError('module_name must be given when module is provided')
            Registry._register_generic(self, module_name, module, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[module_name] = (filename, lineno)
            return

        # used as decorator
        def register_fn(fn: Callable) -> Callable:
            name = fn.__name__ if module_name is None else module_name
            Registry._register_generic(self, name, fn, force_overwrite)
            if _DI_ENGINE_REG_TRACE_IS_ON:
                self.__trace__[name] = (filename, lineno)
            return fn

        return register_fn

    # NOTE: deliberately defined without ``self``/``@staticmethod``; it is called as
    # ``Registry._register_generic(registry, ...)`` and also works as ``registry._register_generic(name, fn)``.
    def _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None:
        """
        Overview:
            Register the module.
        Arguments:
            - module_dict (:obj:`dict`): The dict to store the module.
            - module_name (:obj:`str`): The name of the module.
            - module (:obj:`Callable`): The module to be registered.
            - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
        Raises:
            - AssertionError: If ``module_name`` is already present and ``force_overwrite`` is False.
        """
        # Explicit raise (same exception type as the former ``assert``) so duplicate registrations
        # are still rejected under ``python -O`` instead of silently overwriting.
        if not force_overwrite and module_name in module_dict:
            raise AssertionError(module_name)
        module_dict[module_name] = module

    def get(self, module_name: str) -> Callable:
        """
        Overview:
            Get the module registered under ``module_name``. Unlike ``dict.get`` there is no default.
        Arguments:
            - module_name (:obj:`str`): The name of the module.
        Raises:
            - KeyError: If ``module_name`` is not registered.
        """
        return self[module_name]

    def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
        """
        Overview:
            Look up ``obj_type`` and call the registered callable with the given arguments.
        Arguments:
            - obj_type (:obj:`str`): The registered name of the callable to build with.
            - obj_args (:obj:`Tuple`): The positional arguments passed to the callable.
            - obj_kwargs (:obj:`Dict`): The keyword arguments passed to the callable.
        Returns:
            - obj (:obj:`object`): Whatever the registered callable returns.
        Raises:
            - KeyError: If ``obj_type`` is not registered.
            - Exception: Any exception raised by the registered callable is re-raised unchanged. When \
                ``add_note`` is available (Python 3.11+), the innermost failing build attaches a \
                diagnostic note describing the callable and the given arguments.
        """
        # Look up separately from the call, so that a KeyError raised *inside* build_fn is not
        # misreported as an unknown type.
        try:
            build_fn = self[obj_type]
        except KeyError:
            raise KeyError("not support buildable-object type: {}".format(obj_type))
        try:
            return build_fn(*obj_args, **obj_kwargs)
        except Exception as e:
            # Mark the exception itself rather than using a process-wide flag. In nested builds only
            # the innermost one annotates it, and later, unrelated failures still get diagnostics.
            if not getattr(e, '_registry_build_noted', False):
                e._registry_build_noted = True
                add_note = getattr(e, 'add_note', None)  # BaseException.add_note exists on 3.11+
                if add_note is not None:
                    add_note(Registry._build_error_message(build_fn, obj_type, obj_args, obj_kwargs))
            raise

    @staticmethod
    def _build_error_message(build_fn: Callable, obj_type: str, obj_args: tuple, obj_kwargs: dict) -> str:
        """
        Overview:
            Format a diagnostic for a failed ``build`` call: the callable, its alias, its expected \
            signature and the arguments actually given.
        """
        try:
            argspec = inspect.getfullargspec(build_fn)
        except TypeError:
            # getfullargspec rejects some callables; never let diagnostics mask the real error.
            argspec = '<unavailable>'
        message = 'for {}(alias={})'.format(build_fn, obj_type)
        message += '\nExpected args are:{}'.format(argspec)
        message += '\nGiven args are:{}/{}'.format(len(obj_args), list(obj_kwargs.keys()))
        message += '\nGiven args details are:{}/{}'.format(obj_args, obj_kwargs)
        return message

    def query(self) -> Iterable:
        """
        Overview:
            Return all registered module names.
        Returns:
            - names (:obj:`Iterable`): A view of the registered names.
        """
        return self.keys()

    def query_details(self, aliases: Optional[Iterable] = None) -> OrderedDict:
        """
        Overview:
            Get the ``(filename, lineno)`` of the ``register`` call for each requested alias.
        Arguments:
            - aliases (:obj:`Optional[Iterable]`): The aliases of the modules; all registered names if ``None``.
        Returns:
            - details (:obj:`OrderedDict`): alias -> (filename, lineno), in the order of ``aliases``.
        Raises:
            - AssertionError: If registration tracing (DIENGINEREGTRACE=ON) is not enabled.
            - KeyError: If an alias has no trace record (e.g. it was passed to ``__init__`` or set directly).
        """
        # Explicit raise so the guard also holds under ``python -O``.
        if not _DI_ENGINE_REG_TRACE_IS_ON:
            raise AssertionError("please exec 'export DIENGINEREGTRACE=ON' first")
        if aliases is None:
            aliases = self.keys()
        return OrderedDict((alias, self.__trace__[alias]) for alias in aliases)