Spaces:
Sleeping
Sleeping
import time | |
from typing import TYPE_CHECKING, Any, List, Optional, Union | |
import httpx | |
from litellm.llms.base_llm.chat.transformation import BaseLLMException | |
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig | |
from litellm.types.llms.openai import AllMessageValues | |
from litellm.types.utils import ModelResponse, Usage | |
from ..common_utils import OobaboogaError | |
if not TYPE_CHECKING:
    # At runtime the concrete logging class is never imported here
    # (presumably to avoid an import cycle or import cost — TODO confirm);
    # annotations that use ``LoggingClass`` degrade to ``Any``.
    LoggingClass = Any
else:
    # Static type checkers see the real litellm logging object type.
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as LiteLLMLoggingObj,
    )

    LoggingClass = LiteLLMLoggingObj
class OobaboogaConfig(OpenAIGPTConfig):
    """Provider configuration for an Oobabooga backend.

    Subclasses ``OpenAIGPTConfig`` and supplies Oobabooga-specific error
    wrapping, response parsing and request headers. The response parser reads
    an OpenAI-style chat-completion body (``choices[0].message.content`` plus
    ``usage``), presumably from text-generation-webui's OpenAI-compatible
    API — confirm against the deployed server.
    """

    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: Optional[Union[dict, httpx.Headers]] = None,
    ) -> BaseLLMException:
        """Wrap a provider failure in an ``OobaboogaError``.

        Args:
            error_message: Error text to carry on the exception.
            status_code: HTTP status code of the failed response.
            headers: Optional response headers to attach to the exception.

        Returns:
            An ``OobaboogaError`` instance (not raised here).
        """
        return OobaboogaError(
            status_code=status_code, message=error_message, headers=headers
        )

    @staticmethod
    def _usage_from_response(completion_response: dict) -> Usage:
        """Build a ``Usage`` from the body's ``usage`` object, tolerating gaps.

        Missing or null token counts become 0. A missing ``total_tokens`` is
        derived as ``prompt_tokens + completion_tokens``.
        """
        usage_block = completion_response.get("usage")
        if not isinstance(usage_block, dict):
            usage_block = {}
        prompt_tokens = usage_block.get("prompt_tokens") or 0
        completion_tokens = usage_block.get("completion_tokens") or 0
        total_tokens = usage_block.get("total_tokens")
        if total_tokens is None:
            total_tokens = prompt_tokens + completion_tokens
        return Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Fill ``model_response`` from a raw chat-completion HTTP response.

        Logs the raw response text, parses the JSON body, and then sets the
        following on ``model_response``:

        - ``choices[0].message.content``
        - ``created`` (current Unix time)
        - ``model``
        - ``usage``

        Args:
            model: Model name recorded on the response.
            raw_response: HTTP response returned by the backend.
            model_response: Response object mutated in place and returned.
            logging_obj: litellm logging object; ``post_call`` is invoked.
            request_data: Request payload, forwarded to the logger.
            messages: Input messages, forwarded to the logger.
            optional_params: Unused here; kept for interface compatibility.
            litellm_params: Unused here; kept for interface compatibility.
            encoding: Unused here; kept for interface compatibility.
            api_key: Forwarded to the logger.
            json_mode: Unused here; kept for interface compatibility.

        Returns:
            The same ``model_response`` object.

        Raises:
            OobaboogaError: If the body is not valid JSON, is not a JSON object,
                contains an ``"error"`` key, or lacks
                ``choices[0].message.content``.
        """
        ## LOGGING
        logging_obj.post_call(
            input=messages,
            api_key=api_key,
            original_response=raw_response.text,
            additional_args={"complete_input_dict": request_data},
        )

        ## RESPONSE OBJECT
        try:
            completion_response = raw_response.json()
        except Exception as err:
            raise OobaboogaError(
                message=raw_response.text, status_code=raw_response.status_code
            ) from err

        # A non-object body (e.g. a bare JSON string) would hit the substring
        # semantics of ``in`` below and then fail with an unwrapped TypeError.
        if not isinstance(completion_response, dict):
            raise OobaboogaError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        if "error" in completion_response:
            error_payload = completion_response["error"]
            raise OobaboogaError(
                # The exception's message is text; render object payloads as str.
                message=(
                    error_payload
                    if isinstance(error_payload, str)
                    else str(error_payload)
                ),
                status_code=raw_response.status_code,
            )

        try:
            content = completion_response["choices"][0]["message"]["content"]
            model_response.choices[0].message.content = content  # type: ignore
        except Exception as e:
            raise OobaboogaError(
                message=str(e),
                status_code=raw_response.status_code,
            ) from e

        model_response.created = int(time.time())
        model_response.model = model
        # A missing/partial usage block should not discard a valid completion.
        setattr(model_response, "usage", self._usage_from_response(completion_response))
        return model_response

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Return the HTTP headers for an Oobabooga request.

        The result always contains JSON ``accept``/``content-type`` headers.
        It also contains an ``Authorization`` header when ``api_key`` is given.

        NOTE(review): the incoming ``headers`` argument is not merged into the
        result, so caller-supplied headers are dropped. Confirm this is intended.
        """
        request_headers = {
            "accept": "application/json",
            "content-type": "application/json",
        }
        if api_key is not None:
            # NOTE(review): text-generation-webui's OpenAI-compatible API
            # appears to expect the ``Bearer`` scheme; confirm that ``Token``
            # is accepted by the target server.
            request_headers["Authorization"] = f"Token {api_key}"
        return request_headers