# Spaces: Build error
# (Page-header residue from the hosting site; not part of the test module.)
import asyncio | |
from unittest.mock import MagicMock, call, patch | |
import pytest | |
from prompt_toolkit.formatted_text import HTML | |
from prompt_toolkit.keys import Keys | |
from openhands.cli.tui import process_agent_pause | |
from openhands.core.schema import AgentState | |
from openhands.events import EventSource | |
from openhands.events.action import ChangeAgentStateAction | |
from openhands.events.observation import AgentStateChangedObservation | |
class TestProcessAgentPause:
    """Tests for ``process_agent_pause`` imported from ``openhands.cli.tui``."""

    # NOTE(review): the two patch targets are inferred from the mock parameter
    # names and from the assertions on printed output -- confirm that
    # ``openhands.cli.tui`` exposes ``create_input`` and ``print_formatted_text``.
    # Decorators apply bottom-up, so the lowest ``@patch`` feeds ``mock_print``.
    @pytest.mark.asyncio
    @patch('openhands.cli.tui.create_input')
    @patch('openhands.cli.tui.print_formatted_text')
    async def test_process_agent_pause_ctrl_p(self, mock_print, mock_create_input):
        """Test that process_agent_pause sets the done event when Ctrl+P is pressed."""
        # Create the done event
        done = asyncio.Event()

        # Set up the mock input
        mock_input = MagicMock()
        mock_create_input.return_value = mock_input

        # Mock the context managers
        mock_raw_mode = MagicMock()
        mock_input.raw_mode.return_value = mock_raw_mode
        mock_raw_mode.__enter__ = MagicMock()
        mock_raw_mode.__exit__ = MagicMock()

        mock_attach = MagicMock()
        mock_input.attach.return_value = mock_attach
        mock_attach.__enter__ = MagicMock()
        mock_attach.__exit__ = MagicMock()

        # Capture the keys_ready function
        keys_ready_func = None

        def fake_attach(callback):
            nonlocal keys_ready_func
            keys_ready_func = callback
            return mock_attach

        mock_input.attach.side_effect = fake_attach

        # Create a task to run process_agent_pause
        task = asyncio.create_task(process_agent_pause(done, event_stream=MagicMock()))

        # Everything after task creation runs under try/finally so the
        # background task is cancelled even when an assertion fails.
        try:
            # Give it a moment to start and capture the callback
            await asyncio.sleep(0.1)

            # Make sure we captured the callback
            assert keys_ready_func is not None

            # Create a key press that simulates Ctrl+P
            key_press = MagicMock()
            key_press.key = Keys.ControlP
            mock_input.read_keys.return_value = [key_press]

            # Manually call the callback to simulate key press
            keys_ready_func()

            # Verify done was set
            assert done.is_set()

            # Verify print was called with the pause message
            assert mock_print.call_count == 2
            assert mock_print.call_args_list[0] == call('')

            # Check that the second call contains the pause message HTML
            second_call = mock_print.call_args_list[1][0][0]
            assert isinstance(second_call, HTML)
            assert 'Pausing the agent' in str(second_call)
        finally:
            # Cancel the task
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
class TestCliPauseResumeInRunSession:
    """Tests for the pause handling logic modelled on ``on_event_async``.

    Each test defines a local ``on_event_async_test`` coroutine that replicates
    one branch of the event handler; the production handler itself is not
    invoked by these tests.
    """

    @pytest.mark.asyncio
    async def test_on_event_async_pause_processing(self):
        """Test that on_event_async processes the pause event when is_paused is set."""
        # Create a mock event
        event = MagicMock()

        # Create mock dependencies
        event_stream = MagicMock()
        is_paused = asyncio.Event()
        config = MagicMock()

        # Patch the display_event and update_usage_metrics functions
        with (
            patch('openhands.cli.main.display_event') as mock_display_event,
            patch('openhands.cli.main.update_usage_metrics') as mock_update_metrics,
        ):
            # Create a closure to capture the current context
            async def test_func():
                # Set the pause event
                is_paused.set()

                # Create a context similar to run_session to call on_event_async
                # We're creating a function that mimics the environment of on_event_async
                async def on_event_async_test(event):
                    mock_display_event(event, config)
                    mock_update_metrics(event, usage_metrics=MagicMock())

                    # Pause the agent if the pause event is set (through Ctrl-P)
                    if is_paused.is_set():
                        event_stream.add_event(
                            ChangeAgentStateAction(AgentState.PAUSED),
                            EventSource.USER,
                        )
                        # The pause event is not cleared here because we want to simulate
                        # the PAUSED event processing in a future event

                # Call on_event_async_test
                await on_event_async_test(event)

                # Check that event_stream.add_event was called with the correct action
                event_stream.add_event.assert_called_once()
                args, _ = event_stream.add_event.call_args
                action, source = args

                assert isinstance(action, ChangeAgentStateAction)
                assert action.agent_state == AgentState.PAUSED
                assert source == EventSource.USER

                # Check that is_paused is still set (will be cleared when PAUSED state is processed)
                assert is_paused.is_set()

            # Run the test function
            await test_func()

    @pytest.mark.asyncio
    async def test_awaiting_user_input_paused_skip(self):
        """Test that when is_paused is set, awaiting user input events do not trigger prompting."""
        # Create a mock event with AgentStateChangedObservation
        event = MagicMock()
        event.observation = AgentStateChangedObservation(
            agent_state=AgentState.AWAITING_USER_INPUT, content='Agent awaiting input'
        )

        # Create mock dependencies
        is_paused = asyncio.Event()

        # Mock function that would be called if code reaches that point
        mock_prompt_task = MagicMock()

        # Create a closure to capture the current context
        async def test_func():
            # Set the pause event
            is_paused.set()

            # Create a context similar to run_session to call on_event_async
            async def on_event_async_test(event):
                if isinstance(event.observation, AgentStateChangedObservation):
                    if event.observation.agent_state in [
                        AgentState.AWAITING_USER_INPUT,
                        AgentState.FINISHED,
                    ]:
                        # If the agent is paused, do not prompt for input
                        if is_paused.is_set():
                            return

                        # This code should not be reached if is_paused is set
                        mock_prompt_task()

            # Call on_event_async_test
            await on_event_async_test(event)

            # Verify that mock_prompt_task was not called
            mock_prompt_task.assert_not_called()

        # Run the test
        await test_func()

    @pytest.mark.asyncio
    async def test_awaiting_confirmation_paused_skip(self):
        """Test that when is_paused is set, awaiting confirmation events do not trigger prompting."""
        # Create a mock event with AgentStateChangedObservation
        event = MagicMock()
        event.observation = AgentStateChangedObservation(
            agent_state=AgentState.AWAITING_USER_CONFIRMATION,
            content='Agent awaiting confirmation',
        )

        # Create mock dependencies
        is_paused = asyncio.Event()

        # Mock function that would be called if code reaches that point
        mock_confirmation = MagicMock()

        # Create a closure to capture the current context
        async def test_func():
            # Set the pause event
            is_paused.set()

            # Create a context similar to run_session to call on_event_async
            async def on_event_async_test(event):
                if isinstance(event.observation, AgentStateChangedObservation):
                    if (
                        event.observation.agent_state
                        == AgentState.AWAITING_USER_CONFIRMATION
                    ):
                        if is_paused.is_set():
                            return

                        # This code should not be reached if is_paused is set
                        mock_confirmation()

            # Call on_event_async_test
            await on_event_async_test(event)

            # Verify that confirmation function was not called
            mock_confirmation.assert_not_called()

        # Run the test
        await test_func()
