Spaces:
Configuration error
Configuration error
import asyncio | |
import contextlib | |
import typing | |
from typing import Callable, Dict, Union | |
import aiohttp | |
import aiohttp.client_exceptions | |
import aiohttp.http_exceptions | |
import httpx | |
from aiohttp.client import ClientResponse, ClientSession | |
from litellm._logging import verbose_logger | |
# Lookup table from aiohttp exception classes to the httpx exception classes
# raised in their place; iterated by map_aiohttp_exceptions().
AIOHTTP_EXC_MAP: Dict = {
    # Order matters here, most specific exception first
    # Timeout related exceptions
    aiohttp.ServerTimeoutError: httpx.TimeoutException,
    aiohttp.ConnectionTimeoutError: httpx.ConnectTimeout,
    aiohttp.SocketTimeoutError: httpx.ReadTimeout,
    # Proxy related exceptions
    aiohttp.ClientProxyConnectionError: httpx.ProxyError,
    # SSL related exceptions
    aiohttp.ClientConnectorCertificateError: httpx.ProtocolError,
    aiohttp.ClientSSLError: httpx.ProtocolError,
    aiohttp.ServerFingerprintMismatch: httpx.ProtocolError,
    # Network related exceptions
    aiohttp.ClientConnectorError: httpx.ConnectError,
    aiohttp.ClientOSError: httpx.ConnectError,
    aiohttp.ClientPayloadError: httpx.ReadError,
    # Connection disconnection exceptions
    aiohttp.ServerDisconnectedError: httpx.ReadError,
    # Response related exceptions
    # (ClientPayloadError is already registered above; a second literal entry
    # for the same key would be a silent duplicate, so it is listed only once.)
    aiohttp.ClientConnectionError: httpx.NetworkError,
    aiohttp.ContentTypeError: httpx.ReadError,
    aiohttp.TooManyRedirects: httpx.TooManyRedirects,
    # URL related exceptions
    aiohttp.InvalidURL: httpx.InvalidURL,
    # Base exceptions
    aiohttp.ClientError: httpx.RequestError,
}

# Add client_exceptions module exceptions
try:
    import aiohttp.client_exceptions

    AIOHTTP_EXC_MAP[aiohttp.client_exceptions.ClientPayloadError] = httpx.ReadError
except ImportError:
    pass
@contextlib.contextmanager
def map_aiohttp_exceptions() -> typing.Iterator[None]:
    """Context manager that re-raises aiohttp exceptions as httpx exceptions.

    Any ``Exception`` raised inside the ``with`` body is compared against
    every entry of ``AIOHTTP_EXC_MAP``. Among the entries whose key matches
    the exception (via ``isinstance``), the most derived httpx class is
    chosen: a later candidate replaces the current one only when it is a
    subclass of it.

    Raises:
        The selected httpx exception, built from ``str(exc)`` and chained to
        the original with ``from exc``. An exception that matches no entry
        is re-raised unchanged.
    """
    try:
        yield
    except Exception as exc:
        mapped_exc = None

        for from_exc, to_exc in AIOHTTP_EXC_MAP.items():
            if not isinstance(exc, from_exc):  # type: ignore
                continue
            # Keep the narrowest httpx class seen so far.
            if mapped_exc is None or issubclass(to_exc, mapped_exc):
                mapped_exc = to_exc

        if mapped_exc is None:  # pragma: no cover
            raise

        message = str(exc)
        raise mapped_exc(message) from exc
class AiohttpResponseStream(httpx.AsyncByteStream):
    """Async byte stream that feeds an httpx response from an aiohttp response."""

    CHUNK_SIZE = 1024 * 16

    def __init__(self, aiohttp_response: ClientResponse) -> None:
        self._aiohttp_response = aiohttp_response

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        """Yield body chunks obtained via ``content.iter_chunked(CHUNK_SIZE)``.

        Payload and transfer-encoding errors end the iteration quietly after
        a debug log line; every other exception is re-raised from inside
        ``map_aiohttp_exceptions()``.
        """
        try:
            body = self._aiohttp_response.content
            async for piece in body.iter_chunked(self.CHUNK_SIZE):
                yield piece
        except (
            aiohttp.ClientPayloadError,
            aiohttp.client_exceptions.ClientPayloadError,
        ) as payload_err:
            # Incomplete transfer: log it and stop iterating instead of
            # raising, so the caller keeps whatever was already yielded.
            verbose_logger.debug(
                f"Transfer incomplete, but continuing: {payload_err}"
            )
            return
        except aiohttp.http_exceptions.TransferEncodingError as encoding_err:
            # Same best-effort treatment for transfer-encoding failures.
            verbose_logger.debug(
                f"Transfer encoding error, but continuing: {encoding_err}"
            )
            return
        except Exception:
            # Anything else goes through the regular exception mapping.
            with map_aiohttp_exceptions():
                raise

    async def aclose(self) -> None:
        """Exit the underlying aiohttp response's async context."""
        with map_aiohttp_exceptions():
            await self._aiohttp_response.__aexit__(None, None, None)
