DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
"""
Test for anthropic_endpoints/endpoints.py, focusing on handling dictionary objects in streaming responses
"""
import json
import unittest
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from litellm.proxy.anthropic_endpoints.endpoints import async_data_generator_anthropic
class TestAnthropicEndpoints(unittest.TestCase):
@patch("litellm.litellm_core_utils.safe_json_dumps.safe_dumps")
@pytest.mark.asyncio
async def test_async_data_generator_anthropic_dict_handling(self, mock_safe_dumps):
"""Test async_data_generator_anthropic handles dictionary chunks properly"""
# Setup
mock_response = AsyncMock()
mock_response.__aiter__.return_value = [
{"type": "message_start", "message": {"id": "msg_123"}},
"text chunk data",
{"type": "content_block_delta", "delta": {"text": "more data"}},
"text chunk data again",
]
mock_user_api_key_dict = MagicMock()
mock_request_data = {}
mock_proxy_logging_obj = MagicMock()
mock_proxy_logging_obj.async_post_call_streaming_hook = AsyncMock(
side_effect=lambda **kwargs: kwargs["response"]
)
# Configure safe_dumps to return a properly formatted JSON string
mock_safe_dumps.side_effect = lambda chunk: json.dumps(chunk)
# Execute
result = [
chunk
async for chunk in async_data_generator_anthropic(
response=mock_response,
user_api_key_dict=mock_user_api_key_dict,
request_data=mock_request_data,
proxy_logging_obj=mock_proxy_logging_obj,
)
]
# Verify
expected_result = [
'data: {"type": "message_start", "message": {"id": "msg_123"}}\n\n',
"text chunk data",
'data: {"type": "content_block_delta", "delta": {"text": "more data"}}\n\n',
"text chunk data again",
]
self.assertEqual(result, expected_result)
# Assert safe_dumps was called for dictionary objects
mock_safe_dumps.assert_any_call(
{"type": "message_start", "message": {"id": "msg_123"}}
)
mock_safe_dumps.assert_any_call(
{"type": "content_block_delta", "delta": {"text": "more data"}}
)
assert (
mock_safe_dumps.call_count == 2
) # Called twice, once for each dict object
if __name__ == "__main__":
unittest.main()