Spaces:
Sleeping
Sleeping
""" | |
This file contains the logic for calling Azure OpenAI's `/openai/realtime` endpoint.
This requires websockets, and is currently only supported on LiteLLM Proxy.
""" | |
from typing import Any, Optional | |
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging | |
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming | |
from ..azure import AzureChatCompletion | |
# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01" | |
async def forward_messages(client_ws: Any, backend_ws: Any):
    """
    Relay every message received from the backend websocket to the client.

    Repeatedly awaits `backend_ws.recv()` and hands each message to
    `client_ws.send_text()`. Returns quietly once a
    `websockets.exceptions.ConnectionClosed` is raised; any other exception
    propagates to the caller.
    """
    from websockets.exceptions import ConnectionClosed

    try:
        while True:
            await client_ws.send_text(await backend_ws.recv())
    except ConnectionClosed:  # type: ignore
        return
class AzureOpenAIRealtime(AzureChatCompletion):
    """
    Bridges a client websocket to Azure OpenAI's `/openai/realtime` endpoint.
    """

    def _construct_url(self, api_base: str, model: str, api_version: str) -> str:
        """
        Build the websocket URL for an Azure OpenAI realtime deployment.

        Args:
            api_base: Base endpoint, e.g. "https://my-endpoint.openai.azure.com".
                A trailing "/" is tolerated.
            model: Deployment name, sent as the `deployment` query parameter.
            api_version: Value for the `api-version` query parameter.

        Example output:
        "wss://my-endpoint-sweden-berri992.openai.azure.com/openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"
        """
        # Strip trailing slashes so "https://host/" does not yield
        # "wss://host//openai/realtime".
        api_base = api_base.rstrip("/")
        # Swap the HTTP(S) scheme for its websocket counterpart. A base that
        # already uses ws:// or wss:// is left untouched.
        api_base = api_base.replace("https://", "wss://").replace("http://", "ws://")
        return (
            f"{api_base}/openai/realtime?api-version={api_version}&deployment={model}"
        )

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        client: Optional[Any] = None,
        logging_obj: Optional[LiteLLMLogging] = None,
        timeout: Optional[float] = None,
    ):
        """
        Open a backend websocket to Azure OpenAI and forward traffic both ways.

        Connects to the URL built by `_construct_url`, sending `api_key` in
        the `api-key` header, then hands both sockets to `RealTimeStreaming`
        and awaits its `bidirectional_forward()`.

        Args:
            model: Deployment name used in the realtime URL.
            websocket: The client-side websocket to bridge to the backend.
            api_base: Azure endpoint base URL. Required.
            api_key: Value sent in the `api-key` header.
            api_version: Azure API version. Required.
            azure_ad_token: Accepted but not used by this method.
            client: Accepted but not used by this method.
            logging_obj: Passed through to `RealTimeStreaming`.
            timeout: Accepted but not used by this method.

        Raises:
            ValueError: If `api_base` or `api_version` is None.
        """
        import logging

        import websockets

        if api_base is None:
            raise ValueError("api_base is required for Azure OpenAI calls")
        if api_version is None:
            raise ValueError("api_version is required for Azure OpenAI calls")

        url = self._construct_url(api_base, model, api_version)

        try:
            async with websockets.connect(  # type: ignore
                url,
                extra_headers={
                    "api-key": api_key,  # type: ignore
                },
            ) as backend_ws:
                realtime_streaming = RealTimeStreaming(
                    websocket, backend_ws, logging_obj
                )
                await realtime_streaming.bidirectional_forward()
        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
            # The backend rejected the handshake; close the client socket and
            # pass the status on.
            # NOTE(review): e.status_code looks like an HTTP status rather than
            # a websocket close code - confirm the client framework accepts it.
            await websocket.close(code=e.status_code, reason=str(e))
        except Exception:
            # Stay best-effort (never raise into the caller), but leave a
            # traceback instead of discarding the failure silently.
            logging.getLogger(__name__).exception(
                "Azure OpenAI realtime websocket session failed (model=%s)", model
            )