import asyncio
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import AsyncMock, MagicMock, Mock
from uuid import uuid4

import pytest

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.core.schema import AgentState
from openhands.events import EventSource, EventStream
from openhands.events.action import (
    AgentDelegateAction,
    AgentFinishAction,
    MessageAction,
)
from openhands.events.action.agent import RecallAction
from openhands.events.action.message import SystemMessageAction
from openhands.events.event import Event, RecallType
from openhands.events.observation.agent import RecallObservation
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.storage.memory import InMemoryFileStore


def _make_mock_agent(name: str) -> MagicMock:
    """Build a ``MagicMock`` agent with real metrics, configs and a system message.

    Args:
        name: Value assigned to the mock's ``name`` attribute.

    Returns:
        A ``MagicMock(spec=Agent)`` whose ``llm`` is a ``MagicMock(spec=LLM)``
        carrying fresh ``Metrics`` and a default ``LLMConfig``, whose ``config``
        is a default ``AgentConfig``, and whose ``get_system_message()`` returns
        a ``SystemMessageAction``.
    """
    agent = MagicMock(spec=Agent)
    agent.name = name
    agent.llm = MagicMock(spec=LLM)
    agent.llm.metrics = Metrics()
    agent.llm.config = LLMConfig()
    agent.config = AgentConfig()

    # Add a proper system message mock
    system_message = SystemMessageAction(content='Test system message')
    system_message._source = EventSource.AGENT
    system_message._id = -1  # Set invalid ID to avoid the ID check
    agent.get_system_message.return_value = system_message

    return agent


@pytest.fixture
def mock_event_stream():
    """Creates an event stream in memory."""
    sid = f'test-{uuid4()}'
    file_store = InMemoryFileStore({})
    return EventStream(sid=sid, file_store=file_store)


@pytest.fixture
def mock_parent_agent():
    """Creates a mock parent agent for testing delegation."""
    return _make_mock_agent('ParentAgent')


@pytest.fixture
def mock_child_agent():
    """Creates a mock child agent for testing delegation."""
    return _make_mock_agent('ChildAgent')


