import json
import os
import sys
from typing import List, Optional, Tuple
from unittest.mock import MagicMock, patch

import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams


@pytest.fixture(autouse=True)
def setup_anthropic_api_key(monkeypatch):
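    """Provide a dummy Anthropic API key so completion calls pass auth validation."""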
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-some-key")


class TestCustomPromptManagement(CustomPromptManagement):
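    """Stub prompt-management callback that rewrites outgoing messages based on prompt_id."""

    # Despite the Test* prefix this is a helper, not a test class; skip pytest collection.
    __test__ = False
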
    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
        prompt_label: Optional[str],
    ) -> Tuple[str, List[AllMessageValues], dict]:
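        """Return (model, messages, non_default_params), substituting canned messages for known prompt ids."""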
        print(
            "TestCustomPromptManagement: running get_chat_completion_prompt for prompt_id: ",
            prompt_id,
        )
        if prompt_id == "test_prompt_id":
            messages = [
                {"role": "user", "content": "This is the prompt for test_prompt_id"},
            ]
            return model, messages, non_default_params
        elif prompt_id == "prompt_with_variables":
            content = "Hello, {name}! You are {age} years old and live in {city}."
            content_with_variables = content.format(**(prompt_variables or {}))
            messages = [
                {"role": "user", "content": content_with_variables},
            ]
            return model, messages, non_default_params
        else:
            return model, messages, non_default_params


@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id():
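    """prompt_id='test_prompt_id' should make the callback replace the user message with the canned prompt."""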
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post method
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
            prompt_id="test_prompt_id",
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))

        assert request_body["model"] == "claude-3-5-sonnet"
        # the callback replaced the original message with the canned prompt for test_prompt_id
        assert (
            request_body["messages"][0]["content"][0]["text"]
            == "This is the prompt for test_prompt_id"
        )


@pytest.mark.asyncio
async def test_custom_prompt_management_with_prompt_id_and_prompt_variables():
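    """prompt_variables should be interpolated into the prompt template by the callback."""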
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post method
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[],
            client=client,
            prompt_id="prompt_with_variables",
            prompt_variables={"name": "John", "age": 30, "city": "New York"},
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))

        assert request_body["model"] == "claude-3-5-sonnet"
        # the callback rendered the prompt template with the supplied prompt_variables
        assert (
            request_body["messages"][0]["content"][0]["text"]
            == "Hello, John! You are 30 years old and live in New York."
        )


@pytest.mark.asyncio
async def test_custom_prompt_management_without_prompt_id():
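    """Without a prompt_id, the callback must leave the original messages untouched."""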
    custom_prompt_management = TestCustomPromptManagement()
    litellm.callbacks = [custom_prompt_management]

    # Mock AsyncHTTPHandler.post method
    client = AsyncHTTPHandler()
    with patch.object(client, "post", return_value=MagicMock()) as mock_post:
        await litellm.acompletion(
            model="anthropic/claude-3-5-sonnet",
            messages=[{"role": "user", "content": "Hello, how are you?"}],
            client=client,
        )

        mock_post.assert_called_once()
        print(mock_post.call_args.kwargs)
        request_body = mock_post.call_args.kwargs["json"]
        print("request_body: ", json.dumps(request_body, indent=4))

        assert request_body["model"] == "claude-3-5-sonnet"
        # no prompt_id was passed, so the callback left the original message unchanged
        assert (
            request_body["messages"][0]["content"][0]["text"] == "Hello, how are you?"
        )