# What is this?
## Controller file for Predibase Integration - https://predibase.com/
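#
# A minimal usage sketch (illustrative only - the deployment name and tenant id
# below are placeholders, not values shipped with this module):
#
#   import os
#   import litellm
#
#   response = litellm.completion(
#       model="predibase/llama-3-8b-instruct",   # hypothetical Predibase deployment
#       messages=[{"role": "user", "content": "Hello!"}],
#       tenant_id="my-tenant-id",                # your Predibase tenant id
#       api_key=os.environ["PREDIBASE_API_KEY"],
#   )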
import json
import os
import time
from functools import partial
from typing import Callable, Optional, Union

import httpx  # type: ignore

import litellm
import litellm.litellm_core_utils
import litellm.litellm_core_utils.litellm_logging
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.prompt_templates.factory import (
    custom_prompt,
    prompt_factory,
)
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    get_async_httpx_client,
)
from litellm.types.utils import LiteLLMLoggingBaseClass
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage

from ..common_utils import PredibaseError


async def make_call(
    client: AsyncHTTPHandler,
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj,
    timeout: Optional[Union[float, httpx.Timeout]],
):
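    """
    Open a streaming POST request against the Predibase endpoint and return the
    async line iterator over the response body.

    Raises PredibaseError if the endpoint returns a non-200 status code.
    """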
    response = await client.post(
        api_base, headers=headers, data=data, stream=True, timeout=timeout
    )

    if response.status_code != 200:
        raise PredibaseError(status_code=response.status_code, message=response.text)

    completion_stream = response.aiter_lines()

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response=completion_stream,  # Pass the completion stream for logging
        additional_args={"complete_input_dict": data},
    )

    return completion_stream


class PredibaseChatCompletion:
    def __init__(self) -> None:
        super().__init__()

    def output_parser(self, generated_text: str):
        """
        Parse the output text to remove any special characters. In our current approach we just check for ChatML tokens.

        Initial issue that prompted this - https://github.com/BerriAI/litellm/issues/763
        """
        chat_template_tokens = [
            "<|assistant|>",
            "<|system|>",
            "<|user|>",
            "<s>",
            "</s>",
        ]
        for token in chat_template_tokens:
            if generated_text.strip().startswith(token):
                generated_text = generated_text.replace(token, "", 1)
            if generated_text.endswith(token):
                generated_text = generated_text[::-1].replace(token[::-1], "", 1)[::-1]
        return generated_text

    def process_response(  # noqa: PLR0915
        self,
        model: str,
        response: httpx.Response,
        model_response: ModelResponse,
        stream: bool,
        logging_obj: LiteLLMLoggingBaseClass,
        optional_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: list,
        print_verbose,
        encoding,
    ) -> ModelResponse:
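        """
        Convert a raw (non-streaming) Predibase /generate response into a litellm
        ModelResponse: parse the generated text, map the finish reason, attach
        summed logprobs and best_of choices when present, and compute usage.
        """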
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
        print_verbose(f"raw model_response: {response.text}")

        ## RESPONSE OBJECT
        try:
            completion_response = response.json()
        except Exception:
            raise PredibaseError(message=response.text, status_code=422)
        if "error" in completion_response:
            raise PredibaseError(
                message=str(completion_response["error"]),
                status_code=response.status_code,
            )
        else:
            if not isinstance(completion_response, dict):
                raise PredibaseError(
                    status_code=422,
                    message=f"'completion_response' is not a dictionary - {completion_response}",
                )
elif "generated_text" not in completion_response: | |
raise PredibaseError( | |
status_code=422, | |
message=f"'generated_text' is not a key response dictionary - {completion_response}", | |
) | |
if len(completion_response["generated_text"]) > 0: | |
model_response.choices[0].message.content = self.output_parser( # type: ignore | |
completion_response["generated_text"] | |
) | |
## GETTING LOGPROBS + FINISH REASON | |
if ( | |
"details" in completion_response | |
and "tokens" in completion_response["details"] | |
): | |
model_response.choices[0].finish_reason = map_finish_reason( | |
completion_response["details"]["finish_reason"] | |
) | |
sum_logprob = 0 | |
for token in completion_response["details"]["tokens"]: | |
if token["logprob"] is not None: | |
sum_logprob += token["logprob"] | |
setattr( | |
model_response.choices[0].message, # type: ignore | |
"_logprob", | |
sum_logprob, # [TODO] move this to using the actual logprobs | |
) | |
if "best_of" in optional_params and optional_params["best_of"] > 1: | |
if ( | |
"details" in completion_response | |
and "best_of_sequences" in completion_response["details"] | |
): | |
choices_list = [] | |
for idx, item in enumerate( | |
completion_response["details"]["best_of_sequences"] | |
): | |
sum_logprob = 0 | |
for token in item["tokens"]: | |
if token["logprob"] is not None: | |
sum_logprob += token["logprob"] | |
if len(item["generated_text"]) > 0: | |
message_obj = Message( | |
content=self.output_parser(item["generated_text"]), | |
logprobs=sum_logprob, | |
) | |
else: | |
message_obj = Message(content=None) | |
choice_obj = Choices( | |
finish_reason=map_finish_reason(item["finish_reason"]), | |
index=idx + 1, | |
message=message_obj, | |
) | |
choices_list.append(choice_obj) | |
model_response.choices.extend(choices_list) | |
        ## CALCULATING USAGE
        prompt_tokens = 0
        try:
            prompt_tokens = litellm.token_counter(messages=messages)
        except Exception:
            # this should remain non-blocking; we should not block a response from returning if calculating usage fails
            pass
        output_text = model_response["choices"][0]["message"].get("content", "")
        if output_text is not None and len(output_text) > 0:
            completion_tokens = 0
            try:
                completion_tokens = len(
                    encoding.encode(
                        model_response["choices"][0]["message"].get("content", "")
                    )
                )  # [TODO] use a model-specific tokenizer
            except Exception:
                # this should remain non-blocking; we should not block a response from returning if calculating usage fails
                pass
        else:
            completion_tokens = 0

        total_tokens = prompt_tokens + completion_tokens

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        model_response.usage = usage  # type: ignore

        ## RESPONSE HEADERS
        predibase_headers = response.headers
        response_headers = {}
        for k, v in predibase_headers.items():
            if k.startswith("x-"):
                response_headers["llm_provider-{}".format(k)] = v
        model_response._hidden_params["additional_headers"] = response_headers

        return model_response

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key: str,
        logging_obj,
        optional_params: dict,
        litellm_params: dict,
        tenant_id: str,
        timeout: Union[float, httpx.Timeout],
        acompletion=None,
        logger_fn=None,
        headers: dict = {},
    ) -> Union[ModelResponse, CustomStreamWrapper]:
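        """
        Entry point for Predibase chat completions. Builds the deployment URL from
        the tenant_id and model, renders the prompt, and dispatches to the
        sync/async and streaming/non-streaming code paths below.
        """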
        headers = litellm.PredibaseConfig().validate_environment(
            api_key=api_key,
            headers=headers,
            messages=messages,
            optional_params=optional_params,
            model=model,
            litellm_params=litellm_params,
        )
        completion_url = ""
        input_text = ""
        base_url = "https://serving.app.predibase.com"

        if "https" in model:
            completion_url = model
        elif api_base:
            base_url = api_base
        elif "PREDIBASE_API_BASE" in os.environ:
            base_url = os.getenv("PREDIBASE_API_BASE", "")

        completion_url = f"{base_url}/{tenant_id}/deployments/v2/llms/{model}"

        if optional_params.get("stream", False) is True:
            completion_url += "/generate_stream"
        else:
            completion_url += "/generate"

        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(model=model, messages=messages)
        ## Load Config
        config = litellm.PredibaseConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > predibase_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v
stream = optional_params.pop("stream", False) | |
data = { | |
"inputs": prompt, | |
"parameters": optional_params, | |
} | |
input_text = prompt | |
## LOGGING | |
logging_obj.pre_call( | |
input=input_text, | |
api_key=api_key, | |
additional_args={ | |
"complete_input_dict": data, | |
"headers": headers, | |
"api_base": completion_url, | |
"acompletion": acompletion, | |
}, | |
) | |
        ## COMPLETION CALL
        if acompletion is True:
            ### ASYNC STREAMING
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
            else:
                ### ASYNC COMPLETION
                return self.async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=completion_url,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=False,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                )  # type: ignore
        ### SYNC STREAMING
        if stream is True:
            response = litellm.module_level_client.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),
                stream=stream,
                timeout=timeout,  # type: ignore
            )
            _response = CustomStreamWrapper(
                response.iter_lines(),
                model,
                custom_llm_provider="predibase",
                logging_obj=logging_obj,
            )
            return _response
        ### SYNC COMPLETION
        else:
            response = litellm.module_level_client.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
                timeout=timeout,  # type: ignore
            )

        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=optional_params.get("stream", False),
            logging_obj=logging_obj,  # type: ignore
            optional_params=optional_params,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            encoding=encoding,
        )

    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        stream,
        data: dict,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> ModelResponse:
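        """
        Non-streaming async path: POST the request with the shared async httpx
        client, translate transport errors into PredibaseError, and hand the
        response to process_response.
        """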
        async_handler = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.PREDIBASE,
            params={"timeout": timeout},
        )
        try:
            response = await async_handler.post(
                api_base, headers=headers, data=json.dumps(data)
            )
        except httpx.HTTPStatusError as e:
            raise PredibaseError(
                status_code=e.response.status_code,
                message="HTTPStatusError - received status_code={}, error_message={}".format(
                    e.response.status_code, e.response.text
                ),
            )
        except Exception as e:
            for exception in litellm.LITELLM_EXCEPTION_TYPES:
                if isinstance(e, exception):
                    raise e
            raise PredibaseError(
                status_code=500, message="{}".format(str(e))
            )  # don't use verbose_logger.exception here, since the exception is re-raised
        return self.process_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream,
            logging_obj=logging_obj,
            api_key=api_key,
            data=data,
            messages=messages,
            print_verbose=print_verbose,
            optional_params=optional_params,
            encoding=encoding,
        )

    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        data: dict,
        timeout: Union[float, httpx.Timeout],
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers={},
    ) -> CustomStreamWrapper:
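        """
        Streaming async path: defer the HTTP call to make_call via a partial and
        wrap it in a CustomStreamWrapper so the stream is opened lazily.
        """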
data["stream"] = True | |
streamwrapper = CustomStreamWrapper( | |
completion_stream=None, | |
make_call=partial( | |
make_call, | |
api_base=api_base, | |
headers=headers, | |
data=json.dumps(data), | |
model=model, | |
messages=messages, | |
logging_obj=logging_obj, | |
timeout=timeout, | |
), | |
model=model, | |
custom_llm_provider="predibase", | |
logging_obj=logging_obj, | |
) | |
return streamwrapper | |

    def embedding(self, *args, **kwargs):
        # Embeddings are not implemented for the Predibase integration; this is a no-op stub.
        pass