@pytest.mark.asyncio
async def test_delegation_flow(
    mock_parent_agent, mock_child_agent, mock_event_stream, monkeypatch
):
    """
    Test that when the parent agent delegates to a child, the parent's delegate
    is set, and once the child finishes, the parent is cleaned up properly.
    """
    # Mock the agent class resolution so that AgentController can instantiate
    # mock_child_agent. monkeypatch restores the real Agent.get_cls at teardown,
    # so the mock does not leak into other tests in the same process.
    monkeypatch.setattr(
        Agent, 'get_cls', Mock(return_value=lambda llm, config: mock_child_agent)
    )

    # Create parent controller
    parent_state = State(max_iterations=10)
    parent_controller = AgentController(
        agent=mock_parent_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='parent',
        confirmation_mode=False,
        headless_mode=True,
        initial_state=parent_state,
    )

    try:
        # Setup Memory to catch RecallActions
        mock_memory = MagicMock(spec=Memory)
        mock_memory.event_stream = mock_event_stream

        def on_event(event: Event):
            if isinstance(event, RecallAction):
                # create a RecallObservation
                microagent_observation = RecallObservation(
                    recall_type=RecallType.KNOWLEDGE,
                    content='Found info',
                )
                microagent_observation._cause = event.id  # ignore attr-defined warning
                mock_event_stream.add_event(
                    microagent_observation, EventSource.ENVIRONMENT
                )

        mock_memory.on_event = on_event
        mock_event_stream.subscribe(
            EventStreamSubscriber.MEMORY, mock_memory.on_event, mock_memory
        )

        # Setup a delegate action from the parent
        delegate_action = AgentDelegateAction(
            agent='ChildAgent', inputs={'test': True}
        )
        mock_parent_agent.step.return_value = delegate_action

        # Simulate a user message event to cause parent.step() to run
        message_action = MessageAction(content='please delegate now')
        message_action._source = EventSource.USER
        await parent_controller._on_event(message_action)

        # Give time for the async step() to execute
        await asyncio.sleep(1)

        # Verify that a RecallObservation was added to the event stream
        events = list(mock_event_stream.get_events())

        # SystemMessageAction, RecallAction, AgentChangeState, AgentDelegateAction, SystemMessageAction (for child)
        assert mock_event_stream.get_latest_event_id() == 5

        # a RecallObservation and an AgentDelegateAction should be in the list
        assert any(isinstance(event, RecallObservation) for event in events)
        assert any(isinstance(event, AgentDelegateAction) for event in events)

        # Verify that a delegate agent controller is created
        assert parent_controller.delegate is not None, (
            "Parent's delegate controller was not set."
        )

        # The parent's iteration should have incremented
        assert parent_controller.state.iteration == 1, (
            'Parent iteration should be incremented after step.'
        )

        # Now simulate that the child increments local iteration and finishes its subtask
        delegate_controller = parent_controller.delegate
        delegate_controller.state.iteration = 5  # child had some steps
        delegate_controller.state.outputs = {'delegate_result': 'done'}

        # The child is done, so we simulate it finishing:
        child_finish_action = AgentFinishAction()
        await delegate_controller._on_event(child_finish_action)
        await asyncio.sleep(0.5)

        # Now the parent's delegate is None
        assert parent_controller.delegate is None, (
            'Parent delegate should be None after child finishes.'
        )

        # Parent's global iteration is updated from the child
        assert parent_controller.state.iteration == 6, (
            "Parent iteration should be the child's iteration + 1 after child is done."
        )
    finally:
        # Cleanup runs even when an assertion above fails
        await parent_controller.close()


@pytest.mark.asyncio
@pytest.mark.parametrize(
    'delegate_state',
    [
        AgentState.RUNNING,
        AgentState.FINISHED,
        AgentState.ERROR,
        AgentState.REJECTED,
    ],
)
async def test_delegate_step_different_states(
    mock_parent_agent, mock_event_stream, delegate_state
):
    """Ensure that delegate is closed or remains open based on the delegate's state."""
    controller = AgentController(
        agent=mock_parent_agent,
        event_stream=mock_event_stream,
        max_iterations=10,
        sid='test',
        confirmation_mode=False,
        headless_mode=True,
    )

    try:
        mock_delegate = AsyncMock()
        controller.delegate = mock_delegate

        mock_delegate.state.iteration = 5
        mock_delegate.state.outputs = {'result': 'test'}
        mock_delegate.agent.name = 'TestDelegate'

        # get_agent_state is a plain Mock so it returns the state directly
        # rather than a coroutine, as an AsyncMock attribute would.
        mock_delegate.get_agent_state = Mock(return_value=delegate_state)
        mock_delegate._step = AsyncMock()
        mock_delegate.close = AsyncMock()

        def call_on_event_with_new_loop():
            """
            In this thread, create and set a fresh event loop, so that the
            run_until_complete() calls inside controller.on_event(...) find a valid loop.
            """
            loop_in_thread = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(loop_in_thread)
                msg_action = MessageAction(content='Test message')
                msg_action._source = EventSource.USER
                controller.on_event(msg_action)
            finally:
                loop_in_thread.close()

        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as executor:
            future = loop.run_in_executor(executor, call_on_event_with_new_loop)
            await future

        if delegate_state == AgentState.RUNNING:
            assert controller.delegate is not None
            assert controller.state.iteration == 0
            mock_delegate.close.assert_not_called()
        else:
            assert controller.delegate is None
            assert controller.state.iteration == 5
            # The close method is called once in end_delegate
            assert mock_delegate.close.call_count == 1
    finally:
        # Cleanup runs even when an assertion above fails
        await controller.close()