"""
Common base config for all LLM providers
"""
import types
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Iterator,
List,
Optional,
Type,
Union,
cast,
)
import httpx
from pydantic import BaseModel
from litellm.constants import DEFAULT_MAX_TOKENS, RESPONSE_FORMAT_TOOL_NAME
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionToolChoiceFunctionParam,
ChatCompletionToolChoiceObjectParam,
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
from ..base_utils import (
map_developer_role_to_system_role,
type_to_response_format_param,
)
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any


class BaseLLMException(Exception):
    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message: str = message
        self.headers = headers
        if request:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST", url="https://docs.litellm.ai/docs"
            )
        if response:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        self.body = body
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class BaseConfig(ABC):
    def __init__(self):
        pass

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
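
    # Illustrative sketch (not part of the litellm source): `get_config` returns
    # the plain, non-None class-level attributes of a config subclass, skipping
    # dunders, ABC internals, and methods. Given a hypothetical subclass:
    #
    #     class _SketchConfig(BaseConfig):
    #         temperature = 0.7
    #         max_tokens = None  # dropped: value is None
    #
    #     _SketchConfig.get_config()  # -> {"temperature": 0.7}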

    def get_json_schema_from_pydantic_object(
        self, response_format: Optional[Union[Type[BaseModel], dict]]
    ) -> Optional[dict]:
        return type_to_response_format_param(response_format=response_format)

    def is_thinking_enabled(self, non_default_params: dict) -> bool:
        return (
            non_default_params.get("thinking", {}).get("type") == "enabled"
            or non_default_params.get("reasoning_effort") is not None
        )

    def update_optional_params_with_thinking_tokens(
        self, non_default_params: dict, optional_params: dict
    ):
        """
        Handles the scenario where max_tokens is not specified. Anthropic models
        (Anthropic API / Bedrock / Vertex AI) require max_tokens to be set, and to
        be greater than the thinking token budget.

        If 'thinking' is enabled (in 'optional_params') and 'max_tokens' is not in
        'non_default_params', set 'max_tokens' to the thinking token budget +
        DEFAULT_MAX_TOKENS.
        """
        is_thinking_enabled = self.is_thinking_enabled(optional_params)
        if is_thinking_enabled and "max_tokens" not in non_default_params:
            # 'thinking' may be absent when thinking was enabled via
            # 'reasoning_effort' alone, so look it up defensively.
            thinking_token_budget = cast(
                dict, optional_params.get("thinking") or {}
            ).get("budget_tokens", None)
            if thinking_token_budget is not None:
                optional_params["max_tokens"] = (
                    thinking_token_budget + DEFAULT_MAX_TOKENS
                )
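
    # Worked example (a sketch; DEFAULT_MAX_TOKENS is litellm's constant, so the
    # exact total depends on its configured value):
    #
    #     optional_params = {"thinking": {"type": "enabled", "budget_tokens": 1024}}
    #     config.update_optional_params_with_thinking_tokens(
    #         non_default_params={}, optional_params=optional_params
    #     )
    #     assert optional_params["max_tokens"] == 1024 + DEFAULT_MAX_TOKENS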

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Returns True if the model/provider should fake stream
        """
        return False

    def _add_tools_to_optional_params(self, optional_params: dict, tools: List) -> dict:
        """
        Helper util to add tools to optional_params.
        """
        if "tools" not in optional_params:
            optional_params["tools"] = tools
        else:
            optional_params["tools"] = [
                *optional_params["tools"],
                *tools,
            ]
        return optional_params
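
    # Merge semantics sketch (`web_search_tool` / `calculator_tool` are
    # hypothetical placeholders): existing tools are kept, new ones appended.
    #
    #     params = {"tools": [web_search_tool]}
    #     config._add_tools_to_optional_params(params, [calculator_tool])
    #     assert params["tools"] == [web_search_tool, calculator_tool]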

    def translate_developer_role_to_system_role(
        self,
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """
        Translate `developer` role to `system` role for non-OpenAI providers.

        Overridden by OpenAI/Azure.
        """
        return map_developer_role_to_system_role(messages=messages)
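
    # Example of the mapping (sketch):
    #
    #     [{"role": "developer", "content": "Be terse."}]
    #     # becomes
    #     [{"role": "system", "content": "Be terse."}]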

    def should_retry_llm_api_inside_llm_translation_on_http_error(
        self, e: httpx.HTTPStatusError, litellm_params: dict
    ) -> bool:
        """
        Returns True if the model/provider should retry the LLM API on UnprocessableEntityError.

        Overridden by Azure AI, where different models support different parameters.
        """
        return False

    def transform_request_on_unprocessable_entity_error(
        self, e: httpx.HTTPStatusError, request_data: dict
    ) -> dict:
        """
        Transform the request data on UnprocessableEntityError
        """
        return request_data

    @property
    def max_retry_on_unprocessable_entity_error(self) -> int:
        """
        Returns the max retry count for UnprocessableEntityError.

        Used if `should_retry_llm_api_inside_llm_translation_on_http_error` is True.
        """
        return 0

    @abstractmethod
    def get_supported_openai_params(self, model: str) -> list:
        pass

    def _add_response_format_to_tools(
        self,
        optional_params: dict,
        value: dict,
        is_response_format_supported: bool,
        enforce_tool_choice: bool = True,
    ) -> dict:
        """
        Translate `response_format` to a single tool call, for models/APIs that
        don't support `response_format` directly.

        Follows Anthropic's approach to JSON mode via tool use:
        https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
        - You usually want to provide a single tool.
        - You should set `tool_choice` (see "Forcing tool use") to instruct the
          model to explicitly use that tool.
        - Remember that the model will pass the input to the tool, so the name of
          the tool and description should be from the model's perspective.
        """
        json_schema: Optional[dict] = None
        if "response_schema" in value:
            json_schema = value["response_schema"]
        elif "json_schema" in value:
            json_schema = value["json_schema"]["schema"]

        if json_schema and not is_response_format_supported:
            _tool_choice = ChatCompletionToolChoiceObjectParam(
                type="function",
                function=ChatCompletionToolChoiceFunctionParam(
                    name=RESPONSE_FORMAT_TOOL_NAME
                ),
            )
            _tool = ChatCompletionToolParam(
                type="function",
                function=ChatCompletionToolParamFunctionChunk(
                    name=RESPONSE_FORMAT_TOOL_NAME, parameters=json_schema
                ),
            )
            optional_params.setdefault("tools", [])
            optional_params["tools"].append(_tool)
            if enforce_tool_choice:
                optional_params["tool_choice"] = _tool_choice
            optional_params["json_mode"] = True
        elif is_response_format_supported:
            optional_params["response_format"] = value
        return optional_params
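
    # Translation sketch for a provider without native response_format support
    # (the schema is illustrative):
    #
    #     value = {"json_schema": {"schema": {"type": "object"}}}
    #     config._add_response_format_to_tools(
    #         optional_params={}, value=value, is_response_format_supported=False
    #     )
    #     # -> {"tools": [<function tool named RESPONSE_FORMAT_TOOL_NAME,
    #     #                wrapping the schema>],
    #     #     "tool_choice": <forced choice of that tool>,
    #     #     "json_mode": True}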

    @abstractmethod
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        pass

    @abstractmethod
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        pass

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> dict:
        """
        Some providers, like Bedrock, require signing the request. The sign request
        function needs access to `request_data` and the complete url.

        Args:
            headers: dict
            optional_params: dict
            request_data: dict - the request body being sent in the http request
            api_base: str - the complete url being sent in the http request

        Returns:
            dict - the signed headers

        Update the headers with the signed headers in this function. The return
        value will be sent as headers in the http request.
        """
        return headers
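
    # Override sketch for a hypothetical provider that signs the body with an
    # HMAC (the header name and scheme are illustrative, not a real litellm
    # provider):
    #
    #     import hashlib, hmac, json
    #
    #     def sign_request(self, headers, optional_params, request_data, api_base,
    #                      model=None, stream=None, fake_stream=None) -> dict:
    #         payload = json.dumps(request_data).encode()
    #         headers["x-signature"] = hmac.new(
    #             b"shared-secret", payload, hashlib.sha256
    #         ).hexdigest()
    #         return headers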

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        OPTIONAL

        Get the complete url for the request.

        Some providers need `model` in `api_base`.
        """
        if api_base is None:
            raise ValueError("api_base is required")
        return api_base
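
    # Override sketch: a provider that embeds the model in the url (the path
    # shape is an assumption, not a specific litellm provider):
    #
    #     def get_complete_url(self, api_base, api_key, model, optional_params,
    #                          litellm_params, stream=None) -> str:
    #         if api_base is None:
    #             raise ValueError("api_base is required")
    #         return f"{api_base.rstrip('/')}/models/{model}:generate"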

    @abstractmethod
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        pass

    @abstractmethod
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        pass

    @abstractmethod
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        pass

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        pass

    def get_async_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[AsyncHTTPHandler] = None,
        json_mode: Optional[bool] = None,
    ) -> CustomStreamWrapper:
        raise NotImplementedError

    def get_sync_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        json_mode: Optional[bool] = None,
    ) -> CustomStreamWrapper:
        raise NotImplementedError

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return None

    @property
    def has_custom_stream_wrapper(self) -> bool:
        return False

    @property
    def supports_stream_param_in_request_body(self) -> bool:
        """
        Most providers support the `stream` parameter in the request body, so this
        defaults to True. Some, like Bedrock Invoke, do not.
        """
        return True
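

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of litellm): the smallest concrete subclass of
# BaseConfig implements the abstract methods above. Everything below, including
# the supported params, bearer-token header, and request/response shapes, is
# hypothetical and only meant to show where each hook fits.
# ---------------------------------------------------------------------------
class _ExampleChatConfig(BaseConfig):
    def get_supported_openai_params(self, model: str) -> list:
        return ["max_tokens", "temperature", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Pass through only the params this (hypothetical) provider understands.
        supported = self.get_supported_openai_params(model)
        for k, v in non_default_params.items():
            if k in supported:
                optional_params[k] = v
        return optional_params

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        if api_key is None:
            raise BaseLLMException(status_code=401, message="missing api_key")
        return {**headers, "Authorization": f"Bearer {api_key}"}

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Assumed OpenAI-compatible body shape.
        return {"model": model, "messages": messages, **optional_params}

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # A real config would parse raw_response.json() into model_response here.
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return BaseLLMException(
            status_code=status_code, message=error_message, headers=headers
        )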