"""Abstraction function for OpenAI's realtime API"""
from typing import Any, Optional, cast
import litellm
from litellm import get_llm_provider
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager
from ..litellm_core_utils.get_litellm_params import get_litellm_params
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
from ..llms.openai.realtime.handler import OpenAIRealtime
from ..utils import client as wrapper_client
# Module-level handler singletons, shared by every realtime call in this module.
azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()
# Generic HTTP handler used when a provider exposes a BaseRealtimeConfig.
base_llm_http_handler = BaseLLMHTTPHandler()
@wrapper_client
async def _arealtime(
    model: str,
    websocket: Any,  # fastapi websocket
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    api_version: Optional[str] = None,
    azure_ad_token: Optional[str] = None,
    client: Optional[Any] = None,
    timeout: Optional[float] = None,
    **kwargs,
):
    """
    Private function to handle the realtime API call.

    For PROXY use only.

    Args:
        model: model name, optionally prefixed with a provider (e.g. "azure/gpt-4o-realtime").
        websocket: the accepted FastAPI websocket to bridge to the provider.
        api_base: dynamic API base; falls back to litellm settings / env vars.
        api_key: dynamic API key; falls back to litellm settings / env vars.
        api_version: Azure API version; defaults to "2024-10-01-preview" when unset.
        azure_ad_token: optional Azure AD token forwarded to the Azure handler.
        client: optional pre-built client forwarded to the provider handler.
        timeout: connection timeout in seconds.
        **kwargs: may carry "headers", "extra_headers", "litellm_logging_obj",
            and "user", plus any GenericLiteLLMParams fields.

    Raises:
        ValueError: if the resolved provider has no realtime support.
    """
    headers = cast(Optional[dict], kwargs.get("headers"))
    extra_headers = cast(Optional[dict], kwargs.get("extra_headers"))
    if headers is None:
        headers = {}
    if extra_headers is not None:
        headers.update(extra_headers)
    litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
    user = kwargs.get("user", None)
    litellm_params = GenericLiteLLMParams(**kwargs)
    litellm_params_dict = get_litellm_params(**kwargs)

    # Resolve provider prefix + any dynamically supplied credentials.
    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
        model=model,
        api_base=api_base,
        api_key=api_key,
    )

    litellm_logging_obj.update_environment_variables(
        model=model,
        user=user,
        optional_params={},
        litellm_params=litellm_params_dict,
        custom_llm_provider=_custom_llm_provider,
    )

    # Prefer a provider-specific realtime config (generic HTTP handler) when one exists.
    provider_config: Optional[BaseRealtimeConfig] = None
    if _custom_llm_provider in LlmProviders._member_map_.values():
        provider_config = ProviderConfigManager.get_provider_realtime_config(
            model=model,
            provider=LlmProviders(_custom_llm_provider),
        )
    if provider_config is not None:
        await base_llm_http_handler.async_realtime(
            model=model,
            websocket=websocket,
            logging_obj=litellm_logging_obj,
            provider_config=provider_config,
            api_base=api_base,
            api_key=api_key,
            client=client,
            timeout=timeout,
            headers=headers,
        )
    elif _custom_llm_provider == "azure":
        api_base = (
            dynamic_api_base
            or litellm_params.api_base
            or litellm.api_base
            or get_secret_str("AZURE_API_BASE")
        )
        # set API KEY
        # Fix: fall back to litellm.azure_key (not litellm.openai_key) for Azure —
        # the previous chain could silently send an OpenAI key to Azure.
        api_key = (
            dynamic_api_key
            or litellm.api_key
            or litellm.azure_key
            or get_secret_str("AZURE_API_KEY")
        )
        # Fix: honor caller-supplied api_version / azure_ad_token / client instead
        # of hard-coding them; behavior is unchanged when they are None.
        await azure_realtime.async_realtime(
            model=model,
            websocket=websocket,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version or "2024-10-01-preview",
            azure_ad_token=azure_ad_token,
            client=client,
            timeout=timeout,
            logging_obj=litellm_logging_obj,
        )
    elif _custom_llm_provider == "openai":
        api_base = (
            dynamic_api_base
            or litellm_params.api_base
            or litellm.api_base
            or "https://api.openai.com/"
        )
        # set API KEY
        api_key = (
            dynamic_api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        # Fix: forward the caller-supplied client instead of hard-coding None.
        await openai_realtime.async_realtime(
            model=model,
            websocket=websocket,
            logging_obj=litellm_logging_obj,
            api_base=api_base,
            api_key=api_key,
            client=client,
            timeout=timeout,
        )
    else:
        raise ValueError(f"Unsupported model: {model}")
async def _realtime_health_check(
    model: str,
    custom_llm_provider: str,
    api_key: Optional[str],
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
):
    """
    Health check for realtime API - tries connection to the realtime API websocket

    Args:
        model: str - model name
        api_base: str - api base
        api_version: Optional[str] - api version (Azure only; defaults to "2024-10-01-preview")
        api_key: str - api key
        custom_llm_provider: str - custom llm provider ("azure" or "openai")

    Returns:
        bool - True if connection is successful, False otherwise

    Raises:
        Exception - if the connection is not successful
    """
    import websockets

    url: Optional[str] = None
    headers: dict = {}
    if custom_llm_provider == "azure":
        url = azure_realtime._construct_url(
            api_base=api_base or "",
            model=model,
            api_version=api_version or "2024-10-01-preview",
        )
        # Azure authenticates websockets with an "api-key" header.
        headers = {"api-key": api_key}
    elif custom_llm_provider == "openai":
        url = openai_realtime._construct_url(
            api_base=api_base or "https://api.openai.com/", model=model
        )
        # Fix: OpenAI's realtime websocket requires a Bearer token plus the
        # realtime beta header — the Azure-style "api-key" header is rejected,
        # so the old code failed the health check even with a valid key.
        headers = {
            "Authorization": f"Bearer {api_key}",
            "OpenAI-Beta": "realtime=v1",
        }
    else:
        raise ValueError(f"Unsupported model: {model}")
    # NOTE(review): `extra_headers` was renamed `additional_headers` in
    # websockets >= 13 — confirm the pinned websockets version before upgrading.
    async with websockets.connect(  # type: ignore
        url,
        extra_headers=headers,
    ):
        return True
|