# +-------------------------------------------------------------+
#
# Add Bedrock Knowledge Base Context to your LLM calls
#
# +-------------------------------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from urllib.parse import quote

from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.custom_prompt_management import CustomPromptManagement
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.integrations.rag.bedrock_knowledgebase import (
    BedrockKBContent,
    BedrockKBGuardrailConfiguration,
    BedrockKBRequest,
    BedrockKBResponse,
    BedrockKBRetrievalConfiguration,
    BedrockKBRetrievalQuery,
    BedrockKBRetrievalResult,
)
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import StandardCallbackDynamicParams
else:
    StandardCallbackDynamicParams = Any
class BedrockKnowledgeBaseHook(CustomPromptManagement, BaseAWSLLM):
    """
    Prompt-management hook that adds Bedrock Knowledge Base context to chat requests.

    When a request carries `vector_store_ids`, every id is used as a Bedrock
    Knowledge Base id: the knowledge base is queried with the text of the last
    message and the retrieved text is appended to the conversation as an
    additional user message.
    """

    # Prefix placed in front of the retrieved text in the injected context message.
    CONTENT_PREFIX_STRING = "Context: \n\n"

    def __init__(
        self,
        **kwargs,
    ):
        """
        Create the async HTTP client and remember the caller's keyword arguments.

        Args:
            **kwargs: Stored on `self.optional_params` (later used to resolve the
                AWS region) and forwarded to the parent initializer.
        """
        self.async_handler = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        # store kwargs as optional_params
        self.optional_params = kwargs
        super().__init__(**kwargs)
        BaseAWSLLM.__init__(self)

    async def async_get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """
        Retrieve context from each requested Bedrock Knowledge Base and append it to the messages.

        `vector_store_ids` is removed from `non_default_params`, and `messages`
        is extended in place with one user message per knowledge base that
        returned text. `prompt_id`, `prompt_variables` and
        `dynamic_callback_params` are accepted for interface compatibility and
        are not read here.

        Args:
            model: Model name; returned unchanged.
            messages: Conversation so far; the last message supplies the query.
            non_default_params: Request params; `vector_store_ids` is popped.
            prompt_id: Unused.
            prompt_variables: Unused.
            dynamic_callback_params: Unused.

        Returns:
            `(model, messages, non_default_params)` after the updates above.
        """
        vector_store_ids = non_default_params.pop("vector_store_ids", None)
        if not vector_store_ids:
            return model, messages, non_default_params

        # A bare string would otherwise be iterated character by character.
        if isinstance(vector_store_ids, str):
            vector_store_ids = [vector_store_ids]

        # Resolve the query once, BEFORE any context message is appended.
        # Computing it inside the loop would make the second and later
        # knowledge bases search for the previous knowledge base's context
        # (the new last message) instead of the user's question.
        query = self._get_kb_query_from_messages(messages)

        for vector_store_id in vector_store_ids:
            response = await self.make_bedrock_kb_retrieve_request(
                knowledge_base_id=vector_store_id,
                query=query,
            )
            verbose_logger.debug("Bedrock Knowledge Base Response: %s", response)

            context_message = (
                self.get_chat_completion_message_from_bedrock_kb_response(response)
            )
            if context_message is not None:
                messages.append(context_message)

        return model, messages, non_default_params

    def _get_kb_query_from_messages(self, messages: List[AllMessageValues]) -> str:
        """
        Build the knowledge base query from the text `content` of the last message.

        Returns:
            The content itself when it is a string; the newline-joined `text`
            values when it is a list of content parts; otherwise "".
        """
        if not messages:
            return ""

        last_message_content = messages[-1].get("content", None)
        if isinstance(last_message_content, str):
            return last_message_content

        if isinstance(last_message_content, list):
            # Keep only parts that actually carry string text. Parts without
            # text (e.g. images) would add blank lines to the query, and a
            # `None` text value would make `str.join` raise.
            text_parts = [
                part["text"]
                for part in last_message_content
                if isinstance(part, dict) and isinstance(part.get("text"), str)
            ]
            return "\n".join(text_parts)

        return ""

    def _prepare_request(
        self,
        credentials: Any,
        data: BedrockKBRequest,
        optional_params: dict,
        aws_region_name: str,
        api_base: str,
        extra_headers: Optional[dict] = None,
    ) -> Any:
        """
        Prepare a SigV4-signed AWS POST request for the "bedrock" service.

        Args:
            credentials: AWS credentials handed to `SigV4Auth`
            data: Request payload; serialized to JSON
            optional_params: Additional parameters (not read by this method)
            aws_region_name: AWS region name used for signing
            api_base: Full request URL
            extra_headers: Additional headers merged over the default
                `Content-Type`; a caller-supplied `Authorization` value is
                kept instead of the SigV4 one

        Returns:
            The prepared request produced by `AWSRequest.prepare()`

        Raises:
            ImportError: If botocore (installed with boto3) is not available
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError as import_error:
            raise ImportError(
                "Missing boto3 to call bedrock. Run 'pip install boto3'."
            ) from import_error

        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
        encoded_data = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json", **(extra_headers or {})}

        request = AWSRequest(
            method="POST", url=api_base, data=encoded_data, headers=headers
        )
        sigv4.add_auth(request)
        if extra_headers is not None and "Authorization" in extra_headers:
            # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]

        return request.prepare()

    @staticmethod
    def _redact_sensitive_headers(headers: Any) -> Dict[str, Any]:
        """Return a copy of `headers` with credential-bearing values masked, for logging."""
        sensitive_names = ("authorization", "x-amz-security-token")
        return {
            name: "REDACTED" if str(name).lower() in sensitive_names else value
            for name, value in dict(headers or {}).items()
        }

    async def make_bedrock_kb_retrieve_request(
        self,
        knowledge_base_id: str,
        query: str,
        guardrail_id: Optional[str] = None,
        guardrail_version: Optional[str] = None,
        next_token: Optional[str] = None,
        retrieval_configuration: Optional[BedrockKBRetrievalConfiguration] = None,
    ) -> BedrockKBResponse:
        """
        Make a Bedrock Knowledge Base retrieve request.

        Args:
            knowledge_base_id (str): The unique identifier of the knowledge base to query
            query (str): The query text to search for
            guardrail_id (Optional[str]): The guardrail ID to apply; only sent
                when `guardrail_version` is also provided
            guardrail_version (Optional[str]): The version of the guardrail to apply
            next_token (Optional[str]): Token for pagination
            retrieval_configuration (Optional[BedrockKBRetrievalConfiguration]): Configuration for the retrieval process

        Returns:
            BedrockKBResponse: A typed response object containing the retrieval results

        Raises:
            fastapi.HTTPException: If the service answers with a status other
                than 200; carries that status code and the response text
        """
        credentials = self.get_credentials()
        aws_region_name = self._get_aws_region_name(
            optional_params=self.optional_params
        )

        # Prepare request data
        request_data: BedrockKBRequest = BedrockKBRequest(
            retrievalQuery=BedrockKBRetrievalQuery(text=query),
        )
        if next_token:
            request_data["nextToken"] = next_token
        if retrieval_configuration:
            request_data["retrievalConfiguration"] = retrieval_configuration
        if guardrail_id and guardrail_version:
            request_data["guardrailConfiguration"] = BedrockKBGuardrailConfiguration(
                guardrailId=guardrail_id, guardrailVersion=guardrail_version
            )
        verbose_logger.debug(
            "Request Data: %s", json.dumps(request_data, indent=4, default=str)
        )

        # Prepare the request. The id originates from request parameters, so it
        # is percent-encoded to keep it confined to a single path segment of
        # the URL that gets signed with this process's AWS credentials.
        encoded_knowledge_base_id = quote(str(knowledge_base_id), safe="")
        api_base = (
            f"https://bedrock-agent-runtime.{aws_region_name}.amazonaws.com"
            f"/knowledgebases/{encoded_knowledge_base_id}/retrieve"
        )
        prepared_request = self._prepare_request(
            credentials=credentials,
            data=request_data,
            optional_params=self.optional_params,
            aws_region_name=aws_region_name,
            api_base=api_base,
        )
        # The signed headers carry the request signature (and, with temporary
        # credentials, the session token), so they are masked before logging.
        verbose_proxy_logger.debug(
            "Bedrock Knowledge Base request body: %s, url %s, headers: %s",
            request_data,
            prepared_request.url,
            self._redact_sensitive_headers(prepared_request.headers),
        )

        response = await self.async_handler.post(
            url=prepared_request.url,
            data=prepared_request.body,  # type: ignore
            headers=prepared_request.headers,  # type: ignore
        )
        verbose_proxy_logger.debug("Bedrock Knowledge Base response: %s", response.text)

        if response.status_code == 200:
            response_data = response.json()
            return BedrockKBResponse(**response_data)

        verbose_proxy_logger.error(
            "Bedrock Knowledge Base: error in response. Status code: %s, response: %s",
            response.status_code,
            response.text,
        )
        # Imported only on the error path so that a successful retrieval does
        # not require fastapi to be installed.
        from fastapi import HTTPException

        raise HTTPException(
            status_code=response.status_code,
            detail={
                "error": "Error calling Bedrock Knowledge Base",
                "response": response.text,
            },
        )

    @staticmethod
    def should_use_prompt_management_hook(non_default_params: Dict) -> bool:
        """Return True when `non_default_params` holds a non-empty `vector_store_ids` value."""
        return bool(non_default_params.get("vector_store_ids", None))

    @staticmethod
    def get_initialized_custom_logger(
        non_default_params: Dict,
    ) -> Optional[CustomLogger]:
        """
        Return the initialized "bedrock_knowledgebase_hook" logger when the request needs it.

        Returns:
            The object returned by `_init_custom_logger_compatible_class`, or
            None when `non_default_params` has no `vector_store_ids`.
        """
        if not BedrockKnowledgeBaseHook.should_use_prompt_management_hook(
            non_default_params
        ):
            return None

        # Kept as a function-scope import, as in the original code
        # (presumably to avoid a circular import - confirm before hoisting).
        from litellm.litellm_core_utils.litellm_logging import (
            _init_custom_logger_compatible_class,
        )

        return _init_custom_logger_compatible_class(
            logging_integration="bedrock_knowledgebase_hook",
            internal_usage_cache=None,
            llm_router=None,
        )

    @staticmethod
    def get_chat_completion_message_from_bedrock_kb_response(
        response: BedrockKBResponse,
    ) -> Optional[ChatCompletionUserMessage]:
        """
        Build a user message holding the text retrieved from a Bedrock Knowledge Base response.

        Returns:
            A `ChatCompletionUserMessage` whose content is
            `CONTENT_PREFIX_STRING` followed by the concatenated result texts,
            or None when the response has no `retrievalResults` or none of the
            results carries text.
        """
        retrieval_results: Optional[List[BedrockKBRetrievalResult]] = response.get(
            "retrievalResults", None
        )
        if retrieval_results is None:
            return None

        # collect the text of every result that has one
        context_chunks: List[str] = []
        for retrieval_result in retrieval_results:
            retrieval_result_content: BedrockKBContent = (
                retrieval_result.get("content", None) or {}
            )
            retrieval_result_text: Optional[str] = retrieval_result_content.get(
                "text", None
            )
            if retrieval_result_text is None:
                continue
            context_chunks.append(retrieval_result_text)

        if not context_chunks:
            # Nothing was retrieved: a message consisting of the bare prefix
            # would only add noise to the conversation.
            return None

        return ChatCompletionUserMessage(
            role="user",
            content=BedrockKnowledgeBaseHook.CONTENT_PREFIX_STRING
            + "".join(context_chunks),
        )