import copy
import os
from datetime import datetime
from typing import Dict, List, Literal, Optional, Tuple, Union

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.exceptions import GuardrailRaisedException
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.secret_managers.main import get_secret_str
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.llms.openai import AllMessageValues
from litellm.types.proxy.guardrails.guardrail_hooks.lakera_ai_v2 import (
    LakeraAIRequest,
    LakeraAIResponse,
)


class LakeraAIGuardrail(CustomGuardrail):
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        project_id: Optional[str] = None,
        payload: Optional[bool] = True,
        breakdown: Optional[bool] = True,
        metadata: Optional[Dict] = None,
        dev_info: Optional[bool] = True,
        **kwargs,
    ):
        """
        Initialize the LakeraAIGuardrail class.

        This calls: https://api.lakera.ai/v2/guard

        Args:
            api_key: Lakera API key. Falls back to the LAKERA_API_KEY env var.
            api_base: Lakera API base URL. Falls back to the LAKERA_API_BASE
                secret, then to https://api.lakera.ai.
            project_id: Lakera project id to scope the guard request to.
            payload: If True, request the detection payload (locations of
                flagged content) in the Lakera response.
            breakdown: If True, request a per-detector breakdown in the
                Lakera response.
            metadata: Arbitrary metadata to attach to the guard request.
            dev_info: If True, request Lakera developer/build info in the
                response.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.lakera_api_key = api_key or os.environ["LAKERA_API_KEY"]
        self.project_id = project_id
        self.api_base = (
            api_base or get_secret_str("LAKERA_API_BASE") or "https://api.lakera.ai"
        )
        self.payload: Optional[bool] = payload
        self.breakdown: Optional[bool] = breakdown
        self.metadata: Optional[Dict] = metadata
        self.dev_info: Optional[bool] = dev_info
        super().__init__(**kwargs)

    async def call_v2_guard(
        self,
        messages: List[AllMessageValues],
        request_data: Dict,
    ) -> Tuple[LakeraAIResponse, Dict]:
        """
        Call the Lakera AI v2 guard API.
        """
        status: Literal["success", "failure"] = "success"
        exception_str: str = ""
        start_time: datetime = datetime.now()
        lakera_response: Optional[LakeraAIResponse] = None
        request: Dict = {}
        masked_entity_count: Dict = {}
        try:
            request = dict(
                LakeraAIRequest(
                    messages=messages,
                    project_id=self.project_id,
                    payload=self.payload,
                    breakdown=self.breakdown,
                    metadata=self.metadata,
                    dev_info=self.dev_info,
                )
            )
            verbose_proxy_logger.debug("Lakera AI v2 guard request: %s", request)
            response = await self.async_handler.post(
                url=f"{self.api_base}/v2/guard",
                headers={"Authorization": f"Bearer {self.lakera_api_key}"},
                json=request,
            )
            verbose_proxy_logger.debug(
                "Lakera AI v2 guard response: %s", response.json()
            )
            lakera_response = LakeraAIResponse(**response.json())
            return lakera_response, masked_entity_count
        except Exception as e:
            status = "failure"
            exception_str = str(e)
            raise e
        finally:
            ####################################################
            # Create Guardrail Trace for logging on Langfuse, Datadog, etc.
            ####################################################
            guardrail_json_response: Union[Exception, str, dict, List[dict]] = {}
            if status == "success":
                copy_lakera_response_dict = (
                    dict(copy.deepcopy(lakera_response)) if lakera_response else {}
                )
                # payload contains PII, we don't want to log it
                # (default None avoids a KeyError when the response has no payload)
                copy_lakera_response_dict.pop("payload", None)
                guardrail_json_response = copy_lakera_response_dict
            else:
                guardrail_json_response = exception_str
            self.add_standard_logging_guardrail_information_to_request_data(
                guardrail_json_response=guardrail_json_response,
                guardrail_status=status,
                request_data=request_data,
                start_time=start_time.timestamp(),
                end_time=datetime.now().timestamp(),
                duration=(datetime.now() - start_time).total_seconds(),
                masked_entity_count=masked_entity_count,
            )
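
    # Illustrative shape of the `payload` detections consumed by
    # `_mask_pii_in_messages` below. This is an assumption inferred from the
    # fields accessed in that method, not an official Lakera schema:
    #
    #   [
    #       {
    #           "message_id": 0,                     # index into `messages`
    #           "detector_type": "pii/credit_card",  # "pii/..." for PII hits
    #           "start": 11,                         # character offsets into
    #           "end": 30,                           # the message content
    #       },
    #   ]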

    def _mask_pii_in_messages(
        self,
        messages: List[AllMessageValues],
        lakera_response: Optional[LakeraAIResponse],
        masked_entity_count: Dict,
    ) -> List[AllMessageValues]:
        """
        Return the messages with any detected PII replaced by
        "[MASKED <TYPE>]" tokens. Note: message contents are modified in place.
        """
        payload = lakera_response.get("payload") if lakera_response else None
        if not payload:
            return messages

        # For each message, find its detections on the fly
        for idx, msg in enumerate(messages):
            content = msg.get("content", "")
            if not content:
                continue

            # For v1, we only support masking content strings
            if not isinstance(content, str):
                continue

            # Filter only detections for this message
            detected_modifications = [d for d in payload if d.get("message_id") == idx]
            if not detected_modifications:
                continue

            for modification in detected_modifications:
                start, end = modification.get("start", 0), modification.get("end", 0)

                # Extract the type (e.g. 'credit_card' → 'CREDIT_CARD')
                detector_type = modification.get("detector_type", "")
                if not detector_type:
                    continue
                typ = detector_type.split("/")[-1].upper() or "PII"
                mask = f"[MASKED {typ}]"

                if start is not None and end is not None:
                    content = self.mask_content_in_string(
                        content_string=content,
                        mask_string=mask,
                        start_index=start,
                        end_index=end,
                    )
                    masked_entity_count[typ] = masked_entity_count.get(typ, 0) + 1

            msg["content"] = content

        return messages
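
    # Worked example of the masking above (hypothetical values): with
    # content = "card: 4111111111111111" and a single detection
    # {"message_id": 0, "detector_type": "pii/credit_card", "start": 6,
    # "end": 22}, the message content becomes "card: [MASKED CREDIT_CARD]"
    # and masked_entity_count is updated to {"CREDIT_CARD": 1}.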

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: litellm.DualCache,
        data: Dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Optional[Union[Exception, str, Dict]]:
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        verbose_proxy_logger.debug("Lakera AI: pre_call_hook")
        event_type: GuardrailEventHooks = GuardrailEventHooks.pre_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            verbose_proxy_logger.debug(
                "Lakera AI: not running guardrail. Guardrail is disabled."
            )
            return data

        new_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if new_messages is None:
            verbose_proxy_logger.warning(
                "Lakera AI: not running guardrail. No messages in data"
            )
            return data

        #########################################################
        ########## 1. Make the Lakera AI v2 guard API request ##########
        #########################################################
        lakera_guardrail_response, masked_entity_count = await self.call_v2_guard(
            messages=new_messages,
            request_data=data,
        )

        #########################################################
        ########## 2. Handle flagged content ##########
        #########################################################
        if lakera_guardrail_response.get("flagged") is True:
            # If only PII violations exist, mask the PII
            if self._is_only_pii_violation(lakera_guardrail_response):
                data["messages"] = self._mask_pii_in_messages(
                    messages=new_messages,
                    lakera_response=lakera_guardrail_response,
                    masked_entity_count=masked_entity_count,
                )
                verbose_proxy_logger.info(
                    "Lakera AI: Masked PII in messages instead of blocking request"
                )
            else:
                # If there are other violations or not set to mask PII, raise exception
                raise GuardrailRaisedException(
                    guardrail_name=self.guardrail_name,
                    message="Lakera AI flagged this request. Please review the request and try again.",
                )

        #########################################################
        ########## 3. Add the guardrail to the applied guardrails header ##########
        #########################################################
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

        return data

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        new_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if new_messages is None:
            verbose_proxy_logger.warning(
                "Lakera AI: not running guardrail. No messages in data"
            )
            return

        #########################################################
        ########## 1. Make the Lakera AI v2 guard API request ##########
        #########################################################
        lakera_guardrail_response, masked_entity_count = await self.call_v2_guard(
            messages=new_messages,
            request_data=data,
        )

        #########################################################
        ########## 2. Handle flagged content ##########
        #########################################################
        if lakera_guardrail_response.get("flagged") is True:
            # If only PII violations exist, mask the PII
            if self._is_only_pii_violation(lakera_guardrail_response):
                data["messages"] = self._mask_pii_in_messages(
                    messages=new_messages,
                    lakera_response=lakera_guardrail_response,
                    masked_entity_count=masked_entity_count,
                )
                verbose_proxy_logger.info(
                    "Lakera AI: Masked PII in messages instead of blocking request"
                )
            else:
                # If there are other violations or not set to mask PII, raise exception
                raise GuardrailRaisedException(
                    guardrail_name=self.guardrail_name,
                    message="Lakera AI flagged this request. Please review the request and try again.",
                )

        #########################################################
        ########## 3. Add the guardrail to the applied guardrails header ##########
        #########################################################
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

        return data

    def _is_only_pii_violation(
        self, lakera_response: Optional[LakeraAIResponse]
    ) -> bool:
        """
        Returns True if there are only PII violations in the response.
        """
        if not lakera_response:
            return False
        for item in lakera_response.get("payload", []) or []:
            detector_type = item.get("detector_type", "") or ""
            if not detector_type.startswith("pii/"):
                return False
        return True
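

if __name__ == "__main__":
    # Minimal local sketch of the guard call (illustrative only, not part of
    # the proxy runtime): assumes a valid LAKERA_API_KEY in the environment
    # and uses a placeholder project id.
    import asyncio

    async def _demo() -> None:
        guardrail = LakeraAIGuardrail(project_id="your-project-id")
        lakera_response, masked_entity_count = await guardrail.call_v2_guard(
            messages=[{"role": "user", "content": "My card is 4111111111111111"}],
            request_data={},
        )
        print("flagged:", lakera_response.get("flagged"))
        print("masked entity counts:", masked_entity_count)

    asyncio.run(_demo())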