test3 / litellm /llms /custom_httpx /aiohttp_transport.py
DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
import asyncio
import contextlib
import typing
from typing import Callable, Dict, Union
import aiohttp
import aiohttp.client_exceptions
import aiohttp.http_exceptions
import httpx
from aiohttp.client import ClientResponse, ClientSession
from litellm._logging import verbose_logger
# Maps aiohttp exception classes to the httpx exception class raised in their
# place by ``map_aiohttp_exceptions``.
#
# Order matters: the mapper walks this dict in insertion order, keeps the first
# entry whose key matches the raised exception, and only replaces that choice
# when a later matching entry maps to a *subclass* of the httpx exception
# already selected. Keep the most specific aiohttp exceptions first.
#
# Each aiohttp class must appear at most once; a repeated key in a dict literal
# silently overwrites the earlier value while keeping the earlier position.
AIOHTTP_EXC_MAP: Dict = {
    # Timeout related exceptions
    aiohttp.ServerTimeoutError: httpx.TimeoutException,
    aiohttp.ConnectionTimeoutError: httpx.ConnectTimeout,
    aiohttp.SocketTimeoutError: httpx.ReadTimeout,
    # Proxy related exceptions
    aiohttp.ClientProxyConnectionError: httpx.ProxyError,
    # SSL related exceptions
    aiohttp.ClientConnectorCertificateError: httpx.ProtocolError,
    aiohttp.ClientSSLError: httpx.ProtocolError,
    aiohttp.ServerFingerprintMismatch: httpx.ProtocolError,
    # Network related exceptions
    aiohttp.ClientConnectorError: httpx.ConnectError,
    aiohttp.ClientOSError: httpx.ConnectError,
    # aiohttp.client_exceptions.ClientPayloadError is the same class object as
    # aiohttp.ClientPayloadError, so this single entry covers both spellings.
    aiohttp.ClientPayloadError: httpx.ReadError,
    # Connection disconnection exceptions
    aiohttp.ServerDisconnectedError: httpx.ReadError,
    # Response related exceptions
    aiohttp.ClientConnectionError: httpx.NetworkError,
    aiohttp.ContentTypeError: httpx.ReadError,
    aiohttp.TooManyRedirects: httpx.TooManyRedirects,
    # URL related exceptions
    aiohttp.InvalidURL: httpx.InvalidURL,
    # Base exceptions
    aiohttp.ClientError: httpx.RequestError,
}
@contextlib.contextmanager
def map_aiohttp_exceptions() -> typing.Iterator[None]:
    """Re-raise exceptions from the managed block as their httpx equivalents.

    Every entry of ``AIOHTTP_EXC_MAP`` is checked in insertion order. The first
    entry whose key matches the raised exception is selected; a later matching
    entry replaces it only when its httpx class is a subclass of the one
    currently selected. The chosen httpx exception is raised with
    ``str(original)`` as its message and the original exception as its cause.

    Exceptions that match no entry propagate unchanged.
    """
    try:
        yield
    except Exception as error:
        target = None
        for source_type, candidate in AIOHTTP_EXC_MAP.items():
            if isinstance(error, source_type) and (  # type: ignore
                target is None or issubclass(candidate, target)
            ):
                target = candidate

        if target is not None:
            raise target(str(error)) from error

        # Nothing in the map matched: let the original exception through.
        raise  # pragma: no cover
class AiohttpResponseStream(httpx.AsyncByteStream):
    """httpx byte stream that reads its data from an aiohttp ``ClientResponse``."""

    # Number of bytes requested from aiohttp per chunk (16 KiB).
    CHUNK_SIZE = 16 * 1024

    def __init__(self, aiohttp_response: ClientResponse) -> None:
        """Wrap ``aiohttp_response`` so httpx can iterate over its body."""
        self._aiohttp_response = aiohttp_response

    async def __aiter__(self) -> typing.AsyncIterator[bytes]:
        """Yield the response body in chunks of at most ``CHUNK_SIZE`` bytes.

        A payload error or transfer-encoding error ends the iteration quietly
        after a debug log, whether or not any chunk was yielded, so the
        consumer receives whatever arrived before the failure and no error.
        Any other exception is passed through ``map_aiohttp_exceptions``.
        """
        try:
            body = self._aiohttp_response.content
            async for piece in body.iter_chunked(self.CHUNK_SIZE):
                yield piece
        except aiohttp.ClientPayloadError as exc:
            # Best effort: an incomplete transfer is logged and treated as the
            # end of the stream instead of being raised to the caller.
            verbose_logger.debug(f"Transfer incomplete, but continuing: {exc}")
            return
        except aiohttp.http_exceptions.TransferEncodingError as exc:
            # Same best-effort handling for transfer-encoding failures.
            verbose_logger.debug(f"Transfer encoding error, but continuing: {exc}")
            return
        except Exception:
            # Everything else is translated by the shared exception mapping.
            with map_aiohttp_exceptions():
                raise

    async def aclose(self) -> None:
        """Exit the aiohttp response's async context, mapping any aiohttp error."""
        with map_aiohttp_exceptions():
            await self._aiohttp_response.__aexit__(None, None, None)
