sunshangquan
commit from ssq
f518bf0
import importlib
from os import path as osp
from basicsr.utils import scandir
# automatically scan and import arch modules
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [
osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
if v.endswith('_arch.py')
]
# import all the arch modules
_arch_modules = [
importlib.import_module(f'basicsr.models.archs.{file_name}')
for file_name in arch_filenames
]
def dynamic_instantiation(modules, cls_type, opt):
"""Dynamically instantiate class.
Args:
modules (list[importlib modules]): List of modules from importlib
files.
cls_type (str): Class type.
opt (dict): Class initialization kwargs.
Returns:
class: Instantiated class.
"""
for module in modules:
cls_ = getattr(module, cls_type, None)
if cls_ is not None:
break
if cls_ is None:
raise ValueError(f'{cls_type} is not found.')
return cls_(**opt)
def define_network(opt):
network_type = opt.pop('type')
net = dynamic_instantiation(_arch_modules, network_type, opt)
return net