class TestCliCommandsPauseResume:
    """Tests for dispatching the ``/resume`` command through ``handle_commands``."""

    # NOTE(review): patch target inferred from the ``mock_handle_resume``
    # parameter and the local ``openhands.cli.commands`` import -- confirm that
    # ``handle_commands`` looks up ``handle_resume_command`` in that module.
    @pytest.mark.asyncio
    @patch('openhands.cli.commands.handle_resume_command')
    async def test_handle_commands_resume(self, mock_handle_resume):
        """Test that the handle_commands function properly calls handle_resume_command."""
        # Import here to avoid circular imports in test
        from openhands.cli.commands import handle_commands

        # Create mocks
        message = '/resume'
        event_stream = MagicMock()
        usage_metrics = MagicMock()
        sid = 'test-session-id'
        config = MagicMock()
        current_dir = '/test/dir'
        settings_store = MagicMock()

        # Mock return value
        mock_handle_resume.return_value = (False, False)

        # Call handle_commands
        close_repl, reload_microagents, new_session_requested = await handle_commands(
            message,
            event_stream,
            usage_metrics,
            sid,
            config,
            current_dir,
            settings_store,
        )

        # Check that handle_resume_command was called with correct args
        mock_handle_resume.assert_called_once_with(event_stream)

        # Check the return values
        assert close_repl is False
        assert reload_microagents is False
        assert new_session_requested is False
