Spaces:
Configuration error
Configuration error
File size: 8,335 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 |
import asyncio
import contextlib
import typing
from typing import Callable, Dict, Union
import aiohttp
import aiohttp.client_exceptions
import aiohttp.http_exceptions
import httpx
from aiohttp.client import ClientResponse, ClientSession
from litellm._logging import verbose_logger
# Maps aiohttp exception types to the closest httpx exception types.
#
# ``map_aiohttp_exceptions`` tests a raised exception against EVERY entry.
# Among the matching entries it keeps the first httpx target, replacing it
# only when a later match's target is a subclass of the current one. So the
# most specific httpx type wins. Order only decides between unrelated targets:
# the earlier entry wins, which is why specific aiohttp types precede their bases.
#
# ``aiohttp.ClientPayloadError`` is the same class as
# ``aiohttp.client_exceptions.ClientPayloadError`` (re-exported by aiohttp),
# so a single entry covers both spellings.
AIOHTTP_EXC_MAP: Dict = {
    # Timeout related exceptions
    aiohttp.ServerTimeoutError: httpx.TimeoutException,
    aiohttp.ConnectionTimeoutError: httpx.ConnectTimeout,
    aiohttp.SocketTimeoutError: httpx.ReadTimeout,
    # Proxy related exceptions
    aiohttp.ClientProxyConnectionError: httpx.ProxyError,
    # SSL related exceptions
    aiohttp.ClientConnectorCertificateError: httpx.ProtocolError,
    aiohttp.ClientSSLError: httpx.ProtocolError,
    aiohttp.ServerFingerprintMismatch: httpx.ProtocolError,
    # Network related exceptions
    aiohttp.ClientConnectorError: httpx.ConnectError,
    aiohttp.ClientOSError: httpx.ConnectError,
    # Body read failures (incomplete payload / server dropped the connection)
    aiohttp.ClientPayloadError: httpx.ReadError,
    aiohttp.ServerDisconnectedError: httpx.ReadError,
    # Response related exceptions
    aiohttp.ClientConnectionError: httpx.NetworkError,
    aiohttp.ContentTypeError: httpx.ReadError,
    aiohttp.TooManyRedirects: httpx.TooManyRedirects,
    # URL related exceptions
    aiohttp.InvalidURL: httpx.InvalidURL,
    # Base exceptions
    aiohttp.ClientError: httpx.RequestError,
}
@contextlib.contextmanager
def map_aiohttp_exceptions() -> typing.Iterator[None]:
    """Re-raise aiohttp exceptions from the ``with`` body as httpx exceptions.

    Every ``AIOHTTP_EXC_MAP`` entry whose aiohttp type matches the raised
    exception is considered. The first match sets the httpx type, and a later
    match replaces it only when its httpx type subclasses the current choice.
    The replacement exception gets ``str()`` of the original as its message and
    chains the original as ``__cause__``. Exceptions matching no entry
    propagate unchanged.
    """
    try:
        yield
    except Exception as original:
        target = None
        for aiohttp_type, httpx_type in AIOHTTP_EXC_MAP.items():
            if isinstance(original, aiohttp_type):  # type: ignore
                if target is None or issubclass(httpx_type, target):
                    target = httpx_type
        if target is None:  # pragma: no cover
            raise
        raise target(str(original)) from original
class AiohttpResponseStream(httpx.AsyncByteStream):
    """httpx byte stream that yields the body of an aiohttp ``ClientResponse``.

    The body is read in chunks of ``CHUNK_SIZE`` bytes. Incomplete transfers
    (aiohttp payload errors and transfer-encoding errors) end the stream
    quietly after a debug log. Any other error is translated to an httpx
    exception via ``map_aiohttp_exceptions``.
    """

    CHUNK_SIZE = 1024 * 16

    def __init__(self, aiohttp_response: ClientResponse) -> None:
        self._aiohttp_response = aiohttp_response

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        try:
            async for chunk in self._aiohttp_response.content.iter_chunked(
                self.CHUNK_SIZE
            ):
                yield chunk
        except aiohttp.ClientPayloadError as e:
            # ``aiohttp.ClientPayloadError`` is the same class as
            # ``aiohttp.client_exceptions.ClientPayloadError``, so this one
            # clause covers both spellings.
            # Deliberate best effort: the stream simply ends with whatever
            # chunks were already yielded (possibly none) instead of failing
            # the whole response. Callers may therefore see a truncated body.
            verbose_logger.debug(f"Transfer incomplete, but continuing: {e}")
            return
        except aiohttp.http_exceptions.TransferEncodingError as e:
            # Same best-effort policy for malformed transfer encoding.
            verbose_logger.debug(f"Transfer encoding error, but continuing: {e}")
            return
        except Exception:
            # Everything else is surfaced to httpx callers as an httpx error.
            with map_aiohttp_exceptions():
                raise

    async def aclose(self) -> None:
        """Exit the aiohttp response context, mapping errors to httpx ones."""
        with map_aiohttp_exceptions():
            await self._aiohttp_response.__aexit__(None, None, None)
