Spaces:
Configuration error
Configuration error
import asyncio | |
import concurrent.futures | |
import json | |
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | |
import litellm | |
from litellm._logging import verbose_logger | |
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig | |
from litellm.types.llms.openai import ( | |
OpenAIRealtimeEvents, | |
OpenAIRealtimeOutputItemDone, | |
OpenAIRealtimeResponseDelta, | |
OpenAIRealtimeStreamResponseBaseObject, | |
OpenAIRealtimeStreamSessionEvents, | |
) | |
from litellm.types.realtime import ALL_DELTA_TYPES | |
from .litellm_logging import Logging as LiteLLMLogging | |
if TYPE_CHECKING:
    # Only imported for static type checking; at runtime the alias below
    # falls back to ``Any`` so this module does not import websockets here.
    from websockets.asyncio.client import ClientConnection

    CLIENT_CONNECTION_CLASS = ClientConnection
else:
    CLIENT_CONNECTION_CLASS = Any

# Create a thread pool with a maximum of 10 threads
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)

# Event types stored for logging when ``litellm.logged_real_time_event_types``
# is None.
DefaultLoggedRealTimeEventTypes = [
    "session.created",
    "response.create",
    "response.done",
]
class RealTimeStreaming:
    """Relay realtime events between a client websocket and a backend websocket.

    Messages received from the backend are (optionally) transformed by a
    provider config, recorded for logging when their event type is selected,
    and forwarded to the client. Messages received from the client are
    forwarded to the backend.
    """

    def __init__(
        self,
        websocket: Any,
        backend_ws: CLIENT_CONNECTION_CLASS,
        logging_obj: LiteLLMLogging,
        provider_config: Optional[BaseRealtimeConfig] = None,
        model: str = "",
    ):
        """Initialise the relay.

        Args:
            websocket: Client-facing connection; this class calls its
                ``receive_text`` and ``send_text`` coroutines.
            backend_ws: Backend connection; this class calls its ``recv`` and
                ``send`` coroutines.
            logging_obj: Logging object used for ``pre_call`` and the
                success handlers.
            provider_config: Optional transformer for requests/responses.
                When None, messages are forwarded unchanged.
            model: Model name passed through to ``provider_config``.
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        self.messages: List[OpenAIRealtimeEvents] = []
        self.input_message: Dict = {}

        # Fall back to the module default when no global setting is configured.
        _logged_real_time_event_types = litellm.logged_real_time_event_types
        if _logged_real_time_event_types is None:
            _logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
        self.logged_real_time_event_types = _logged_real_time_event_types

        self.provider_config = provider_config
        self.model = model

        # State handed to / returned from provider_config.transform_realtime_response
        # on every backend message.
        self.current_delta_chunks: Optional[List[OpenAIRealtimeResponseDelta]] = None
        self.current_output_item_id: Optional[str] = None
        self.current_response_id: Optional[str] = None
        self.current_conversation_id: Optional[str] = None
        self.current_item_chunks: Optional[List[OpenAIRealtimeOutputItemDone]] = None
        self.current_delta_type: Optional[ALL_DELTA_TYPES] = None
        self.session_configuration_request: Optional[str] = None

    def _should_store_message(
        self,
        message_obj: Union[dict, OpenAIRealtimeEvents],
    ) -> bool:
        """Return True when ``message_obj`` should be kept for logging.

        Everything is kept when the configured event types equal ``"*"``;
        otherwise only messages whose ``"type"`` is in the configured
        collection are kept.
        """
        _msg_type = message_obj["type"] if "type" in message_obj else None
        if self.logged_real_time_event_types == "*":
            return True
        if _msg_type and _msg_type in self.logged_real_time_event_types:
            return True
        return False

    def store_message(self, message: Union[str, bytes, OpenAIRealtimeEvents]):
        """Parse ``message`` if needed and store it when its type is logged.

        Args:
            message: A JSON document as ``str``/``bytes`` (bytes are decoded
                as UTF-8), or an already-parsed event dict.

        Raises:
            Exception: Re-raises any error hit while building the typed event
                (after logging it at debug level). ``json.loads`` errors on
                malformed input propagate as well.
        """
        if isinstance(message, bytes):
            message = message.decode("utf-8")

        was_parsed = not isinstance(message, dict)
        message_obj = json.loads(message) if was_parsed else message

        try:
            if message_obj.get("type") in ("session.created", "session.updated"):
                message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj)  # type: ignore
            elif was_parsed:
                # Previously unreachable: parsed non-session events were
                # always built as session events.
                message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj)  # type: ignore
        except Exception as e:
            verbose_logger.debug(f"Error parsing message for logging: {e}")
            raise

        if self._should_store_message(message_obj):
            self.messages.append(message_obj)

    def store_input(self, message: dict):
        """Store input message"""
        self.input_message = message
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def log_messages(self):
        """Hand the stored messages to the async and sync success handlers.

        Does nothing when ``logging_obj`` is falsy.
        """
        if self.logging_obj:
            ## ASYNC LOGGING
            asyncio.create_task(self.logging_obj.async_success_handler(self.messages))
            ## SYNC LOGGING
            # Pass the callable and its argument separately so the handler runs
            # in the thread pool rather than inline on the event loop.
            executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client_send_messages(self):
        """Forward backend messages to the client until the loop fails.

        Each message is optionally transformed by ``provider_config``, passed
        to ``store_message`` and sent with ``websocket.send_text``. Any
        exception ends the loop and is logged; ``log_messages`` always runs
        on exit.
        """
        import websockets

        try:
            while True:
                try:
                    raw_response = await self.backend_ws.recv(
                        decode=False
                    )  # improves performance
                except TypeError:
                    # NOTE(review): presumably for recv() implementations that
                    # do not accept ``decode`` - confirm.
                    raw_response = await self.backend_ws.recv()  # type: ignore[assignment]

                if not self.provider_config:
                    ## LOGGING
                    self.store_message(raw_response)
                    await self.websocket.send_text(raw_response)
                    continue

                returned_object = self.provider_config.transform_realtime_response(
                    raw_response,
                    self.model,
                    self.logging_obj,
                    realtime_response_transform_input={
                        "session_configuration_request": self.session_configuration_request,
                        "current_output_item_id": self.current_output_item_id,
                        "current_response_id": self.current_response_id,
                        "current_delta_chunks": self.current_delta_chunks,
                        "current_conversation_id": self.current_conversation_id,
                        "current_item_chunks": self.current_item_chunks,
                        "current_delta_type": self.current_delta_type,
                    },
                )

                transformed_response = returned_object["response"]

                # Carry the transformer's state over to the next message.
                self.current_output_item_id = returned_object["current_output_item_id"]
                self.current_response_id = returned_object["current_response_id"]
                self.current_delta_chunks = returned_object["current_delta_chunks"]
                self.current_conversation_id = returned_object[
                    "current_conversation_id"
                ]
                self.current_item_chunks = returned_object["current_item_chunks"]
                self.current_delta_type = returned_object["current_delta_type"]
                self.session_configuration_request = returned_object[
                    "session_configuration_request"
                ]

                # A transform may yield one event or a list of events.
                if isinstance(transformed_response, list):
                    events = transformed_response
                else:
                    events = [transformed_response]

                for event in events:
                    event_str = json.dumps(event)
                    ## LOGGING
                    self.store_message(event_str)
                    await self.websocket.send_text(event_str)
        except websockets.exceptions.ConnectionClosed as e:  # type: ignore
            verbose_logger.exception(
                f"Connection closed in backend to client send messages - {e}"
            )
        except Exception as e:
            verbose_logger.exception(f"Error in backend to client send messages: {e}")
        finally:
            await self.log_messages()

    async def client_ack_messages(self):
        """Forward client messages to the backend until an error occurs.

        Each received text message is recorded via ``store_input``. With a
        ``provider_config`` the message is transformed and every resulting
        item is sent; otherwise it is sent unchanged. Any ``Exception`` ends
        the loop and is logged at debug level instead of propagating.
        """
        try:
            while True:
                message = await self.websocket.receive_text()
                ## LOGGING
                self.store_input(message=message)
                ## FORWARD TO BACKEND
                if self.provider_config:
                    message = self.provider_config.transform_realtime_request(
                        message, self.model
                    )
                    for msg in message:
                        await self.backend_ws.send(msg)
                else:
                    await self.backend_ws.send(message)
        except Exception as e:
            verbose_logger.debug(f"Error in client ack messages: {e}")

    async def bidirectional_forward(self):
        """Run both forwarding directions until the client side finishes.

        Starts ``backend_to_client_send_messages`` as a task, runs
        ``client_ack_messages`` to completion, then cancels the backend task
        if it is still running and waits for it to finish.
        """
        # Imported lazily, matching backend_to_client_send_messages.
        import websockets

        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
        try:
            await self.client_ack_messages()
        # The exception class comes from the websockets package; the previous
        # lookup went through an attribute on the client websocket object.
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            verbose_logger.debug("Connection closed")
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass