Spaces:
Configuration error
Configuration error
# +-------------------------------------------------------------+
#
#           Use Lasso Security Guardrails for your LLM calls
#                   https://www.lasso.security/
#
# +-------------------------------------------------------------+
import os | |
from typing import Any, Dict, List, Literal, Optional, Union | |
from fastapi import HTTPException | |
from litellm import DualCache | |
from litellm._logging import verbose_proxy_logger | |
from litellm.integrations.custom_guardrail import ( | |
CustomGuardrail, | |
log_guardrail_information, | |
) | |
from litellm.llms.custom_httpx.http_handler import ( | |
get_async_httpx_client, | |
httpxSpecialProvider, | |
) | |
from litellm.proxy._types import UserAPIKeyAuth | |
class LassoGuardrailMissingSecrets(Exception):
    """Raised when no Lasso API key is available to the guardrail."""
class LassoGuardrailAPIError(Exception):
    """Signals that a call to the Lasso API could not be completed successfully."""
class LassoGuardrail(CustomGuardrail):
    """Guardrail that sends request messages to the Lasso classify API.

    The request's ``messages`` list is posted to ``api_base``. If the
    response reports ``violations_detected`` as ``True`` the request is
    rejected with an HTTP 400; if the API call itself fails, the request is
    rejected with ``LassoGuardrailAPIError`` rather than being let through.
    """

    # Single source of truth for the "no API key" error text, used both at
    # construction time and when building request headers.
    _MISSING_API_KEY_MSG = (
        "Couldn't get Lasso api key, either set the `LASSO_API_KEY` in the environment or "
        "pass it as a parameter to the guardrail in the config file"
    )

    def __init__(
        self,
        lasso_api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        user_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        **kwargs,
    ):
        """Configure the guardrail.

        Args:
            lasso_api_key: Lasso API key; falls back to the ``LASSO_API_KEY``
                environment variable.
            api_base: Classify endpoint URL; defaults to the hosted Lasso
                gateway URL below.
            user_id: Optional value for the ``lasso-user-id`` header; falls
                back to ``LASSO_USER_ID``.
            conversation_id: Optional value for the ``lasso-conversation-id``
                header; falls back to ``LASSO_CONVERSATION_ID``.
            **kwargs: Forwarded unchanged to ``CustomGuardrail.__init__``.

        Raises:
            LassoGuardrailMissingSecrets: If no non-empty API key is found.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.lasso_api_key = lasso_api_key or os.environ.get("LASSO_API_KEY")
        self.user_id = user_id or os.environ.get("LASSO_USER_ID")
        self.conversation_id = conversation_id or os.environ.get(
            "LASSO_CONVERSATION_ID"
        )

        # Fail fast on an empty key as well as a missing one, matching the
        # falsiness check in `_prepare_headers`; otherwise an empty
        # LASSO_API_KEY would only surface as an error on every request.
        if not self.lasso_api_key:
            raise LassoGuardrailMissingSecrets(self._MISSING_API_KEY_MSG)

        self.api_base = api_base or "https://server.lasso.security/gateway/v2/classify"
        super().__init__(**kwargs)

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, dict, None]:
        """Run the Lasso guardrail on ``data`` and return ``data`` if it passes.

        ``user_api_key_dict``, ``cache`` and ``call_type`` are accepted for
        the hook signature but are not used here.
        """
        verbose_proxy_logger.debug("Inside Lasso Pre-Call Hook")
        return await self.run_lasso_guardrail(data)

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        """
        This is used for during_call moderation.

        Runs the Lasso guardrail on ``data`` and returns ``data`` if it passes.
        """
        verbose_proxy_logger.debug("Inside Lasso Moderation Hook")
        return await self.run_lasso_guardrail(data)

    async def run_lasso_guardrail(
        self,
        data: dict,
    ):
        """
        Run the Lasso guardrail.

        Returns:
            ``data`` unchanged when there are no messages to check or when
            Lasso reports no violations.

        Raises:
            HTTPException: (400) if Lasso reports a policy violation.
            LassoGuardrailAPIError: If the Lasso API call fails for any other
                reason; the original exception is attached as ``__cause__``.
        """
        messages: List[Dict[str, str]] = data.get("messages", [])
        # Nothing to classify: let the request through untouched.
        if not messages:
            return data

        try:
            headers = self._prepare_headers()
            payload = self._prepare_payload(messages)
            response = await self._call_lasso_api(
                headers=headers,
                payload=payload,
            )
            self._process_lasso_response(response)
            return data
        except HTTPException:
            # Policy violations must reach the caller as-is.
            raise
        except Exception as e:
            verbose_proxy_logger.error(f"Error calling Lasso API: {str(e)}")
            # Fail closed: instead of allowing the request to proceed, raise,
            # keeping the underlying error chained for debugging.
            raise LassoGuardrailAPIError(
                f"Failed to verify request safety with Lasso API: {str(e)}"
            ) from e

    def _prepare_headers(self) -> Dict[str, str]:
        """Prepare headers for the Lasso API request.

        Raises:
            LassoGuardrailMissingSecrets: If the API key is unset or empty.
        """
        if not self.lasso_api_key:
            raise LassoGuardrailMissingSecrets(self._MISSING_API_KEY_MSG)

        headers: Dict[str, str] = {
            "lasso-api-key": self.lasso_api_key,
            "Content-Type": "application/json",
        }

        # Add optional headers if provided
        if self.user_id:
            headers["lasso-user-id"] = self.user_id
        if self.conversation_id:
            headers["lasso-conversation-id"] = self.conversation_id

        return headers

    def _prepare_payload(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """Prepare the payload for the Lasso API request."""
        return {"messages": messages}

    async def _call_lasso_api(
        self, headers: Dict[str, str], payload: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Call the Lasso API and return the decoded JSON response.

        Non-success HTTP statuses raise via ``response.raise_for_status()``.
        """
        verbose_proxy_logger.debug(f"Sending request to Lasso API: {payload}")
        response = await self.async_handler.post(
            url=self.api_base,
            headers=headers,
            json=payload,
            timeout=10.0,
        )
        response.raise_for_status()
        res = response.json()
        verbose_proxy_logger.debug(f"Lasso API response: {res}")
        return res

    def _process_lasso_response(self, response: Dict[str, Any]) -> None:
        """Process the Lasso API response and raise if violations are detected.

        Raises:
            HTTPException: (400) when ``violations_detected`` is exactly
                ``True``; the detail carries the violated deputies and the
                full Lasso response.
        """
        if not response or response.get("violations_detected") is not True:
            return

        violated_deputies = self._parse_violated_deputies(response)
        verbose_proxy_logger.warning(
            f"Lasso guardrail detected violations: {violated_deputies}"
        )
        raise HTTPException(
            status_code=400,
            detail={
                "error": "Violated Lasso guardrail policy",
                "detection_message": f"Guardrail violations detected: {', '.join(violated_deputies)}",
                "lasso_response": response,
            },
        )

    def _parse_violated_deputies(self, response: Dict[str, Any]) -> List[str]:
        """Return the names of deputies flagged as violated in ``response``.

        A missing or null ``deputies`` entry yields an empty list, so a
        detected violation is still reported as a policy error rather than
        failing while parsing.
        """
        deputies = response.get("deputies") or {}
        return [deputy for deputy, is_violated in deputies.items() if is_violated]