aiben / openai_server /test_prompt_caching.py
abugaber's picture
Upload folder using huggingface_hub
3943768 verified
raw
history blame
4.11 kB
import sys
import pytest
from typing import List, Dict
# Put the relative path 'src' (resolved against the current working directory)
# on the import path, presumably so that modules inside src/ which import each
# other by bare name can be loaded. The `src.` package import below presumably
# also relies on the tests being run from the repository root — TODO confirm.
if 'src' not in sys.path:
    sys.path.append('src')
from src.gpt_langchain import H2OChatAnthropic3

# Alias the class attribute under test. The tests call it without an instance,
# so it is presumably a staticmethod on H2OChatAnthropic3 — confirm in
# src/gpt_langchain.py.
process_messages = H2OChatAnthropic3.process_messages
def assert_cache_control_count(messages: List[Dict], expected_count: int):
    """Assert how many user-message content items carry a ``cache_control`` key.

    Only messages whose ``role`` is ``"user"`` are inspected. A string
    ``content`` is treated as a single item, and only dict items are checked
    for the key, so plain strings never count.

    Raises:
        AssertionError: if the number of tagged items differs from
            ``expected_count``.
    """
    actual_count = 0
    for message in messages:
        if message["role"] != "user":
            continue  # non-user turns are not counted
        content = message["content"]
        items = content if isinstance(content, list) else [content]
        # Each bool adds 0 or 1; non-dict items short-circuit to False.
        actual_count += sum(
            isinstance(item, dict) and "cache_control" in item for item in items
        )
    assert actual_count == expected_count, f"Expected {expected_count} cache_control entries, but found {actual_count}"
def test_simple_string_messages():
    """Plain-string user turns: only the final three user messages are tagged."""
    conversation = [
        {"role": "user", "content": "Message 1"},
        {"role": "assistant", "content": "Response 1"},
        {"role": "user", "content": "Message 2"},
        {"role": "user", "content": "Message 3"},
        {"role": "user", "content": "Message 4"},
        {"role": "user", "content": "Message 5"},
    ]
    processed = process_messages(conversation)

    assert len(processed) == 6
    assert_cache_control_count(processed, 3)
    # Every user message among the last three carries the breakpoint.
    for message in processed[-3:]:
        if message["role"] != "user":
            continue
        assert "cache_control" in message["content"][0]
    # The oldest user message stays untagged.
    assert "cache_control" not in processed[0]["content"][0]
def test_mixed_content_types():
    """String and list contents may be mixed; every item of a list message is tagged."""
    image_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": "List item 1"},
            {"type": "image", "image_url": "example.com/image.jpg"},
        ],
    }
    messages = [
        {"role": "user", "content": "Text message"},
        {"role": "assistant", "content": "Response"},
        image_turn,
        {"role": "user", "content": "Another text message"},
    ]
    out = process_messages(messages)

    assert len(out) == 4
    assert_cache_control_count(out, 3)
    newest, list_message, oldest = out[-1], out[-2], out[0]
    assert "cache_control" in newest["content"][0]
    # Both the text item and the image item of the list message are tagged.
    for entry in list_message["content"]:
        assert "cache_control" in entry
    assert "cache_control" not in oldest["content"][0]
def test_max_cache_control_limit():
    """Only three content items in total receive ``cache_control``.

    Five content items are spread over three user messages. Counting back from
    the end, the single item of the last message and both items of the middle
    message are expected to be tagged, and the first message is expected to be
    left without any ``cache_control``.
    """
    messages = [
        {"role": "user", "content": [{"type": "text", "text": "Item 1"}, {"type": "text", "text": "Item 2"}]},
        {"role": "user", "content": [{"type": "text", "text": "Item 3"}, {"type": "text", "text": "Item 4"}]},
        {"role": "user", "content": "Text message"},
    ]
    result = process_messages(messages)
    # Like the sibling tests: no message may be dropped or merged.
    assert len(result) == 3
    assert_cache_control_count(result, 3)
    assert "cache_control" in result[-1]["content"][0]
    assert "cache_control" in result[-2]["content"][0]
    assert "cache_control" in result[-2]["content"][1]
    # Check every item of the oldest message, not just the first one.
    assert all("cache_control" not in item for item in result[0]["content"])
def test_empty_list_content():
    """An empty list content passes through unchanged; the string turn is tagged."""
    empty_turn = {"role": "user", "content": []}
    text_turn = {"role": "user", "content": "Text message"}
    processed = process_messages([empty_turn, text_turn])

    assert len(processed) == 2
    first, second = processed[0], processed[1]
    assert first["content"] == []
    assert "cache_control" in second["content"][0]
def test_preserve_message_order():
    """Turns keep their order and roles; user strings become text blocks.

    The last three user turns are expected to carry an ephemeral
    ``cache_control`` and the first user turn none.
    """
    ephemeral = {"type": "ephemeral"}
    messages = [
        {"role": "user", "content": "First"},
        {"role": "assistant", "content": "Response 1"},
        {"role": "user", "content": "Second"},
        {"role": "assistant", "content": "Response 2"},
        {"role": "user", "content": "Third"},
        {"role": "user", "content": "Fourth"},
    ]
    result = process_messages(messages)

    expected_user_contents = [
        [{"type": "text", "text": "First"}],
        [{"type": "text", "text": "Second", "cache_control": ephemeral}],
        [{"type": "text", "text": "Third", "cache_control": ephemeral}],
        [{"type": "text", "text": "Fourth", "cache_control": ephemeral}],
    ]
    actual_user_contents = [m["content"] for m in result if m["role"] == "user"]
    assert actual_user_contents == expected_user_contents

    # All six turns survive, in their original role order.
    assert len(result) == 6
    expected_roles = ["user", "assistant", "user", "assistant", "user", "user"]
    assert [m["role"] for m in result] == expected_roles
if __name__ == "__main__":
    # pytest.main() returns the exit code instead of exiting. Forward it so that
    # `python test_prompt_caching.py` exits non-zero when a test fails.
    sys.exit(pytest.main([__file__]))