|
import asyncio |
|
import dataclasses |
|
import inspect |
|
import logging |
|
import socket |
|
import time |
|
import traceback |
|
from abc import ABC, abstractmethod |
|
from copy import deepcopy |
|
from typing import Any, Callable, Optional, Union |
|
|
|
from omagent_core.base import BotBase |
|
from omagent_core.engine.automator import utils |
|
from omagent_core.engine.automator.utils import convert_from_dict_or_list |
|
from omagent_core.engine.configuration.configuration import Configuration |
|
from omagent_core.engine.http.api_client import ApiClient |
|
from omagent_core.engine.http.models import TaskExecLog |
|
from omagent_core.engine.http.models.task import Task |
|
from omagent_core.engine.http.models.task_result import TaskResult |
|
from omagent_core.engine.http.models.task_result_status import TaskResultStatus |
|
from omagent_core.engine.worker.exception import NonRetryableException |
|
from pydantic import Field |
|
from typing_extensions import Self |
|
|
|
ExecuteTaskFunction = Callable[[Union[Task, object]], Union[TaskResult, object]] |
|
|
|
logger = logging.getLogger(Configuration.get_logging_formatted_name(__name__)) |
|
|
|
|
|
def is_callable_input_parameter_a_task( |
|
callable: ExecuteTaskFunction, object_type: Any |
|
) -> bool: |
|
parameters = inspect.signature(callable).parameters |
|
if len(parameters) != 1: |
|
return False |
|
parameter = parameters[list(parameters.keys())[0]] |
|
return ( |
|
parameter.annotation == object_type |
|
or parameter.annotation == parameter.empty |
|
or parameter.annotation == object |
|
) |
|
|
|
|
|
def is_callable_return_value_of_type( |
|
callable: ExecuteTaskFunction, object_type: Any |
|
) -> bool: |
|
return_annotation = inspect.signature(callable).return_annotation |
|
return return_annotation == object_type |
|
|
|
|
|
class BaseWorker(BotBase, ABC): |
|
poll_interval: float = Field( |
|
default=100, description="Worker poll interval in millisecond" |
|
) |
|
domain: Optional[str] = Field(default=None, description="The domain of workflow") |
|
concurrency: int = Field(default=5, description="The concurrency of worker") |
|
_task_type: Optional[str] = None |
|
|
|
def model_post_init(self, __context: Any) -> None: |
|
self.task_definition_name = self.id or self.name |
|
self.next_task_index = 0 |
|
self._task_definition_name_cache = None |
|
|
|
self.api_client = ApiClient() |
|
self.worker_id = deepcopy(self.get_identity()) |
|
|
|
self._workflow_instance_id = None |
|
self._task_type = None |
|
for _, attr_value in self.__dict__.items(): |
|
if isinstance(attr_value, BotBase): |
|
attr_value._parent = self |
|
|
|
@property |
|
def task_type(self) -> str: |
|
return self._task_type |
|
|
|
@task_type.setter |
|
def task_type(self, value: str): |
|
self._task_type = value |
|
|
|
@property |
|
def workflow_instance_id(self) -> Optional[str]: |
|
return self._workflow_instance_id |
|
|
|
@workflow_instance_id.setter |
|
def workflow_instance_id(self, value: Optional[str]): |
|
self._workflow_instance_id = value |
|
|
|
@abstractmethod |
|
def _run(self, *args, **kwargs) -> Any: |
|
"""Run the Node.""" |
|
|
|
def __call__(self, *args: Any, **kwds: Any) -> Any: |
|
print ("__call__") |
|
return self._run(*args, **kwds) |
|
|
|
def execute(self, task: Task) -> TaskResult: |
|
task_input = {} |
|
task_output = None |
|
task_result: TaskResult = self.get_task_result_from_task(task) |
|
if task.conversation_info: |
|
self.workflow_instance_id = '|'.join([ |
|
task.workflow_instance_id, |
|
task.conversation_info.get('agentId', ''), |
|
task.conversation_info.get('conversationId', ''), |
|
task.conversation_info.get('chatId', ''), |
|
]) |
|
else: |
|
self.workflow_instance_id = task.workflow_instance_id |
|
|
|
try: |
|
if is_callable_input_parameter_a_task( |
|
callable=self._run, |
|
object_type=Task, |
|
): |
|
task_output = self._run(task) |
|
else: |
|
params = inspect.signature(self._run).parameters |
|
for input_name in params: |
|
typ = params[input_name].annotation |
|
default_value = params[input_name].default |
|
if input_name in task.input_data: |
|
if typ in utils.simple_types: |
|
task_input[input_name] = task.input_data[input_name] |
|
else: |
|
task_input[input_name] = convert_from_dict_or_list( |
|
typ, task.input_data[input_name] |
|
) |
|
else: |
|
if default_value is not inspect.Parameter.empty: |
|
task_input[input_name] = default_value |
|
else: |
|
task_input[input_name] = None |
|
if inspect.iscoroutinefunction(self._run): |
|
try: |
|
loop = asyncio.get_running_loop() |
|
except RuntimeError: |
|
loop = asyncio.new_event_loop() |
|
asyncio.set_event_loop(loop) |
|
|
|
task_output = loop.run_until_complete( |
|
asyncio.gather(self._run(**task_input), return_exceptions=True) |
|
)[0] |
|
else: |
|
task_output = self._run(**task_input) |
|
if type(task_output) == TaskResult: |
|
task_output.task_id = task.task_id |
|
task_output.workflow_instance_id = task.workflow_instance_id |
|
return task_output |
|
else: |
|
task_result.status = TaskResultStatus.COMPLETED |
|
task_result.output_data = task_output |
|
|
|
except NonRetryableException as ne: |
|
task_result.status = TaskResultStatus.FAILED_WITH_TERMINAL_ERROR |
|
if len(ne.args) > 0: |
|
task_result.reason_for_incompletion = ne.args[0] |
|
|
|
except Exception as ne: |
|
logger.error( |
|
f"Error executing task {task.task_def_name} with id {task.task_id}. error = {traceback.format_exc()}" |
|
) |
|
|
|
task_result.logs = [ |
|
TaskExecLog( |
|
traceback.format_exc(), task_result.task_id, int(time.time()) |
|
) |
|
] |
|
task_result.status = TaskResultStatus.FAILED |
|
if len(ne.args) > 0: |
|
task_result.reason_for_incompletion = ne.args[0] |
|
self.workflow_instance_id = None |
|
|
|
if dataclasses.is_dataclass(type(task_result.output_data)): |
|
task_output = dataclasses.asdict(task_result.output_data) |
|
task_result.output_data = task_output |
|
return task_result |
|
if not isinstance(task_result.output_data, dict): |
|
task_output = task_result.output_data |
|
task_result.output_data = self.api_client.sanitize_for_serialization( |
|
task_output |
|
) |
|
if not isinstance(task_result.output_data, dict): |
|
task_result.output_data = {"result": task_result.output_data} |
|
return task_result |
|
|
|
def get_identity(self) -> str: |
|
return self.worker_id if hasattr(self, "worker_id") else socket.gethostname() |
|
|
|
def get_polling_interval_in_seconds(self) -> float: |
|
return self.poll_interval / 1000 |
|
|
|
def get_task_definition_name(self) -> str: |
|
return self.task_definition_name_cache |
|
|
|
@property |
|
def task_definition_names(self): |
|
if isinstance(self.task_definition_name, list): |
|
return self.task_definition_name |
|
else: |
|
return [self.task_definition_name] |
|
|
|
@property |
|
def task_definition_name_cache(self): |
|
if self._task_definition_name_cache is None: |
|
self._task_definition_name_cache = self.compute_task_definition_name() |
|
return self._task_definition_name_cache |
|
|
|
def clear_task_definition_name_cache(self): |
|
self._task_definition_name_cache = None |
|
|
|
def compute_task_definition_name(self): |
|
if isinstance(self.task_definition_name, list): |
|
task_definition_name = self.task_definition_name[self.next_task_index] |
|
self.next_task_index = (self.next_task_index + 1) % len( |
|
self.task_definition_name |
|
) |
|
return task_definition_name |
|
return self.task_definition_name |
|
|
|
def get_task_result_from_task(self, task: Task) -> TaskResult: |
|
return TaskResult( |
|
task_id=task.task_id, |
|
workflow_instance_id=task.workflow_instance_id, |
|
worker_id=self.get_identity(), |
|
biz_meta=task.biz_meta, |
|
callback_url=task.callback_url, |
|
) |
|
|
|
def get_domain(self) -> str: |
|
return self.domain |
|
|
|
def paused(self) -> bool: |
|
return False |
|
|