class AiohttpTransport(httpx.AsyncBaseTransport):
    """httpx async transport backed by an aiohttp ``ClientSession``.

    ``client`` is either a ready ``ClientSession`` or a zero-argument callable
    returning one. It is stored on ``self.client`` exactly as given.
    """

    def __init__(
        self, client: Union[ClientSession, Callable[[], ClientSession]]
    ) -> None:
        self.client = client

    async def aclose(self) -> None:
        """Close the held session. If only a factory is held, do nothing."""
        session = self.client
        if not isinstance(session, ClientSession):
            return
        await session.close()
class LiteLLMAiohttpTransport(AiohttpTransport):
    """
    LiteLLM wrapper around AiohttpTransport to handle %-encodings in URLs
    and event loop lifecycle issues in CI/CD environments

    Credit to: https://github.com/karpetrosyan/httpx-aiohttp for this implementation
    """

    def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]):
        super().__init__(client=client)
        # Keep the factory (when one was given) so a fresh session can be
        # built later, e.g. when the current one belongs to another loop.
        self._client_factory: typing.Optional[Callable[[], ClientSession]] = (
            client if callable(client) else None
        )

    def _create_client_session(self) -> ClientSession:
        """Build a new session from the stored factory, else a default one."""
        factory = getattr(self, "_client_factory", None)
        if callable(factory):
            return factory()
        return ClientSession()

    def _get_valid_client_session(self) -> ClientSession:
        """
        Helper to get a valid ClientSession for the current event loop.

        A new session is created, via ``_create_client_session``, and stored
        on ``self.client`` when any of these hold:
        - ``self.client`` is not a ClientSession yet (e.g. it is the factory);
        - the session has already been closed;
        - the session is bound to a different, missing, or closed event loop;
        - the session's loop cannot be inspected (RuntimeError/AttributeError).

        This handles sessions created in an event loop that has since been
        replaced or closed (common in CI/CD environments).
        """
        if not isinstance(self.client, ClientSession):
            self.client = self._create_client_session()
            return self.client

        try:
            session_loop = getattr(self.client, "_loop", None)
            current_loop = asyncio.get_running_loop()
            needs_new_session = (
                # A closed session rejects every request ("Session is closed").
                self.client.closed
                or session_loop is None
                or session_loop != current_loop
                or session_loop.is_closed()
            )
        except (RuntimeError, AttributeError):
            # If we can't check the loop or the session is invalid, replace it.
            needs_new_session = True

        if needs_new_session:
            # The old session is intentionally not awaited/closed here: it may
            # belong to another (possibly closed) loop.
            # NOTE(review): this can leave an unclosed session for the GC.
            self.client = self._create_client_session()
        return self.client

    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Send ``request`` through aiohttp and wrap the reply as httpx.Response.

        The response body is not read here. It is exposed lazily through
        ``AiohttpResponseStream``. aiohttp errors raised while sending are
        translated to httpx exceptions by ``map_aiohttp_exceptions``.
        """
        from aiohttp import ClientTimeout
        from yarl import URL as YarlURL

        timeout = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname")

        # Use helper to ensure we have a valid session for the current event loop
        client_session = self._get_valid_client_session()

        with map_aiohttp_exceptions():
            try:
                data = request.content
            except httpx.RequestNotRead:
                # Body not read into memory yet: pass the stream to aiohttp.
                data = request.stream  # type: ignore
                request.headers.pop("transfer-encoding", None)  # handled by aiohttp

            response = await client_session.request(
                method=request.method,
                # encoded=True stops yarl from re-encoding %-escapes that are
                # already present in the httpx URL.
                url=YarlURL(str(request.url), encoded=True),
                headers=request.headers,
                data=data,
                allow_redirects=False,  # leave redirect handling to httpx
                auto_decompress=False,  # httpx.Response decodes Content-Encoding
                timeout=ClientTimeout(
                    sock_connect=timeout.get("connect"),
                    sock_read=timeout.get("read"),
                    connect=timeout.get("pool"),
                ),
                server_hostname=sni_hostname,
            ).__aenter__()

        return httpx.Response(
            status_code=response.status,
            headers=response.headers,
            content=AiohttpResponseStream(response),
            request=request,
        )
|