Spaces:
Configuration error
Configuration error
File size: 6,931 Bytes
447ebeb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
# +-------------------------------------------------------------+
#
# Use Lasso Security Guardrails for your LLM calls
# https://www.lasso.security/
#
# +-------------------------------------------------------------+
import os
from typing import Any, Dict, List, Literal, Optional, Union
from fastapi import HTTPException
from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
class LassoGuardrailMissingSecrets(Exception):
    """Raised by ``LassoGuardrail`` when no Lasso API key is available.

    The key is taken from the ``lasso_api_key`` argument, falling back to the
    ``LASSO_API_KEY`` environment variable.
    """
class LassoGuardrailAPIError(Exception):
    """Raised when the Lasso API could not be called or its reply could not be handled.

    The guardrail raises this instead of letting an unverified request
    proceed.
    """
class LassoGuardrail(CustomGuardrail):
    """Guardrail that classifies chat messages with the Lasso Security API.

    The request's ``messages`` are posted to Lasso's classify endpoint. If
    Lasso reports ``violations_detected``, the request is rejected with an
    HTTP 400. If the API call itself fails, the request is rejected with
    ``LassoGuardrailAPIError``; the guardrail never lets an unverified
    request through.
    """

    # Shared by ``__init__`` and ``_prepare_headers`` so both report a missing
    # key with the same text.
    _MISSING_API_KEY_MSG = (
        "Couldn't get Lasso api key, either set the `LASSO_API_KEY` in the environment or "
        "pass it as a parameter to the guardrail in the config file"
    )
    # Endpoint used when no ``api_base`` is supplied.
    _DEFAULT_API_BASE = "https://server.lasso.security/gateway/v2/classify"
    # Timeout value passed to the HTTP client's ``post`` call.
    _REQUEST_TIMEOUT = 10.0

    def __init__(
        self,
        lasso_api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        **kwargs,
    ):
        """Configure the guardrail.

        Args:
            lasso_api_key: Lasso API key; falls back to ``LASSO_API_KEY``.
            api_base: Classify endpoint URL; defaults to Lasso's v2 classify URL.
            user_id: Optional id sent as ``lasso-user-id``; falls back to
                ``LASSO_USER_ID``.
            conversation_id: Optional id sent as ``lasso-conversation-id``;
                falls back to ``LASSO_CONVERSATION_ID``.
            **kwargs: Forwarded to ``CustomGuardrail.__init__``.

        Raises:
            LassoGuardrailMissingSecrets: If no non-empty API key is available.
        """
        self.lasso_api_key = lasso_api_key or os.environ.get("LASSO_API_KEY")
        self.user_id = user_id or os.environ.get("LASSO_USER_ID")
        self.conversation_id = conversation_id or os.environ.get(
            "LASSO_CONVERSATION_ID"
        )
        # Reject empty strings too: an empty key would otherwise only fail
        # later, inside ``_prepare_headers``, on the first request.
        if not self.lasso_api_key:
            raise LassoGuardrailMissingSecrets(self._MISSING_API_KEY_MSG)
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.api_base = api_base or self._DEFAULT_API_BASE
        super().__init__(**kwargs)

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, dict, None]:
        """Run the Lasso check before the LLM call; returns ``data`` if allowed."""
        verbose_proxy_logger.debug("Inside Lasso Pre-Call Hook")
        return await self.run_lasso_guardrail(data)

    @log_guardrail_information
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        """
        This is used for during_call moderation
        """
        verbose_proxy_logger.debug("Inside Lasso Moderation Hook")
        return await self.run_lasso_guardrail(data)

    async def run_lasso_guardrail(
        self,
        data: dict,
    ):
        """Classify ``data["messages"]`` with Lasso and return ``data`` if allowed.

        Requests without messages are returned unchanged without calling Lasso.

        Raises:
            HTTPException: Status 400 when Lasso reports a policy violation.
            LassoGuardrailAPIError: If the Lasso API call, or handling of its
                response, fails for any other reason.
        """
        # Message ``content`` is not guaranteed to be a plain string, so the
        # values are typed loosely.
        messages: List[Dict[str, Any]] = data.get("messages") or []
        if not messages:
            return data
        try:
            headers = self._prepare_headers()
            payload = self._prepare_payload(messages)
            response = await self._call_lasso_api(
                headers=headers,
                payload=payload,
            )
            self._process_lasso_response(response)
        except HTTPException:
            # Guardrail verdicts must reach the caller unchanged.
            raise
        except Exception as e:
            verbose_proxy_logger.error("Error calling Lasso API: %s", e)
            # Fail closed: an unverifiable request is rejected, not forwarded.
            # ``from e`` keeps the original traceback for debugging.
            raise LassoGuardrailAPIError(
                f"Failed to verify request safety with Lasso API: {str(e)}"
            ) from e
        return data

    def _prepare_headers(self) -> Dict[str, str]:
        """Build the Lasso request headers (API key plus optional user/conversation ids).

        Raises:
            LassoGuardrailMissingSecrets: If the API key is empty or unset.
        """
        if not self.lasso_api_key:
            raise LassoGuardrailMissingSecrets(self._MISSING_API_KEY_MSG)
        headers: Dict[str, str] = {
            "lasso-api-key": self.lasso_api_key,
            "Content-Type": "application/json",
        }
        # Optional headers are only sent when configured.
        if self.user_id:
            headers["lasso-user-id"] = self.user_id
        if self.conversation_id:
            headers["lasso-conversation-id"] = self.conversation_id
        return headers

    def _prepare_payload(self, messages: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Wrap the chat messages in the request body expected by the classify endpoint."""
        return {"messages": messages}

    async def _call_lasso_api(
        self, headers: Dict[str, str], payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """POST ``payload`` to ``self.api_base`` and return the decoded JSON reply.

        Raises whatever ``raise_for_status`` / ``json`` raise on a failed call
        or an undecodable body.
        """
        verbose_proxy_logger.debug("Sending request to Lasso API: %s", payload)
        response = await self.async_handler.post(
            url=self.api_base,
            headers=headers,
            json=payload,
            timeout=self._REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        res = response.json()
        verbose_proxy_logger.debug("Lasso API response: %s", res)
        return res

    def _process_lasso_response(self, response: Dict[str, Any]) -> None:
        """Raise an HTTP 400 if ``response["violations_detected"]`` is exactly ``True``.

        Raises:
            HTTPException: Detail includes the violated deputy names and the raw
                Lasso response.
        """
        if response and response.get("violations_detected") is True:
            violated_deputies = self._parse_violated_deputies(response)
            verbose_proxy_logger.warning(
                "Lasso guardrail detected violations: %s", violated_deputies
            )
            raise HTTPException(
                status_code=400,
                detail={
                    "error": "Violated Lasso guardrail policy",
                    "detection_message": f"Guardrail violations detected: {', '.join(violated_deputies)}",
                    "lasso_response": response,
                },
            )

    def _parse_violated_deputies(self, response: Dict[str, Any]) -> List[str]:
        """Return the names of deputies whose value in ``response["deputies"]`` is truthy.

        A missing, ``null``, or non-mapping ``deputies`` field yields an empty
        list. Without this, a detected violation would crash here and be
        re-reported as an API error instead of a guardrail block.
        """
        deputies = response.get("deputies")
        if not isinstance(deputies, dict):
            return []
        return [deputy for deputy, is_violated in deputies.items() if is_violated]
|