class AiohttpTransport(httpx.AsyncBaseTransport):
    """httpx async transport holding an aiohttp session or a session factory."""

    def __init__(
        self, client: Union[ClientSession, Callable[[], ClientSession]]
    ) -> None:
        # Either a ready ClientSession or a zero-argument callable returning one.
        self.client = client

    async def aclose(self) -> None:
        """Close the held session; do nothing while only a factory is stored."""
        session = self.client
        if not isinstance(session, ClientSession):
            return
        await session.close()
class LiteLLMAiohttpTransport(AiohttpTransport):
    """
    LiteLLM wrapper around AiohttpTransport to handle %-encodings in URLs
    and event loop lifecycle issues in CI/CD environments

    Credit to: https://github.com/karpetrosyan/httpx-aiohttp for this implementation
    """

    def __init__(
        self, client: Union[ClientSession, Callable[[], ClientSession]]
    ) -> None:
        """Store ``client`` and, when it is callable, remember it as the factory.

        Args:
            client: An aiohttp ``ClientSession`` or a zero-argument callable
                that returns one.
        """
        # The parent constructor assigns ``self.client``.
        super().__init__(client=client)
        # Store the client factory for recreating sessions when needed
        if callable(client):
            self._client_factory = client

    def _create_session(self) -> ClientSession:
        """Build a session (factory if available, else ``ClientSession()``),
        store it on ``self.client`` and return it."""
        # Imported locally so the name is resolved at call time.
        from aiohttp.client import ClientSession

        factory = getattr(self, "_client_factory", None)
        self.client = factory() if callable(factory) else ClientSession()
        return self.client

    def _get_valid_client_session(self) -> ClientSession:
        """
        Helper to get a valid ClientSession for the current event loop.

        This handles the case where the session was created in a different
        event loop that may have been closed (common in CI/CD environments).

        Returns:
            The session now stored in ``self.client``. A new one is created
            when ``self.client`` is not a ``ClientSession``, when its
            ``_loop`` is missing, differs from the running loop or is closed,
            or when the loop check raises ``RuntimeError``/``AttributeError``.
        """
        from aiohttp.client import ClientSession

        # If we don't have a client or it's not a ClientSession, create one
        if not isinstance(self.client, ClientSession):
            return self._create_session()

        # Check if the existing session is still valid for the current event loop
        try:
            session_loop = getattr(self.client, "_loop", None)
            current_loop = asyncio.get_running_loop()

            # If session is from a different or closed loop, recreate it
            if (
                session_loop is None
                or session_loop != current_loop
                or session_loop.is_closed()
            ):
                # The old session is deliberately not closed here: it might
                # be from a different loop, so close() is not awaited and the
                # object is left to be garbage collected.
                self._create_session()
        except (RuntimeError, AttributeError):
            # If we can't check the loop or session is invalid, recreate it
            self._create_session()

        return self.client

    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Send ``request`` through the aiohttp session and wrap the reply.

        The request is issued with ``allow_redirects=False`` and
        ``auto_decompress=False``. The httpx ``timeout`` extension is
        translated to an aiohttp ``ClientTimeout``: ``connect`` ->
        ``sock_connect``, ``read`` -> ``sock_read``, ``pool`` -> ``connect``.
        The ``sni_hostname`` extension is forwarded as ``server_hostname``.

        Args:
            request: The httpx request to send.

        Returns:
            An ``httpx.Response`` whose content is an
            ``AiohttpResponseStream`` over the aiohttp response.

        Raises:
            Exceptions raised while sending pass through
            ``map_aiohttp_exceptions()``.
        """
        from aiohttp import ClientTimeout
        from yarl import URL as YarlURL

        timeout = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname")

        # Use helper to ensure we have a valid session for the current event loop
        client_session = self._get_valid_client_session()

        with map_aiohttp_exceptions():
            try:
                data = request.content
            except httpx.RequestNotRead:
                # Body not read yet: hand the stream itself to aiohttp.
                data = request.stream  # type: ignore
                request.headers.pop("transfer-encoding", None)  # handled by aiohttp

            response = await client_session.request(
                method=request.method,
                # The URL string is passed as already encoded (encoded=True).
                url=YarlURL(str(request.url), encoded=True),
                headers=request.headers,
                data=data,
                allow_redirects=False,
                auto_decompress=False,
                timeout=ClientTimeout(
                    sock_connect=timeout.get("connect"),
                    sock_read=timeout.get("read"),
                    connect=timeout.get("pool"),
                ),
                server_hostname=sni_hostname,
            ).__aenter__()

        return httpx.Response(
            status_code=response.status,
            headers=response.headers,
            content=AiohttpResponseStream(response),
            request=request,
        )