File size: 12,576 Bytes
e3278e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
from typing import TYPE_CHECKING, Any, List, Optional, Union

from httpx._models import Headers, Response

import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
    convert_to_azure_openai_messages,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.utils import ModelResponse
from litellm.utils import supports_response_schema

from ....exceptions import UnsupportedParamsError
from ....types.llms.openai import (
    AllMessageValues,
    ChatCompletionToolChoiceFunctionParam,
    ChatCompletionToolChoiceObjectParam,
    ChatCompletionToolParam,
    ChatCompletionToolParamFunctionChunk,
)
from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import AzureOpenAIError

# The logging class is imported only while static type checkers run; at
# runtime `LoggingClass` is just an alias for `Any`, so the logging module is
# not imported by this file.
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj

    LoggingClass = LiteLLMLoggingObj
else:
    LoggingClass = Any


class AzureOpenAIConfig(BaseConfig):
    """
    Configuration and OpenAI-parameter mapping for Azure OpenAI chat completions.

    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

    Constructor parameters (all optional):

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        """
        Record every non-None argument as an attribute of the config class.
        """
        # Snapshot the arguments before any other local name is created.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Stored on the class (not on the instance), so a value set
                # here is visible through every instance of this class.
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the configuration as computed by the parent class."""
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        Return the OpenAI parameter names that `map_openai_params` forwards.

        `model` is accepted for interface compatibility; the list returned
        here does not depend on it.
        """
        return [
            "temperature",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "max_completion_tokens",
            "tools",
            "tool_choice",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "function_call",
            "functions",
            "top_p",
            "logprobs",
            "top_logprobs",
            "response_format",
            "seed",
            "extra_headers",
            "parallel_tool_calls",
        ]

    def _is_response_format_supported_model(self, model: str) -> bool:
        """
        Report whether `model` can take a `response_format` JSON schema.

        - all 4o models are supported
        - otherwise defer to `supports_response_schema(model)`
        - [TODO] support smart retries for 3.5 models (some supported, some not)
        """
        return "4o" in model or bool(supports_response_schema(model))

    @staticmethod
    def _parse_api_version(api_version: Optional[str]) -> Optional[tuple]:
        """
        Split an Azure API version such as "2024-08-01-preview" into an
        integer `(year, month, day)` tuple.

        Returns None when the value is not a string or does not start with
        three numeric, dash-separated fields (e.g. "" or "preview"), so the
        caller can skip version gating instead of raising.
        """
        if not isinstance(api_version, str):
            return None
        parts = api_version.split("-")
        if len(parts) < 3:
            return None
        try:
            return (int(parts[0]), int(parts[1]), int(parts[2]))
        except ValueError:
            return None

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        api_version: str = "",
    ) -> dict:
        """
        Copy supported OpenAI params from `non_default_params` into
        `optional_params`, applying Azure API-version specific rules.

        Args:
            non_default_params: params explicitly set by the caller.
            optional_params: dict that is updated in place and returned.
            model: model name, used to decide `response_format` support.
            drop_params: when True, unsupported params are silently dropped
                instead of raising (`litellm.drop_params` has the same effect).
            api_version: Azure API version, e.g. "2024-08-01-preview". When it
                is empty or cannot be parsed, no version gating is applied and
                the params are forwarded unchanged.

        Returns:
            The updated `optional_params` dict.

        Raises:
            UnsupportedParamsError: `tool_choice` is used with an API version
                that does not support it (or does not support the value
                'required') and params are not being dropped.
        """
        supported_openai_params = self.get_supported_openai_params(model)
        version = self._parse_api_version(api_version)
        should_drop = litellm.drop_params is True or drop_params is True

        for param, value in non_default_params.items():
            if param == "tool_choice":
                # This parameter requires API version 2023-12-01-preview or later.
                # tool_choice='required' is not supported as of 2024-05-01-preview.
                if version is None:
                    # Unknown version: nothing to check against, forward as-is.
                    optional_params["tool_choice"] = value
                elif version < (2023, 12, 1):
                    ## api version does not support this param ##
                    if not should_drop:
                        raise UnsupportedParamsError(
                            status_code=400,
                            message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
                        )
                elif value == "required" and version[0] == 2024 and version[1] <= 5:
                    ## tool_choice value is not supported on this api version ##
                    # NOTE(review): 2023-12-xx versions with 'required' are
                    # forwarded unchanged here - confirm that is intended.
                    if not should_drop:
                        raise UnsupportedParamsError(
                            status_code=400,
                            message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
                        )
                else:
                    optional_params["tool_choice"] = value
            elif param == "response_format" and isinstance(value, dict):
                json_schema: Optional[dict] = None
                schema_name: str = ""
                if "response_schema" in value:
                    json_schema = value["response_schema"]
                    schema_name = "json_tool_call"
                elif "json_schema" in value:
                    json_schema = value["json_schema"]["schema"]
                    schema_name = value["json_schema"]["name"]

                # Follow a similar approach to anthropic - translate to a single tool call.
                #
                # When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
                # - You usually want to provide a single tool
                # - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
                # - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
                model_supports_schema = self._is_response_format_supported_model(
                    model
                )
                # azure api version "2024-08-01-preview" onwards supports
                # 'json_schema' only for gpt-4o/3.5 models. Compare
                # (year, month) as a pair so e.g. 2023-12 counts as older
                # than 2024-08.
                version_supports_schema = version is None or version[:2] >= (2024, 8)

                if json_schema is not None and not (
                    version_supports_schema and model_supports_schema
                ):
                    _tool_choice = ChatCompletionToolChoiceObjectParam(
                        type="function",
                        function=ChatCompletionToolChoiceFunctionParam(
                            name=schema_name
                        ),
                    )

                    _tool = ChatCompletionToolParam(
                        type="function",
                        function=ChatCompletionToolParamFunctionChunk(
                            name=schema_name, parameters=json_schema
                        ),
                    )

                    optional_params["tools"] = [_tool]
                    optional_params["tool_choice"] = _tool_choice
                    optional_params["json_mode"] = True
                else:
                    optional_params["response_format"] = value
            elif param in supported_openai_params:
                optional_params[param] = value

        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the request body: the model, the messages after
        `convert_to_azure_openai_messages`, and every entry of
        `optional_params` (which is spread last, so its keys take precedence).
        """
        messages = convert_to_azure_openai_messages(messages)
        return {
            "model": model,
            "messages": messages,
            **optional_params,
        }

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Not implemented for this config; always raises NotImplementedError."""
        raise NotImplementedError(
            "Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
        )

    def get_mapped_special_auth_params(self) -> dict:
        """Return the mapping from generic auth param names to Azure names."""
        return {"token": "azure_ad_token"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """
        Copy a caller-supplied `token` into `optional_params` under
        `azure_ad_token` and return `optional_params`.
        """
        if "token" in non_default_params:
            optional_params["azure_ad_token"] = non_default_params["token"]
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Return the region names this config treats as EU regions.

        Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
        """
        return ["europe", "sweden", "switzerland", "france", "uk"]

    def get_us_regions(self) -> List[str]:
        """
        Return the region names this config treats as US regions.

        Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
        """
        return [
            "us",
            "eastus",
            "eastus2",
            "eastus2euap",
            "eastus3",
            "southcentralus",
            "westus",
            "westus2",
            "westus3",
            "westus4",
        ]

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap the error details in an `AzureOpenAIError`."""
        return AzureOpenAIError(
            message=error_message, status_code=status_code, headers=headers
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Not implemented for this config; always raises NotImplementedError."""
        raise NotImplementedError(
            "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
        )