from typing import TYPE_CHECKING, Any, List, Literal, Optional, Union

from httpx import Headers, Response

from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse

from ..common_utils import PredibaseError

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class PredibaseConfig(BaseConfig):
    """
    Generation-parameter configuration for Predibase.

    The class attributes below are the default request parameters. Creating an
    instance with non-None keyword arguments overwrites those defaults on the
    class object itself (see ``__init__``), so the new values are visible to
    every later user of the class, not just the new instance.

    Reference: https://docs.predibase.com/user-guide/inference/rest_api
    """

    adapter_id: Optional[str] = None
    adapter_source: Optional[Literal["pbase", "hub", "s3"]] = None
    best_of: Optional[int] = None
    decoder_input_details: Optional[bool] = None
    details: bool = True  # enables returning logprobs + best of
    max_new_tokens: int = (
        256  # openai default - requests hang if max_new_tokens not given
    )
    repetition_penalty: Optional[float] = None
    return_full_text: Optional[bool] = (
        False  # by default don't return the input as part of the output
    )
    seed: Optional[int] = None
    stop: Optional[List[str]] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    # Float, not int: map_openai_params forwards OpenAI's (fractional) top_p as-is.
    top_p: Optional[float] = None
    truncate: Optional[int] = None
    typical_p: Optional[float] = None
    watermark: Optional[bool] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        decoder_input_details: Optional[bool] = None,
        details: Optional[bool] = None,
        max_new_tokens: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: Optional[bool] = None,
        seed: Optional[int] = None,
        stop: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: Optional[bool] = None,
    ) -> None:
        """Override class-level defaults with every argument that is not None.

        NOTE: values are assigned with ``setattr(self.__class__, ...)``, i.e. on
        the class, so they persist for all instances created afterwards.
        """
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the config as computed by ``BaseConfig.get_config``."""
        return super().get_config()

    def get_supported_openai_params(self, model: str):
        """Return the OpenAI-style parameter names this provider accepts."""
        return [
            "stream",
            "temperature",
            "max_completion_tokens",
            "max_tokens",
            "top_p",
            "stop",
            "n",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI-style parameters into Predibase/TGI request parameters.

        Args:
            non_default_params: OpenAI-style params explicitly set by the caller.
            optional_params: Provider params dict; updated in place.
            model: Model name (unused here).
            drop_params: Unused here; names not matched below are simply ignored.

        Returns:
            The same ``optional_params`` dict, after mutation.
        """
        for param, value in non_default_params.items():
            # temperature, top_p, n, stream, stop, max_tokens, presence_penalty default to None
            if param == "temperature":
                if value == 0:  # also matches 0.0
                    # hugging face exception raised when temp==0
                    # Failed: Error occurred: HuggingfaceException - Input validation error: `temperature` must be strictly positive
                    value = 0.01
                optional_params["temperature"] = value
            elif param == "top_p":
                optional_params["top_p"] = value
            elif param == "n":
                optional_params["best_of"] = value
                optional_params["do_sample"] = (
                    True  # Need to sample if you want best of for hf inference endpoints
                )
            elif param == "stream":
                optional_params["stream"] = value
            elif param == "stop":
                optional_params["stop"] = value
            elif param in ("max_tokens", "max_completion_tokens"):
                # HF TGI raises the following exception when max_new_tokens==0
                # Failed: Error occurred: HuggingfaceException - Input validation error: `max_new_tokens` must be strictly positive
                if value == 0:
                    value = 1
                # If both names are supplied, whichever is iterated last wins.
                optional_params["max_new_tokens"] = value
            elif param == "echo":
                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_generation.decoder_input_details
                # Return the decoder input token logprobs and ids. You must set details=True as well for it to be taken into account. Defaults to False
                optional_params["decoder_input_details"] = True
            elif param == "response_format":
                optional_params["response_format"] = value
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Not implemented here; always raises NotImplementedError."""
        raise NotImplementedError(
            "Predibase transformation currently done in handler.py. Need to migrate to this file."
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Not implemented here; always raises NotImplementedError."""
        raise NotImplementedError(
            "Predibase transformation currently done in handler.py. Need to migrate to this file."
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap an error response in a ``PredibaseError``."""
        return PredibaseError(
            status_code=status_code, message=error_message, headers=headers
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build request headers, requiring an API key.

        Caller-supplied headers take precedence over the defaults. If
        ``headers`` is None, the default headers are returned (previously None
        was returned, silently dropping the Authorization header).

        Raises:
            ValueError: if ``api_key`` is None.
        """
        if api_key is None:
            raise ValueError(
                "Missing Predibase API Key - A call is being made to predibase but no key is set either in the environment variables or via params"
            )
        default_headers = {
            "content-type": "application/json",
            "Authorization": "Bearer {}".format(api_key),
        }
        if headers is None:
            return default_headers
        if isinstance(headers, dict):
            return {**default_headers, **headers}
        # Non-dict, non-None input is passed through unchanged (original behavior).
        return headers