File size: 5,478 Bytes
447ebeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from base_llm_unit_tests import BaseLLMChatTest
import pytest
import sys
import os


# Put the directory two levels up at the front of sys.path before importing
# litellm below (presumably so the local checkout is picked up - confirm).
# NOTE: "../.." is resolved by os.path.abspath against the current working
# directory, not against this file's location.
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the directory two levels above the CWD to the system path
import litellm
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest


class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
    """Configure the inherited BaseLLMChatTest cases for Claude 3.5 Sonnet on the Bedrock invoke route."""

    def get_base_completion_call_args(self) -> dict:
        """Return the base completion kwargs (only ``model``) after enabling litellm debug output."""
        litellm._turn_on_debug()
        completion_args = dict(
            model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
        )
        return completion_args

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Intentionally empty test body for the "tool call with no arguments" case.

        Relevant issue: https://github.com/BerriAI/litellm/issues/6833
        """


class TestBedrockInvokeNovaJson(BaseLLMChatTest):
    """Configure the inherited BaseLLMChatTest cases for Nova Micro on the Bedrock invoke route.

    Only tests whose function name contains "json" are executed; every other
    test in this class is skipped by the autouse fixture below.
    """

    def get_base_completion_call_args(self) -> dict:
        """Return the base completion kwargs (only ``model``) for the Nova model."""
        return {
            "model": "bedrock/invoke/us.amazon.nova-micro-v1:0",
        }

    def test_tool_call_no_arguments(self, tool_call_no_arguments):
        """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
        # Intentionally empty. Its name has no "json" in it, so the autouse
        # fixture below skips it before this body would run anyway.
        pass

    @pytest.fixture(autouse=True)
    def skip_non_json_tests(self, request):
        """Skip the requesting test unless its function name contains "json".

        The match is case-insensitive. Being ``autouse``, the fixture applies
        to every test collected in this class.

        Args:
            request: pytest's fixture request object; ``request.function`` is
                the test function about to run.
        """
        test_name = request.function.__name__
        if "json" not in test_name.lower():
            pytest.skip(
                f"Skipping non-JSON test: {test_name} does not contain 'json'"
            )


def test_nova_invoke_remove_empty_system_messages():
    """Check that _remove_empty_system_messages drops an empty ``system`` list and keeps the other keys."""
    request_body = BedrockInvokeNovaRequest(
        messages=[{"content": [{"text": "Hello"}], "role": "user"}],
        system=[],
        inferenceConfig={"temperature": 0.7},
    )
    nova_config = litellm.AmazonInvokeNovaConfig()

    # The return value is ignored: the assertions inspect the very object that
    # was passed in, so the helper is expected to modify it in place.
    nova_config._remove_empty_system_messages(request_body)

    assert "system" not in request_body
    for kept_key in ("messages", "inferenceConfig"):
        assert kept_key in request_body


def test_nova_invoke_filter_allowed_fields():
    """
    Check that _filter_allowed_fields keeps only the fields defined in BedrockInvokeNovaRequest.

    Nova Invoke rejects `additionalModelRequestFields` and
    `additionalModelResponseFieldPaths` in the request body, so both must be
    absent from the filtered result while the regular fields survive.
    """
    rejected_keys = (
        "additionalModelRequestFields",
        "additionalModelResponseFieldPaths",
    )
    retained_keys = ("messages", "system", "inferenceConfig")

    raw_request = {
        "messages": [{"content": [{"text": "Hello"}], "role": "user"}],
        "system": [{"text": "System prompt"}],
        "inferenceConfig": {"temperature": 0.7},
        "additionalModelRequestFields": {"this": "should be removed"},
        "additionalModelResponseFieldPaths": ["this", "should", "be", "removed"],
    }
    typed_request = BedrockInvokeNovaRequest(**raw_request)

    nova_config = litellm.AmazonInvokeNovaConfig()
    filtered = nova_config._filter_allowed_fields(typed_request)

    for key in rejected_keys:
        assert key not in filtered
    for key in retained_keys:
        assert key in filtered


def test_nova_invoke_streaming_chunk_parsing():
    """
    Check AWSEventStreamDecoder._chunk_parser on Nova's /bedrock/invoke/ streaming
    chunks, whose payload is nested under the 'contentBlockDelta' key.

    Four chunks are fed, in order, to a single decoder instance: plain text,
    tool-use start, tool-use arguments, and a stop reason.
    """
    from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder

    # One decoder, built for a Nova model, is shared by every case below.
    decoder = AWSEventStreamDecoder(model="bedrock/invoke/us.amazon.nova-micro-v1:0")

    # Case 1: text delta
    text_chunk = {
        "contentBlockDelta": {
            "delta": {"text": "Hello, how can I help?"},
            "contentBlockIndex": 0,
        }
    }
    result = decoder._chunk_parser(text_chunk)
    choice = result.choices[0]
    assert choice.delta.content == "Hello, how can I help?"
    assert choice.index == 0
    assert not choice.finish_reason
    assert choice.delta.tool_calls is None

    # Case 2: tool use start
    tool_start_chunk = {
        "contentBlockDelta": {
            "start": {"toolUse": {"name": "get_weather", "toolUseId": "tool_1"}},
            "contentBlockIndex": 1,
        }
    }
    result = decoder._chunk_parser(tool_start_chunk)
    choice = result.choices[0]
    assert choice.delta.content == ""
    assert choice.index == 1
    assert choice.delta.tool_calls is not None
    started_call = choice.delta.tool_calls[0]
    assert started_call.type == "function"
    assert started_call.function.name == "get_weather"
    assert started_call.id == "tool_1"

    # Case 3: tool use arguments
    expected_arguments = '{"location": "New York"}'
    tool_args_chunk = {
        "contentBlockDelta": {
            "delta": {"toolUse": {"input": '{"location": "New York"}'}},
            "contentBlockIndex": 2,
        }
    }
    result = decoder._chunk_parser(tool_args_chunk)
    choice = result.choices[0]
    assert choice.delta.content == ""
    assert choice.index == 2
    assert choice.delta.tool_calls is not None
    assert choice.delta.tool_calls[0].function.arguments == expected_arguments

    # Case 4: stop reason; "tool_use" must surface as finish_reason "tool_calls"
    stop_chunk = {
        "contentBlockDelta": {
            "stopReason": "tool_use",
        }
    }
    result = decoder._chunk_parser(stop_chunk)
    print(result)
    assert result.choices[0].finish_reason == "tool_calls"