File size: 5,785 Bytes
246d201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from unittest.mock import Mock

import pytest
from litellm import ModelResponse

from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.events.action import MessageAction
from openhands.llm.llm import LLM


@pytest.fixture
def mock_llm():
    """Provide an ``LLM`` built from a config with prompt caching switched on.

    The API key is the placeholder string ``'fake'``; tests in this module
    replace the completion callable before any request would be sent.
    """
    caching_config = LLMConfig(
        model='claude-3-5-sonnet-20241022',
        api_key='fake',
        caching_prompt=True,
    )
    return LLM(caching_config)


@pytest.fixture
def codeact_agent(mock_llm):
    """Provide a ``CodeActAgent`` backed by ``mock_llm`` and a default ``AgentConfig``."""
    return CodeActAgent(mock_llm, AgentConfig())


def response_mock(content: str, tool_call_id: str):
    """Build a ``ModelResponse`` with a single choice carrying one tool call.

    Args:
        content: Text placed in the choice's message ``content`` field.
        tool_call_id: Value stored under ``'id'`` inside the tool call's
            ``'function'`` mapping.

    Returns:
        A ``ModelResponse`` constructed from a ``choices`` list with one
        message whose only tool call names ``execute_bash`` with empty
        JSON arguments.
    """
    function_payload = {
        'id': tool_call_id,
        'name': 'execute_bash',
        'arguments': '{}',
    }
    message = {
        'content': content,
        'tool_calls': [{'function': function_payload}],
    }
    return ModelResponse(choices=[{'message': message}])


def test_get_messages(codeact_agent: CodeActAgent):
    """Check roles, text and cache flags of messages built from a short chat."""
    # Conversation turns in chronological order: (text, source).
    turns = [
        ('Initial user message', 'user'),
        ('Sure!', 'assistant'),
        ('Hello, agent!', 'user'),
        ('Hello, user!', 'assistant'),
        ('Laaaaaaaast!', 'user'),
    ]
    history = []
    for text, source in turns:
        action = MessageAction(text)
        action._source = source
        history.append(action)

    codeact_agent.reset()
    state = Mock(history=history, max_iterations=5, iteration=0, extra_data={})
    messages = codeact_agent._get_messages(state)

    # Six entries expected: the system message at index 0, then
    # (presumably) one message per history event.
    assert len(messages) == 6

    # System message carries a cache marker.
    assert messages[0].content[0].cache_prompt

    # First user turn: only the tail of the text is checked, so extra text
    # in front of it is tolerated.
    first_user = messages[1]
    assert first_user.role == 'user'
    assert first_user.content[0].text.endswith('Initial user message')
    # Each of the three user turns is expected to carry a cache marker.
    assert first_user.content[0].cache_prompt

    middle_user = messages[3]
    assert middle_user.role == 'user'
    assert middle_user.content[0].text == ('Hello, agent!')
    assert middle_user.content[0].cache_prompt

    # Assistant turns must not be marked for caching.
    agent_reply = messages[4]
    assert agent_reply.role == 'assistant'
    assert agent_reply.content[0].text == 'Hello, user!'
    assert not agent_reply.content[0].cache_prompt

    # Final user turn: only the head of the text is checked, so extra text
    # after it is tolerated.
    final_user = messages[5]
    assert final_user.role == 'user'
    assert final_user.content[0].text.startswith('Laaaaaaaast!')
    assert final_user.content[0].cache_prompt


def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
    """Check which messages are flagged for caching in a long conversation.

    Builds 15 user/assistant exchanges and verifies that exactly four
    system/user messages carry ``cache_prompt``: the system prompt first and
    the two most recent user messages last.
    """
    history = []
    # Add multiple user and agent messages
    for i in range(15):
        message_action_user = MessageAction(f'User message {i}')
        message_action_user._source = 'user'
        history.append(message_action_user)
        message_action_agent = MessageAction(f'Agent message {i}')
        message_action_agent._source = 'assistant'
        history.append(message_action_agent)

    codeact_agent.reset()
    messages = codeact_agent._get_messages(
        Mock(history=history, max_iterations=10, iteration=5, extra_data={})
    )

    # Collect every system/user message flagged for caching, in order.
    cached_messages = [
        msg
        for msg in messages
        if msg.role in ('user', 'system') and msg.content[0].cache_prompt
    ]
    # Four cache markers in total are expected across system + user messages.
    assert len(cached_messages) == 4

    # The first cached entry is the system prompt.
    assert cached_messages[0].content[0].text.startswith('You are OpenHands agent')
    # The last two cached entries must be the two most recent user messages.
    # The full index is matched because the bare prefix 'User message 1'
    # would also accept messages 1 and 10-14 and so prove nothing about
    # recency. Entry [1] is intentionally not pinned here.
    assert cached_messages[2].content[0].text.startswith('User message 13')
    assert cached_messages[3].content[0].text.startswith('User message 14')


def test_prompt_caching_headers(codeact_agent: CodeActAgent):
    """Check that ``step`` passes the prompt-caching beta header to the completion call."""
    # Setup: a two-message exchange used as the state history.
    user_greeting = MessageAction('Hello, agent!')
    user_greeting._source = 'user'
    agent_greeting = MessageAction('Hello, user!')
    agent_greeting._source = 'agent'

    mock_state = Mock(
        history=[user_greeting, agent_greeting],
        max_iterations=5,
        iteration=0,
        extra_data={},
    )

    codeact_agent.reset()

    # Stand-in for the underlying completion callable: it validates the
    # keyword arguments it receives and returns a canned reply.
    def check_headers(**kwargs):
        assert 'extra_headers' in kwargs
        headers = kwargs['extra_headers']
        assert 'anthropic-beta' in headers
        assert headers['anthropic-beta'] == 'prompt-caching-2024-07-31'
        canned_choice = {'message': {'content': 'Hello! How can I assist you today?'}}
        return ModelResponse(choices=[canned_choice])

    codeact_agent.llm._completion_unwrapped = check_headers
    result = codeact_agent.step(mock_state)

    # Assert: the canned reply comes back as a plain message action.
    assert isinstance(result, MessageAction)
    assert result.content == 'Hello! How can I assist you today?'