Spaces:
Sleeping
Sleeping
File size: 4,796 Bytes
1b7e88c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import importlib
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List
# Categories a module can be registered under. The Registry keeps one
# name -> module table per entry and exposes register_<category> /
# get_<category> shortcuts for each of them.
CATEGORIES = [
    "prompt", "llm", "node", "worker",
    "tool", "encoder", "connector", "component",
]
class Registry:
"""Class for module registration and retrieval."""
def __init__(self):
# Initializes a mapping for different categories of modules.
self.mapping = {key: {} for key in CATEGORIES}
def __getattr__(self, name: str) -> Callable:
if name.startswith(("register_", "get_")):
prefix, category = name.split("_", 1)
if category in CATEGORIES:
if prefix == "register":
return partial(self.register, category)
elif prefix == "get":
return partial(self.get, category)
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{name}'"
)
def _register(self, category: str, name: str = None):
"""
Registers a module under a specific category.
:param category: The category to register the module under.
:param name: The name to register the module as.
"""
def wrap(module):
nonlocal name
name = name or module.__name__
if name in self.mapping[category]:
raise ValueError(
f"Module {name} [{self.mapping[category].get(name)}] already registered in category {category}. Please use a different class name."
)
self.mapping.setdefault(category, {})[name] = module
return module
return wrap
def _get(self, category: str, name: str):
"""
Retrieves a module from a specified category.
:param category: The category to search in.
:param name: The name of the module to retrieve.
:raises KeyError: If the module is not found.
"""
try:
return self.mapping[category][name]
except KeyError:
raise KeyError(f"Module {name} not found in category {category}")
def register(self, category: str, name: str = None):
"""
Registers a module under a general category.
:param category: The category to register the module under.
:param name: Optional name to register the module as.
"""
return self._register(category, name)
def get(self, category: str, name: str):
"""
Retrieves a module from a general category.
:param category: The category to search in.
:param name: The name of the module to retrieve.
"""
return self._get(category, name)
def import_module(self, project_path: List[str] | str = None):
"""Import modules from default paths and optional project paths.
Args:
project_path: Optional path or list of paths to import modules from
"""
# Handle default paths
root_path = Path(__file__).parents[1]
default_path = [
root_path.joinpath("models"),
root_path.joinpath("tool_system"),
root_path.joinpath("services"),
root_path.joinpath("memories"),
root_path.joinpath("advanced_components"),
root_path.joinpath("clients"),
]
for path in default_path:
for module in path.rglob("*.[ps][yo]"):
if module.name == "workflow.py":
continue
module = str(module)
if "__init__" in module or "base.py" in module or "entry.py" in module:
continue
module = "omagent_core" + module.rsplit("omagent_core", 1)[1].rsplit(
".", 1
)[0].replace("/", ".")
importlib.import_module(module)
# Handle project paths
if project_path:
if isinstance(project_path, (str, Path)):
project_path = [project_path]
for path in project_path:
path = Path(path).absolute()
project_root = path.parent
for module in path.rglob("*.[ps][yo]"):
module = str(module)
if "__init__" in module:
continue
module = (
module.replace(str(project_root) + "/", "")
.rsplit(".", 1)[0]
.replace("/", ".")
)
importlib.import_module(module)
# Instantiate the module-level registry used through
# ``registry.register_<category>()`` / ``registry.get_<category>()``.
registry = Registry()

if __name__ == "__main__":

    @registry.register_node()
    class TestNode:
        # A real class attribute; an annotation alone (``name: "TestNode"``)
        # would not create it.
        name = "TestNode"

    print(registry.get_node("TestNode"))
|