import importlib from functools import partial from pathlib import Path from typing import Any, Callable, Dict, List CATEGORIES = [ "prompt", "llm", "node", "worker", "tool", "encoder", "connector", "component", ] class Registry: """Class for module registration and retrieval.""" def __init__(self): # Initializes a mapping for different categories of modules. self.mapping = {key: {} for key in CATEGORIES} def __getattr__(self, name: str) -> Callable: if name.startswith(("register_", "get_")): prefix, category = name.split("_", 1) if category in CATEGORIES: if prefix == "register": return partial(self.register, category) elif prefix == "get": return partial(self.get, category) raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{name}'" ) def _register(self, category: str, name: str = None): """ Registers a module under a specific category. :param category: The category to register the module under. :param name: The name to register the module as. """ def wrap(module): nonlocal name name = name or module.__name__ if name in self.mapping[category]: raise ValueError( f"Module {name} [{self.mapping[category].get(name)}] already registered in category {category}. Please use a different class name." ) self.mapping.setdefault(category, {})[name] = module return module return wrap def _get(self, category: str, name: str): """ Retrieves a module from a specified category. :param category: The category to search in. :param name: The name of the module to retrieve. :raises KeyError: If the module is not found. """ try: return self.mapping[category][name] except KeyError: raise KeyError(f"Module {name} not found in category {category}") def register(self, category: str, name: str = None): """ Registers a module under a general category. :param category: The category to register the module under. :param name: Optional name to register the module as. """ return self._register(category, name) def get(self, category: str, name: str): """ Retrieves a module from a general category. :param category: The category to search in. :param name: The name of the module to retrieve. """ return self._get(category, name) def import_module(self, project_path: List[str] | str = None): """Import modules from default paths and optional project paths. Args: project_path: Optional path or list of paths to import modules from """ # Handle default paths root_path = Path(__file__).parents[1] default_path = [ root_path.joinpath("models"), root_path.joinpath("tool_system"), root_path.joinpath("services"), root_path.joinpath("memories"), root_path.joinpath("advanced_components"), root_path.joinpath("clients"), ] for path in default_path: for module in path.rglob("*.[ps][yo]"): if module.name == "workflow.py": continue module = str(module) if "__init__" in module or "base.py" in module or "entry.py" in module: continue module = "omagent_core" + module.rsplit("omagent_core", 1)[1].rsplit( ".", 1 )[0].replace("/", ".") importlib.import_module(module) # Handle project paths if project_path: if isinstance(project_path, (str, Path)): project_path = [project_path] for path in project_path: path = Path(path).absolute() project_root = path.parent for module in path.rglob("*.[ps][yo]"): module = str(module) if "__init__" in module: continue module = ( module.replace(str(project_root) + "/", "") .rsplit(".", 1)[0] .replace("/", ".") ) importlib.import_module(module) # Instantiate registry registry = Registry() if __name__ == "__main__": @registry.register_node() class TestNode: name: "TestNode" print(registry.get_node("TestNode"))