Spaces:
Configuration error
Configuration error
File size: 8,882 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
import asyncio
import concurrent.futures
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
from litellm.types.llms.openai import (
OpenAIRealtimeEvents,
OpenAIRealtimeOutputItemDone,
OpenAIRealtimeResponseDelta,
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
)
from litellm.types.realtime import ALL_DELTA_TYPES
from .litellm_logging import Logging as LiteLLMLogging
# `websockets` is only imported for static type checking; at runtime the
# alias degrades to `Any`, so this module can be imported without it.
if TYPE_CHECKING:
    from websockets.asyncio.client import ClientConnection

    CLIENT_CONNECTION_CLASS = ClientConnection
else:
    CLIENT_CONNECTION_CLASS = Any
# Module-wide worker pool, capped at 10 threads.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)

# Event types used for logging when `litellm.logged_real_time_event_types`
# is not set.
DefaultLoggedRealTimeEventTypes = ["session.created", "response.create", "response.done"]
class RealTimeStreaming:
    """Relay realtime events between a client websocket and a backend websocket.

    Backend events are forwarded to the client (optionally transformed by
    ``provider_config``), and those whose ``type`` is selected for logging are
    buffered in ``self.messages``.  Client messages are forwarded to the
    backend.  The buffered events are handed to ``logging_obj`` once the
    backend-to-client loop ends.
    """

    def __init__(
        self,
        websocket: Any,
        backend_ws: CLIENT_CONNECTION_CLASS,
        logging_obj: LiteLLMLogging,
        provider_config: Optional[BaseRealtimeConfig] = None,
        model: str = "",
    ):
        """Set up the relay state.

        Args:
            websocket: Client-side connection; ``receive_text()`` and
                ``send_text()`` are awaited on it.
            backend_ws: Backend connection; ``recv()`` and ``send()`` are
                awaited on it.
            logging_obj: Logging handler that receives the client input and
                the buffered backend events.
            provider_config: Optional request/response transformer.  When it
                is not given, messages are relayed unchanged.
            model: Model name passed through to ``provider_config``.
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        self.messages: List[OpenAIRealtimeEvents] = []
        self.input_message: Dict = {}

        # Use the globally configured event filter, or the module default.
        _logged_real_time_event_types = litellm.logged_real_time_event_types
        if _logged_real_time_event_types is None:
            _logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
        self.logged_real_time_event_types = _logged_real_time_event_types

        self.provider_config = provider_config
        self.model = model

        # Transformation state: passed to, and overwritten from, every call to
        # `provider_config.transform_realtime_response`.
        self.current_delta_chunks: Optional[List[OpenAIRealtimeResponseDelta]] = None
        self.current_output_item_id: Optional[str] = None
        self.current_response_id: Optional[str] = None
        self.current_conversation_id: Optional[str] = None
        self.current_item_chunks: Optional[List[OpenAIRealtimeOutputItemDone]] = None
        self.current_delta_type: Optional[ALL_DELTA_TYPES] = None
        self.session_configuration_request: Optional[str] = None

    def _should_store_message(
        self,
        message_obj: Union[dict, OpenAIRealtimeEvents],
    ) -> bool:
        """Return True when the event's ``type`` is selected for logging.

        A filter equal to ``"*"`` selects every event, including events that
        carry no ``type`` key.
        """
        _msg_type = message_obj["type"] if "type" in message_obj else None
        if self.logged_real_time_event_types == "*":
            return True
        if _msg_type and _msg_type in self.logged_real_time_event_types:
            return True
        return False

    def store_message(self, message: Union[str, bytes, OpenAIRealtimeEvents]):
        """Buffer ``message`` for logging if its event type is selected.

        ``bytes`` are decoded as UTF-8 and ``str`` payloads are parsed as
        JSON; dict events are used as-is.

        Raises:
            Exception: Any error raised while building the event object from a
                parsed payload is logged at debug level and re-raised.  JSON
                decoding errors propagate directly.
        """
        if isinstance(message, bytes):
            message = message.decode("utf-8")

        message_obj: Any
        if isinstance(message, dict):
            message_obj = message
        else:
            parsed = json.loads(message)
            try:
                # Session lifecycle events get the session event type; every
                # other parsed payload becomes a generic response object.  The
                # previous condition also matched every non-dict input, which
                # left the response-object branch unreachable.
                # NOTE(review): assumes both constructors accept the raw event
                # keys without validation - confirm against their definitions.
                if parsed.get("type") in ("session.created", "session.updated"):
                    message_obj = OpenAIRealtimeStreamSessionEvents(**parsed)  # type: ignore
                else:
                    message_obj = OpenAIRealtimeStreamResponseBaseObject(**parsed)  # type: ignore
            except Exception as e:
                verbose_logger.debug(f"Error parsing message for logging: {e}")
                raise

        if self._should_store_message(message_obj):
            self.messages.append(message_obj)

    def store_input(self, message: dict):
        """Remember the latest client message and report it via ``pre_call``."""
        self.input_message = message
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def log_messages(self):
        """Hand the buffered events to the async and sync success handlers."""
        if not self.logging_obj:
            return

        ## ASYNC LOGGING
        asyncio.create_task(self.logging_obj.async_success_handler(self.messages))

        ## SYNC LOGGING
        # Submit the callable and its argument separately so the handler runs
        # on the worker pool.  Calling it inline would run it on the event
        # loop and submit its return value instead.
        executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client_send_messages(self):
        """Forward backend events to the client until the connection ends.

        Each backend message is optionally transformed by ``provider_config``,
        offered to ``store_message`` for logging, and sent to the client.
        Errors are logged rather than raised, and ``log_messages`` always runs
        when the loop exits.
        """
        import websockets

        try:
            while True:
                try:
                    raw_response = await self.backend_ws.recv(
                        decode=False
                    )  # improves performance
                except TypeError:
                    # Retry without the keyword if `recv` rejected it.
                    raw_response = await self.backend_ws.recv()  # type: ignore[assignment]

                if not self.provider_config:
                    ## LOGGING
                    self.store_message(raw_response)
                    await self.websocket.send_text(raw_response)
                    continue

                returned_object = self.provider_config.transform_realtime_response(
                    raw_response,
                    self.model,
                    self.logging_obj,
                    realtime_response_transform_input={
                        "session_configuration_request": self.session_configuration_request,
                        "current_output_item_id": self.current_output_item_id,
                        "current_response_id": self.current_response_id,
                        "current_delta_chunks": self.current_delta_chunks,
                        "current_conversation_id": self.current_conversation_id,
                        "current_item_chunks": self.current_item_chunks,
                        "current_delta_type": self.current_delta_type,
                    },
                )

                # Carry the transformer's updated state over to the next message.
                transformed_response = returned_object["response"]
                self.current_output_item_id = returned_object["current_output_item_id"]
                self.current_response_id = returned_object["current_response_id"]
                self.current_delta_chunks = returned_object["current_delta_chunks"]
                self.current_conversation_id = returned_object["current_conversation_id"]
                self.current_item_chunks = returned_object["current_item_chunks"]
                self.current_delta_type = returned_object["current_delta_type"]
                self.session_configuration_request = returned_object[
                    "session_configuration_request"
                ]

                # One backend message may expand into several client events.
                if isinstance(transformed_response, list):
                    events = transformed_response
                else:
                    events = [transformed_response]
                for event in events:
                    event_str = json.dumps(event)
                    ## LOGGING
                    self.store_message(event_str)
                    await self.websocket.send_text(event_str)
        except websockets.exceptions.ConnectionClosed as e:  # type: ignore
            verbose_logger.exception(
                f"Connection closed in backend to client send messages - {e}"
            )
        except Exception as e:
            verbose_logger.exception(f"Error in backend to client send messages: {e}")
        finally:
            await self.log_messages()

    async def client_ack_messages(self):
        """Forward client messages to the backend until an error occurs.

        Every received message is recorded with ``store_input``.  When
        ``provider_config`` is set, each message it returns from
        ``transform_realtime_request`` is sent to the backend in turn.  Any
        exception ends the loop and is logged at debug level.
        """
        try:
            while True:
                message = await self.websocket.receive_text()
                ## LOGGING
                self.store_input(message=message)
                ## FORWARD TO BACKEND
                if self.provider_config:
                    message = self.provider_config.transform_realtime_request(
                        message, self.model
                    )
                    for msg in message:
                        await self.backend_ws.send(msg)
                else:
                    await self.backend_ws.send(message)
        except Exception as e:
            verbose_logger.debug(f"Error in client ack messages: {e}")

    async def bidirectional_forward(self):
        """Run both relay directions and clean up the backend task on exit.

        The backend-to-client loop runs as a background task while the
        client-to-backend loop is awaited here.  When the latter ends, the
        background task is cancelled if still running and then awaited.
        """
        # Local import, matching `backend_to_client_send_messages`.
        import websockets

        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
        try:
            await self.client_ack_messages()
        # Use the library's exception class.  The client websocket object is
        # not shown to expose an `exceptions` namespace, and a failing lookup
        # in the `except` clause would mask the exception being propagated.
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            verbose_logger.debug("Connection closed")
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass
|