from typing import Any, Callable, Optional, Union
import types, time, json
import httpx
from .base import BaseLLM
from litellm.utils import (
    ModelResponse,
    Choices,
    Message,
    CustomStreamWrapper,
    convert_to_model_response_object,
    Usage,
)
import aiohttp, requests
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI


class OpenAIError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST", url="https://api.openai.com/v1"
            )
        if response:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class OpenAIConfig:
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIConfig` provides configuration for OpenAI's Chat Completions API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby reducing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter sets how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    frequency_penalty: Optional[float] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
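
# Usage sketch (illustrative, kept as a comment so nothing runs at import time; the
# parameter values are assumptions for demonstration): constructing OpenAIConfig sets
# class-level defaults, and get_config() returns only the explicitly set, non-None ones.
#
#   OpenAIConfig(max_tokens=256, temperature=0.7)
#   params = OpenAIConfig.get_config()
#   # params == {"max_tokens": 256, "temperature": 0.7}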


class OpenAITextCompletionConfig:
    """
    Reference: https://platform.openai.com/docs/api-reference/completions/create

    The class `OpenAITextCompletionConfig` provides configuration for OpenAI's text completion API interface. Below are the parameters:

    - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.

    - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.

    - `frequency_penalty` (number or null): Defaults to 0. It is a number from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.

    - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.

    - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.

    - `temperature` (number or null): This optional parameter defines the sampling temperature to use.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    best_of: Optional[int] = None
    echo: Optional[bool] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
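
# Usage sketch (illustrative, commented out; values are assumptions): as with
# OpenAIConfig above, get_config() only reflects parameters that were explicitly set.
#
#   OpenAITextCompletionConfig(max_tokens=100, best_of=2)
#   OpenAITextCompletionConfig.get_config()
#   # -> {"max_tokens": 100, "best_of": 2}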


class OpenAIChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def completion(
        self,
        model_response: ModelResponse,
        timeout: float,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        logging_obj=None,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
    ):
        super().completion()
        exception_mapping_worked = False
        try:
            if headers:
                optional_params["extra_headers"] = headers
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")
            if not isinstance(timeout, float):
                raise OpenAIError(
                    status_code=422, message="Timeout needs to be a float"
                )

            for _ in range(
                2
            ):  # if the call fails due to alternating messages, retry with reformatted messages
                data = {"model": model, "messages": messages, **optional_params}

                try:
                    max_retries = data.pop("max_retries", 2)
                    if acompletion is True:
                        if optional_params.get("stream", False):
                            return self.async_streaming(
                                logging_obj=logging_obj,
                                headers=headers,
                                data=data,
                                model=model,
                                api_base=api_base,
                                api_key=api_key,
                                timeout=timeout,
                                client=client,
                                max_retries=max_retries,
                            )
                        else:
                            return self.acompletion(
                                data=data,
                                headers=headers,
                                logging_obj=logging_obj,
                                model_response=model_response,
                                api_base=api_base,
                                api_key=api_key,
                                timeout=timeout,
                                client=client,
                                max_retries=max_retries,
                            )
                    elif optional_params.get("stream", False):
                        return self.streaming(
                            logging_obj=logging_obj,
                            headers=headers,
                            data=data,
                            model=model,
                            api_base=api_base,
                            api_key=api_key,
                            timeout=timeout,
                            client=client,
                            max_retries=max_retries,
                        )
                    else:
                        if not isinstance(max_retries, int):
                            raise OpenAIError(
                                status_code=422, message="max retries must be an int"
                            )
                        if client is None:
                            openai_client = OpenAI(
                                api_key=api_key,
                                base_url=api_base,
                                http_client=litellm.client_session,
                                timeout=timeout,
                                max_retries=max_retries,
                            )
                        else:
                            openai_client = client

                        ## LOGGING
                        logging_obj.pre_call(
                            input=messages,
                            api_key=openai_client.api_key,
                            additional_args={
                                "headers": headers,
                                "api_base": openai_client._base_url._uri_reference,
                                "acompletion": acompletion,
                                "complete_input_dict": data,
                            },
                        )

                        response = openai_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
                        stringified_response = response.model_dump()
                        logging_obj.post_call(
                            input=messages,
                            api_key=api_key,
                            original_response=stringified_response,
                            additional_args={"complete_input_dict": data},
                        )
                        return convert_to_model_response_object(
                            response_object=stringified_response,
                            model_response_object=model_response,
                        )
                except Exception as e:
                    if "Conversation roles must alternate user/assistant" in str(
                        e
                    ) or "user and assistant roles should be alternating" in str(e):
                        # Reformat messages so user/assistant roles alternate: whenever two
                        # consecutive messages share a role, insert a blank message with the
                        # opposite role (see the standalone sketch after this method).
                        new_messages = []
                        for i in range(len(messages) - 1):
                            new_messages.append(messages[i])
                            if messages[i]["role"] == messages[i + 1]["role"]:
                                if messages[i]["role"] == "user":
                                    new_messages.append(
                                        {"role": "assistant", "content": ""}
                                    )
                                else:
                                    new_messages.append(
                                        {"role": "user", "content": ""}
                                    )
                        new_messages.append(messages[-1])
                        messages = new_messages
                    elif "Last message must have role `user`" in str(e):
                        new_messages = messages
                        new_messages.append({"role": "user", "content": ""})
                        messages = new_messages
                    else:
                        raise e
        except OpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if hasattr(e, "status_code"):
                raise OpenAIError(status_code=e.status_code, message=str(e))
            else:
                raise OpenAIError(status_code=500, message=str(e))
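
    # Sketch of the message-repair step used in completion() above. This hypothetical
    # helper is not part of the class; it only restates the retry logic: whenever two
    # consecutive messages share a role, insert an empty message of the opposite role
    # so the conversation alternates user/assistant.
    #
    #   def _alternate_roles(messages: list) -> list:
    #       fixed = []
    #       for i in range(len(messages) - 1):
    #           fixed.append(messages[i])
    #           if messages[i]["role"] == messages[i + 1]["role"]:
    #               opposite = "assistant" if messages[i]["role"] == "user" else "user"
    #               fixed.append({"role": opposite, "content": ""})
    #       fixed.append(messages[-1])
    #       return fixed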

    async def acompletion(
        self,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
        headers=None,
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_aclient = client

            ## LOGGING
            logging_obj.pre_call(
                input=data["messages"],
                api_key=openai_aclient.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {openai_aclient.api_key}"},
                    "api_base": openai_aclient._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )

            response = await openai_aclient.chat.completions.create(
                **data, timeout=timeout
            )
            stringified_response = response.model_dump()
            logging_obj.post_call(
                input=data["messages"],
                api_key=api_key,
                original_response=stringified_response,
                additional_args={"complete_input_dict": data},
            )
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
            )
        except Exception as e:
            raise e

    def streaming(
        self,
        logging_obj,
        timeout: float,
        data: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
    ):
        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            openai_client = client

        ## LOGGING
        logging_obj.pre_call(
            input=data["messages"],
            api_key=api_key,
            additional_args={
                "headers": headers,
                "api_base": api_base,
                "acompletion": False,
                "complete_input_dict": data,
            },
        )
        response = openai_client.chat.completions.create(**data, timeout=timeout)
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="openai",
            logging_obj=logging_obj,
        )
        return streamwrapper

    async def async_streaming(
        self,
        logging_obj,
        timeout: float,
        data: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_aclient = client

            ## LOGGING
            logging_obj.pre_call(
                input=data["messages"],
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )

            response = await openai_aclient.chat.completions.create(
                **data, timeout=timeout
            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=response,
                model=model,
                custom_llm_provider="openai",
                logging_obj=logging_obj,
            )
            return streamwrapper
        except (
            Exception
        ) as e:  # exceptions must be handled here; async exceptions are not caught by the calling sync function
            if response is not None and hasattr(response, "text"):
                raise OpenAIError(
                    status_code=500,
                    message=f"{str(e)}\n\nOriginal Response: {response.text}",
                )
            else:
                if type(e).__name__ == "ReadTimeout":
                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
                elif hasattr(e, "status_code"):
                    raise OpenAIError(status_code=e.status_code, message=str(e))
                else:
                    raise OpenAIError(status_code=500, message=f"{str(e)}")

    async def aembedding(
        self,
        input: list,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_aclient = client

            response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
            stringified_response = response.model_dump()

            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
                response_type="embedding",
            )  # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model_response: Optional[litellm.utils.EmbeddingResponse] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aembedding=None,
    ):
        super().embedding()
        exception_mapping_worked = False
        try:
            data = {"model": model, "input": input, **optional_params}
            max_retries = data.pop("max_retries", 2)
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")

            ## LOGGING
            logging_obj.pre_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data, "api_base": api_base},
            )

            if aembedding is True:
                response = self.aembedding(
                    data=data,
                    input=input,
                    logging_obj=logging_obj,
                    model_response=model_response,
                    api_base=api_base,
                    api_key=api_key,
                    timeout=timeout,
                    client=client,
                    max_retries=max_retries,
                )  # type: ignore
                return response

            if client is None:
                openai_client = OpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.client_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_client = client

            ## COMPLETION CALL
            response = openai_client.embeddings.create(**data, timeout=timeout)  # type: ignore

            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )
            return convert_to_model_response_object(
                response_object=response.model_dump(),
                model_response_object=model_response,
                response_type="embedding",
            )  # type: ignore
        except OpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if hasattr(e, "status_code"):
                raise OpenAIError(status_code=e.status_code, message=str(e))
            else:
                raise OpenAIError(status_code=500, message=str(e))

    async def aimage_generation(
        self,
        prompt: str,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_aclient = client

            response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
            stringified_response = response.model_dump()

            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
                response_type="image_generation",
            )  # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def image_generation(
        self,
        model: Optional[str],
        prompt: str,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model_response: Optional[litellm.utils.ImageResponse] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aimg_generation=None,
    ):
        exception_mapping_worked = False
        try:
            data = {"model": model, "prompt": prompt, **optional_params}
            max_retries = data.pop("max_retries", 2)
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")

            if aimg_generation is True:
                response = self.aimage_generation(
                    data=data,
                    prompt=prompt,
                    logging_obj=logging_obj,
                    model_response=model_response,
                    api_base=api_base,
                    api_key=api_key,
                    timeout=timeout,
                    client=client,
                    max_retries=max_retries,
                )  # type: ignore
                return response

            if client is None:
                openai_client = OpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.client_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_client = client

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=openai_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
                    "api_base": openai_client._base_url._uri_reference,
                    "acompletion": False,
                    "complete_input_dict": data,
                },
            )

            ## COMPLETION CALL
            response = openai_client.images.generate(**data, timeout=timeout)  # type: ignore

            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )
            return convert_to_model_response_object(
                response_object=response.model_dump(),
                model_response_object=model_response,
                response_type="image_generation",
            )  # type: ignore
        except OpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if hasattr(e, "status_code"):
                raise OpenAIError(status_code=e.status_code, message=str(e))
            else:
                raise OpenAIError(status_code=500, message=str(e))

    async def ahealth_check(
        self,
        model: Optional[str],
        api_key: str,
        timeout: float,
        mode: str,
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
    ):
        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
        if model is None and mode != "image_generation":
            raise Exception("model is not set")

        completion = None

        if mode == "completion":
            completion = await client.completions.with_raw_response.create(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "chat":
            if messages is None:
                raise Exception("messages is not set")
            completion = await client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        elif mode == "embedding":
            if input is None:
                raise Exception("input is not set")
            completion = await client.embeddings.with_raw_response.create(
                model=model,  # type: ignore
                input=input,  # type: ignore
            )
        elif mode == "image_generation":
            if prompt is None:
                raise Exception("prompt is not set")
            completion = await client.images.with_raw_response.generate(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        else:
            raise Exception("mode not set")

        response = {}

        if completion is None or not hasattr(completion, "headers"):
            raise Exception("invalid completion response")

        if (
            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
        ):  # not provided for dall-e requests
            response["x-ratelimit-remaining-requests"] = completion.headers[
                "x-ratelimit-remaining-requests"
            ]
        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
            response["x-ratelimit-remaining-tokens"] = completion.headers[
                "x-ratelimit-remaining-tokens"
            ]
        return response
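
    # Usage sketch (illustrative, commented out; the model name and returned header
    # values are assumptions): ahealth_check issues a minimal request for the given
    # mode and surfaces the OpenAI rate-limit headers that came back with it.
    #
    #   result = await OpenAIChatCompletion().ahealth_check(
    #       model="gpt-3.5-turbo",
    #       api_key="sk-...",  # placeholder key
    #       timeout=10.0,
    #       mode="chat",
    #       messages=[{"role": "user", "content": "ping"}],
    #   )
    #   # result -> {"x-ratelimit-remaining-requests": "...",
    #   #            "x-ratelimit-remaining-tokens": "..."}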


class OpenAITextCompletion(BaseLLM):
    _client_session: httpx.Client

    def __init__(self) -> None:
        super().__init__()
        self._client_session = self.create_client_session()

    def validate_environment(self, api_key):
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def convert_to_model_response_object(
        self,
        response_object: Optional[dict] = None,
        model_response_object: Optional[ModelResponse] = None,
    ):
        try:
            ## RESPONSE OBJECT
            if response_object is None or model_response_object is None:
                raise ValueError("Error in response object format")
            choice_list = []
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(content=choice["text"], role="assistant")
                choice = Choices(
                    finish_reason=choice["finish_reason"], index=idx, message=message
                )
                choice_list.append(choice)
            model_response_object.choices = choice_list

            if "usage" in response_object:
                model_response_object.usage = response_object["usage"]

            if "id" in response_object:
                model_response_object.id = response_object["id"]

            if "model" in response_object:
                model_response_object.model = response_object["model"]

            model_response_object._hidden_params[
                "original_response"
            ] = response_object  # track the original response; if users make a litellm.text_completion() request, we can return it to them
            return model_response_object
        except Exception as e:
            raise e

    def completion(
        self,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        messages: list,
        timeout: float,
        print_verbose: Optional[Callable] = None,
        api_base: Optional[str] = None,
        logging_obj=None,
        acompletion: bool = False,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
    ):
        super().completion()
        exception_mapping_worked = False
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")

            api_base = f"{api_base}/completions"

            if (
                len(messages) > 0
                and "content" in messages[0]
                and type(messages[0]["content"]) == list
            ):
                prompt = messages[0]["content"]
            else:
                prompt = " ".join([message["content"] for message in messages])  # type: ignore

            # don't send max retries to the api, if set
            optional_params.pop("max_retries", None)

            data = {"model": model, "prompt": prompt, **optional_params}

            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        data=data,
                        headers=headers,
                        model_response=model_response,
                        model=model,
                        timeout=timeout,
                    )
                else:
                    return self.acompletion(
                        api_base=api_base,
                        data=data,
                        headers=headers,
                        model_response=model_response,
                        prompt=prompt,
                        api_key=api_key,
                        logging_obj=logging_obj,
                        model=model,
                        timeout=timeout,
                    )  # type: ignore
            elif optional_params.get("stream", False):
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    data=data,
                    headers=headers,
                    model_response=model_response,
                    model=model,
                    timeout=timeout,
                )
            else:
                response = httpx.post(
                    url=api_base, json=data, headers=headers, timeout=timeout
                )
                if response.status_code != 200:
                    raise OpenAIError(
                        status_code=response.status_code, message=response.text
                    )

                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=response,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(
                    response_object=response.json(),
                    model_response_object=model_response,
                )
        except Exception as e:
            raise e

    async def acompletion(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        prompt: str,
        api_key: str,
        model: str,
        timeout: float,
    ):
        async with httpx.AsyncClient(timeout=timeout) as client:
            try:
                response = await client.post(
                    api_base,
                    json=data,
                    headers=headers,
                    timeout=litellm.request_timeout,
                )
                response_json = response.json()
                if response.status_code != 200:
                    raise OpenAIError(
                        status_code=response.status_code, message=response.text
                    )

                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=response,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(
                    response_object=response_json, model_response_object=model_response
                )
            except Exception as e:
                raise e

    def streaming(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
    ):
        with httpx.stream(
            url=api_base,
            json=data,
            headers=headers,
            method="POST",
            timeout=timeout,
        ) as response:
            if response.status_code != 200:
                response.read()  # a streamed response body must be read before .text is available
                raise OpenAIError(
                    status_code=response.status_code, message=response.text
                )

            streamwrapper = CustomStreamWrapper(
                completion_stream=response.iter_lines(),
                model=model,
                custom_llm_provider="text-completion-openai",
                logging_obj=logging_obj,
            )
            for transformed_chunk in streamwrapper:
                yield transformed_chunk

    async def async_streaming(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
    ):
        async with httpx.AsyncClient() as client:  # ensure the client is closed once the stream ends
            async with client.stream(
                url=api_base,
                json=data,
                headers=headers,
                method="POST",
                timeout=timeout,
            ) as response:
                try:
                    if response.status_code != 200:
                        await response.aread()  # a streamed response body must be read before .text is available
                        raise OpenAIError(
                            status_code=response.status_code, message=response.text
                        )

                    streamwrapper = CustomStreamWrapper(
                        completion_stream=response.aiter_lines(),
                        model=model,
                        custom_llm_provider="text-completion-openai",
                        logging_obj=logging_obj,
                    )
                    async for transformed_chunk in streamwrapper:
                        yield transformed_chunk
                except Exception as e:
                    raise e