"""Tests for ``AgentSession.start()`` with and without a state to restore."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
    ActionExecutionClient,
)
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore


class _SpyAgentController(AgentController):
    """AgentController subclass that records calls to ``set_initial_state``.

    The counters are class-level defaults; the first write inside
    ``set_initial_state`` creates per-instance attributes, so sharing this
    class between tests does not leak counts from one controller to another.
    """

    set_initial_state_call_count = 0
    test_initial_state = None

    def set_initial_state(self, *args, state=None, **kwargs):
        """Record the call and the ``state`` argument, then delegate to the parent."""
        self.set_initial_state_call_count += 1
        self.test_initial_state = state
        super().set_initial_state(*args, state=state, **kwargs)


def _make_session() -> AgentSession:
    """Build an AgentSession whose ``_create_runtime`` installs a mock runtime.

    Returns:
        A session with sid ``'test-session'`` backed by an empty in-memory
        file store. Its ``_create_runtime`` is replaced by an ``AsyncMock``
        that sets ``session.runtime`` to a mock and returns ``True``.
    """
    session = AgentSession(
        sid='test-session',
        file_store=InMemoryFileStore({}),
    )
    mock_runtime = MagicMock(spec=ActionExecutionClient)

    # Stand-in for runtime creation: it only sets the runtime attribute.
    async def mock_create_runtime(*args, **kwargs):
        session.runtime = mock_runtime
        return True

    session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
    return session


def _make_mock_event_stream(latest_event_id: int) -> MagicMock:
    """Build a mock EventStream with no events.

    Args:
        latest_event_id: Value returned by ``get_latest_event_id()``.

    Returns:
        A ``MagicMock(spec=EventStream)`` whose ``get_events()`` returns an
        empty list and whose ``subscribe`` is a plain ``MagicMock``.
    """
    stream = MagicMock(spec=EventStream)
    stream.get_events.return_value = []
    stream.subscribe = MagicMock()
    stream.get_latest_event_id.return_value = latest_event_id
    return stream


@pytest.fixture
def mock_agent():
    """Create a properly configured mock agent with all required nested attributes"""
    # Create the base mocks
    agent = MagicMock(spec=Agent)
    llm = MagicMock(spec=LLM)
    metrics = MagicMock(spec=Metrics)
    llm_config = MagicMock(spec=LLMConfig)
    agent_config = MagicMock(spec=AgentConfig)

    # Configure the LLM config
    llm_config.model = 'test-model'
    llm_config.base_url = 'http://test'
    llm_config.max_message_chars = 1000

    # Configure the agent config
    agent_config.disabled_microagents = []
    agent_config.enable_mcp = True

    # Set up the chain of mocks
    llm.metrics = metrics
    llm.config = llm_config
    agent.llm = llm
    agent.name = 'test-agent'
    agent.sandbox_plugins = []
    agent.config = agent_config
    agent.prompt_manager = MagicMock()

    return agent


@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
    """Test that AgentSession.start() works correctly when there's no state to restore"""
    # Setup: session with a stubbed runtime and an event stream with no events
    session = _make_session()
    mock_event_stream = _make_mock_event_stream(latest_event_id=0)

    # Inject the mock event stream into the session
    session.event_stream = mock_event_stream

    # Create a real Memory instance with the mock event stream
    memory = Memory(event_stream=mock_event_stream, sid='test-session')
    memory.microagents_dir = 'test-dir'

    # Patch AgentController and State.restore_from_session to fail; patch Memory in AgentSession
    with (
        patch(
            'openhands.server.session.agent_session.AgentController',
            _SpyAgentController,
        ),
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            side_effect=Exception('No state found'),
        ),
        patch('openhands.server.session.agent_session.Memory', return_value=memory),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        # Verify EventStream.subscribe was called with correct parameters
        mock_event_stream.subscribe.assert_any_call(
            EventStreamSubscriber.AGENT_CONTROLLER,
            session.controller.on_event,
            session.controller.id,
        )
        mock_event_stream.subscribe.assert_any_call(
            EventStreamSubscriber.MEMORY,
            session.memory.on_event,
            session.controller.id,
        )

        # Verify set_initial_state was called once with None as state
        assert session.controller.set_initial_state_call_count == 1
        assert session.controller.test_initial_state is None
        assert session.controller.state.max_iterations == 10
        assert session.controller.agent.name == 'test-agent'
        assert session.controller.state.start_id == 0
        assert session.controller.state.end_id == -1


@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
    """Test that AgentSession.start() works correctly when there's a state to restore"""
    # Setup: session with a stubbed runtime
    session = _make_session()

    # Mock EventStream reporting a latest event id of 5, i.e. some events exist
    mock_event_stream = _make_mock_event_stream(latest_event_id=5)

    # Inject the mock event stream into the session
    session.event_stream = mock_event_stream

    # Create a mock restored state
    mock_restored_state = MagicMock(spec=State)
    mock_restored_state.start_id = -1
    mock_restored_state.end_id = -1
    mock_restored_state.max_iterations = 5

    # Create a mock Memory
    mock_memory = MagicMock(spec=Memory)

    # Patch AgentController and State.restore_from_session to succeed, patch Memory in AgentSession
    with (
        patch(
            'openhands.server.session.agent_session.AgentController',
            _SpyAgentController,
        ),
        patch(
            'openhands.server.session.agent_session.EventStream',
            return_value=mock_event_stream,
        ),
        patch(
            'openhands.controller.state.state.State.restore_from_session',
            return_value=mock_restored_state,
        ),
        patch('openhands.server.session.agent_session.Memory', mock_memory),
    ):
        await session.start(
            runtime_name='test-runtime',
            config=OpenHandsConfig(),
            agent=mock_agent,
            max_iterations=10,
        )

        # Verify set_initial_state was called once with the restored state
        assert session.controller.set_initial_state_call_count == 1

        # Verify EventStream.subscribe was called with correct parameters.
        # NOTE: assert_called_with inspects only the most recent subscribe() call.
        mock_event_stream.subscribe.assert_called_with(
            EventStreamSubscriber.AGENT_CONTROLLER,
            session.controller.on_event,
            session.controller.id,
        )
        assert session.controller.test_initial_state is mock_restored_state
        assert session.controller.state is mock_restored_state
        assert session.controller.state.max_iterations == 5
        assert session.controller.state.start_id == 0
        assert session.controller.state.end_id == -1