|
from typing import Any, List, Optional, Tuple, cast |
|
|
|
import httpx |
|
from httpx import Response |
|
|
|
import litellm |
|
from litellm._logging import verbose_logger |
|
from litellm.litellm_core_utils.prompt_templates.common_utils import ( |
|
_audio_or_image_in_message_content, |
|
convert_content_list_to_str, |
|
) |
|
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj |
|
from litellm.llms.openai.common_utils import drop_params_from_unprocessable_entity_error |
|
from litellm.llms.openai.openai import OpenAIConfig |
|
from litellm.secret_managers.main import get_secret_str |
|
from litellm.types.llms.openai import AllMessageValues |
|
from litellm.types.utils import ModelResponse, ProviderField |
|
from litellm.utils import _add_path_to_api_base |
|
|
|
|
|
class AzureAIStudioConfig(OpenAIConfig): |
|
def validate_environment( |
|
self, |
|
headers: dict, |
|
model: str, |
|
messages: List[AllMessageValues], |
|
optional_params: dict, |
|
api_key: Optional[str] = None, |
|
api_base: Optional[str] = None, |
|
) -> dict: |
|
        if api_base and "services.ai.azure.com" in api_base:
            # Serverless Azure AI endpoints authenticate via an "api-key" header.
            headers["api-key"] = api_key
        else:
            # All other endpoints expect a standard bearer token.
            headers["Authorization"] = f"Bearer {api_key}"
|
|
|
return headers |
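
    # Illustrative sketch (hypothetical values): the header used depends on the
    # endpoint style.
    #
    #   AzureAIStudioConfig().validate_environment(
    #       headers={},
    #       model="azure_ai/mistral-large",
    #       messages=[{"role": "user", "content": "hi"}],
    #       optional_params={},
    #       api_key="my-key",
    #       api_base="https://litellm8397336933.services.ai.azure.com",
    #   )
    #   # -> {"api-key": "my-key"}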
|
|
|
def get_complete_url( |
|
self, |
|
api_base: str, |
|
model: str, |
|
optional_params: dict, |
|
stream: Optional[bool] = None, |
|
) -> str: |
|
""" |
|
Constructs a complete URL for the API request. |
|
|
|
Args: |
|
- api_base: Base URL, e.g., |
|
"https://litellm8397336933.services.ai.azure.com" |
|
OR |
|
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" |
|
- model: Model name. |
|
- optional_params: Additional query parameters, including "api_version". |
|
- stream: If streaming is required (optional). |
|
|
|
Returns: |
|
- A complete URL string, e.g., |
|
"https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview" |
|
""" |
|
        original_url = httpx.URL(api_base)

        # Extract api_version from the optional params, if one was supplied.
        api_version = cast(Optional[str], optional_params.get("api_version"))

        # httpx.QueryParams is immutable, so copy the existing query parameters
        # into a plain dict and add "api-version" if it isn't already present.
        query_params = dict(original_url.params)
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version

        # Serverless Azure AI endpoints serve chat completions under "/models".
        if "services.ai.azure.com" in api_base:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/models/chat/completions"
            )
        else:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/chat/completions"
            )

        # Re-attach the query parameters to the rebuilt URL.
        final_url = httpx.URL(new_url).copy_with(params=query_params)

        return str(final_url)
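
    # Illustrative sketch (hypothetical values): a serverless base URL gains the
    # "/models/chat/completions" path, and the api-version lands in the query string.
    #
    #   AzureAIStudioConfig().get_complete_url(
    #       api_base="https://litellm8397336933.services.ai.azure.com",
    #       model="mistral-large",
    #       optional_params={"api_version": "2024-05-01-preview"},
    #   )
    #   # -> "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"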
|
|
|
def get_required_params(self) -> List[ProviderField]: |
|
"""For a given provider, return it's required fields with a description""" |
|
return [ |
|
ProviderField( |
|
field_name="api_key", |
|
field_type="string", |
|
field_description="Your Azure AI Studio API Key.", |
|
field_value="zEJ...", |
|
), |
|
ProviderField( |
|
field_name="api_base", |
|
field_type="string", |
|
field_description="Your Azure AI Studio API Base.", |
|
field_value="https://Mistral-serverless.", |
|
), |
|
] |
|
|
|
    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
    ) -> List[AllMessageValues]:
        """
        Azure AI Studio doesn't support content as a list. This method:
        1. Transforms list content into a single string.
        2. Leaves any message containing an image or audio as-is (user-intended).
        """
|
for message in messages: |
|
|
|
|
|
if _audio_or_image_in_message_content(message): |
|
continue |
|
|
|
texts = convert_content_list_to_str(message=message) |
|
if texts: |
|
message["content"] = texts |
|
return messages |
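
    # Illustrative sketch (hypothetical message): text-only list content is
    # flattened to a plain string, while messages carrying image or audio
    # parts pass through untouched.
    #
    #   msgs = [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]
    #   AzureAIStudioConfig()._transform_messages(messages=msgs, model="mistral-large")
    #   # -> [{"role": "user", "content": "Hello"}]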
|
|
|
def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool: |
|
try: |
|
if "/" in model: |
|
model = model.split("/", 1)[1] |
|
if ( |
|
model in litellm.open_ai_chat_completion_models |
|
or model in litellm.open_ai_text_completion_models |
|
or model in litellm.open_ai_embedding_models |
|
): |
|
return True |
|
|
|
except Exception: |
|
return False |
|
return False |
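
    # Illustrative sketch (assuming "gpt-4o" appears in litellm's OpenAI model
    # lists): the provider prefix is stripped before the registry lookup.
    #
    #   AzureAIStudioConfig()._is_azure_openai_model(model="azure_ai/gpt-4o", api_base=None)
    #   # -> True
    #   AzureAIStudioConfig()._is_azure_openai_model(model="azure_ai/mistral-large", api_base=None)
    #   # -> False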
|
|
|
def _get_openai_compatible_provider_info( |
|
self, |
|
model: str, |
|
api_base: Optional[str], |
|
api_key: Optional[str], |
|
custom_llm_provider: str, |
|
) -> Tuple[Optional[str], Optional[str], str]: |
|
api_base = api_base or get_secret_str("AZURE_AI_API_BASE") |
|
dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY") |
|
|
|
if self._is_azure_openai_model(model=model, api_base=api_base): |
|
verbose_logger.debug( |
|
"Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format( |
|
model |
|
) |
|
) |
|
custom_llm_provider = "azure" |
|
return api_base, dynamic_api_key, custom_llm_provider |
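
    # Illustrative sketch (hypothetical env vars): missing credentials fall back
    # to AZURE_AI_API_BASE / AZURE_AI_API_KEY, and an Azure OpenAI model is
    # rerouted to the "azure" provider.
    #
    #   AzureAIStudioConfig()._get_openai_compatible_provider_info(
    #       model="gpt-4o", api_base=None, api_key=None, custom_llm_provider="azure_ai"
    #   )
    #   # -> ("<AZURE_AI_API_BASE>", "<AZURE_AI_API_KEY>", "azure")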
|
|
|
def transform_request( |
|
self, |
|
model: str, |
|
messages: List[AllMessageValues], |
|
optional_params: dict, |
|
litellm_params: dict, |
|
headers: dict, |
|
) -> dict: |
|
        # Flatten "extra_body" into the top-level params and drop "max_retries",
        # which the Azure AI endpoint does not accept.
        extra_body = optional_params.pop("extra_body", {})
        if extra_body and isinstance(extra_body, dict):
            optional_params.update(extra_body)
        optional_params.pop("max_retries", None)
|
return super().transform_request( |
|
model, messages, optional_params, litellm_params, headers |
|
) |
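
    # Illustrative sketch (hypothetical params): "extra_body" contents are
    # promoted to the top level and "max_retries" is discarded before the
    # parent OpenAI transform runs.
    #
    #   optional_params = {"temperature": 0.2, "extra_body": {"safe_mode": True}, "max_retries": 3}
    #   # after the pops/update above: {"temperature": 0.2, "safe_mode": True}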
|
|
|
def transform_response( |
|
self, |
|
model: str, |
|
raw_response: Response, |
|
model_response: ModelResponse, |
|
logging_obj: LiteLLMLoggingObj, |
|
request_data: dict, |
|
messages: List[AllMessageValues], |
|
optional_params: dict, |
|
litellm_params: dict, |
|
encoding: Any, |
|
api_key: Optional[str] = None, |
|
json_mode: Optional[bool] = None, |
|
) -> ModelResponse: |
|
model_response.model = f"azure_ai/{model}" |
|
return super().transform_response( |
|
model=model, |
|
raw_response=raw_response, |
|
model_response=model_response, |
|
logging_obj=logging_obj, |
|
request_data=request_data, |
|
messages=messages, |
|
optional_params=optional_params, |
|
litellm_params=litellm_params, |
|
encoding=encoding, |
|
api_key=api_key, |
|
json_mode=json_mode, |
|
) |
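
    # Illustrative note: the provider prefix is re-attached on the way out,
    # e.g. model="mistral-large" yields model_response.model == "azure_ai/mistral-large".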
|
|
|
def should_retry_llm_api_inside_llm_translation_on_http_error( |
|
self, e: httpx.HTTPStatusError, litellm_params: dict |
|
) -> bool: |
|
should_drop_params = litellm_params.get("drop_params") or litellm.drop_params |
|
error_text = e.response.text |
|
if should_drop_params and "Extra inputs are not permitted" in error_text: |
|
return True |
|
        elif "unknown field: parameter index is not a valid field" in error_text:
            return True
|
return super().should_retry_llm_api_inside_llm_translation_on_http_error( |
|
e=e, litellm_params=litellm_params |
|
) |
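
    # Illustrative sketch (hypothetical error body): a 422 whose text contains
    # "Extra inputs are not permitted" is retried when drop_params is enabled,
    # so the offending params can be stripped on the next attempt.
    #
    #   # e.response.text == '{"error": "Extra inputs are not permitted"}'
    #   # litellm_params  == {"drop_params": True}
    #   # -> returns True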
|
|
|
@property |
|
def max_retry_on_unprocessable_entity_error(self) -> int: |
|
return 2 |
|
|
|
def transform_request_on_unprocessable_entity_error( |
|
self, e: httpx.HTTPStatusError, request_data: dict |
|
) -> dict: |
|
_messages = cast(Optional[List[AllMessageValues]], request_data.get("messages")) |
|
if ( |
|
"unknown field: parameter index is not a valid field" in e.response.text |
|
and _messages is not None |
|
): |
|
litellm.remove_index_from_tool_calls( |
|
messages=_messages, |
|
) |
|
data = drop_params_from_unprocessable_entity_error(e=e, data=request_data) |
|
return data |
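
    # Illustrative sketch (hypothetical request): when the endpoint rejects the
    # "index" field on tool calls, it is removed from the messages in place and
    # any other offending params are dropped before the retry.
    #
    #   request_data = {"messages": [{"role": "assistant", "content": None, "tool_calls": [
    #       {"id": "call_1", "index": 0, "type": "function",
    #        "function": {"name": "get_weather", "arguments": "{}"}}]}]}
    #   # after this hook: each tool call no longer carries the "index" key.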
|
|