from pathlib import Path from typing import Dict, List, Optional, Type from threading import Thread from omagent_core.engine.configuration.aaas_config import AaasConfig import yaml from omagent_core.engine.configuration.configuration import (TEMPLATE_CONFIG, Configuration) from omagent_core.engine.configuration.aaas_config import AAAS_TEMPLATE_CONFIG from omagent_core.utils.registry import registry from pydantic import BaseModel import os class Container: def __init__(self): self._connectors: Dict[str, BaseModel] = {} self._components: Dict[str, BaseModel] = {} self._stm_name: Optional[str] = None self._ltm_name: Optional[str] = None self._callback_name: Optional[str] = None self._input_name: Optional[str] = None self.conductor_config = Configuration() self.aaas_config = AaasConfig() def register_connector( self, connector: Type[BaseModel], name: str = None, overwrite: bool = False, **kwargs, ) -> None: """Register a connector""" if name is None: name = connector.__name__ if name not in self._connectors or overwrite: self._connectors[name] = connector(**kwargs) def get_connector(self, name: str) -> BaseModel: if name not in self._connectors: raise KeyError(f"There is no connector named '{name}' in container.") return self._connectors[name] def register_component( self, component: str | Type[BaseModel], name: str = None, config: dict = {}, overwrite: bool = False, ) -> None: """Generic component registration method Args: component: Component name or class key: The key to save and retrieve component config: Component configuration target_dict: Target dictionary to store component instances component_category: One of the register mapping types, should be provided if component is a string """ if isinstance(component, str): component_cls = registry.get_component(component) component_name = component if not component_cls: raise ValueError(f"{component} not found in registry") elif isinstance(component, type) and issubclass(component, BaseModel): component_cls = component component_name = component.__name__ else: raise ValueError(f"Invalid component type: {type(component)}") if ( name in self._components or component_name in self._components ) and not overwrite: return name or component_name required_connectors = self._get_required_connectors(component_cls) if required_connectors: for connector, cls_name in required_connectors: if connector not in self._connectors.keys(): connector_cls = registry.get_connector(cls_name) self.register_connector(connector_cls, connector) config[connector] = self._connectors[connector] self._components[name or component_name] = component_cls(**config) return name or component_name def get_component(self, component_name: str) -> BaseModel: if component_name not in self._components: raise KeyError( f"There is no component named '{component_name}' in container. You need to register it first." ) return self._components[component_name] def _get_required_connectors(self, cls: Type[BaseModel]) -> List[str]: required_connectors = [] for field_name, field in cls.model_fields.items(): if isinstance(field.annotation, type) and issubclass( field.annotation, BaseModel ): required_connectors.append([field_name, field.annotation.__name__]) return required_connectors @property def components(self) -> Dict[str, BaseModel]: return self._components def register_stm( self, stm: str | Type[BaseModel], name: str = None, config: dict = {}, overwrite: bool = False, ): if os.getenv("OMAGENT_MODE") == "lite": name = "SharedMemSTM" name = self.register_component(stm, name, config, overwrite) self._stm_name = name @property def stm(self) -> BaseModel: if self._stm_name is None: if os.getenv("OMAGENT_MODE") == "lite": self.register_stm("SharedMemSTM") self._stm_name = "SharedMemSTM" else: raise ValueError( "STM component is not registered. Please use register_stm to register." ) return self.get_component(self._stm_name) def register_ltm( self, ltm: str | Type[BaseModel], name: str = None, config: dict = {}, overwrite: bool = False, ): name = self.register_component(ltm, name, config, overwrite) self._ltm_name = name @property def ltm(self) -> BaseModel: if self._ltm_name is None: raise ValueError( "LTM component is not registered. Please use register_ltm to register." ) return self.get_component(self._ltm_name) def register_callback( self, callback: str | Type[BaseModel], name: str = None, config: dict = {}, overwrite: bool = False, ): name = self.register_component(callback, name, config, overwrite) self._callback_name = name @property def callback(self) -> BaseModel: if self._callback_name is None: raise ValueError( "Callback component is not registered. Please use register_callback to register." ) return self.get_component(self._callback_name) def register_input( self, input: str | Type[BaseModel], name: str = None, config: dict = {}, overwrite: bool = False, ): name = self.register_component(input, name, config, overwrite) self._input_name = name @property def input(self) -> BaseModel: if self._input_name is None: raise ValueError( "Input component is not registered. Please use register_input to register." ) return self.get_component(self._input_name) def compile_config( self, output_path: Path, description: bool = True, env_var: bool = True ) -> None: if (output_path / "container.yaml").exists(): print("container.yaml already exists, skip compiling") config = yaml.load( open(output_path / "container.yaml", "r"), Loader=yaml.FullLoader ) return config config = { "conductor_config": TEMPLATE_CONFIG, "aaas_config": AAAS_TEMPLATE_CONFIG, "connectors": {}, "components": {}, } exclude_fields = [ "_parent", "component_stm", "component_ltm", "component_callback", "component_input", ] for name, connector in self._connectors.items(): config["connectors"][name] = connector.__class__.get_config_template( description=description, env_var=env_var, exclude_fields=exclude_fields ) exclude_fields.extend(self._connectors.keys()) for name, component in self._components.items(): config["components"][name] = component.__class__.get_config_template( description=description, env_var=env_var, exclude_fields=exclude_fields ) with open(output_path / "container.yaml", "w") as f: f.write(yaml.dump(config, sort_keys=False, allow_unicode=True)) return config def from_config(self, config_data: dict | str | Path) -> None: """Update container from configuration Args: config_data: The dict including connectors and components configurations """ def clean_config_dict(config_dict: dict) -> dict: """Recursively clean up the configuration dictionary, removing all 'description' and 'env_var' keys""" cleaned = {} for key, value in config_dict.items(): if isinstance(value, dict): if "value" in value: cleaned[key] = value["value"] else: cleaned[key] = clean_config_dict(value) else: cleaned[key] = value return cleaned if isinstance(config_data, str | Path): if not Path(config_data).exists(): if os.getenv("OMAGENT_MODE") == "lite": return else: raise FileNotFoundError(f"Config file not found: {config_data}") config_data = yaml.load(open(config_data, "r"), Loader=yaml.FullLoader) config_data = clean_config_dict(config_data) if "conductor_config" in config_data: self.conductor_config = Configuration(**config_data["conductor_config"]) if "aaas_config" in config_data: self.aaas_config = AaasConfig(**config_data["aaas_config"]) # connectors if "connectors" in config_data: for name, config in config_data["connectors"].items(): connector_cls = registry.get_connector(config.pop("name")) if connector_cls: self.register_connector( name=name, connector=connector_cls, overwrite=True, **config ) # components if "components" in config_data: for name, config in config_data["components"].items(): self.register_component( component=config.pop("name"), name=name, config=config, overwrite=True, ) self.check_connection() def check_connection(self): if os.getenv("OMAGENT_MODE") == "lite": return for name, connector in self._connectors.items(): try: connector.check_connection() except Exception as e: raise ConnectionError( f"Connection to {name} failed. Please check your connector config in container.yaml. \n Error Message: {e}" ) try: from omagent_core.engine.orkes.orkes_workflow_client import \ OrkesWorkflowClient conductor_client = OrkesWorkflowClient(self.conductor_config) conductor_client.check_connection() except Exception as e: raise ConnectionError( f"Connection to Conductor failed. Please check your conductor config in container.yaml. \n Error Message: {e}" ) print("--------------------------------") print("All connections passed the connection check") print("--------------------------------") container = Container()