DesertWolf's picture
Upload folder using huggingface_hub
447ebeb verified
# +-------------------------------------------------------------+
#
# Use Lasso Security Guardrails for your LLM calls
# https://www.lasso.security/
#
# +-------------------------------------------------------------+
import os
from typing import Any, Dict, List, Literal, Optional, Union
from fastapi import HTTPException
from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
class LassoGuardrailMissingSecrets(Exception):
    """Raised when no Lasso API key is available to the guardrail.

    The key is looked up from the ``lasso_api_key`` argument and then the
    ``LASSO_API_KEY`` environment variable.
    """
class LassoGuardrailAPIError(Exception):
    """Signals that the Lasso API could not be called or its reply could not be processed."""
class LassoGuardrail(CustomGuardrail):
    """Guardrail that classifies chat messages with the Lasso Security API.

    The request's ``messages`` are posted to the Lasso classify endpoint from
    the pre-call hook or the during-call moderation hook. If Lasso reports
    ``violations_detected``, the request is rejected with an HTTP 400. If the
    Lasso call itself fails, the request is rejected with
    ``LassoGuardrailAPIError`` rather than allowed through.
    """

    # Shared by __init__ and _prepare_headers so both report the same message.
    _MISSING_API_KEY_MESSAGE = (
        "Couldn't get Lasso api key, either set the `LASSO_API_KEY` in the environment or "
        "pass it as a parameter to the guardrail in the config file"
    )
    # Per-request timeout, in seconds, for calls to the Lasso API.
    _REQUEST_TIMEOUT_SECONDS = 10.0
    # Endpoint used when no ``api_base`` is configured.
    _DEFAULT_API_BASE = "https://server.lasso.security/gateway/v2/classify"

    def __init__(
        self,
        lasso_api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        **kwargs,
    ):
        """Configure the guardrail.

        Args:
            lasso_api_key: Lasso API key; falls back to ``LASSO_API_KEY``.
            api_base: Classification endpoint; defaults to Lasso's v2
                classify URL.
            user_id: Sent as the ``lasso-user-id`` header when set; falls
                back to ``LASSO_USER_ID``.
            conversation_id: Sent as the ``lasso-conversation-id`` header when
                set; falls back to ``LASSO_CONVERSATION_ID``.
            **kwargs: Forwarded to ``CustomGuardrail.__init__``.

        Raises:
            LassoGuardrailMissingSecrets: If no non-empty API key is available.
        """
        self.lasso_api_key = lasso_api_key or os.environ.get("LASSO_API_KEY")
        self.user_id = user_id or os.environ.get("LASSO_USER_ID")
        self.conversation_id = conversation_id or os.environ.get(
            "LASSO_CONVERSATION_ID"
        )
        # Reject empty keys too (e.g. LASSO_API_KEY=""), matching
        # _prepare_headers, so misconfiguration fails at startup instead of
        # on every request. Validate before allocating the HTTP client.
        if not self.lasso_api_key:
            raise LassoGuardrailMissingSecrets(self._MISSING_API_KEY_MESSAGE)
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_base = api_base or self._DEFAULT_API_BASE
        super().__init__(**kwargs)

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, dict, None]:
        """Run the Lasso check before the LLM call and return ``data`` if it passes."""
        verbose_proxy_logger.debug("Inside Lasso Pre-Call Hook")
        return await self.run_lasso_guardrail(data)

    @log_guardrail_information
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        """
        This is used for during_call moderation.

        Runs the same Lasso check as the pre-call hook.
        """
        verbose_proxy_logger.debug("Inside Lasso Moderation Hook")
        return await self.run_lasso_guardrail(data)

    async def run_lasso_guardrail(
        self,
        data: dict,
    ):
        """
        Run the Lasso guardrail on ``data["messages"]``.

        Requests with no (or empty) ``messages`` are returned unchanged
        without contacting Lasso.

        Args:
            data: The request body; only its ``messages`` entry is inspected.

        Returns:
            ``data`` unchanged when Lasso reports no violation.

        Raises:
            HTTPException: (400) if Lasso reports a policy violation.
            LassoGuardrailAPIError: If the Lasso API call fails or its response
                cannot be processed. The underlying error is chained as
                ``__cause__``.
        """
        # NOTE(review): message content may be a list of parts, not only str.
        messages: List[Dict[str, Any]] = data.get("messages", [])
        if not messages:
            return data
        try:
            headers = self._prepare_headers()
            payload = self._prepare_payload(messages)
            response = await self._call_lasso_api(
                headers=headers,
                payload=payload,
            )
            self._process_lasso_response(response)
        except HTTPException:
            # Policy-violation verdicts must reach the caller unchanged.
            raise
        except Exception as e:
            verbose_proxy_logger.error(f"Error calling Lasso API: {str(e)}")
            # Fail closed: block the request instead of letting it proceed
            # unverified, and keep the original error as the cause.
            raise LassoGuardrailAPIError(
                f"Failed to verify request safety with Lasso API: {str(e)}"
            ) from e
        return data

    def _prepare_headers(self) -> Dict[str, str]:
        """Build the Lasso request headers.

        Always includes the API key and JSON content type. The user and
        conversation ids are included only when configured.

        Raises:
            LassoGuardrailMissingSecrets: If the API key is missing or empty.
        """
        # Also narrows Optional[str] -> str for the header dict below.
        if not self.lasso_api_key:
            raise LassoGuardrailMissingSecrets(self._MISSING_API_KEY_MESSAGE)
        headers: Dict[str, str] = {
            "lasso-api-key": self.lasso_api_key,
            "Content-Type": "application/json",
        }
        if self.user_id:
            headers["lasso-user-id"] = self.user_id
        if self.conversation_id:
            headers["lasso-conversation-id"] = self.conversation_id
        return headers

    def _prepare_payload(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Wrap the chat messages in the JSON body sent to Lasso."""
        return {"messages": messages}

    async def _call_lasso_api(
        self, headers: Dict[str, str], payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """POST ``payload`` to ``self.api_base`` and return the decoded JSON reply.

        Raises:
            Whatever ``raise_for_status()`` raises for a non-success status,
            and any error from the HTTP client or JSON decoding.
        """
        # NOTE(review): this debug log includes the full message contents.
        verbose_proxy_logger.debug(f"Sending request to Lasso API: {payload}")
        response = await self.async_handler.post(
            url=self.api_base,
            headers=headers,
            json=payload,
            timeout=self._REQUEST_TIMEOUT_SECONDS,
        )
        response.raise_for_status()
        res = response.json()
        verbose_proxy_logger.debug(f"Lasso API response: {res}")
        return res

    def _process_lasso_response(self, response: Dict[str, Any]) -> None:
        """Raise an HTTP 400 if Lasso flagged the messages; otherwise do nothing.

        Only a literal ``True`` in ``violations_detected`` counts as a
        violation. The raw Lasso response is echoed back in the error detail.
        """
        if response and response.get("violations_detected") is True:
            violated_deputies = self._parse_violated_deputies(response)
            verbose_proxy_logger.warning(
                f"Lasso guardrail detected violations: {violated_deputies}"
            )
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Violated Lasso guardrail policy",
                    "detection_message": f"Guardrail violations detected: {', '.join(violated_deputies)}",
                    "lasso_response": response,
                },
            )

    def _parse_violated_deputies(self, response: Dict[str, Any]) -> List[str]:
        """Return the names of deputies whose flag in ``response["deputies"]`` is truthy.

        A missing or ``null`` ``deputies`` entry yields an empty list.
        """
        deputies = response.get("deputies") or {}
        return [deputy for deputy, is_violated in deputies.items() if is_violated]