# +-------------------------------------------------------------+ #
#
#           Use Bedrock Guardrails for your LLM calls
#
# +-------------------------------------------------------------+ #
#  Thank you users! We ❤️ you! - Krrish & Ishaan

import os
import sys

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import json
import sys
from typing import Any, AsyncGenerator, List, Literal, Optional, Tuple, Union

from fastapi import HTTPException

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    log_guardrail_information,
)
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.llms.openai import AllMessageValues
from litellm.types.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
    BedrockContentItem,
    BedrockGuardrailOutput,
    BedrockGuardrailResponse,
    BedrockRequest,
    BedrockTextContent,
)
from litellm.types.utils import ModelResponse, ModelResponseStream

GUARDRAIL_NAME = "bedrock"


class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
    def __init__(
        self,
        guardrailIdentifier: Optional[str] = None,
        guardrailVersion: Optional[str] = None,
        **kwargs,
    ):
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.GuardrailCallback
        )
        self.guardrailIdentifier = guardrailIdentifier
        self.guardrailVersion = guardrailVersion

        # store kwargs as optional_params
        self.optional_params = kwargs

        super().__init__(**kwargs)
        BaseAWSLLM.__init__(self)
        verbose_proxy_logger.debug(
            "Bedrock Guardrail initialized with guardrailIdentifier: %s, guardrailVersion: %s",
            self.guardrailIdentifier,
            self.guardrailVersion,
        )

    def convert_to_bedrock_format(
        self,
        messages: Optional[List[AllMessageValues]] = None,
        response: Optional[Union[Any, ModelResponse]] = None,
    ) -> BedrockRequest:
        bedrock_request: BedrockRequest = BedrockRequest(source="INPUT")
        bedrock_request_content: List[BedrockContentItem] = []

        if messages:
            for message in messages:
                message_text_content: Optional[
                    List[str]
                ] = self.get_content_for_message(message=message)
                if message_text_content is None:
                    continue
                for text_content in message_text_content:
                    bedrock_content_item = BedrockContentItem(
                        text=BedrockTextContent(text=text_content)
                    )
                    bedrock_request_content.append(bedrock_content_item)

            bedrock_request["content"] = bedrock_request_content
        if response:
            bedrock_request["source"] = "OUTPUT"
            if isinstance(response, litellm.ModelResponse):
                for choice in response.choices:
                    if isinstance(choice, litellm.Choices):
                        if choice.message.content and isinstance(
                            choice.message.content, str
                        ):
                            bedrock_content_item = BedrockContentItem(
                                text=BedrockTextContent(text=choice.message.content)
                            )
                            bedrock_request_content.append(bedrock_content_item)
                bedrock_request["content"] = bedrock_request_content
        return bedrock_request
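
    # Illustrative sketch of the request body produced by
    # `convert_to_bedrock_format` above. The shape is inferred from the typed
    # dicts used in this module (BedrockRequest / BedrockContentItem /
    # BedrockTextContent); the text value is a placeholder:
    #
    #   {
    #       "source": "INPUT",                      # "OUTPUT" when a model response is scanned
    #       "content": [
    #           {"text": {"text": "hello world"}},  # one item per text block extracted from the messages
    #       ],
    #   }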

    #### CALL HOOKS - proxy only ####
    def _load_credentials(
        self,
    ):
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        ## CREDENTIALS ##
        aws_secret_access_key = self.optional_params.get("aws_secret_access_key", None)
        aws_access_key_id = self.optional_params.get("aws_access_key_id", None)
        aws_session_token = self.optional_params.get("aws_session_token", None)
        aws_region_name = self.optional_params.get("aws_region_name", None)
        aws_role_name = self.optional_params.get("aws_role_name", None)
        aws_session_name = self.optional_params.get("aws_session_name", None)
        aws_profile_name = self.optional_params.get("aws_profile_name", None)
        aws_web_identity_token = self.optional_params.get(
            "aws_web_identity_token", None
        )
        aws_sts_endpoint = self.optional_params.get("aws_sts_endpoint", None)

        ### SET REGION NAME ###
        aws_region_name = self.get_aws_region_name_for_non_llm_api_calls(
            aws_region_name=aws_region_name,
        )

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name

    def _prepare_request(
        self,
        credentials,
        data: dict,
        optional_params: dict,
        aws_region_name: str,
        extra_headers: Optional[dict] = None,
    ):
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
        api_base = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"

        encoded_data = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        request = AWSRequest(
            method="POST", url=api_base, data=encoded_data, headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]

        prepped_request = request.prepare()

        return prepped_request
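
    # Illustrative sketch of the ApplyGuardrail response fields this module
    # consumes (only `action` and `outputs[].text` are read below; any other
    # fields returned by Bedrock are ignored here; values are placeholders):
    #
    #   {
    #       "action": "GUARDRAIL_INTERVENED",                  # or "NONE" when nothing was flagged
    #       "outputs": [{"text": "He lives in {ADDRESS}."}],   # masked text when masking policies apply
    #   }
    #
    # When `action` is "GUARDRAIL_INTERVENED" and masking is not enabled, the
    # request is rejected with an HTTP 400 by `make_bedrock_api_request` below.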

    async def make_bedrock_api_request(
        self, kwargs: dict, response: Optional[Union[Any, litellm.ModelResponse]] = None
    ) -> BedrockGuardrailResponse:
        credentials, aws_region_name = self._load_credentials()
        bedrock_request_data: dict = dict(
            self.convert_to_bedrock_format(
                messages=kwargs.get("messages"), response=response
            )
        )
        bedrock_guardrail_response: BedrockGuardrailResponse = (
            BedrockGuardrailResponse()
        )
        bedrock_request_data.update(
            self.get_guardrail_dynamic_request_body_params(request_data=kwargs)
        )
        prepared_request = self._prepare_request(
            credentials=credentials,
            data=bedrock_request_data,
            optional_params=self.optional_params,
            aws_region_name=aws_region_name,
        )
        verbose_proxy_logger.debug(
            "Bedrock AI request body: %s, url %s, headers: %s",
            bedrock_request_data,
            prepared_request.url,
            prepared_request.headers,
        )

        response = await self.async_handler.post(
            url=prepared_request.url,
            data=prepared_request.body,  # type: ignore
            headers=prepared_request.headers,  # type: ignore
        )
        verbose_proxy_logger.debug("Bedrock AI response: %s", response.text)
        if response.status_code == 200:
            # check if the response was flagged
            _json_response = response.json()
            bedrock_guardrail_response = BedrockGuardrailResponse(**_json_response)
            if self._should_raise_guardrail_blocked_exception(
                bedrock_guardrail_response
            ):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated guardrail policy",
                        "bedrock_guardrail_response": _json_response,
                    },
                )
        else:
            verbose_proxy_logger.error(
                "Bedrock AI: error in response. Status code: %s, response: %s",
                response.status_code,
                response.text,
            )
        return bedrock_guardrail_response

    def _should_raise_guardrail_blocked_exception(
        self, response: BedrockGuardrailResponse
    ) -> bool:
        """
        By default, always raise an exception when a guardrail intervention is detected.

        If `self.mask_request_content` or `self.mask_response_content` is set to `True`,
        then use the output from the guardrail to mask the request or response content.
        """
        # if user opted into masking, return False. since we'll use the masked output from the guardrail
        if self.mask_request_content or self.mask_response_content:
            return False

        # if intervention, return True
        if response.get("action") == "GUARDRAIL_INTERVENED":
            return True

        # if no intervention, return False
        return False

    @log_guardrail_information
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal[
            "completion",
            "text_completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "pass_through_endpoint",
            "rerank",
        ],
    ) -> Union[Exception, str, dict, None]:
        verbose_proxy_logger.debug("Inside Bedrock Guardrail Pre-Call Hook")
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.pre_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return data

        new_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if new_messages is None:
            verbose_proxy_logger.warning(
                "Bedrock AI: not running guardrail. No messages in data"
            )
            return data

        #########################################################
        ########## 1. Make the Bedrock API request ##########
        #########################################################
        bedrock_guardrail_response = await self.make_bedrock_api_request(kwargs=data)
        #########################################################

        #########################################################
        ########## 2. Update the messages with the guardrail response ##########
        #########################################################
        data[
            "messages"
        ] = self._update_messages_with_updated_bedrock_guardrail_response(
            messages=new_messages,
            bedrock_guardrail_response=bedrock_guardrail_response,
        )

        #########################################################
        ########## 3. Add the guardrail to the applied guardrails header ##########
        #########################################################
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

        return data

    @log_guardrail_information
    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
            "responses",
        ],
    ):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )

        event_type: GuardrailEventHooks = GuardrailEventHooks.during_call
        if self.should_run_guardrail(data=data, event_type=event_type) is not True:
            return

        new_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if new_messages is None:
            verbose_proxy_logger.warning(
                "Bedrock AI: not running guardrail. No messages in data"
            )
            return

        #########################################################
        ########## 1. Make the Bedrock API request ##########
        #########################################################
        bedrock_guardrail_response = await self.make_bedrock_api_request(kwargs=data)
        #########################################################

        #########################################################
        ########## 2. Update the messages with the guardrail response ##########
        #########################################################
        data[
            "messages"
        ] = self._update_messages_with_updated_bedrock_guardrail_response(
            messages=new_messages,
            bedrock_guardrail_response=bedrock_guardrail_response,
        )

        #########################################################
        ########## 3. Add the guardrail to the applied guardrails header ##########
        #########################################################
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

        return data

    @log_guardrail_information
    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        from litellm.proxy.common_utils.callback_utils import (
            add_guardrail_to_applied_guardrails_header,
        )
        from litellm.types.guardrails import GuardrailEventHooks

        if (
            self.should_run_guardrail(
                data=data, event_type=GuardrailEventHooks.post_call
            )
            is not True
        ):
            return

        new_messages: Optional[List[AllMessageValues]] = data.get("messages")
        if new_messages is None:
            verbose_proxy_logger.warning(
                "Bedrock AI: not running guardrail. No messages in data"
            )
            return

        #########################################################
        ########## 1. Make the Bedrock API request ##########
        #########################################################
        bedrock_guardrail_response = await self.make_bedrock_api_request(
            kwargs=data, response=response
        )
        #########################################################

        #########################################################
        ########## 2. Update the messages with the guardrail response ##########
        #########################################################
        data[
            "messages"
        ] = self._update_messages_with_updated_bedrock_guardrail_response(
            messages=new_messages,
            bedrock_guardrail_response=bedrock_guardrail_response,
        )

        #########################################################
        ########## 3. Add the guardrail to the applied guardrails header ##########
        #########################################################
        add_guardrail_to_applied_guardrails_header(
            request_data=data, guardrail_name=self.guardrail_name
        )

    ########### HELPER FUNCTIONS for bedrock guardrails ############################
    ##############################################################################
    ##############################################################################

    def _update_messages_with_updated_bedrock_guardrail_response(
        self,
        messages: List[AllMessageValues],
        bedrock_guardrail_response: BedrockGuardrailResponse,
    ) -> List[AllMessageValues]:
        """
        Use the output from the bedrock guardrail to mask sensitive content in messages.

        Args:
            messages: Original list of messages
            bedrock_guardrail_response: Response from Bedrock guardrail containing masked content

        Returns:
            List of messages with content masked according to guardrail response
        """
        # Skip processing if masking is not enabled
        if not (self.mask_request_content or self.mask_response_content):
            return messages

        # Get masked texts from guardrail response
        masked_texts = self._extract_masked_texts_from_response(
            bedrock_guardrail_response
        )
        if not masked_texts:
            return messages

        # Apply masking to messages using index tracking
        return self._apply_masking_to_messages(
            messages=messages, masked_texts=masked_texts
        )

    async def async_post_call_streaming_iterator_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        response: Any,
        request_data: dict,
    ) -> AsyncGenerator[ModelResponseStream, None]:
        """
        Process streaming response chunks.

        Collect content from the stream and make a bedrock api request to get the guardrail response.
        """
        # Import here to avoid circular imports
        from litellm.llms.base_llm.base_model_iterator import MockResponseIterator
        from litellm.main import stream_chunk_builder
        from litellm.types.utils import TextCompletionResponse

        # Collect all chunks to process them together
        all_chunks: List[ModelResponseStream] = []
        async for chunk in response:
            all_chunks.append(chunk)

        assembled_model_response: Optional[
            Union[ModelResponse, TextCompletionResponse]
        ] = stream_chunk_builder(
            chunks=all_chunks,
        )
        if isinstance(assembled_model_response, ModelResponse):
            ####################################################################
            ########## 1. Make the Bedrock Apply Guardrail API request ##########
            # Bedrock will raise an exception if this violates the guardrail policy
            ###################################################################
            await self.make_bedrock_api_request(
                kwargs=request_data, response=assembled_model_response
            )

            #########################################################################
            ########## If guardrail passed, then return the collected chunks ##########
            #########################################################################
            mock_response = MockResponseIterator(
                model_response=assembled_model_response
            )

            # Return the reconstructed stream
            async for chunk in mock_response:
                yield chunk
        else:
            for chunk in all_chunks:
                yield chunk

    def _extract_masked_texts_from_response(
        self, bedrock_guardrail_response: BedrockGuardrailResponse
    ) -> List[str]:
        """
        Extract all masked text outputs from the guardrail response.

        Args:
            bedrock_guardrail_response: Response from Bedrock guardrail

        Returns:
            List of masked text strings
        """
        masked_output_text: List[str] = []
        masked_outputs: Optional[List[BedrockGuardrailOutput]] = (
            bedrock_guardrail_response.get("outputs", []) or []
        )
        if not masked_outputs:
            verbose_proxy_logger.debug("No masked outputs found in guardrail response")
            return []

        for output in masked_outputs:
            text_content: Optional[str] = output.get("text")
            if text_content is not None:
                masked_output_text.append(text_content)

        return masked_output_text
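
    # Minimal sketch of the index-tracking behaviour implemented by the two
    # helpers below: masked texts are consumed in order, one per text block,
    # across all messages (example values and mask placeholders are hypothetical):
    #
    #   messages     = [
    #       {"role": "user", "content": "my ssn is 111-22-3333"},
    #       {"role": "user", "content": [{"type": "text", "text": "email me at a@b.com"}]},
    #   ]
    #   masked_texts = ["my ssn is {SSN}", "email me at {EMAIL}"]
    #
    #   _apply_masking_to_messages(messages=messages, masked_texts=masked_texts) ->
    #   [
    #       {"role": "user", "content": "my ssn is {SSN}"},
    #       {"role": "user", "content": [{"type": "text", "text": "email me at {EMAIL}"}]},
    #   ]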

    def _apply_masking_to_messages(
        self, messages: List[AllMessageValues], masked_texts: List[str]
    ) -> List[AllMessageValues]:
        """
        Apply masked texts to message content using index tracking.

        Args:
            messages: Original messages
            masked_texts: List of masked text strings from guardrail

        Returns:
            Updated messages with masked content
        """
        updated_messages = []
        masking_index = 0

        for message in messages:
            new_message = message.copy()
            content = new_message.get("content")

            # Skip messages with no content
            if content is None:
                updated_messages.append(new_message)
                continue

            # Handle string content
            if isinstance(content, str):
                if masking_index < len(masked_texts):
                    new_message["content"] = masked_texts[masking_index]
                    masking_index += 1

            # Handle list content
            elif isinstance(content, list):
                new_message["content"], masking_index = self._mask_content_list(
                    content_list=content,
                    masked_texts=masked_texts,
                    masking_index=masking_index,
                )

            updated_messages.append(new_message)

        return updated_messages

    def _mask_content_list(
        self, content_list: List[Any], masked_texts: List[str], masking_index: int
    ) -> Tuple[List[Any], int]:
        """
        Apply masking to a list of content items.

        Args:
            content_list: List of content items
            masked_texts: List of masked text strings
            masking_index: Current index into the masked_texts list

        Returns:
            Tuple of (updated content list with masked items, updated masking index)
        """
        new_content: List[Union[dict, str]] = []
        for item in content_list:
            if isinstance(item, dict) and "text" in item:
                new_item = item.copy()
                if masking_index < len(masked_texts):
                    new_item["text"] = masked_texts[masking_index]
                    masking_index += 1
                new_content.append(new_item)
            elif isinstance(item, str):
                if masking_index < len(masked_texts):
                    item = masked_texts[masking_index]
                    masking_index += 1
                if item is not None:
                    new_content.append(item)
        return new_content, masking_index

    def get_content_for_message(self, message: AllMessageValues) -> Optional[List[str]]:
        """
        Get the content for a message.

        For bedrock guardrails we create a list of all the text content in the message.

        If a message has a list of content items, we flatten the list and return a list of text content.
        """
        message_text_content = []
        content = message.get("content")
        if content is None:
            return None
        if isinstance(content, str):
            message_text_content.append(content)
        elif isinstance(content, list):
            for item in content:
                if isinstance(item, dict) and "text" in item:
                    message_text_content.append(item["text"])
                elif isinstance(item, str):
                    message_text_content.append(item)
        return message_text_content
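

# --------------------------------------------------------------------------- #
# Illustrative usage sketch (not executed as part of this module). Assumes an
# existing Bedrock guardrail and AWS credentials resolvable from the
# environment; the identifier, version, and region below are placeholders:
#
#   import asyncio
#
#   async def _demo():
#       guardrail = BedrockGuardrail(
#           guardrailIdentifier="gr-placeholder-id",
#           guardrailVersion="DRAFT",
#           aws_region_name="us-east-1",
#       )
#       # Applies the guardrail to the request messages; raises
#       # HTTPException(400) if Bedrock intervenes and masking is disabled.
#       await guardrail.make_bedrock_api_request(
#           kwargs={"messages": [{"role": "user", "content": "hello"}]}
#       )
#
#   asyncio.run(_demo())
# --------------------------------------------------------------------------- #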