""" Transformation for Bedrock Invoke Agent https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_InvokeAgent.html """ import base64 import json import uuid from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import httpx from litellm._logging import verbose_logger from litellm.litellm_core_utils.prompt_templates.common_utils import ( convert_content_list_to_str, ) from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM from litellm.llms.bedrock.common_utils import BedrockError from litellm.types.llms.bedrock_invoke_agents import ( InvokeAgentChunkPayload, InvokeAgentEvent, InvokeAgentEventHeaders, InvokeAgentEventList, InvokeAgentTrace, InvokeAgentTracePayload, InvokeAgentUsage, ) from litellm.types.llms.openai import AllMessageValues from litellm.types.utils import Choices, Message, ModelResponse if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj LiteLLMLoggingObj = _LiteLLMLoggingObj else: LiteLLMLoggingObj = Any class AmazonInvokeAgentConfig(BaseConfig, BaseAWSLLM): def __init__(self, **kwargs): BaseConfig.__init__(self, **kwargs) BaseAWSLLM.__init__(self, **kwargs) def get_supported_openai_params(self, model: str) -> List[str]: """ This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class. Bedrock Invoke Agents has 0 OpenAI compatible params As of May 29th, 2025 - they don't support streaming. """ return [] def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, ) -> dict: """ This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class. """ return optional_params def get_complete_url( self, api_base: Optional[str], api_key: Optional[str], model: str, optional_params: dict, litellm_params: dict, stream: Optional[bool] = None, ) -> str: """ Get the complete url for the request """ ### SET RUNTIME ENDPOINT ### aws_bedrock_runtime_endpoint = optional_params.get( "aws_bedrock_runtime_endpoint", None ) # https://bedrock-runtime.{region_name}.amazonaws.com endpoint_url, _ = self.get_runtime_endpoint( api_base=api_base, aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, aws_region_name=self._get_aws_region_name( optional_params=optional_params, model=model ), endpoint_type="agent", ) agent_id, agent_alias_id = self._get_agent_id_and_alias_id(model) session_id = self._get_session_id(optional_params) endpoint_url = f"{endpoint_url}/agents/{agent_id}/agentAliases/{agent_alias_id}/sessions/{session_id}/text" return endpoint_url def sign_request( self, headers: dict, optional_params: dict, request_data: dict, api_base: str, model: Optional[str] = None, stream: Optional[bool] = None, fake_stream: Optional[bool] = None, ) -> Tuple[dict, Optional[bytes]]: return self._sign_request( service_name="bedrock", headers=headers, optional_params=optional_params, request_data=request_data, api_base=api_base, model=model, stream=stream, fake_stream=fake_stream, ) def _get_agent_id_and_alias_id(self, model: str) -> tuple[str, str]: """ model = "agent/L1RT58GYRW/MFPSBCXYTW" agent_id = "L1RT58GYRW" agent_alias_id = "MFPSBCXYTW" """ # Split the model string by '/' and extract components parts = model.split("/") if len(parts) != 3 or parts[0] != "agent": raise ValueError( "Invalid model format. 

    def _get_session_id(self, optional_params: dict) -> str:
        """
        Use the caller-provided session id if one was passed, otherwise
        generate a fresh UUID so each request gets its own session.
        """
        return optional_params.get("sessionID", None) or str(uuid.uuid4())

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # use the last message content as the query
        query: str = convert_content_list_to_str(messages[-1])

        return {
            "inputText": query,
            "enableTrace": True,
            **optional_params,
        }

    def _parse_aws_event_stream(self, raw_content: bytes) -> InvokeAgentEventList:
        """
        Parse AWS event stream format using boto3/botocore's built-in parser.
        This is the same approach used in the existing AWSEventStreamDecoder.
        """
        try:
            from botocore.eventstream import EventStreamBuffer
            from botocore.parsers import EventStreamJSONParser
        except ImportError:
            raise ImportError("boto3/botocore is required for AWS event stream parsing")

        events: InvokeAgentEventList = []
        parser = EventStreamJSONParser()
        event_stream_buffer = EventStreamBuffer()

        # Add the entire response to the buffer
        event_stream_buffer.add_data(raw_content)

        # Process all events in the buffer
        for event in event_stream_buffer:
            try:
                headers = self._extract_headers_from_event(event)
                event_type = headers.get("event_type", "")

                if event_type == "chunk":
                    # Handle chunk events specially - they contain decoded content, not JSON
                    message = self._parse_message_from_event(event, parser)
                    parsed_event: InvokeAgentEvent = InvokeAgentEvent()
                    if message:
                        # For chunk events, create a payload with the decoded content
                        parsed_event = {
                            "headers": headers,
                            "payload": {
                                "bytes": base64.b64encode(
                                    message.encode("utf-8")
                                ).decode("utf-8")
                            },  # Re-encode for consistency
                        }
                        events.append(parsed_event)
                elif event_type == "trace":
                    # Handle trace events normally - they contain JSON
                    message = self._parse_message_from_event(event, parser)
                    if message:
                        try:
                            event_data = json.loads(message)
                            parsed_event = {
                                "headers": headers,
                                "payload": event_data,
                            }
                            events.append(parsed_event)
                        except json.JSONDecodeError as e:
                            verbose_logger.warning(
                                f"Failed to parse trace event JSON: {e}"
                            )
                else:
                    verbose_logger.debug(f"Unknown event type: {event_type}")
            except Exception as e:
                verbose_logger.error(f"Error processing event: {e}")
                continue

        return events
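
    # Minimal sketch of the decode loop above, runnable against a captured
    # payload (assumes botocore is installed; `raw` is a hypothetical bytes
    # payload saved from an InvokeAgent HTTP response):
    #
    #   from botocore.eventstream import EventStreamBuffer
    #   buf = EventStreamBuffer()
    #   buf.add_data(raw)
    #   for event in buf:
    #       print(event.to_response_dict()["headers"].get(":event-type"))
    #
    # Each decoded event carries prelude headers such as ":event-type"
    # ("chunk" or "trace"), ":content-type", and ":message-type", which is
    # exactly what _extract_headers_from_event reads below.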

    def _parse_message_from_event(self, event, parser) -> Optional[str]:
        """Extract message content from an AWS event, adapted from AWSEventStreamDecoder."""
        try:
            response_dict = event.to_response_dict()
            verbose_logger.debug(f"Response dict: {response_dict}")

            # Use the same response shape parsing as the existing decoder
            parsed_response = parser.parse(
                response_dict, self._get_response_stream_shape()
            )
            verbose_logger.debug(f"Parsed response: {parsed_response}")

            if response_dict["status_code"] != 200:
                decoded_body = response_dict["body"].decode()
                # decode() always yields a str; error payloads are usually JSON
                # objects with a "message" field, so try that first.
                try:
                    parsed_body = json.loads(decoded_body)
                    error_message = (
                        parsed_body.get("message", decoded_body)
                        if isinstance(parsed_body, dict)
                        else decoded_body
                    )
                except json.JSONDecodeError:
                    error_message = decoded_body
                exception_status = response_dict["headers"].get(":exception-type") or ""
                raise BedrockError(
                    status_code=response_dict["status_code"],
                    message=f"{exception_status} {error_message}".strip(),
                )

            if "chunk" in parsed_response:
                chunk = parsed_response.get("chunk")
                if not chunk:
                    return None
                return chunk.get("bytes").decode()
            else:
                chunk = response_dict.get("body")
                if not chunk:
                    return None
                return chunk.decode()
        except Exception as e:
            verbose_logger.debug(f"Error parsing message from event: {e}")
            return None

    def _extract_headers_from_event(self, event) -> InvokeAgentEventHeaders:
        """Extract headers from an AWS event for categorization."""
        try:
            response_dict = event.to_response_dict()
            headers = response_dict.get("headers", {})

            # Extract the event-type and content-type headers that we care about
            return InvokeAgentEventHeaders(
                event_type=headers.get(":event-type", ""),
                content_type=headers.get(":content-type", ""),
                message_type=headers.get(":message-type", ""),
            )
        except Exception as e:
            verbose_logger.debug(f"Error extracting headers: {e}")
            return InvokeAgentEventHeaders(
                event_type="", content_type="", message_type=""
            )

    def _get_response_stream_shape(self):
        """Get the response stream shape for parsing, reusing existing logic."""
        try:
            # Try to reuse the cached shape from the existing decoder
            from litellm.llms.bedrock.chat.invoke_handler import (
                get_response_stream_shape,
            )

            return get_response_stream_shape()
        except ImportError:
            # Fallback: create our own shape
            try:
                from botocore.loaders import Loader
                from botocore.model import ServiceModel

                loader = Loader()
                bedrock_service_dict = loader.load_service_model(
                    "bedrock-runtime", "service-2"
                )
                bedrock_service_model = ServiceModel(bedrock_service_dict)
                return bedrock_service_model.shape_for("ResponseStream")
            except Exception as e:
                verbose_logger.warning(f"Could not load response stream shape: {e}")
                return None

    def _extract_response_content(self, events: InvokeAgentEventList) -> str:
        """Extract the final response content from parsed events."""
        response_parts = []

        for event in events:
            headers = event.get("headers", {})
            payload = event.get("payload")
            event_type = headers.get("event_type")  # Note: "event_type", not ":event-type"

            if event_type == "chunk" and payload:
                # Extract base64 encoded content from chunk events
                chunk_payload: InvokeAgentChunkPayload = payload  # type: ignore
                encoded_bytes = chunk_payload.get("bytes", "")
                if encoded_bytes:
                    try:
                        decoded_content = base64.b64decode(encoded_bytes).decode(
                            "utf-8"
                        )
                        response_parts.append(decoded_content)
                    except Exception as e:
                        verbose_logger.warning(f"Failed to decode chunk content: {e}")

        return "".join(response_parts)

    def _extract_usage_info(self, events: InvokeAgentEventList) -> InvokeAgentUsage:
        """Extract token usage information from trace events."""
        usage_info = InvokeAgentUsage(
            inputTokens=0,
            outputTokens=0,
            model=None,
        )
        response_model: Optional[str] = None

        for event in events:
            if not self._is_trace_event(event):
                continue

            trace_data = self._get_trace_data(event)
            if not trace_data:
                continue

            verbose_logger.debug(f"Trace event: {trace_data}")

            # Extract usage from pre-processing trace
            self._extract_and_update_preprocessing_usage(
                trace_data=trace_data,
                usage_info=usage_info,
            )

            # Extract model from orchestration trace
            if response_model is None:
                response_model = self._extract_orchestration_model(trace_data)

        usage_info["model"] = response_model
        return usage_info

    def _is_trace_event(self, event: InvokeAgentEvent) -> bool:
        """Check if the event is a trace event."""
        headers = event.get("headers", {})
        event_type = headers.get("event_type")
        payload = event.get("payload")
        return event_type == "trace" and payload is not None

    def _get_trace_data(self, event: InvokeAgentEvent) -> Optional[InvokeAgentTrace]:
        """Extract trace data from a trace event."""
        payload = event.get("payload")
        if not payload:
            return None

        trace_payload: InvokeAgentTracePayload = payload  # type: ignore
        return trace_payload.get("trace", {})
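
    # For reference, the trace payload fields read by the helpers below have
    # roughly this shape (abridged; real traces carry many more fields, and
    # availability varies by agent configuration):
    #
    #   {
    #       "trace": {
    #           "preProcessingTrace": {
    #               "modelInvocationOutput": {
    #                   "metadata": {
    #                       "usage": {"inputTokens": 10, "outputTokens": 5}
    #                   }
    #               }
    #           },
    #           "orchestrationTrace": {
    #               "modelInvocationInput": {"foundationModel": "..."}
    #           },
    #       }
    #   }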

    def _extract_and_update_preprocessing_usage(
        self, trace_data: InvokeAgentTrace, usage_info: InvokeAgentUsage
    ) -> None:
        """Extract usage information from preprocessing trace."""
        pre_processing = trace_data.get("preProcessingTrace", {})
        if not pre_processing:
            return

        model_output = pre_processing.get("modelInvocationOutput", {})
        if not model_output:
            return

        metadata = model_output.get("metadata", {})
        if not metadata:
            return

        usage: Optional[Union[InvokeAgentUsage, Dict]] = metadata.get("usage", {})
        if not usage:
            return

        usage_info["inputTokens"] += usage.get("inputTokens", 0)
        usage_info["outputTokens"] += usage.get("outputTokens", 0)

    def _extract_orchestration_model(
        self, trace_data: InvokeAgentTrace
    ) -> Optional[str]:
        """Extract model information from orchestration trace."""
        orchestration_trace = trace_data.get("orchestrationTrace", {})
        if not orchestration_trace:
            return None

        model_invocation = orchestration_trace.get("modelInvocationInput", {})
        if not model_invocation:
            return None

        return model_invocation.get("foundationModel")

    def _build_model_response(
        self,
        content: str,
        model: str,
        usage_info: InvokeAgentUsage,
        model_response: ModelResponse,
    ) -> ModelResponse:
        """Build the final ModelResponse object."""
        # Create the message content
        message = Message(content=content, role="assistant")

        # Create choices
        choice = Choices(finish_reason="stop", index=0, message=message)

        # Update model response
        model_response.choices = [choice]
        model_response.model = usage_info.get("model", model)

        # Add usage information if available
        if usage_info:
            from litellm.types.utils import Usage

            usage = Usage(
                prompt_tokens=usage_info.get("inputTokens", 0),
                completion_tokens=usage_info.get("outputTokens", 0),
                total_tokens=usage_info.get("inputTokens", 0)
                + usage_info.get("outputTokens", 0),
            )
            setattr(model_response, "usage", usage)

        return model_response

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        try:
            # Get the raw binary content
            raw_content = raw_response.content
            verbose_logger.debug(
                f"Processing {len(raw_content)} bytes of AWS event stream data"
            )

            # Parse the AWS event stream format
            events = self._parse_aws_event_stream(raw_content)
            verbose_logger.debug(f"Parsed {len(events)} events from stream")

            # Extract response content from chunk events
            content = self._extract_response_content(events)

            # Extract usage information from trace events
            usage_info = self._extract_usage_info(events)

            # Build and return the model response
            return self._build_model_response(
                content=content,
                model=model,
                usage_info=usage_info,
                model_response=model_response,
            )
        except Exception as e:
            verbose_logger.error(
                f"Error processing Bedrock Invoke Agent response: {str(e)}"
            )
            raise BedrockError(
                message=f"Error processing response: {str(e)}",
                status_code=raw_response.status_code,
            )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return BedrockError(status_code=status_code, message=error_message)
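
    # Invoke Agent has no streaming API (see get_supported_openai_params), so
    # this config always opts into litellm's fake streaming: the complete
    # response is fetched, then replayed as stream chunks when the caller
    # sets stream=True.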
    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        return True
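
# Example usage (sketch, not executed here): routing a completion call through
# this config. Assumes valid AWS credentials in the environment and an existing
# agent; the agent/alias IDs below are placeholders.
#
#   import litellm
#
#   response = litellm.completion(
#       model="bedrock/agent/L1RT58GYRW/MFPSBCXYTW",
#       messages=[{"role": "user", "content": "What is my account balance?"}],
#   )
#   print(response.choices[0].message.content)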