# litellm/proxy/guardrails/guardrail_hooks/pangea.py
import os
from typing import Any, Optional, Protocol

from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.caching.dual_cache import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.common_utils.callback_utils import (
    add_guardrail_to_applied_guardrails_header,
)
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import LLMResponseTypes, ModelResponse, TextCompletionResponse


class PangeaGuardrailMissingSecrets(Exception):
    """Custom exception for missing Pangea secrets."""

    pass


class _Transformer(Protocol):
    def get_messages(self) -> list[dict]:
        ...

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        ...


class _TextCompletionRequest:
    def __init__(self, body: dict):
        self.body = body

    def get_messages(self) -> list[dict]:
        return [{"role": "user", "content": self.body["prompt"]}]

    # This mutates the original dict, but we still return it for the caller's convenience.
    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == 1
        self.body["prompt"] = prompt_messages[0]["content"]
        return self.body
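

# Illustrative sketch, not part of the original module: how a transformer
# round-trips a request body through get_messages()/update_original_body().
# The sample body and redacted text below are hypothetical.
def _example_text_completion_roundtrip() -> None:
    body = {"prompt": "What is the capital of France?"}
    transformer = _TextCompletionRequest(body)
    messages = transformer.get_messages()
    # -> [{"role": "user", "content": "What is the capital of France?"}]
    # Pretend Pangea redacted the prompt; write it back into the body.
    messages[0]["content"] = "What is the capital of <LOCATION>?"
    updated = transformer.update_original_body(messages)
    assert updated["prompt"] == "What is the capital of <LOCATION>?"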


class _TextCompletionResponse:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []
        for choice in self.body["choices"]:
            messages.append({"role": "assistant", "content": choice["text"]})
        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == len(self.body["choices"])
        for choice, prompt_message in zip(self.body["choices"], prompt_messages):
            choice["text"] = prompt_message["content"]
        return self.body


class _ChatCompletionRequest:
    def __init__(self, body: dict):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []
        for message in self.body["messages"]:
            role = message["role"]
            content = message["content"]
            if isinstance(content, str):
                messages.append({"role": role, "content": content})
            if isinstance(content, list):
                for content_part in content:
                    if content_part["type"] == "text":
                        messages.append({"role": role, "content": content_part["text"]})
        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        count = 0
        for message in self.body["messages"]:
            content = message["content"]
            if isinstance(content, str):
                message["content"] = prompt_messages[count]["content"]
                count += 1
            if isinstance(content, list):
                for content_part in content:
                    if content_part["type"] == "text":
                        content_part["text"] = prompt_messages[count]["content"]
                        count += 1
        assert len(prompt_messages) == count
        return self.body
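

# Illustrative sketch, not part of the original module: multi-part chat
# content is flattened to one entry per text part, and update_original_body()
# writes redactions back in the same order. The body below is hypothetical.
def _example_chat_request_flattening() -> None:
    body = {
        "messages": [
            {"role": "system", "content": "You are helpful."},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "My SSN is 078-05-1120."},
                    {"type": "image_url", "image_url": {"url": "https://example.com/x.png"}},
                ],
            },
        ]
    }
    transformer = _ChatCompletionRequest(body)
    flat = transformer.get_messages()
    assert len(flat) == 2  # one string content + one text part; the image part is skipped
    flat[1]["content"] = "My SSN is <US_SSN>."
    transformer.update_original_body(flat)
    assert body["messages"][1]["content"][0]["text"] == "My SSN is <US_SSN>."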


class _ChatCompletionResponse:
    def __init__(self, body):
        self.body = body

    def get_messages(self) -> list[dict]:
        messages = []
        for choice in self.body["choices"]:
            messages.append(
                {
                    "role": choice["message"]["role"],
                    "content": choice["message"]["content"],
                }
            )
        return messages

    def update_original_body(self, prompt_messages: list[dict]) -> Any:
        assert len(prompt_messages) == len(self.body["choices"])
        for choice, prompt_message in zip(self.body["choices"], prompt_messages):
            choice["message"]["content"] = prompt_message["content"]
        return self.body


def _get_transformer_for_request(body: dict, call_type: str) -> Optional[_Transformer]:
    match call_type:
        case "text_completion" | "atext_completion":
            return _TextCompletionRequest(body)
        case "completion" | "acompletion":
            return _ChatCompletionRequest(body)
    return None


def _get_transformer_for_response(body) -> Optional[_Transformer]:
    match body:
        case TextCompletionResponse():
            return _TextCompletionResponse(body)
        case ModelResponse():
            return _ChatCompletionResponse(body)
    return None
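

# Illustrative sketch, not part of the original module: dispatch picks a
# transformer by call_type on the way in and by response class on the way
# out; unknown shapes yield None so the guardrail is skipped instead of
# failing the request.
def _example_dispatch() -> None:
    assert isinstance(
        _get_transformer_for_request({"prompt": "hi"}, "atext_completion"),
        _TextCompletionRequest,
    )
    assert isinstance(
        _get_transformer_for_request({"messages": []}, "acompletion"),
        _ChatCompletionRequest,
    )
    assert _get_transformer_for_request({}, "aembedding") is None  # unsupported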


class PangeaHandler(CustomGuardrail):
    """
    Pangea AI Guardrail handler to interact with the Pangea AI Guard service.

    This class implements the necessary hooks to call the Pangea AI Guard API
    for input and output scanning based on the configured recipes.
    """

    def __init__(
        self,
        guardrail_name: str,
        pangea_input_recipe: Optional[str] = None,
        pangea_output_recipe: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the PangeaHandler.

        Args:
            guardrail_name (str): The name of the guardrail instance.
            pangea_input_recipe (Optional[str]): The Pangea recipe key used to scan request input.
            pangea_output_recipe (Optional[str]): The Pangea recipe key used to scan LLM output.
            api_key (Optional[str]): The Pangea API key. Reads from the PANGEA_API_KEY env var if None.
            api_base (Optional[str]): The Pangea API base URL. Reads from the PANGEA_API_BASE env var or falls back to the default if None.
            **kwargs: Additional arguments passed to the CustomGuardrail base class.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_key = api_key or os.environ.get("PANGEA_API_KEY")
        if not self.api_key:
            raise PangeaGuardrailMissingSecrets(
                "Pangea API Key not found. Set the PANGEA_API_KEY environment variable or pass it in litellm_params."
            )
        # Default Pangea base URL if not provided
        self.api_base = (
            api_base
            or os.environ.get("PANGEA_API_BASE")
            or "https://ai-guard.aws.us.pangea.cloud"
        )
        self.pangea_input_recipe = pangea_input_recipe
        self.pangea_output_recipe = pangea_output_recipe
        self.guardrail_endpoint = f"{self.api_base}/v1/text/guard"

        # Pass relevant kwargs to the parent class
        super().__init__(guardrail_name=guardrail_name, **kwargs)
        verbose_proxy_logger.info(
            f"Initialized Pangea Guardrail: name={guardrail_name}, input_recipe={pangea_input_recipe}, "
            f"output_recipe={pangea_output_recipe}, api_base={self.api_base}"
        )
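
    # Illustrative proxy config sketch (recipe names and the mode syntax are
    # assumptions, not taken from this file) showing how these constructor
    # arguments are typically supplied from the proxy's config.yaml:
    #
    #   guardrails:
    #     - guardrail_name: pangea-ai-guard
    #       litellm_params:
    #         guardrail: pangea
    #         mode: [pre_call, post_call]
    #         api_key: os.environ/PANGEA_API_KEY
    #         pangea_input_recipe: pangea_prompt_guard
    #         pangea_output_recipe: pangea_llm_response_guard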

    async def _call_pangea_guard(self, payload: dict, hook_name: str) -> dict:
        """
        Makes the API call to the Pangea AI Guard endpoint.

        Raises an HTTPException if the content should be blocked; otherwise
        returns the parsed Pangea response, whose result may contain redacted
        messages that the caller should write back into the request or response.

        Args:
            payload (dict): The request payload.
            hook_name (str): Name of the hook calling this function (for logging).

        Raises:
            HTTPException: If the Pangea API returns a 'blocked: true' response.
            HTTPException: For other API call failures (surfaced as a 500).

        Returns:
            dict: The parsed Pangea API response body.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        try:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail ({hook_name}): Calling endpoint {self.guardrail_endpoint} with payload: {payload}"
            )
            response = await self.async_handler.post(
                url=self.guardrail_endpoint, json=payload, headers=headers
            )
            response.raise_for_status()  # Raise HTTPError for bad responses (4xx or 5xx)
            result = response.json()
            verbose_proxy_logger.debug(
                f"Pangea Guardrail ({hook_name}): Received response: {result}"
            )

            # Check if the request was blocked
            if result.get("result", {}).get("blocked") is True:
                verbose_proxy_logger.warning(
                    f"Pangea Guardrail ({hook_name}): Request blocked. Response: {result}"
                )
                raise HTTPException(
                    status_code=400,  # Bad Request, indicating a policy violation
                    detail={
                        "error": "Violated Pangea guardrail policy",
                        "guardrail_name": self.guardrail_name,
                        "pangea_response": result.get("result"),
                    },
                )
            else:
                verbose_proxy_logger.info(
                    f"Pangea Guardrail ({hook_name}): Request passed. Response: {result.get('result', {}).get('detectors')}"
                )

            return result
        except HTTPException as e:
            # Re-raise the HTTPException we raised for blocking
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                f"Pangea Guardrail ({hook_name}): Error calling API: {e}. Response text: {getattr(e, 'response', None) and getattr(e.response, 'text', None)}"  # type: ignore
            )
            # Raising here blocks the request by default on API errors; to fail
            # open instead, log the error and return without raising.
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Error communicating with Pangea Guardrail",
                    "guardrail_name": self.guardrail_name,
                    "exception": str(e),
                },
            ) from e
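
    # Illustrative response shapes, inferred from the fields this method reads
    # (not an authoritative Pangea API reference):
    #
    #   passed:  {"result": {"blocked": false,
    #                        "detectors": {...},
    #                        "prompt_messages": [{"role": "user", "content": "..."}]}}
    #   blocked: {"result": {"blocked": true, "detectors": {...}}}
    #
    # A blocked result surfaces to the client as an HTTP 400 carrying the
    # guardrail name and Pangea's "result" payload in the error detail.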

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        event_type = GuardrailEventHooks.pre_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail (async_pre_call_hook): Guardrail {self.guardrail_name} is disabled."
            )
            return data

        transformer = _get_transformer_for_request(data, call_type)
        if not transformer:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_pre_call_hook): Skipping guardrail {self.guardrail_name}"
                f" because we cannot determine the type of request: call_type '{call_type}'"
            )
            return data

        messages = transformer.get_messages()
        if not messages:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_pre_call_hook): Skipping guardrail {self.guardrail_name}"
                " because messages is empty."
            )
            return data

        ai_guard_payload = {
            "debug": False,  # Or make this configurable if needed
            "messages": messages,
        }
        if self.pangea_input_recipe:
            ai_guard_payload["recipe"] = self.pangea_input_recipe

        ai_guard_response = await self._call_pangea_guard(
            ai_guard_payload, "async_pre_call_hook"
        )

        # Add the guardrail name to the applied-guardrails header
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

        prompt_messages = ai_guard_response.get("result", {}).get("prompt_messages", [])
        try:
            return transformer.update_original_body(prompt_messages)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to update original request body",
                    "guardrail_name": self.guardrail_name,
                    "exception": str(e),
                },
            ) from e

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        # This union isn't actually correct -- it can get other response types depending on the API called
        response: LLMResponseTypes,
    ):
        """
        Guardrail hook run after a successful LLM call (scans output).

        Args:
            data (dict): The original request data.
            user_api_key_dict (UserAPIKeyAuth): User API key details.
            response (LLMResponseTypes): The response object from the LLM call.
        """
        event_type = GuardrailEventHooks.post_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            verbose_proxy_logger.debug(
                f"Pangea Guardrail (async_post_call_success_hook): Guardrail {self.guardrail_name} is disabled."
            )
            return data

        transformer = _get_transformer_for_response(response)
        if not transformer:
            verbose_proxy_logger.warning(
                f"Pangea Guardrail (async_post_call_success_hook): Skipping guardrail {self.guardrail_name}"
                " because we cannot determine the type of response"
            )
            return

        messages = transformer.get_messages()
        verbose_proxy_logger.debug(
            f"Pangea Guardrail (async_post_call_success_hook): Scanning messages: {messages}"
        )

        ai_guard_payload = {
            "debug": False,  # Or make this configurable if needed
            "messages": messages,
        }
        if self.pangea_output_recipe:
            ai_guard_payload["recipe"] = self.pangea_output_recipe

        ai_guard_response = await self._call_pangea_guard(
            ai_guard_payload, "async_post_call_success_hook"
        )

        prompt_messages = ai_guard_response.get("result", {}).get("prompt_messages", [])
        try:
            return transformer.update_original_body(prompt_messages)
        except Exception as e:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to update original response body",
                    "guardrail_name": self.guardrail_name,
                    "exception": str(e),
                },
            ) from e
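

# Illustrative sketch, not part of the original module: how the proxy
# effectively drives the pre-call hook. The request body and call_type are
# hypothetical; in production, litellm supplies these arguments itself.
async def _example_pre_call_scan(
    handler: PangeaHandler, user: UserAPIKeyAuth, cache: DualCache
) -> dict:
    data = {"messages": [{"role": "user", "content": "hi"}], "metadata": {}}
    # Returns the (possibly redacted) request body, or raises HTTPException(400)
    # if Pangea's recipe blocks the prompt.
    return await handler.async_pre_call_hook(
        user_api_key_dict=user, cache=cache, data=data, call_type="acompletion"
    )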