import enum
import re
from typing import Any, List, Optional, Tuple, cast
from urllib.parse import urlparse

import httpx
from httpx import Response

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    _audio_or_image_in_message_content,
    convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import LiteLLMLoggingObj
from litellm.llms.openai.common_utils import (
    drop_params_from_unprocessable_entity_error,
)
from litellm.llms.openai.openai import OpenAIConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, ProviderField
from litellm.utils import _add_path_to_api_base, supports_tool_choice


class AzureFoundryErrorStrings(str, enum.Enum):
    SET_EXTRA_PARAMETERS_TO_PASS_THROUGH = "Set extra-parameters to 'pass-through'"


class AzureAIStudioConfig(OpenAIConfig):
    def get_supported_openai_params(self, model: str) -> List:
        model_supports_tool_choice = True  # azure ai supports this by default
        if not supports_tool_choice(model=f"azure_ai/{model}"):
            model_supports_tool_choice = False
        supported_params = super().get_supported_openai_params(model)
        if not model_supports_tool_choice:
            filtered_supported_params = []
            for param in supported_params:
                if param != "tool_choice":
                    filtered_supported_params.append(param)
            return filtered_supported_params
        return supported_params
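
    # Illustrative usage sketch (the model name below is hypothetical): for a
    # model whose "azure_ai/<model>" entry does not advertise tool_choice
    # support, that param is filtered out of the supported list.
    #
    #   config = AzureAIStudioConfig()
    #   params = config.get_supported_openai_params(model="my-foundry-model")
    #   # "tool_choice" is absent iff supports_tool_choice() returned False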

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if api_base and self._should_use_api_key_header(api_base):
            headers["api-key"] = api_key
        else:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def _should_use_api_key_header(self, api_base: str) -> bool:
        """
        Returns True if the request should use the `api-key` header for authentication.
        """
        parsed_url = urlparse(api_base)
        host = parsed_url.hostname
        if host and (
            host.endswith(".services.ai.azure.com")
            or host.endswith(".openai.azure.com")
        ):
            return True
        return False
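
    # Illustrative sketch: Azure-hosted endpoints authenticate via the
    # `api-key` header; any other host gets a Bearer token. (The non-Azure
    # hostname below is a hypothetical example.)
    #
    #   config = AzureAIStudioConfig()
    #   config._should_use_api_key_header("https://myres.services.ai.azure.com")  # True
    #   config._should_use_api_key_header("https://myres.openai.azure.com")  # True
    #   config._should_use_api_key_header("https://my-serverless-host.example.com")  # False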

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Constructs a complete URL for the API request.

        Args:
        - api_base: Base URL, e.g.,
            "https://litellm8397336933.services.ai.azure.com"
            OR
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        - model: Model name.
        - litellm_params: Additional litellm params; "api_version" is read from here.
        - stream: If streaming is required (optional).

        Returns:
        - A complete URL string, e.g.,
            "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"
        """
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )
        original_url = httpx.URL(api_base)

        # Extract api_version from litellm_params, if provided
        api_version = cast(Optional[str], litellm_params.get("api_version"))

        # Start from the query params already present on the base URL
        query_params = dict(original_url.params)

        # Add api-version if it isn't already in the URL
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version

        # Append the chat-completions path to the base URL
        if "services.ai.azure.com" in api_base:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/models/chat/completions"
            )
        else:
            new_url = _add_path_to_api_base(
                api_base=api_base, ending_path="/chat/completions"
            )

        # Re-attach the merged query params
        final_url = httpx.URL(new_url).copy_with(params=query_params)

        return str(final_url)
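
    # Illustrative sketch (resource name reused from the docstring example;
    # the model name is hypothetical):
    #
    #   config = AzureAIStudioConfig()
    #   config.get_complete_url(
    #       api_base="https://litellm8397336933.services.ai.azure.com",
    #       api_key=None,
    #       model="mistral-large",
    #       optional_params={},
    #       litellm_params={"api_version": "2024-05-01-preview"},
    #   )
    #   # -> "https://litellm8397336933.services.ai.azure.com/models/chat/completions?api-version=2024-05-01-preview"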

    def get_required_params(self) -> List[ProviderField]:
        """For a given provider, return its required fields with a description."""
        return [
            ProviderField(
                field_name="api_key",
                field_type="string",
                field_description="Your Azure AI Studio API Key.",
                field_value="zEJ...",
            ),
            ProviderField(
                field_name="api_base",
                field_type="string",
                field_description="Your Azure AI Studio API Base.",
                field_value="https://Mistral-serverless.",
            ),
        ]

    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
    ) -> List:
        """
        Azure AI Studio doesn't support content as a list, so this:
        1. Transforms list content into a string.
        2. Leaves messages containing an image or audio as-is (user-intended).
        """
        for message in messages:
            # Do nothing if the message contains an image or audio
            if _audio_or_image_in_message_content(message):
                continue
            texts = convert_content_list_to_str(message=message)
            if texts:
                message["content"] = texts
        return messages
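
    # Illustrative sketch: a text-only content list is flattened to a plain
    # string; a message carrying image/audio parts is passed through unchanged.
    #
    #   messages = [{"role": "user", "content": [{"type": "text", "text": "Hi"}]}]
    #   AzureAIStudioConfig()._transform_messages(messages=messages, model="m")
    #   # -> [{"role": "user", "content": "Hi"}]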

    def _is_azure_openai_model(self, model: str, api_base: Optional[str]) -> bool:
        try:
            # Strip any "provider/" prefix before checking the model lists
            if "/" in model:
                model = model.split("/", 1)[1]
            if (
                model in litellm.open_ai_chat_completion_models
                or model in litellm.open_ai_text_completion_models
                or model in litellm.open_ai_embedding_models
            ):
                return True
        except Exception:
            return False
        return False

    def _get_openai_compatible_provider_info(
        self,
        model: str,
        api_base: Optional[str],
        api_key: Optional[str],
        custom_llm_provider: str,
    ) -> Tuple[Optional[str], Optional[str], str]:
        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
        dynamic_api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
        if self._is_azure_openai_model(model=model, api_base=api_base):
            verbose_logger.debug(
                "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
                    model
                )
            )
            custom_llm_provider = "azure"
        return api_base, dynamic_api_key, custom_llm_provider
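
    # Illustrative sketch: a deployment that is actually an Azure OpenAI model
    # (e.g. gpt-4o) gets rerouted to the `azure` provider; api_base / api_key
    # fall back to the AZURE_AI_API_BASE / AZURE_AI_API_KEY secrets.
    #
    #   config = AzureAIStudioConfig()
    #   config._get_openai_compatible_provider_info(
    #       model="gpt-4o", api_base=None, api_key="sk-...", custom_llm_provider="azure_ai",
    #   )
    #   # -> (<AZURE_AI_API_BASE or None>, "sk-...", "azure")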

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Hoist any `extra_body` keys to the top level of the request
        extra_body = optional_params.pop("extra_body", {})
        if extra_body and isinstance(extra_body, dict):
            optional_params.update(extra_body)
        # `max_retries` is a client-side litellm param; don't send it upstream
        optional_params.pop("max_retries", None)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )
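
    # Illustrative sketch (the `safe_mode` key is a hypothetical extra param):
    # keys nested under `extra_body` are hoisted into the top-level payload,
    # and the client-side `max_retries` knob never reaches the provider.
    #
    #   optional_params = {"extra_body": {"safe_mode": True}, "max_retries": 3}
    #   # after transform_request(...): the payload contains "safe_mode": True
    #   # and no "max_retries" key.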

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # Report the model under the azure_ai/ provider prefix
        model_response.model = f"azure_ai/{model}"
        return super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        should_drop_params = litellm_params.get("drop_params") or litellm.drop_params
        error_text = e.response.text
        if should_drop_params and "Extra inputs are not permitted" in error_text:
            return True
        elif (
            "unknown field: parameter index is not a valid field" in error_text
        ):  # remove index from tool calls
            return True
        elif (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in error_text
        ):  # remove extra-parameters from tool calls
            return True
        return super().should_retry_llm_api_inside_llm_translation_on_http_error(
            e=e, litellm_params=litellm_params
        )
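
    # Illustrative sketch: an error body matching one of the known Azure
    # Foundry strings makes the request retryable after it is rewritten.
    #
    #   # response body: "Extra parameters ['stream_options'] are not allowed
    #   # when extra-parameters is not set or set to be 'error'."
    #   # -> should_retry_llm_api_inside_llm_translation_on_http_error(e, {}) is True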

    def max_retry_on_unprocessable_entity_error(self) -> int:
        return 2

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        _messages = cast(Optional[List[AllMessageValues]], request_data.get("messages"))
        if (
            "unknown field: parameter index is not a valid field" in e.response.text
            and _messages is not None
        ):
            litellm.remove_index_from_tool_calls(
                messages=_messages,
            )
        elif (
            AzureFoundryErrorStrings.SET_EXTRA_PARAMETERS_TO_PASS_THROUGH.value
            in e.response.text
        ):
            request_data = self._drop_extra_params_from_request_data(
                request_data, e.response.text
            )
        data = drop_params_from_unprocessable_entity_error(e=e, data=request_data)
        return data
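
    # Illustrative sketch: on "unknown field: parameter index is not a valid
    # field", litellm.remove_index_from_tool_calls strips the `index` key from
    # each tool call before the retry, e.g.
    #
    #   {"id": "call_1", "index": 0, "type": "function", "function": {...}}
    #   # -> {"id": "call_1", "type": "function", "function": {...}}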

    def _drop_extra_params_from_request_data(
        self, request_data: dict, error_text: str
    ) -> dict:
        params_to_drop = self._extract_params_to_drop_from_error_text(error_text)
        if params_to_drop:
            for param in params_to_drop:
                request_data.pop(param, None)
        return request_data

    def _extract_params_to_drop_from_error_text(
        self, error_text: str
    ) -> Optional[List[str]]:
        """
        Error text looks like this:
        "Extra parameters ['stream_options', 'extra-parameters'] are not allowed when extra-parameters is not set or set to be 'error'."
        """
        # Extract the parameter names listed within square brackets
        match = re.search(r"\[(.*?)\]", error_text)
        if not match:
            return []
        # Parse the extracted string into a list of parameter names
        params_str = match.group(1)
        params = []
        for param in params_str.split(","):
            # Clean up the parameter name (remove quotes, spaces)
            clean_param = param.strip().strip("'").strip('"')
            if clean_param:
                params.append(clean_param)
        return params
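
    # Illustrative sketch of the extraction:
    #
    #   text = ("Extra parameters ['stream_options', 'extra-parameters'] are not "
    #           "allowed when extra-parameters is not set or set to be 'error'.")
    #   AzureAIStudioConfig()._extract_params_to_drop_from_error_text(text)
    #   # -> ["stream_options", "extra-parameters"]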