test3 / litellm /litellm_core_utils /realtime_streaming.py
DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
import asyncio
import concurrent.futures
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
from litellm.types.llms.openai import (
OpenAIRealtimeEvents,
OpenAIRealtimeOutputItemDone,
OpenAIRealtimeResponseDelta,
OpenAIRealtimeStreamResponseBaseObject,
OpenAIRealtimeStreamSessionEvents,
)
from litellm.types.realtime import ALL_DELTA_TYPES
from .litellm_logging import Logging as LiteLLMLogging
# The concrete websocket client type is only imported for static analysis;
# at runtime the alias falls back to ``Any`` so no import is performed here.
if not TYPE_CHECKING:
    CLIENT_CONNECTION_CLASS = Any
else:
    from websockets.asyncio.client import ClientConnection

    CLIENT_CONNECTION_CLASS = ClientConnection

# Module-wide thread pool (at most 10 worker threads) used for synchronous
# logging callbacks.
executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)

# Event types that are recorded for logging when
# ``litellm.logged_real_time_event_types`` is not configured.
DefaultLoggedRealTimeEventTypes = [
    "session.created",
    "response.create",
    "response.done",
]
class RealTimeStreaming:
    """Relay realtime events between a client websocket and a backend websocket.

    The instance forwards client messages to ``backend_ws`` and backend
    messages to ``websocket``, optionally passing both directions through a
    ``provider_config`` transformation, and records selected events so they
    can be handed to the logging object when the backend stream ends.
    """

    # Strong references to in-flight async logging tasks. The event loop only
    # keeps weak references to tasks, so a task whose result is discarded may
    # be garbage-collected before it finishes.
    _background_tasks: set = set()

    def __init__(
        self,
        websocket: Any,
        backend_ws: CLIENT_CONNECTION_CLASS,
        logging_obj: LiteLLMLogging,
        provider_config: Optional[BaseRealtimeConfig] = None,
        model: str = "",
    ):
        """Store the two websocket endpoints and initialise streaming state.

        Args:
            websocket: Client-facing websocket; must provide ``receive_text``
                and ``send_text`` coroutines.
            backend_ws: Backend websocket; must provide ``recv`` and ``send``
                coroutines.
            logging_obj: Logging object that receives ``pre_call`` and the
                success handlers.
            provider_config: Optional transformer applied to requests and
                responses. When ``None`` messages are forwarded unchanged.
            model: Model name passed through to ``provider_config``.
        """
        self.websocket = websocket
        self.backend_ws = backend_ws
        self.logging_obj = logging_obj
        self.messages: List[OpenAIRealtimeEvents] = []
        self.input_message: Dict = {}

        # Fall back to the module default when no event filter is configured.
        _logged_real_time_event_types = litellm.logged_real_time_event_types
        if _logged_real_time_event_types is None:
            _logged_real_time_event_types = DefaultLoggedRealTimeEventTypes
        self.logged_real_time_event_types = _logged_real_time_event_types

        self.provider_config = provider_config
        self.model = model

        # State that is passed to, and updated from, every
        # ``provider_config.transform_realtime_response`` call.
        self.current_delta_chunks: Optional[List[OpenAIRealtimeResponseDelta]] = None
        self.current_output_item_id: Optional[str] = None
        self.current_response_id: Optional[str] = None
        self.current_conversation_id: Optional[str] = None
        self.current_item_chunks: Optional[List[OpenAIRealtimeOutputItemDone]] = None
        self.current_delta_type: Optional[ALL_DELTA_TYPES] = None
        self.session_configuration_request: Optional[str] = None

    def _should_store_message(
        self,
        message_obj: Union[dict, OpenAIRealtimeEvents],
    ) -> bool:
        """Return True when the event's ``type`` passes the logging filter.

        A filter equal to ``"*"`` accepts every event; otherwise the event
        must have a truthy ``type`` contained in the configured filter.
        """
        _msg_type = message_obj["type"] if "type" in message_obj else None
        if self.logged_real_time_event_types == "*":
            return True
        return bool(_msg_type and _msg_type in self.logged_real_time_event_types)

    def store_message(self, message: Union[str, bytes, OpenAIRealtimeEvents]):
        """Parse an outgoing event and keep it if it passes the logging filter.

        Args:
            message: Event as JSON text, UTF-8 encoded JSON bytes, or an
                already-parsed mapping.

        Raises:
            Exception: Any error raised while decoding or wrapping the event
                is logged at debug level and re-raised.
        """
        if isinstance(message, bytes):
            message = message.decode("utf-8")
        if isinstance(message, dict):
            message_obj = message
        else:
            message_obj = json.loads(message)
        try:
            event_type = message_obj.get("type")
            if event_type in ("session.created", "session.updated"):
                message_obj = OpenAIRealtimeStreamSessionEvents(**message_obj)  # type: ignore
            elif not isinstance(message, dict):
                # Previously the first condition also matched every non-dict
                # input, so this branch was unreachable and all text events
                # were wrapped as session events.
                message_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj)  # type: ignore
        except Exception as e:
            verbose_logger.debug(f"Error parsing message for logging: {e}")
            raise e
        if self._should_store_message(message_obj):
            self.messages.append(message_obj)

    def store_input(self, message: dict):
        """Remember the latest client message and report it via ``pre_call``."""
        self.input_message = message
        if self.logging_obj:
            self.logging_obj.pre_call(input=message, api_key="")

    async def log_messages(self):
        """Hand the stored events to the async and sync success handlers.

        The async handler is scheduled as a task on the running loop; the
        sync handler is submitted to the module-level thread pool. Neither
        is awaited here.
        """
        if not self.logging_obj:
            return

        ## ASYNC LOGGING
        logging_task = asyncio.create_task(
            self.logging_obj.async_success_handler(self.messages)
        )
        # Hold a reference until completion so the task is not collected early.
        self._background_tasks.add(logging_task)
        logging_task.add_done_callback(self._background_tasks.discard)

        ## SYNC LOGGING
        # Pass the callable and its argument separately: calling the handler
        # inline would run it on the event-loop thread and submit its return
        # value (not a callable) to the pool.
        executor.submit(self.logging_obj.success_handler, self.messages)

    async def backend_to_client_send_messages(self):
        """Forward backend events to the client until the stream ends.

        Each backend message is optionally transformed by ``provider_config``,
        recorded via ``store_message`` and sent to the client with
        ``send_text``. Connection-closed and other exceptions are logged and
        end the loop; ``log_messages`` is always awaited on exit.
        """
        import websockets

        try:
            while True:
                try:
                    raw_response = await self.backend_ws.recv(
                        decode=False
                    )  # improves performance
                except TypeError:
                    # Fallback for a ``recv`` that rejects the ``decode`` keyword.
                    raw_response = await self.backend_ws.recv()  # type: ignore[assignment]

                if not self.provider_config:
                    ## LOGGING
                    self.store_message(raw_response)
                    await self.websocket.send_text(raw_response)
                    continue

                returned_object = self.provider_config.transform_realtime_response(
                    raw_response,
                    self.model,
                    self.logging_obj,
                    realtime_response_transform_input={
                        "session_configuration_request": self.session_configuration_request,
                        "current_output_item_id": self.current_output_item_id,
                        "current_response_id": self.current_response_id,
                        "current_delta_chunks": self.current_delta_chunks,
                        "current_conversation_id": self.current_conversation_id,
                        "current_item_chunks": self.current_item_chunks,
                        "current_delta_type": self.current_delta_type,
                    },
                )

                # Carry the transformer's updated state into the next iteration.
                transformed_response = returned_object["response"]
                self.current_output_item_id = returned_object[
                    "current_output_item_id"
                ]
                self.current_response_id = returned_object["current_response_id"]
                self.current_delta_chunks = returned_object["current_delta_chunks"]
                self.current_conversation_id = returned_object[
                    "current_conversation_id"
                ]
                self.current_item_chunks = returned_object["current_item_chunks"]
                self.current_delta_type = returned_object["current_delta_type"]
                self.session_configuration_request = returned_object[
                    "session_configuration_request"
                ]

                # The transformer may return one event or a list of events.
                if isinstance(transformed_response, list):
                    events = transformed_response
                else:
                    events = [transformed_response]
                for event in events:
                    event_str = json.dumps(event)
                    ## LOGGING
                    self.store_message(event_str)
                    await self.websocket.send_text(event_str)
        except websockets.exceptions.ConnectionClosed as e:  # type: ignore
            verbose_logger.exception(
                f"Connection closed in backend to client send messages - {e}"
            )
        except Exception as e:
            verbose_logger.exception(f"Error in backend to client send messages: {e}")
        finally:
            await self.log_messages()

    async def client_ack_messages(self):
        """Forward client messages to the backend until an error occurs.

        Each received text message is recorded with ``store_input``. With a
        ``provider_config`` the message is transformed and every resulting
        item is sent to the backend; otherwise it is sent unchanged. Any
        ``Exception`` is logged at debug level and ends the loop.
        """
        try:
            while True:
                message = await self.websocket.receive_text()
                ## LOGGING
                self.store_input(message=message)
                ## FORWARD TO BACKEND
                if self.provider_config:
                    transformed_messages = (
                        self.provider_config.transform_realtime_request(
                            message, self.model
                        )
                    )
                    for msg in transformed_messages:
                        await self.backend_ws.send(msg)
                else:
                    await self.backend_ws.send(message)
        except Exception as e:
            verbose_logger.debug(f"Error in client ack messages: {e}")

    async def bidirectional_forward(self):
        """Run both forwarding directions and tear down the backend reader.

        The backend-to-client loop runs as a task while the client-to-backend
        loop is awaited directly. When the latter returns or raises, the
        backend task is cancelled if still running and awaited so that its
        cleanup completes.
        """
        import websockets

        forward_task = asyncio.create_task(self.backend_to_client_send_messages())
        try:
            await self.client_ack_messages()
        # Use the websockets package's exception class (as the backend loop
        # does) rather than looking one up on the client websocket object; a
        # failing attribute lookup here would mask the in-flight exception.
        except websockets.exceptions.ConnectionClosed:  # type: ignore
            verbose_logger.debug("Connection closed")
            forward_task.cancel()
        finally:
            if not forward_task.done():
                forward_task.cancel()
                try:
                    await forward_task
                except asyncio.CancelledError:
                    pass