|
import importlib |
|
from functools import partial |
|
from pathlib import Path |
|
from typing import Any, Callable, Dict, List |
|
|
|
# Module categories known to ``Registry``: each one gets its own
# ``{name: module}`` bucket in ``Registry.mapping`` and its own
# ``register_<category>`` / ``get_<category>`` shortcut.
CATEGORIES = [
    "prompt", "llm", "node", "worker",
    "tool", "encoder", "connector", "component",
]
|
|
|
|
|
class Registry: |
|
"""Class for module registration and retrieval.""" |
|
|
|
def __init__(self): |
|
|
|
self.mapping = {key: {} for key in CATEGORIES} |
|
|
|
def __getattr__(self, name: str) -> Callable: |
|
if name.startswith(("register_", "get_")): |
|
prefix, category = name.split("_", 1) |
|
if category in CATEGORIES: |
|
if prefix == "register": |
|
return partial(self.register, category) |
|
elif prefix == "get": |
|
return partial(self.get, category) |
|
raise AttributeError( |
|
f"'{self.__class__.__name__}' object has no attribute '{name}'" |
|
) |
|
|
|
def _register(self, category: str, name: str = None): |
|
""" |
|
Registers a module under a specific category. |
|
|
|
:param category: The category to register the module under. |
|
:param name: The name to register the module as. |
|
""" |
|
|
|
def wrap(module): |
|
nonlocal name |
|
name = name or module.__name__ |
|
if name in self.mapping[category]: |
|
raise ValueError( |
|
f"Module {name} [{self.mapping[category].get(name)}] already registered in category {category}. Please use a different class name." |
|
) |
|
self.mapping.setdefault(category, {})[name] = module |
|
return module |
|
|
|
return wrap |
|
|
|
def _get(self, category: str, name: str): |
|
""" |
|
Retrieves a module from a specified category. |
|
|
|
:param category: The category to search in. |
|
:param name: The name of the module to retrieve. |
|
:raises KeyError: If the module is not found. |
|
""" |
|
try: |
|
return self.mapping[category][name] |
|
except KeyError: |
|
raise KeyError(f"Module {name} not found in category {category}") |
|
|
|
def register(self, category: str, name: str = None): |
|
""" |
|
Registers a module under a general category. |
|
|
|
:param category: The category to register the module under. |
|
:param name: Optional name to register the module as. |
|
""" |
|
return self._register(category, name) |
|
|
|
def get(self, category: str, name: str): |
|
""" |
|
Retrieves a module from a general category. |
|
|
|
:param category: The category to search in. |
|
:param name: The name of the module to retrieve. |
|
""" |
|
return self._get(category, name) |
|
|
|
def import_module(self, project_path: List[str] | str = None): |
|
"""Import modules from default paths and optional project paths. |
|
|
|
Args: |
|
project_path: Optional path or list of paths to import modules from |
|
""" |
|
|
|
root_path = Path(__file__).parents[1] |
|
default_path = [ |
|
root_path.joinpath("models"), |
|
root_path.joinpath("tool_system"), |
|
root_path.joinpath("services"), |
|
root_path.joinpath("memories"), |
|
root_path.joinpath("advanced_components"), |
|
root_path.joinpath("clients"), |
|
] |
|
|
|
for path in default_path: |
|
for module in path.rglob("*.[ps][yo]"): |
|
if module.name == "workflow.py": |
|
continue |
|
module = str(module) |
|
if "__init__" in module or "base.py" in module or "entry.py" in module: |
|
continue |
|
module = "omagent_core" + module.rsplit("omagent_core", 1)[1].rsplit( |
|
".", 1 |
|
)[0].replace("/", ".") |
|
importlib.import_module(module) |
|
|
|
|
|
if project_path: |
|
if isinstance(project_path, (str, Path)): |
|
project_path = [project_path] |
|
|
|
for path in project_path: |
|
path = Path(path).absolute() |
|
project_root = path.parent |
|
for module in path.rglob("*.[ps][yo]"): |
|
module = str(module) |
|
if "__init__" in module: |
|
continue |
|
module = ( |
|
module.replace(str(project_root) + "/", "") |
|
.rsplit(".", 1)[0] |
|
.replace("/", ".") |
|
) |
|
importlib.import_module(module) |
|
|
|
|
|
|
|
# Module-level ``Registry`` instance; every importer of this module shares it.
registry = Registry()


if __name__ == "__main__":
    # Small demo: register a class through the dynamic ``register_node``
    # shortcut, then fetch it back by its class name.

    @registry.register_node()
    class TestNode:
        # A real class attribute; a bare ``name: "TestNode"`` would only be an
        # annotation and would not create the attribute.
        name = "TestNode"

    print(registry.get_node("TestNode"))
|
|