# What is this? ## Controller file for Predibase Integration - https://predibase.com/ import json import os import time from functools import partial from typing import Callable, Optional, Union import httpx # type: ignore import litellm import litellm.litellm_core_utils import litellm.litellm_core_utils.litellm_logging from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.prompt_templates.factory import ( custom_prompt, prompt_factory, ) from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, get_async_httpx_client, ) from litellm.types.utils import LiteLLMLoggingBaseClass from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage from ..common_utils import PredibaseError async def make_call( client: AsyncHTTPHandler, api_base: str, headers: dict, data: str, model: str, messages: list, logging_obj, timeout: Optional[Union[float, httpx.Timeout]], ): response = await client.post( api_base, headers=headers, data=data, stream=True, timeout=timeout ) if response.status_code != 200: raise PredibaseError(status_code=response.status_code, message=response.text) completion_stream = response.aiter_lines() # LOGGING logging_obj.post_call( input=messages, api_key="", original_response=completion_stream, # Pass the completion stream for logging additional_args={"complete_input_dict": data}, ) return completion_stream class PredibaseChatCompletion: def __init__(self) -> None: super().__init__() def output_parser(self, generated_text: str): """ Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens. Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763 """ chat_template_tokens = [ "<|assistant|>", "<|system|>", "<|user|>", "", "", ] for token in chat_template_tokens: if generated_text.strip().startswith(token): generated_text = generated_text.replace(token, "", 1) if generated_text.endswith(token): generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1] return generated_text def process_response( # noqa: PLR0915 self, model: str, response: httpx.Response, model_response: ModelResponse, stream: bool, logging_obj: LiteLLMLoggingBaseClass, optional_params: dict, api_key: str, data: Union[dict, str], messages: list, print_verbose, encoding, ) -> ModelResponse: ## LOGGING logging_obj.post_call( input=messages, api_key=api_key, original_response=response.text, additional_args={"complete_input_dict": data}, ) print_verbose(f"raw model_response: {response.text}") ## RESPONSE OBJECT try: completion_response = response.json() except Exception: raise PredibaseError(message=response.text, status_code=422) if "error" in completion_response: raise PredibaseError( message=str(completion_response["error"]), status_code=response.status_code, ) else: if not isinstance(completion_response, dict): raise PredibaseError( status_code=422, message=f"'completion_response' is not a dictionary - {completion_response}", ) elif "generated_text" not in completion_response: raise PredibaseError( status_code=422, message=f"'generated_text' is not a key response dictionary - {completion_response}", ) if len(completion_response["generated_text"]) > 0: model_response.choices[0].message.content = self.output_parser( # type: ignore completion_response["generated_text"] ) ## GETTING LOGPROBS + FINISH REASON if ( "details" in completion_response and "tokens" in completion_response["details"] ): model_response.choices[0].finish_reason = map_finish_reason( completion_response["details"]["finish_reason"] ) sum_logprob = 0 for token in completion_response["details"]["tokens"]: if token["logprob"] is not None: sum_logprob += token["logprob"] setattr( model_response.choices[0].message, # type: ignore "_logprob", sum_logprob, # [TODO] move this to using the actual logprobs ) if "best_of" in optional_params and optional_params["best_of"] > 1: if ( "details" in completion_response and "best_of_sequences" in completion_response["details"] ): choices_list = [] for idx, item in enumerate( completion_response["details"]["best_of_sequences"] ): sum_logprob = 0 for token in item["tokens"]: if token["logprob"] is not None: sum_logprob += token["logprob"] if len(item["generated_text"]) > 0: message_obj = Message( content=self.output_parser(item["generated_text"]), logprobs=sum_logprob, ) else: message_obj = Message(content=None) choice_obj = Choices( finish_reason=map_finish_reason(item["finish_reason"]), index=idx + 1, message=message_obj, ) choices_list.append(choice_obj) model_response.choices.extend(choices_list) ## CALCULATING USAGE prompt_tokens = 0 try: prompt_tokens = litellm.token_counter(messages=messages) except Exception: # this should remain non blocking we should not block a response returning if calculating usage fails pass output_text = model_response["choices"][0]["message"].get("content", "") if output_text is not None and len(output_text) > 0: completion_tokens = 0 try: completion_tokens = len( encoding.encode( model_response["choices"][0]["message"].get("content", "") ) ) ##[TODO] use a model-specific tokenizer except Exception: # this should remain non blocking we should not block a response returning if calculating usage fails pass else: completion_tokens = 0 total_tokens = prompt_tokens + completion_tokens model_response.created = int(time.time()) model_response.model = model usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, ) model_response.usage = usage # type: ignore ## RESPONSE HEADERS predibase_headers = response.headers response_headers = {} for k, v in predibase_headers.items(): if k.startswith("x-"): response_headers["llm_provider-{}".format(k)] = v model_response._hidden_params["additional_headers"] = response_headers return model_response def completion( self, model: str, messages: list, api_base: str, custom_prompt_dict: dict, model_response: ModelResponse, print_verbose: Callable, encoding, api_key: str, logging_obj, optional_params: dict, tenant_id: str, timeout: Union[float, httpx.Timeout], acompletion=None, litellm_params=None, logger_fn=None, headers: dict = {}, ) -> Union[ModelResponse, CustomStreamWrapper]: headers = litellm.PredibaseConfig().validate_environment( api_key=api_key, headers=headers, messages=messages, optional_params=optional_params, model=model, ) completion_url = "" input_text = "" base_url = "https://serving.app.predibase.com" if "https" in model: completion_url = model elif api_base: base_url = api_base elif "PREDIBASE_API_BASE" in os.environ: base_url = os.getenv("PREDIBASE_API_BASE", "") completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}" if optional_params.get("stream", False) is True: completion_url += "/generate_stream" else: completion_url += "/generate" if model in custom_prompt_dict: # check if the model has a registered custom prompt model_prompt_details = custom_prompt_dict[model] prompt = custom_prompt( role_dict=model_prompt_details["roles"], initial_prompt_value=model_prompt_details["initial_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"], messages=messages, ) else: prompt = prompt_factory(model=model, messages=messages) ## Load Config config = litellm.PredibaseConfig.get_config() for k, v in config.items(): if ( k not in optional_params ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in optional_params[k] = v stream = optional_params.pop("stream", False) data = { "inputs": prompt, "parameters": optional_params, } input_text = prompt ## LOGGING logging_obj.pre_call( input=input_text, api_key=api_key, additional_args={ "complete_input_dict": data, "headers": headers, "api_base": completion_url, "acompletion": acompletion, }, ) ## COMPLETION CALL if acompletion is True: ### ASYNC STREAMING if stream is True: return self.async_streaming( model=model, messages=messages, data=data, api_base=completion_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, api_key=api_key, logging_obj=logging_obj, optional_params=optional_params, litellm_params=litellm_params, logger_fn=logger_fn, headers=headers, timeout=timeout, ) # type: ignore else: ### ASYNC COMPLETION return self.async_completion( model=model, messages=messages, data=data, api_base=completion_url, model_response=model_response, print_verbose=print_verbose, encoding=encoding, api_key=api_key, logging_obj=logging_obj, optional_params=optional_params, stream=False, litellm_params=litellm_params, logger_fn=logger_fn, headers=headers, timeout=timeout, ) # type: ignore ### SYNC STREAMING if stream is True: response = litellm.module_level_client.post( completion_url, headers=headers, data=json.dumps(data), stream=stream, timeout=timeout, # type: ignore ) _response = CustomStreamWrapper( response.iter_lines(), model, custom_llm_provider="predibase", logging_obj=logging_obj, ) return _response ### SYNC COMPLETION else: response = litellm.module_level_client.post( url=completion_url, headers=headers, data=json.dumps(data), timeout=timeout, # type: ignore ) return self.process_response( model=model, response=response, model_response=model_response, stream=optional_params.get("stream", False), logging_obj=logging_obj, # type: ignore optional_params=optional_params, api_key=api_key, data=data, messages=messages, print_verbose=print_verbose, encoding=encoding, ) async def async_completion( self, model: str, messages: list, api_base: str, model_response: ModelResponse, print_verbose: Callable, encoding, api_key, logging_obj, stream, data: dict, optional_params: dict, timeout: Union[float, httpx.Timeout], litellm_params=None, logger_fn=None, headers={}, ) -> ModelResponse: async_handler = get_async_httpx_client( llm_provider=litellm.LlmProviders.PREDIBASE, params={"timeout": timeout}, ) try: response = await async_handler.post( api_base, headers=headers, data=json.dumps(data) ) except httpx.HTTPStatusError as e: raise PredibaseError( status_code=e.response.status_code, message="HTTPStatusError - received status_code={}, error_message={}".format( e.response.status_code, e.response.text ), ) except Exception as e: for exception in litellm.LITELLM_EXCEPTION_TYPES: if isinstance(e, exception): raise e raise PredibaseError( status_code=500, message="{}".format(str(e)) ) # don't use verbose_logger.exception, if exception is raised return self.process_response( model=model, response=response, model_response=model_response, stream=stream, logging_obj=logging_obj, api_key=api_key, data=data, messages=messages, print_verbose=print_verbose, optional_params=optional_params, encoding=encoding, ) async def async_streaming( self, model: str, messages: list, api_base: str, model_response: ModelResponse, print_verbose: Callable, encoding, api_key, logging_obj, data: dict, timeout: Union[float, httpx.Timeout], optional_params=None, litellm_params=None, logger_fn=None, headers={}, ) -> CustomStreamWrapper: data["stream"] = True streamwrapper = CustomStreamWrapper( completion_stream=None, make_call=partial( make_call, api_base=api_base, headers=headers, data=json.dumps(data), model=model, messages=messages, logging_obj=logging_obj, timeout=timeout, ), model=model, custom_llm_provider="predibase", logging_obj=logging_obj, ) return streamwrapper def embedding(self, *args, **kwargs): pass