class AiohttpTransport(httpx.AsyncBaseTransport):
    """httpx async transport holding an aiohttp session or a session factory.

    This class only stores the client and closes it; request handling is not
    defined here.
    """

    def __init__(
        self, client: Union[ClientSession, Callable[[], ClientSession]]
    ) -> None:
        """Store ``client``: a ``ClientSession`` or a zero-argument factory for one."""
        self.client = client

    async def aclose(self) -> None:
        """Close the held session; do nothing if only a factory is held."""
        session = self.client
        if not isinstance(session, ClientSession):
            return
        await session.close()
class LiteLLMAiohttpTransport(AiohttpTransport):
    """
    LiteLLM wrapper around AiohttpTransport to handle %-encodings in URLs
    and event loop lifecycle issues in CI/CD environments

    Credit to: https://github.com/karpetrosyan/httpx-aiohttp for this implementation
    """

    def __init__(self, client: Union[ClientSession, Callable[[], ClientSession]]):
        """Store ``client`` and, when it is a factory, remember it for later.

        Args:
            client: An existing ``ClientSession`` or a zero-argument callable
                returning one.
        """
        # The parent constructor assigns ``self.client``.
        super().__init__(client=client)

        # Store the client factory for recreating sessions when needed
        if callable(client):
            self._client_factory = client

    def _create_session(self) -> ClientSession:
        """Return a new session from the stored factory, else a default ``ClientSession``."""
        factory = getattr(self, "_client_factory", None)
        if callable(factory):
            return factory()
        return ClientSession()

    def _get_valid_client_session(self) -> ClientSession:
        """
        Helper to get a valid ClientSession for the current event loop.

        This handles the case where the session was created in a different
        event loop that may have been closed (common in CI/CD environments).

        Returns:
            ``self.client``, replaced by a freshly created session when it was
            not yet a ``ClientSession``, when its loop is missing, closed or
            not the running loop, or when the loop check itself fails.
        """
        # Nothing usable yet (still a factory, or something else): build one.
        if not isinstance(self.client, ClientSession):
            self.client = self._create_session()
            return self.client

        # Check if the existing session is still valid for the current event loop
        try:
            session_loop = getattr(self.client, "_loop", None)
            current_loop = asyncio.get_running_loop()
            stale = (
                session_loop is None
                or session_loop is not current_loop
                or session_loop.is_closed()
            )
        except (RuntimeError, AttributeError):
            # If we can't check the loop or session is invalid, recreate it
            stale = True

        if stale:
            # The old session is deliberately not closed here: close() would
            # have to be awaited and may belong to a different event loop. The
            # reference is simply dropped.
            self.client = self._create_session()

        return self.client

    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Send ``request`` through aiohttp and wrap the result as an httpx response.

        Args:
            request: The httpx request. Its ``timeout`` and ``sni_hostname``
                extensions are read; its ``transfer-encoding`` header is
                removed when the body is sent as a stream.

        Returns:
            An ``httpx.Response`` whose body is streamed lazily from the
            aiohttp response via ``AiohttpResponseStream``.

        Raises:
            httpx exceptions produced by ``map_aiohttp_exceptions`` for aiohttp
            errors raised while sending the request.
        """
        from aiohttp import ClientTimeout
        from yarl import URL as YarlURL

        timeout = request.extensions.get("timeout", {})
        sni_hostname = request.extensions.get("sni_hostname")

        # Use helper to ensure we have a valid session for the current event loop
        client_session = self._get_valid_client_session()

        with map_aiohttp_exceptions():
            try:
                data = request.content
            except httpx.RequestNotRead:
                # Body has not been read into memory: hand the stream to aiohttp.
                data = request.stream  # type: ignore
                request.headers.pop("transfer-encoding", None)  # handled by aiohttp

            # httpx timeout keys mapped onto aiohttp's ClientTimeout fields;
            # note that httpx's "pool" value is passed as aiohttp's ``connect``.
            aiohttp_timeout = ClientTimeout(
                sock_connect=timeout.get("connect"),
                sock_read=timeout.get("read"),
                connect=timeout.get("pool"),
            )

            # ``encoded=True`` passes the URL string to yarl as already encoded.
            # Redirects are not followed and the body is not decompressed by
            # aiohttp; the response is handed to httpx as received.
            # ``__aenter__`` is awaited directly so the response stays open
            # after this method returns; AiohttpResponseStream.aclose() calls
            # the matching ``__aexit__``.
            response = await client_session.request(
                method=request.method,
                url=YarlURL(str(request.url), encoded=True),
                headers=request.headers,
                data=data,
                allow_redirects=False,
                auto_decompress=False,
                timeout=aiohttp_timeout,
                server_hostname=sni_hostname,
            ).__aenter__()

        return httpx.Response(
            status_code=response.status,
            headers=response.headers,
            content=AiohttpResponseStream(response),
            request=request,
        )