class TestAgentStatePauseResume:
    """Tests for agent-state transitions around pause and resume.

    Each test defines a local ``on_event_async_test`` coroutine that replicates
    one branch of the event handler; the patched mocks are called directly by
    that stand-in rather than through production code.
    """

    # NOTE(review): ``openhands.cli.tui.process_agent_pause`` matches the
    # module-level import of that name; the module for
    # ``display_agent_running_message`` is an assumption -- confirm it.
    # Decorators apply bottom-up, so the lowest ``@patch`` feeds
    # ``mock_process_agent_pause``.
    @pytest.mark.asyncio
    @patch('openhands.cli.tui.display_agent_running_message')
    @patch('openhands.cli.tui.process_agent_pause')
    async def test_agent_running_enables_pause(
        self, mock_process_agent_pause, mock_display_message
    ):
        """Test that when the agent is running, pause functionality is enabled."""
        # Create a mock event and event stream
        event = MagicMock()
        event.observation = AgentStateChangedObservation(
            agent_state=AgentState.RUNNING, content='Agent is running'
        )
        event_stream = MagicMock()

        # Create mock dependencies
        is_paused = asyncio.Event()
        loop = MagicMock()

        # Create a closure to capture the current context
        async def test_func():
            # Create a context similar to run_session to call on_event_async
            async def on_event_async_test(event):
                if isinstance(event.observation, AgentStateChangedObservation):
                    if event.observation.agent_state == AgentState.RUNNING:
                        mock_display_message()
                        loop.create_task(
                            mock_process_agent_pause(is_paused, event_stream)
                        )

            # Call on_event_async_test
            await on_event_async_test(event)

            # Check that display_agent_running_message was called
            mock_display_message.assert_called_once()

            # Check that loop.create_task was called
            loop.create_task.assert_called_once()

        # Run the test function
        await test_func()

    # The patch targets match those used with ``patch(...)`` as context
    # managers elsewhere in this file; the lowest ``@patch`` feeds
    # ``mock_update_metrics``.
    @pytest.mark.asyncio
    @patch('openhands.cli.main.display_event')
    @patch('openhands.cli.main.update_usage_metrics')
    async def test_pause_event_changes_agent_state(
        self, mock_update_metrics, mock_display_event
    ):
        """Test that when is_paused is set, a PAUSED state change event is added to the stream."""
        # Create mock dependencies
        event = MagicMock()
        event_stream = MagicMock()
        is_paused = asyncio.Event()
        config = MagicMock()

        # Set the pause event
        is_paused.set()

        # Create a closure to capture the current context
        async def test_func():
            # Create a context similar to run_session to call on_event_async
            async def on_event_async_test(event):
                mock_display_event(event, config)
                mock_update_metrics(event, MagicMock())

                # Pause the agent if the pause event is set (through Ctrl-P)
                if is_paused.is_set():
                    event_stream.add_event(
                        ChangeAgentStateAction(AgentState.PAUSED),
                        EventSource.USER,
                    )
                    is_paused.clear()

            # Call the function
            await on_event_async_test(event)

            # Check that the event_stream.add_event was called with the correct action
            event_stream.add_event.assert_called_once()
            args, _ = event_stream.add_event.call_args
            action, source = args

            assert isinstance(action, ChangeAgentStateAction)
            assert action.agent_state == AgentState.PAUSED
            assert source == EventSource.USER

            # Check that is_paused was cleared
            assert not is_paused.is_set()

        # Run the test
        await test_func()

    @pytest.mark.asyncio
    async def test_paused_agent_awaits_input(self):
        """Test that when the agent is paused, it awaits user input."""
        # Create mock dependencies
        event = MagicMock()
        # AgentStateChangedObservation requires a content parameter
        event.observation = AgentStateChangedObservation(
            agent_state=AgentState.PAUSED, content='Agent state changed to PAUSED'
        )
        is_paused = asyncio.Event()

        # Mock function that would be called for prompting
        mock_prompt_task = MagicMock()

        # Create a closure to capture the current context
        async def test_func():
            # Create a simplified version of on_event_async
            async def on_event_async_test(event):
                if isinstance(event.observation, AgentStateChangedObservation):
                    if event.observation.agent_state == AgentState.PAUSED:
                        is_paused.clear()  # Revert the event state before prompting for user input
                        mock_prompt_task(event.observation.agent_state)

            # Set is_paused to test that it gets cleared
            is_paused.set()

            # Call the function
            await on_event_async_test(event)

            # Check that is_paused was cleared
            assert not is_paused.is_set()

            # Check that prompt task was called with the correct state
            mock_prompt_task.assert_called_once_with(AgentState.PAUSED)

        # Run the test
        await test_func()