"""
Handler for transforming /chat/completions api requests to litellm.responses requests
"""
from typing import TYPE_CHECKING, Any, Coroutine, TypedDict, Union

if TYPE_CHECKING:
    from litellm import CustomStreamWrapper, LiteLLMLoggingObj, ModelResponse


class ResponsesToCompletionBridgeHandlerInputKwargs(TypedDict):
    model: str
    messages: list
    optional_params: dict
    litellm_params: dict
    headers: dict
    model_response: "ModelResponse"
    logging_obj: "LiteLLMLoggingObj"
    custom_llm_provider: str

class ResponsesToCompletionBridgeHandler:
    def __init__(self):
        # Imported lazily, presumably to avoid a circular import at module load.
        from .transformation import LiteLLMResponsesTransformationHandler

        super().__init__()
        self.transformation_handler = LiteLLMResponsesTransformationHandler()

    def validate_input_kwargs(
        self, kwargs: dict
    ) -> ResponsesToCompletionBridgeHandlerInputKwargs:
        """Check that every required completion kwarg is present and correctly typed."""
        from litellm import LiteLLMLoggingObj
        from litellm.types.utils import ModelResponse

        model = kwargs.get("model")
        if model is None or not isinstance(model, str):
            raise ValueError("model is required")

        custom_llm_provider = kwargs.get("custom_llm_provider")
        if custom_llm_provider is None or not isinstance(custom_llm_provider, str):
            raise ValueError("custom_llm_provider is required")

        messages = kwargs.get("messages")
        if messages is None or not isinstance(messages, list):
            raise ValueError("messages is required")

        optional_params = kwargs.get("optional_params")
        if optional_params is None or not isinstance(optional_params, dict):
            raise ValueError("optional_params is required")

        litellm_params = kwargs.get("litellm_params")
        if litellm_params is None or not isinstance(litellm_params, dict):
            raise ValueError("litellm_params is required")

        headers = kwargs.get("headers")
        if headers is None or not isinstance(headers, dict):
            raise ValueError("headers is required")

        model_response = kwargs.get("model_response")
        if model_response is None or not isinstance(model_response, ModelResponse):
            raise ValueError("model_response is required")

        logging_obj = kwargs.get("logging_obj")
        if logging_obj is None or not isinstance(logging_obj, LiteLLMLoggingObj):
            raise ValueError("logging_obj is required")

        return ResponsesToCompletionBridgeHandlerInputKwargs(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            model_response=model_response,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
        )
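    # Illustrative only: a kwargs dict that passes validation has the shape of
    # ResponsesToCompletionBridgeHandlerInputKwargs, e.g. (hypothetical values):
    #   {"model": "gpt-4o", "messages": [{"role": "user", "content": "hi"}],
    #    "optional_params": {}, "litellm_params": {}, "headers": {},
    #    "model_response": ModelResponse(), "logging_obj": <LiteLLMLoggingObj>,
    #    "custom_llm_provider": "openai"}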
    def completion(
        self, *args, **kwargs
    ) -> Union[
        Coroutine[Any, Any, Union["ModelResponse", "CustomStreamWrapper"]],
        "ModelResponse",
        "CustomStreamWrapper",
    ]:
        """Sync path: route a /chat/completions call through litellm.responses."""
        # Delegate to the async path when the caller requested an async completion.
        if kwargs.get("acompletion") is True:
            return self.acompletion(**kwargs)

        from litellm import responses
        from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
        from litellm.types.llms.openai import ResponsesAPIResponse

        validated_kwargs = self.validate_input_kwargs(kwargs)
        model = validated_kwargs["model"]
        messages = validated_kwargs["messages"]
        optional_params = validated_kwargs["optional_params"]
        litellm_params = validated_kwargs["litellm_params"]
        headers = validated_kwargs["headers"]
        model_response = validated_kwargs["model_response"]
        logging_obj = validated_kwargs["logging_obj"]
        custom_llm_provider = validated_kwargs["custom_llm_provider"]

        request_data = self.transformation_handler.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        result = responses(
            **request_data,
        )

        if isinstance(result, ResponsesAPIResponse):
            # Non-streaming: map the Responses API response back to a
            # chat-completions ModelResponse.
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=result,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        else:
            # Streaming: wrap the Responses API event stream so it yields
            # chat-completions chunks.
            completion_stream = self.transformation_handler.get_model_response_iterator(
                streaming_response=result,  # type: ignore
                sync_stream=True,
                json_mode=kwargs.get("json_mode"),
            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )
            return streamwrapper
    async def acompletion(
        self, *args, **kwargs
    ) -> Union["ModelResponse", "CustomStreamWrapper"]:
        """Async path: route a /chat/completions call through litellm.aresponses."""
        from litellm import aresponses
        from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
        from litellm.types.llms.openai import ResponsesAPIResponse

        validated_kwargs = self.validate_input_kwargs(kwargs)
        model = validated_kwargs["model"]
        messages = validated_kwargs["messages"]
        optional_params = validated_kwargs["optional_params"]
        litellm_params = validated_kwargs["litellm_params"]
        headers = validated_kwargs["headers"]
        model_response = validated_kwargs["model_response"]
        logging_obj = validated_kwargs["logging_obj"]
        custom_llm_provider = validated_kwargs["custom_llm_provider"]

        request_data = self.transformation_handler.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        result = await aresponses(
            **request_data,
            aresponses=True,
        )

        if isinstance(result, ResponsesAPIResponse):
            # Non-streaming: map the Responses API response back to a
            # chat-completions ModelResponse.
            return self.transformation_handler.transform_response(
                model=model,
                raw_response=result,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=kwargs.get("encoding"),
                api_key=kwargs.get("api_key"),
                json_mode=kwargs.get("json_mode"),
            )
        else:
            # Streaming: wrap the async Responses API event stream.
            completion_stream = self.transformation_handler.get_model_response_iterator(
                streaming_response=result,  # type: ignore
                sync_stream=False,
                json_mode=kwargs.get("json_mode"),
            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider=custom_llm_provider,
                logging_obj=logging_obj,
            )
            return streamwrapper

# Module-level singleton used to route chat-completions traffic over the Responses API.
responses_api_bridge = ResponsesToCompletionBridgeHandler()
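# A minimal usage sketch (illustrative assumptions: the kwargs below mirror
# ResponsesToCompletionBridgeHandlerInputKwargs; "gpt-4o" and the helper
# objects are placeholders, not values defined in this module):
#
#     from litellm.types.utils import ModelResponse
#
#     result = responses_api_bridge.completion(
#         model="gpt-4o",
#         messages=[{"role": "user", "content": "hello"}],
#         optional_params={},
#         litellm_params={},
#         headers={},
#         model_response=ModelResponse(),
#         logging_obj=logging_obj,  # a LiteLLMLoggingObj built by the caller
#         custom_llm_provider="openai",
#     )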