OpenHands / tests /unit /test_agent_session.py
Backup-bdg's picture
Upload 964 files
51ff9e5 verified
raw
history blame
8.11 kB
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openhands.controller.agent import Agent
from openhands.controller.agent_controller import AgentController
from openhands.controller.state.state import State
from openhands.core.config import LLMConfig, OpenHandsConfig
from openhands.core.config.agent_config import AgentConfig
from openhands.events import EventStream, EventStreamSubscriber
from openhands.llm import LLM
from openhands.llm.metrics import Metrics
from openhands.memory.memory import Memory
from openhands.runtime.impl.action_execution.action_execution_client import (
ActionExecutionClient,
)
from openhands.server.session.agent_session import AgentSession
from openhands.storage.memory import InMemoryFileStore
@pytest.fixture
def mock_agent():
"""Create a properly configured mock agent with all required nested attributes"""
# Create the base mocks
agent = MagicMock(spec=Agent)
llm = MagicMock(spec=LLM)
metrics = MagicMock(spec=Metrics)
llm_config = MagicMock(spec=LLMConfig)
agent_config = MagicMock(spec=AgentConfig)
# Configure the LLM config
llm_config.model = 'test-model'
llm_config.base_url = 'http://test'
llm_config.max_message_chars = 1000
# Configure the agent config
agent_config.disabled_microagents = []
agent_config.enable_mcp = True
# Set up the chain of mocks
llm.metrics = metrics
llm.config = llm_config
agent.llm = llm
agent.name = 'test-agent'
agent.sandbox_plugins = []
agent.config = agent_config
agent.prompt_manager = MagicMock()
return agent
@pytest.mark.asyncio
async def test_agent_session_start_with_no_state(mock_agent):
"""Test that AgentSession.start() works correctly when there's no state to restore"""
# Setup
file_store = InMemoryFileStore({})
session = AgentSession(
sid='test-session',
file_store=file_store,
)
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=ActionExecutionClient)
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
session.runtime = mock_runtime
return True
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
# Create a mock EventStream with no events
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.get_events.return_value = []
mock_event_stream.subscribe = MagicMock()
mock_event_stream.get_latest_event_id.return_value = 0
# Inject the mock event stream into the session
session.event_stream = mock_event_stream
# Create a spy on set_initial_state
class SpyAgentController(AgentController):
set_initial_state_call_count = 0
test_initial_state = None
def set_initial_state(self, *args, state=None, **kwargs):
self.set_initial_state_call_count += 1
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)
# Create a real Memory instance with the mock event stream
memory = Memory(event_stream=mock_event_stream, sid='test-session')
memory.microagents_dir = 'test-dir'
# Patch AgentController and State.restore_from_session to fail; patch Memory in AgentSession
with (
patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
),
patch(
'openhands.server.session.agent_session.EventStream',
return_value=mock_event_stream,
),
patch(
'openhands.controller.state.state.State.restore_from_session',
side_effect=Exception('No state found'),
),
patch('openhands.server.session.agent_session.Memory', return_value=memory),
):
await session.start(
runtime_name='test-runtime',
config=OpenHandsConfig(),
agent=mock_agent,
max_iterations=10,
)
# Verify EventStream.subscribe was called with correct parameters
mock_event_stream.subscribe.assert_any_call(
EventStreamSubscriber.AGENT_CONTROLLER,
session.controller.on_event,
session.controller.id,
)
mock_event_stream.subscribe.assert_any_call(
EventStreamSubscriber.MEMORY,
session.memory.on_event,
session.controller.id,
)
# Verify set_initial_state was called once with None as state
assert session.controller.set_initial_state_call_count == 1
assert session.controller.test_initial_state is None
assert session.controller.state.max_iterations == 10
assert session.controller.agent.name == 'test-agent'
assert session.controller.state.start_id == 0
assert session.controller.state.end_id == -1
@pytest.mark.asyncio
async def test_agent_session_start_with_restored_state(mock_agent):
"""Test that AgentSession.start() works correctly when there's a state to restore"""
# Setup
file_store = InMemoryFileStore({})
session = AgentSession(
sid='test-session',
file_store=file_store,
)
# Create a mock runtime and set it up
mock_runtime = MagicMock(spec=ActionExecutionClient)
# Mock the runtime creation to set up the runtime attribute
async def mock_create_runtime(*args, **kwargs):
session.runtime = mock_runtime
return True
session._create_runtime = AsyncMock(side_effect=mock_create_runtime)
# Create a mock EventStream with some events
mock_event_stream = MagicMock(spec=EventStream)
mock_event_stream.get_events.return_value = []
mock_event_stream.subscribe = MagicMock()
mock_event_stream.get_latest_event_id.return_value = 5 # Indicate some events exist
# Inject the mock event stream into the session
session.event_stream = mock_event_stream
# Create a mock restored state
mock_restored_state = MagicMock(spec=State)
mock_restored_state.start_id = -1
mock_restored_state.end_id = -1
mock_restored_state.max_iterations = 5
# Create a spy on set_initial_state by subclassing AgentController
class SpyAgentController(AgentController):
set_initial_state_call_count = 0
test_initial_state = None
def set_initial_state(self, *args, state=None, **kwargs):
self.set_initial_state_call_count += 1
self.test_initial_state = state
super().set_initial_state(*args, state=state, **kwargs)
# create a mock Memory
mock_memory = MagicMock(spec=Memory)
# Patch AgentController and State.restore_from_session to succeed, patch Memory in AgentSession
with (
patch(
'openhands.server.session.agent_session.AgentController', SpyAgentController
),
patch(
'openhands.server.session.agent_session.EventStream',
return_value=mock_event_stream,
),
patch(
'openhands.controller.state.state.State.restore_from_session',
return_value=mock_restored_state,
),
patch('openhands.server.session.agent_session.Memory', mock_memory),
):
await session.start(
runtime_name='test-runtime',
config=OpenHandsConfig(),
agent=mock_agent,
max_iterations=10,
)
# Verify set_initial_state was called once with the restored state
assert session.controller.set_initial_state_call_count == 1
# Verify EventStream.subscribe was called with correct parameters
mock_event_stream.subscribe.assert_called_with(
EventStreamSubscriber.AGENT_CONTROLLER,
session.controller.on_event,
session.controller.id,
)
assert session.controller.test_initial_state is mock_restored_state
assert session.controller.state is mock_restored_state
assert session.controller.state.max_iterations == 5
assert session.controller.state.start_id == 0
assert session.controller.state.end_id == -1