File size: 4,796 Bytes
1b7e88c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import importlib
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List

# Built-in module categories. ``Registry.__init__`` creates one empty bucket
# per entry in ``Registry.mapping``, and ``Registry.__getattr__`` exposes
# ``register_<category>`` / ``get_<category>`` shortcuts for each of them.
CATEGORIES = "prompt llm node worker tool encoder connector component".split()


class Registry:
    """Class for module registration and retrieval."""

    def __init__(self):
        # Initializes a mapping for different categories of modules.
        self.mapping = {key: {} for key in CATEGORIES}

    def __getattr__(self, name: str) -> Callable:
        if name.startswith(("register_", "get_")):
            prefix, category = name.split("_", 1)
            if category in CATEGORIES:
                if prefix == "register":
                    return partial(self.register, category)
                elif prefix == "get":
                    return partial(self.get, category)
        raise AttributeError(
            f"'{self.__class__.__name__}' object has no attribute '{name}'"
        )

    def _register(self, category: str, name: str = None):
        """
        Registers a module under a specific category.

        :param category: The category to register the module under.
        :param name: The name to register the module as.
        """

        def wrap(module):
            nonlocal name
            name = name or module.__name__
            if name in self.mapping[category]:
                raise ValueError(
                    f"Module {name} [{self.mapping[category].get(name)}] already registered in category {category}. Please use a different class name."
                )
            self.mapping.setdefault(category, {})[name] = module
            return module

        return wrap

    def _get(self, category: str, name: str):
        """
        Retrieves a module from a specified category.

        :param category: The category to search in.
        :param name: The name of the module to retrieve.
        :raises KeyError: If the module is not found.
        """
        try:
            return self.mapping[category][name]
        except KeyError:
            raise KeyError(f"Module {name} not found in category {category}")

    def register(self, category: str, name: str = None):
        """
        Registers a module under a general category.

        :param category: The category to register the module under.
        :param name: Optional name to register the module as.
        """
        return self._register(category, name)

    def get(self, category: str, name: str):
        """
        Retrieves a module from a general category.

        :param category: The category to search in.
        :param name: The name of the module to retrieve.
        """
        return self._get(category, name)

    def import_module(self, project_path: List[str] | str = None):
        """Import modules from default paths and optional project paths.

        Args:
            project_path: Optional path or list of paths to import modules from
        """
        # Handle default paths
        root_path = Path(__file__).parents[1]
        default_path = [
            root_path.joinpath("models"),
            root_path.joinpath("tool_system"),
            root_path.joinpath("services"),
            root_path.joinpath("memories"),
            root_path.joinpath("advanced_components"),
            root_path.joinpath("clients"),
        ]

        for path in default_path:
            for module in path.rglob("*.[ps][yo]"):
                if module.name == "workflow.py":
                    continue
                module = str(module)
                if "__init__" in module or "base.py" in module or "entry.py" in module:
                    continue
                module = "omagent_core" + module.rsplit("omagent_core", 1)[1].rsplit(
                    ".", 1
                )[0].replace("/", ".")
                importlib.import_module(module)

        # Handle project paths
        if project_path:
            if isinstance(project_path, (str, Path)):
                project_path = [project_path]

            for path in project_path:
                path = Path(path).absolute()
                project_root = path.parent
                for module in path.rglob("*.[ps][yo]"):
                    module = str(module)
                    if "__init__" in module:
                        continue
                    module = (
                        module.replace(str(project_root) + "/", "")
                        .rsplit(".", 1)[0]
                        .replace("/", ".")
                    )
                    importlib.import_module(module)


# Instantiate the module-level registry used by decorators such as
# ``@registry.register_node()``.
registry = Registry()


if __name__ == "__main__":
    # Smoke test: register a class through the ``register_node`` shortcut
    # and read it back with ``get_node``.

    @registry.register_node()
    class TestNode:
        # Assignment, not a bare annotation: ``name: "TestNode"`` would only
        # record an annotation and leave ``TestNode.name`` undefined.
        name = "TestNode"

    print(registry.get_node("TestNode"))