"""
PagerDuty Alerting Integration
Handles two types of alerts:
- High LLM API Failure Rate: configure X failures within Y seconds to trigger an alert.
- High Number of Hanging LLM Requests: configure X hangs within Y seconds to trigger an alert.

Note: this is a free feature on the regular litellm docker image, though the
code itself falls under the enterprise license.
"""
import asyncio
import os
from datetime import datetime, timedelta, timezone
from typing import List, Literal, Optional, Union
from litellm._logging import verbose_logger
from litellm.caching import DualCache
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.integrations.pagerduty import (
AlertingConfig,
PagerDutyInternalEvent,
PagerDutyPayload,
PagerDutyRequestBody,
)
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingPayloadErrorInformation,
)
# Default alert thresholds and windows.
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD = 60
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS = 60
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS = 60  # seconds before a request counts as hanging
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_FAILS = 60  # number of hangs needed to trigger an alert
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS = 600
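# Semantics of the defaults above: a failure alert fires once 60 failures land
# within a trailing 60-second window; a hanging alert fires once 60 requests
# each wait 60+ seconds for a response within a trailing 600-second window.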
class PagerDutyAlerting(SlackAlerting):
"""
    Tracks failed requests and hanging requests separately.
    If the configured threshold is crossed for either type, a PagerDuty alert is triggered.
"""
def __init__(
self, alerting_args: Optional[Union[AlertingConfig, dict]] = None, **kwargs
):
super().__init__()
_api_key = os.getenv("PAGERDUTY_API_KEY")
if not _api_key:
raise ValueError("PAGERDUTY_API_KEY is not set")
self.api_key: str = _api_key
alerting_args = alerting_args or {}
self.pagerduty_alerting_args: AlertingConfig = AlertingConfig(
failure_threshold=alerting_args.get(
"failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
),
failure_threshold_window_seconds=alerting_args.get(
"failure_threshold_window_seconds",
PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
),
            hanging_threshold_seconds=alerting_args.get(
                "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
            ),
            hanging_threshold_fails=alerting_args.get(
                "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_FAILS
            ),
            hanging_threshold_window_seconds=alerting_args.get(
                "hanging_threshold_window_seconds",
                PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
            ),
)
# Separate storage for failures vs. hangs
self._failure_events: List[PagerDutyInternalEvent] = []
self._hanging_events: List[PagerDutyInternalEvent] = []
# ------------------ MAIN LOGIC ------------------ #
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
"""
Record a failure event. Only send an alert to PagerDuty if the
configured *failure* threshold is exceeded in the specified window.
"""
now = datetime.now(timezone.utc)
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
if not standard_logging_payload:
raise ValueError(
"standard_logging_object is required for PagerDutyAlerting"
)
        # Extract error details (never None: fall back to an empty dict)
        error_info: StandardLoggingPayloadErrorInformation = (
            standard_logging_payload.get("error_information") or {}
        )
_meta = standard_logging_payload.get("metadata") or {}
self._failure_events.append(
PagerDutyInternalEvent(
failure_event_type="failed_response",
timestamp=now,
error_class=error_info.get("error_class"),
error_code=error_info.get("error_code"),
error_llm_provider=error_info.get("llm_provider"),
user_api_key_hash=_meta.get("user_api_key_hash"),
user_api_key_alias=_meta.get("user_api_key_alias"),
user_api_key_org_id=_meta.get("user_api_key_org_id"),
user_api_key_team_id=_meta.get("user_api_key_team_id"),
user_api_key_user_id=_meta.get("user_api_key_user_id"),
user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
user_api_key_user_email=_meta.get("user_api_key_user_email"),
)
)
        # Prune old events; if the failure threshold is crossed, send a PD alert
        window_seconds = self.pagerduty_alerting_args.get(
            "failure_threshold_window_seconds",
            PAGERDUTY_DEFAULT_FAILURE_THRESHOLD_WINDOW_SECONDS,
        )
        threshold = self.pagerduty_alerting_args.get(
            "failure_threshold", PAGERDUTY_DEFAULT_FAILURE_THRESHOLD
        )
await self._send_alert_if_thresholds_crossed(
events=self._failure_events,
window_seconds=window_seconds,
threshold=threshold,
alert_prefix="High LLM API Failure Rate",
)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: Literal[
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
"pass_through_endpoint",
"rerank",
],
) -> Optional[Union[Exception, str, dict]]:
"""
Example of detecting hanging requests by waiting a given threshold.
If the request didn't finish by then, we treat it as 'hanging'.
"""
verbose_logger.info("Inside Proxy Logging Pre-call hook!")
asyncio.create_task(
self.hanging_response_handler(
request_data=data, user_api_key_dict=user_api_key_dict
)
)
return None
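    # The fire-and-forget watchdog pattern used above, in isolation (a sketch
    # with hypothetical helper names, not litellm API):
    #
    #     async def watchdog(request_data: dict, threshold_seconds: int):
    #         await asyncio.sleep(threshold_seconds)
    #         if not await is_completed(request_data):  # hypothetical helper
    #             record_hang(request_data)             # hypothetical helper
    #
    #     asyncio.create_task(watchdog(data, 60))  # returns immediately; the
    #     # request proceeds while the watchdog sleeps in the background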
async def hanging_response_handler(
self, request_data: Optional[dict], user_api_key_dict: UserAPIKeyAuth
):
"""
Checks if request completed by the time 'hanging_threshold_seconds' elapses.
If not, we classify it as a hanging request.
"""
        hanging_threshold_seconds: int = self.pagerduty_alerting_args.get(
            "hanging_threshold_seconds", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_SECONDS
        )
        verbose_logger.debug(
            f"Hanging-response handler: sleeping for {hanging_threshold_seconds} seconds"
        )
        await asyncio.sleep(hanging_threshold_seconds)
if await self._request_is_completed(request_data=request_data):
return # It's not hanging if completed
# Otherwise, record it as hanging
self._hanging_events.append(
PagerDutyInternalEvent(
failure_event_type="hanging_response",
timestamp=datetime.now(timezone.utc),
error_class="HangingRequest",
error_code="HangingRequest",
error_llm_provider="HangingRequest",
user_api_key_hash=user_api_key_dict.api_key,
user_api_key_alias=user_api_key_dict.key_alias,
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_team_id=user_api_key_dict.team_id,
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_user_email=user_api_key_dict.user_email,
)
)
# Prune + Possibly alert
window_seconds = self.pagerduty_alerting_args.get(
"hanging_threshold_window_seconds",
PAGERDUTY_DEFAULT_HANGING_THRESHOLD_WINDOW_SECONDS,
)
        # Note: this is a count of hanging requests, not a duration in seconds
        threshold: int = self.pagerduty_alerting_args.get(
            "hanging_threshold_fails", PAGERDUTY_DEFAULT_HANGING_THRESHOLD_FAILS
        )
# If threshold is crossed, send PD alert for hangs
await self._send_alert_if_thresholds_crossed(
events=self._hanging_events,
window_seconds=window_seconds,
threshold=threshold,
alert_prefix="High Number of Hanging LLM Requests",
)
# ------------------ HELPERS ------------------ #
async def _send_alert_if_thresholds_crossed(
self,
events: List[PagerDutyInternalEvent],
window_seconds: int,
threshold: int,
alert_prefix: str,
):
"""
1. Prune old events
2. If threshold is reached, build alert, send to PagerDuty
3. Clear those events
"""
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=window_seconds)
        # Sentinel must be timezone-aware: comparing naive datetime.min against
        # the aware cutoff would raise a TypeError.
        sentinel = datetime.min.replace(tzinfo=timezone.utc)
        pruned = [e for e in events if e.get("timestamp", sentinel) > cutoff]
# Update the reference list
events.clear()
events.extend(pruned)
# Check threshold
verbose_logger.debug(
f"Have {len(events)} events in the last {window_seconds} seconds. Threshold is {threshold}"
)
if len(events) >= threshold:
# Build short summary of last N events
error_summaries = self._build_error_summaries(events, max_errors=5)
alert_message = (
f"{alert_prefix}: {len(events)} in the last {window_seconds} seconds."
)
custom_details = {"recent_errors": error_summaries}
await self.send_alert_to_pagerduty(
alert_message=alert_message,
custom_details=custom_details,
)
# Clear them after sending an alert, so we don't spam
events.clear()
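    # Worked example: with window_seconds=60 and threshold=3, events at
    # t=0, 10, 70, 75, 80 (seconds) are pruned at t=80 down to {70, 75, 80};
    # that is 3 >= 3, so one alert fires and the buffer is cleared, preventing
    # repeat alerts for the same burst.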
def _build_error_summaries(
self, events: List[PagerDutyInternalEvent], max_errors: int = 5
) -> List[PagerDutyInternalEvent]:
"""
Build short text summaries for the last `max_errors`.
Example: "ValueError (code: 500, provider: openai)"
"""
recent = events[-max_errors:]
summaries = []
for fe in recent:
# If any of these is None, show "N/A" to avoid messing up the summary string
fe.pop("timestamp")
summaries.append(fe)
return summaries
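    # Example of the returned structure (field values illustrative):
    #
    #     [
    #         {"failure_event_type": "failed_response",
    #          "error_class": "RateLimitError", "error_code": "429",
    #          "error_llm_provider": "openai", ...},
    #         {"failure_event_type": "hanging_response",
    #          "error_class": "HangingRequest", ...},
    #     ]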
async def send_alert_to_pagerduty(self, alert_message: str, custom_details: dict):
"""
Send [critical] Alert to PagerDuty
https://developer.pagerduty.com/api-reference/YXBpOjI3NDgyNjU-pager-duty-v2-events-api
"""
try:
verbose_logger.debug(f"Sending alert to PagerDuty: {alert_message}")
async_client: AsyncHTTPHandler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
payload: PagerDutyRequestBody = PagerDutyRequestBody(
payload=PagerDutyPayload(
summary=alert_message,
severity="critical",
source="LiteLLM Alert",
component="LiteLLM",
custom_details=custom_details,
),
routing_key=self.api_key,
event_action="trigger",
)
return await async_client.post(
url="https://events.pagerduty.com/v2/enqueue",
json=dict(payload),
headers={"Content-Type": "application/json"},
)
except Exception as e:
verbose_logger.exception(f"Error sending alert to PagerDuty: {e}")
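# The resulting PagerDuty Events API v2 request body looks roughly like this
# (fields taken from the PagerDutyPayload construction above; values illustrative):
#
#     {
#         "routing_key": "<PAGERDUTY_API_KEY>",
#         "event_action": "trigger",
#         "payload": {
#             "summary": "High LLM API Failure Rate: 60 in the last 60 seconds.",
#             "severity": "critical",
#             "source": "LiteLLM Alert",
#             "component": "LiteLLM",
#             "custom_details": {"recent_errors": [...]}
#         }
#     }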