|
from pathlib import Path |
|
from typing import Dict, List, Optional, Type |
|
from threading import Thread |
|
|
|
from omagent_core.engine.configuration.aaas_config import AaasConfig |
|
import yaml |
|
from omagent_core.engine.configuration.configuration import (TEMPLATE_CONFIG, |
|
Configuration) |
|
from omagent_core.engine.configuration.aaas_config import AAAS_TEMPLATE_CONFIG |
|
from omagent_core.utils.registry import registry |
|
from pydantic import BaseModel |
|
import os |
|
|
|
|
|
class Container:
    """Holds named connector and component instances plus engine configuration.

    Components can be registered from a class or from a name looked up in the
    global ``registry``. Any component field annotated with a pydantic model
    class is filled with the connector registered under that field's name; the
    connector is created on demand when missing. The STM, LTM, callback and
    input "role" slots remember which registered component plays each role.
    """

    def __init__(self):
        # Instantiated connectors, keyed by registration name.
        self._connectors: Dict[str, BaseModel] = {}
        # Instantiated components, keyed by registration name.
        self._components: Dict[str, BaseModel] = {}
        # Registration names of the components filling each role slot.
        self._stm_name: Optional[str] = None
        self._ltm_name: Optional[str] = None
        self._callback_name: Optional[str] = None
        self._input_name: Optional[str] = None
        self.conductor_config = Configuration()
        self.aaas_config = AaasConfig()

    def register_connector(
        self,
        connector: Type[BaseModel],
        name: Optional[str] = None,
        overwrite: bool = False,
        **kwargs,
    ) -> None:
        """Instantiate ``connector`` with ``kwargs`` and store it under ``name``.

        Args:
            connector: Connector class to instantiate.
            name: Registration key; defaults to the connector class name.
            overwrite: Replace an existing connector with the same key. When
                False and the key is already taken, nothing happens.
            **kwargs: Keyword arguments forwarded to the connector constructor.
        """
        if name is None:
            name = connector.__name__
        if name not in self._connectors or overwrite:
            self._connectors[name] = connector(**kwargs)

    def get_connector(self, name: str) -> BaseModel:
        """Return the connector registered under ``name``.

        Raises:
            KeyError: If no connector is registered under ``name``.
        """
        if name not in self._connectors:
            raise KeyError(f"There is no connector named '{name}' in container.")
        return self._connectors[name]

    def register_component(
        self,
        component: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ) -> str:
        """Instantiate a component and store it in the container.

        Args:
            component: Name of a component in ``registry``, or a pydantic
                model class.
            name: Registration key; defaults to the registry name / class name.
            config: Keyword arguments for the component constructor. The dict
                is copied and never mutated.
            overwrite: Replace an existing component with the same key. When
                False and the key is already taken, nothing is instantiated.

        Returns:
            The key under which the component is (or already was) stored.

        Raises:
            ValueError: If ``component`` is an unknown registry name, or is
                neither a string nor a pydantic model class.
        """
        if isinstance(component, str):
            component_cls = registry.get_component(component)
            component_name = component
            if not component_cls:
                raise ValueError(f"{component} not found in registry")
        elif isinstance(component, type) and issubclass(component, BaseModel):
            component_cls = component
            component_name = component.__name__
        else:
            raise ValueError(f"Invalid component type: {type(component)}")

        key = name or component_name
        # Only the key that will actually be written decides whether this is a
        # duplicate. Also checking the class name could return a key that was
        # never registered.
        if key in self._components and not overwrite:
            return key

        # Copy so that neither the caller's dict nor a shared default picks up
        # the connector instances injected below.
        config = dict(config) if config else {}
        for field_name, connector_cls_name in self._get_required_connectors(
            component_cls
        ):
            if field_name not in self._connectors:
                connector_cls = registry.get_connector(connector_cls_name)
                self.register_connector(connector_cls, field_name)
            config[field_name] = self._connectors[field_name]

        self._components[key] = component_cls(**config)
        return key

    def get_component(self, component_name: str) -> BaseModel:
        """Return the component registered under ``component_name``.

        Raises:
            KeyError: If no component is registered under that name.
        """
        if component_name not in self._components:
            raise KeyError(
                f"There is no component named '{component_name}' in container. You need to register it first."
            )
        return self._components[component_name]

    def _get_required_connectors(self, cls: Type[BaseModel]) -> List[List[str]]:
        """List ``[field_name, annotation_class_name]`` pairs for the fields of
        ``cls`` whose annotation is itself a pydantic model class.

        Fields with non-class annotations (e.g. ``Optional[...]``) are skipped.
        """
        return [
            [field_name, field.annotation.__name__]
            for field_name, field in cls.model_fields.items()
            if isinstance(field.annotation, type)
            and issubclass(field.annotation, BaseModel)
        ]

    @property
    def components(self) -> Dict[str, BaseModel]:
        """The live mapping of registration key -> component instance."""
        return self._components

    def register_stm(
        self,
        stm: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the short-term memory (STM).

        When ``OMAGENT_MODE`` is ``"lite"``, the registration key is forced to
        ``"SharedMemSTM"`` regardless of ``name``.
        """
        if os.getenv("OMAGENT_MODE") == "lite":
            name = "SharedMemSTM"
        name = self.register_component(stm, name, config, overwrite)
        self._stm_name = name

    @property
    def stm(self) -> BaseModel:
        """The registered STM component.

        In lite mode an unregistered STM is lazily registered as
        ``SharedMemSTM``; otherwise a missing STM raises ``ValueError``.
        """
        if self._stm_name is None:
            if os.getenv("OMAGENT_MODE") == "lite":
                # register_stm records the name in self._stm_name itself.
                self.register_stm("SharedMemSTM")
            else:
                raise ValueError(
                    "STM component is not registered. Please use register_stm to register."
                )
        return self.get_component(self._stm_name)

    def register_ltm(
        self,
        ltm: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the long-term memory (LTM)."""
        name = self.register_component(ltm, name, config, overwrite)
        self._ltm_name = name

    @property
    def ltm(self) -> BaseModel:
        """The registered LTM component; raises ``ValueError`` if unset."""
        if self._ltm_name is None:
            raise ValueError(
                "LTM component is not registered. Please use register_ltm to register."
            )
        return self.get_component(self._ltm_name)

    def register_callback(
        self,
        callback: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the callback component."""
        name = self.register_component(callback, name, config, overwrite)
        self._callback_name = name

    @property
    def callback(self) -> BaseModel:
        """The registered callback component; raises ``ValueError`` if unset."""
        if self._callback_name is None:
            raise ValueError(
                "Callback component is not registered. Please use register_callback to register."
            )
        return self.get_component(self._callback_name)

    def register_input(
        self,
        input: str | Type[BaseModel],
        name: Optional[str] = None,
        config: Optional[dict] = None,
        overwrite: bool = False,
    ):
        """Register a component and mark it as the input component."""
        name = self.register_component(input, name, config, overwrite)
        self._input_name = name

    @property
    def input(self) -> BaseModel:
        """The registered input component; raises ``ValueError`` if unset."""
        if self._input_name is None:
            raise ValueError(
                "Input component is not registered. Please use register_input to register."
            )
        return self.get_component(self._input_name)

    def compile_config(
        self, output_path: Path, description: bool = True, env_var: bool = True
    ) -> dict:
        """Write a ``container.yaml`` template for everything registered.

        If ``output_path / "container.yaml"`` already exists, it is loaded and
        returned unchanged instead of being regenerated.

        Args:
            output_path: Directory for ``container.yaml`` (``str`` also accepted).
            description: Forwarded to each class's ``get_config_template``.
            env_var: Forwarded to each class's ``get_config_template``.

        Returns:
            The configuration dict that was written (or loaded).
        """
        config_file = Path(output_path) / "container.yaml"
        if config_file.exists():
            print("container.yaml already exists, skip compiling")
            with open(config_file, "r", encoding="utf-8") as f:
                return yaml.load(f, Loader=yaml.FullLoader)

        config = {
            "conductor_config": TEMPLATE_CONFIG,
            "aaas_config": AAAS_TEMPLATE_CONFIG,
            "connectors": {},
            "components": {},
        }
        exclude_fields = [
            "_parent",
            "component_stm",
            "component_ltm",
            "component_callback",
            "component_input",
        ]
        for name, connector in self._connectors.items():
            config["connectors"][name] = connector.__class__.get_config_template(
                description=description, env_var=env_var, exclude_fields=exclude_fields
            )
        # Component fields named after connectors are injected at registration
        # time, so they are left out of the component templates.
        exclude_fields.extend(self._connectors.keys())
        for name, component in self._components.items():
            config["components"][name] = component.__class__.get_config_template(
                description=description, env_var=env_var, exclude_fields=exclude_fields
            )

        # allow_unicode=True may emit non-ASCII text, so pin the file encoding.
        with open(config_file, "w", encoding="utf-8") as f:
            f.write(yaml.dump(config, sort_keys=False, allow_unicode=True))

        return config

    def from_config(self, config_data: dict | str | Path) -> None:
        """Update the container from a configuration dict or YAML file.

        Template entries of the form ``{"value": ..., ...}`` are collapsed to
        their ``value``. ``conductor_config`` / ``aaas_config`` replace the
        current engine configs. Each connector/component entry is registered
        (with ``overwrite=True``) using its ``name`` field as the registry
        class name. ``check_connection`` runs afterwards.

        Args:
            config_data: The dict including connectors and components
                configurations, or a path to a YAML file with that content.

        Raises:
            FileNotFoundError: If a path is given that does not exist and
                ``OMAGENT_MODE`` is not ``"lite"`` (in lite mode the call
                returns silently).
        """

        def clean_config_dict(config_dict: dict) -> dict:
            """Recursively rebuild ``config_dict``, replacing every nested dict
            that has a ``"value"`` key with that value (which drops sibling
            keys such as ``description`` and ``env_var``)."""
            cleaned = {}
            for key, value in config_dict.items():
                if isinstance(value, dict):
                    if "value" in value:
                        cleaned[key] = value["value"]
                    else:
                        cleaned[key] = clean_config_dict(value)
                else:
                    cleaned[key] = value
            return cleaned

        if isinstance(config_data, str | Path):
            config_path = Path(config_data)
            if not config_path.exists():
                if os.getenv("OMAGENT_MODE") == "lite":
                    return
                raise FileNotFoundError(f"Config file not found: {config_data}")
            with open(config_path, "r", encoding="utf-8") as f:
                # An empty YAML document loads as None; treat it as "no settings".
                config_data = yaml.load(f, Loader=yaml.FullLoader) or {}
        # Produces fresh nested dicts, so the pops below never touch the
        # caller's data.
        config_data = clean_config_dict(config_data)

        if "conductor_config" in config_data:
            self.conductor_config = Configuration(**config_data["conductor_config"])
        if "aaas_config" in config_data:
            self.aaas_config = AaasConfig(**config_data["aaas_config"])

        if "connectors" in config_data:
            for name, config in config_data["connectors"].items():
                connector_cls = registry.get_connector(config.pop("name"))
                if connector_cls:
                    self.register_connector(
                        name=name, connector=connector_cls, overwrite=True, **config
                    )

        if "components" in config_data:
            for name, config in config_data["components"].items():
                self.register_component(
                    component=config.pop("name"),
                    name=name,
                    config=config,
                    overwrite=True,
                )

        self.check_connection()

    def check_connection(self):
        """Verify every connector and the Conductor server are reachable.

        Skipped entirely when ``OMAGENT_MODE`` is ``"lite"``.

        Raises:
            ConnectionError: On the first failing check; the original error is
                chained as ``__cause__``.
        """
        if os.getenv("OMAGENT_MODE") == "lite":
            return

        for name, connector in self._connectors.items():
            try:
                connector.check_connection()
            except Exception as e:
                raise ConnectionError(
                    f"Connection to {name} failed. Please check your connector config in container.yaml. \n Error Message: {e}"
                ) from e

        try:
            from omagent_core.engine.orkes.orkes_workflow_client import \
                OrkesWorkflowClient

            conductor_client = OrkesWorkflowClient(self.conductor_config)
            conductor_client.check_connection()
        except Exception as e:
            raise ConnectionError(
                f"Connection to Conductor failed. Please check your conductor config in container.yaml. \n Error Message: {e}"
            ) from e

        print("--------------------------------")
        print("All connections passed the connection check")
        print("--------------------------------")
|
|
|
|
|
# Module-level Container instance, created once when this module is imported.
container: Container = Container()
|
|