"""Registry of models discovered in this file's sibling subdirectories.

At import time every immediate subdirectory of the directory containing this
file is scanned for a ``model.py``. Each such module must define:

* ``name`` -- a ``str`` key under which the model is registered, and
* ``get_model`` -- a callable that is invoked with no arguments by
  :func:`get_model` to build the model.

Modules that violate this contract are skipped with a warning. Use
:func:`get_model` to build a registered model and :func:`get_all_model_names`
to list the registered keys.
"""

import importlib.util
import os
import sys
import warnings
from pathlib import Path
from types import ModuleType


def _module_name_from_path(file_path: os.PathLike) -> str:
    """Return the dotted module name used to register *file_path*.

    The name mirrors the file's location relative to the current working
    directory, with the ``.py`` suffix dropped (``models/foo/model.py`` ->
    ``models.foo.model``), i.e. the name a regular ``import`` would use.
    If the file is not under the working directory, fall back to a name
    namespaced under this module (``<this module>.<parent dir>.<stem>``)
    instead of raising ``ValueError``.
    """
    path = Path(file_path).resolve().with_suffix("")
    try:
        relative = path.relative_to(Path.cwd().resolve())
    except ValueError:
        return ".".join((__name__, path.parent.name, path.name))
    return ".".join(relative.parts)


def import_model_module(file_path: os.PathLike) -> ModuleType:
    """Import the Python source file at *file_path* and return the module.

    The module is stored in ``sys.modules`` under the name returned by
    ``_module_name_from_path``. If executing the module raises, that entry is
    removed again before the exception propagates, so no half-initialised
    module is left behind.

    Raises:
        ImportError: if no import spec/loader can be created for *file_path*.
        Exception: whatever the module's own top-level code raises.
    """
    module_name = _module_name_from_path(file_path)
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {file_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as importlib's "import a source file
    # directly" recipe does, so the module is visible while it runs.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(module_name, None)
        raise
    return module


def _check_attr_exists(file_path, attr_name, attr) -> bool:
    """Return True if *attr* is not None; otherwise warn and return False."""
    if attr is None:
        warnings.warn(
            f"Module {file_path} should define attribute '{attr_name}'"
        )
        return False
    return True


def _check_attr_type(attr_name, attr, expected_type) -> bool:
    """Return True if *attr* is an *expected_type*; otherwise warn."""
    if isinstance(attr, expected_type):
        return True
    # Parameter is named ``expected_type`` (not ``type``) so that the builtin
    # ``type`` used below reports the attribute's actual type.
    warnings.warn(
        f"'{attr_name}' should be of type {expected_type}, "
        f"but it is of type {type(attr)}"
    )
    return False


def _check_attr_callable(attr_name, attr) -> bool:
    """Return True if *attr* is callable; otherwise warn and return False."""
    if callable(attr):
        return True
    warnings.warn(f"'{attr_name}' should be callable")
    return False


def _discover_models(directory: Path) -> dict:
    """Map ``name -> get_model`` for every ``<subdir>/model.py`` in *directory*.

    Subdirectories are visited in sorted order so that registration order
    (and thus :func:`get_all_model_names`) is deterministic. A module is
    skipped, with a warning, if it lacks ``name`` or ``get_model``, if
    ``name`` is not a ``str``, or if ``get_model`` is not callable. If two
    modules use the same ``name``, a warning is issued and the one visited
    later wins.
    """
    registry = {}
    for entry in sorted(directory.glob("*")):
        model_file_path = entry / "model.py"
        if not entry.is_dir() or not model_file_path.exists():
            continue
        module = import_model_module(model_file_path)
        name = getattr(module, "name", None)
        factory = getattr(module, "get_model", None)
        # Short-circuit so that at most one warning is emitted per module,
        # checking existence before type/callability.
        if not (
            _check_attr_exists(model_file_path, "name", name)
            and _check_attr_exists(model_file_path, "get_model", factory)
            and _check_attr_type("name", name, str)
            and _check_attr_callable("get_model", factory)
        ):
            continue
        if name in registry:
            warnings.warn(
                f"Model name '{name}' from {model_file_path} overrides an "
                f"earlier registration"
            )
        registry[name] = factory
    return registry


# Directory whose subdirectories are scanned for ``model.py`` files.
file_dir = Path(__file__).parent
# Registered model factories, keyed by each module's ``name``.
models = _discover_models(file_dir)


def get_model(name: str):
    """Build and return the model registered under *name*.

    Raises:
        KeyError: if no discovered module registered *name*.
    """
    # Check membership first (rather than catching KeyError around the call)
    # so a KeyError raised inside the factory is not misreported as an
    # unknown model name.
    if name not in models:
        raise KeyError(f"No model with name {name}")
    return models[name]()


def get_all_model_names() -> list:
    """Return the registered model names, in registration order."""
    return list(models.keys())