|
""" |
|
Manages calling Bedrock's `/converse` and `/invoke` APIs
|
""" |
|
|
|
import copy |
|
import json |
|
import time |
|
import types |
|
import urllib.parse |
|
import uuid |
|
from functools import partial |
|
from typing import ( |
|
Any, |
|
AsyncIterator, |
|
Callable, |
|
Iterator, |
|
List, |
|
Optional, |
|
Tuple, |
|
Union, |
|
cast, |
|
get_args, |
|
) |
|
|
|
import httpx |
|
|
|
import litellm |
|
from litellm import verbose_logger |
|
from litellm._logging import print_verbose |
|
from litellm.caching.caching import InMemoryCache |
|
from litellm.litellm_core_utils.core_helpers import map_finish_reason |
|
from litellm.litellm_core_utils.litellm_logging import Logging |
|
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing |
|
from litellm.litellm_core_utils.prompt_templates.factory import ( |
|
cohere_message_pt, |
|
construct_tool_use_system_prompt, |
|
contains_tag, |
|
custom_prompt, |
|
extract_between_tags, |
|
parse_xml_params, |
|
prompt_factory, |
|
) |
|
from litellm.llms.custom_httpx.http_handler import ( |
|
AsyncHTTPHandler, |
|
HTTPHandler, |
|
_get_httpx_client, |
|
get_async_httpx_client, |
|
) |
|
from litellm.types.llms.bedrock import * |
|
from litellm.types.llms.openai import ( |
|
ChatCompletionToolCallChunk, |
|
ChatCompletionToolCallFunctionChunk, |
|
ChatCompletionUsageBlock, |
|
) |
|
from litellm.types.utils import ChatCompletionMessageToolCall, Choices |
|
from litellm.types.utils import GenericStreamingChunk as GChunk |
|
from litellm.types.utils import ModelResponse, Usage |
|
from litellm.utils import CustomStreamWrapper, get_secret |
|
|
|
from ..base_aws_llm import BaseAWSLLM |
|
from ..common_utils import BedrockError, ModelResponseIterator, get_bedrock_tool_name |
|
|
|
# Lazily-populated botocore shape used to parse Bedrock event-stream responses
# (see get_response_stream_shape below)
_response_stream_shape_cache = None

# short-lived cache mapping Bedrock-returned tool names back to the caller's
# original tool names (see get_bedrock_tool_name in common_utils)
bedrock_tool_name_mappings: InMemoryCache = InMemoryCache(
    max_size_in_memory=50, default_ttl=600
)
|
|
|
|
|
class AmazonCohereChatConfig: |
|
""" |
|
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command-r-plus.html |
|
""" |
|
|
|
documents: Optional[List[Document]] = None |
|
search_queries_only: Optional[bool] = None |
|
preamble: Optional[str] = None |
|
max_tokens: Optional[int] = None |
|
temperature: Optional[float] = None |
|
p: Optional[float] = None |
|
k: Optional[float] = None |
|
prompt_truncation: Optional[str] = None |
|
frequency_penalty: Optional[float] = None |
|
presence_penalty: Optional[float] = None |
|
seed: Optional[int] = None |
|
return_prompt: Optional[bool] = None |
|
stop_sequences: Optional[List[str]] = None |
|
raw_prompting: Optional[bool] = None |
|
|
|
def __init__( |
|
self, |
|
documents: Optional[List[Document]] = None, |
|
search_queries_only: Optional[bool] = None, |
|
preamble: Optional[str] = None, |
|
max_tokens: Optional[int] = None, |
|
temperature: Optional[float] = None, |
|
p: Optional[float] = None, |
|
k: Optional[float] = None, |
|
prompt_truncation: Optional[str] = None, |
|
frequency_penalty: Optional[float] = None, |
|
presence_penalty: Optional[float] = None, |
|
seed: Optional[int] = None, |
|
return_prompt: Optional[bool] = None, |
|
        stop_sequences: Optional[List[str]] = None,
|
raw_prompting: Optional[bool] = None, |
|
) -> None: |
|
locals_ = locals() |
|
for key, value in locals_.items(): |
|
if key != "self" and value is not None: |
|
setattr(self.__class__, key, value) |
|
|
|
@classmethod |
|
def get_config(cls): |
|
return { |
|
k: v |
|
for k, v in cls.__dict__.items() |
|
if not k.startswith("__") |
|
and not isinstance( |
|
v, |
|
( |
|
types.FunctionType, |
|
types.BuiltinFunctionType, |
|
classmethod, |
|
staticmethod, |
|
), |
|
) |
|
and v is not None |
|
} |
|
|
|
def get_supported_openai_params(self) -> List[str]: |
|
return [ |
|
"max_tokens", |
|
"max_completion_tokens", |
|
"stream", |
|
"stop", |
|
"temperature", |
|
"top_p", |
|
"frequency_penalty", |
|
"presence_penalty", |
|
"seed", |
|
"stop", |
|
"tools", |
|
"tool_choice", |
|
] |
|
|
|
def map_openai_params( |
|
self, non_default_params: dict, optional_params: dict |
|
) -> dict: |
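        """
        Map OpenAI-style params onto the Cohere chat invoke params.

        Illustrative example (assumed inputs; values pass through unchanged):
            >>> AmazonCohereChatConfig().map_openai_params(
            ...     {"max_tokens": 256, "top_p": 0.9, "stop": "END"}, {}
            ... )
            {'max_tokens': 256, 'p': 0.9, 'stop_sequences': ['END']}
        """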
|
for param, value in non_default_params.items(): |
|
if param == "max_tokens" or param == "max_completion_tokens": |
|
optional_params["max_tokens"] = value |
|
if param == "stream": |
|
optional_params["stream"] = value |
|
if param == "stop": |
|
if isinstance(value, str): |
|
value = [value] |
|
optional_params["stop_sequences"] = value |
|
if param == "temperature": |
|
optional_params["temperature"] = value |
|
if param == "top_p": |
|
optional_params["p"] = value |
|
if param == "frequency_penalty": |
|
optional_params["frequency_penalty"] = value |
|
if param == "presence_penalty": |
|
optional_params["presence_penalty"] = value |
|
if "seed": |
|
optional_params["seed"] = value |
|
return optional_params |
|
|
|
|
|
async def make_call( |
|
client: Optional[AsyncHTTPHandler], |
|
api_base: str, |
|
headers: dict, |
|
data: str, |
|
model: str, |
|
messages: list, |
|
logging_obj: Logging, |
|
fake_stream: bool = False, |
|
json_mode: Optional[bool] = False, |
|
): |
|
try: |
|
if client is None: |
|
client = get_async_httpx_client( |
|
llm_provider=litellm.LlmProviders.BEDROCK |
|
) |
|
|
|
response = await client.post( |
|
api_base, |
|
headers=headers, |
|
data=data, |
|
stream=not fake_stream, |
|
logging_obj=logging_obj, |
|
) |
|
|
|
if response.status_code != 200: |
|
raise BedrockError(status_code=response.status_code, message=response.text) |
|
|
|
if fake_stream: |
|
            # provider has no native streaming support: transform the complete
            # response and replay it below as a single mock chunk
            model_response: ModelResponse = litellm.AmazonConverseConfig()._transform_response(
|
model=model, |
|
response=response, |
|
model_response=litellm.ModelResponse(), |
|
stream=True, |
|
logging_obj=logging_obj, |
|
optional_params={}, |
|
api_key="", |
|
data=data, |
|
messages=messages, |
|
print_verbose=print_verbose, |
|
encoding=litellm.encoding, |
|
) |
|
completion_stream: Any = MockResponseIterator( |
|
model_response=model_response, json_mode=json_mode |
|
) |
|
else: |
|
decoder = AWSEventStreamDecoder(model=model) |
|
completion_stream = decoder.aiter_bytes( |
|
response.aiter_bytes(chunk_size=1024) |
|
) |
|
|
|
|
|
logging_obj.post_call( |
|
input=messages, |
|
api_key="", |
|
original_response="first stream response received", |
|
additional_args={"complete_input_dict": data}, |
|
) |
|
|
|
return completion_stream |
|
except httpx.HTTPStatusError as err: |
|
error_code = err.response.status_code |
|
raise BedrockError(status_code=error_code, message=err.response.text) |
|
except httpx.TimeoutException: |
|
raise BedrockError(status_code=408, message="Timeout error occurred.") |
|
    except BedrockError:
        # preserve the original Bedrock status code instead of re-wrapping as 500
        raise
    except Exception as e:
        raise BedrockError(status_code=500, message=str(e))
|
|
|
|
|
class BedrockLLM(BaseAWSLLM): |
|
""" |
|
Example call |
|
|
|
``` |
|
curl --location --request POST 'https://bedrock-runtime.{aws_region_name}.amazonaws.com/model/{bedrock_model_name}/invoke' \ |
|
--header 'Content-Type: application/json' \ |
|
--header 'Accept: application/json' \ |
|
--user "$AWS_ACCESS_KEY_ID":"$AWS_SECRET_ACCESS_KEY" \ |
|
--aws-sigv4 "aws:amz:us-east-1:bedrock" \ |
|
--data-raw '{ |
|
"prompt": "Hi", |
|
"temperature": 0, |
|
"p": 0.9, |
|
"max_tokens": 4096 |
|
}' |
|
``` |
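
    An equivalent litellm call (illustrative; any invoke-route model id works):

    ```
    import litellm

    response = litellm.completion(
        model="bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0",
        messages=[{"role": "user", "content": "Hi"}],
        max_tokens=4096,
    )
    ```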
|
""" |
|
|
|
def __init__(self) -> None: |
|
super().__init__() |
|
|
|
def convert_messages_to_prompt( |
|
self, model, messages, provider, custom_prompt_dict |
|
) -> Tuple[str, Optional[list]]: |
|
|
|
prompt = "" |
|
chat_history: Optional[list] = None |
|
|
|
if model in custom_prompt_dict: |
|
|
|
model_prompt_details = custom_prompt_dict[model] |
|
prompt = custom_prompt( |
|
role_dict=model_prompt_details["roles"], |
|
initial_prompt_value=model_prompt_details.get( |
|
"initial_prompt_value", "" |
|
), |
|
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), |
|
messages=messages, |
|
) |
|
return prompt, None |
|
|
|
if provider == "anthropic" or provider == "amazon": |
|
prompt = prompt_factory( |
|
model=model, messages=messages, custom_llm_provider="bedrock" |
|
) |
|
elif provider == "mistral": |
|
prompt = prompt_factory( |
|
model=model, messages=messages, custom_llm_provider="bedrock" |
|
) |
|
elif provider == "meta" or provider == "llama": |
|
prompt = prompt_factory( |
|
model=model, messages=messages, custom_llm_provider="bedrock" |
|
) |
|
elif provider == "cohere": |
|
prompt, chat_history = cohere_message_pt(messages=messages) |
|
else: |
|
prompt = "" |
|
            for message in messages:
                # all roles are handled identically: concatenate content as-is
                prompt += f"{message['content']}"
|
return prompt, chat_history |
|
|
|
def process_response( |
|
self, |
|
model: str, |
|
response: httpx.Response, |
|
model_response: ModelResponse, |
|
stream: Optional[bool], |
|
logging_obj: Logging, |
|
optional_params: dict, |
|
api_key: str, |
|
data: Union[dict, str], |
|
messages: List, |
|
print_verbose, |
|
encoding, |
|
) -> Union[ModelResponse, CustomStreamWrapper]: |
|
provider = self.get_bedrock_invoke_provider(model) |
|
|
|
logging_obj.post_call( |
|
input=messages, |
|
api_key=api_key, |
|
original_response=response.text, |
|
additional_args={"complete_input_dict": data}, |
|
) |
|
print_verbose(f"raw model_response: {response.text}") |
|
|
|
|
|
try: |
|
completion_response = response.json() |
|
except Exception: |
|
raise BedrockError(message=response.text, status_code=422) |
|
|
|
outputText: Optional[str] = None |
|
try: |
|
if provider == "cohere": |
|
if "text" in completion_response: |
|
outputText = completion_response["text"] |
|
elif "generations" in completion_response: |
|
outputText = completion_response["generations"][0]["text"] |
|
model_response.choices[0].finish_reason = map_finish_reason( |
|
completion_response["generations"][0]["finish_reason"] |
|
) |
|
elif provider == "anthropic": |
|
if model.startswith("anthropic.claude-3"): |
|
json_schemas: dict = {} |
|
_is_function_call = False |
|
|
|
if "tools" in optional_params: |
|
_is_function_call = True |
|
for tool in optional_params["tools"]: |
|
json_schemas[tool["function"]["name"]] = tool[ |
|
"function" |
|
].get("parameters", None) |
|
outputText = completion_response.get("content")[0].get("text", None) |
|
if outputText is not None and contains_tag( |
|
"invoke", outputText |
|
): |
|
function_name = extract_between_tags("tool_name", outputText)[0] |
|
function_arguments_str = extract_between_tags( |
|
"invoke", outputText |
|
)[0].strip() |
|
function_arguments_str = ( |
|
f"<invoke>{function_arguments_str}</invoke>" |
|
) |
|
function_arguments = parse_xml_params( |
|
function_arguments_str, |
|
json_schema=json_schemas.get( |
|
function_name, None |
|
), |
|
) |
|
_message = litellm.Message( |
|
tool_calls=[ |
|
{ |
|
"id": f"call_{uuid.uuid4()}", |
|
"type": "function", |
|
"function": { |
|
"name": function_name, |
|
"arguments": json.dumps(function_arguments), |
|
}, |
|
} |
|
], |
|
content=None, |
|
) |
|
model_response.choices[0].message = _message |
|
model_response._hidden_params["original_response"] = ( |
|
outputText |
|
) |
|
if ( |
|
_is_function_call is True |
|
and stream is not None |
|
and stream is True |
|
): |
|
print_verbose( |
|
"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK" |
|
) |
|
|
|
streaming_model_response = ModelResponse(stream=True) |
|
streaming_model_response.choices[0].finish_reason = getattr( |
|
model_response.choices[0], "finish_reason", "stop" |
|
) |
|
|
|
streaming_choice = litellm.utils.StreamingChoices() |
|
streaming_choice.index = model_response.choices[0].index |
|
_tool_calls = [] |
|
print_verbose( |
|
f"type of model_response.choices[0]: {type(model_response.choices[0])}" |
|
) |
|
print_verbose( |
|
f"type of streaming_choice: {type(streaming_choice)}" |
|
) |
|
if isinstance(model_response.choices[0], litellm.Choices): |
|
if getattr( |
|
model_response.choices[0].message, "tool_calls", None |
|
) is not None and isinstance( |
|
model_response.choices[0].message.tool_calls, list |
|
): |
|
for tool_call in model_response.choices[ |
|
0 |
|
].message.tool_calls: |
|
_tool_call = {**tool_call.dict(), "index": 0} |
|
_tool_calls.append(_tool_call) |
|
delta_obj = litellm.utils.Delta( |
|
content=getattr( |
|
model_response.choices[0].message, "content", None |
|
), |
|
role=model_response.choices[0].message.role, |
|
tool_calls=_tool_calls, |
|
) |
|
streaming_choice.delta = delta_obj |
|
streaming_model_response.choices = [streaming_choice] |
|
completion_stream = ModelResponseIterator( |
|
model_response=streaming_model_response |
|
) |
|
print_verbose( |
|
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" |
|
) |
|
return litellm.CustomStreamWrapper( |
|
completion_stream=completion_stream, |
|
model=model, |
|
custom_llm_provider="cached_response", |
|
logging_obj=logging_obj, |
|
) |
|
|
|
model_response.choices[0].finish_reason = map_finish_reason( |
|
completion_response.get("stop_reason", "") |
|
) |
|
_usage = litellm.Usage( |
|
prompt_tokens=completion_response["usage"]["input_tokens"], |
|
completion_tokens=completion_response["usage"]["output_tokens"], |
|
total_tokens=completion_response["usage"]["input_tokens"] |
|
+ completion_response["usage"]["output_tokens"], |
|
) |
|
setattr(model_response, "usage", _usage) |
|
else: |
|
outputText = completion_response["completion"] |
|
|
|
model_response.choices[0].finish_reason = completion_response[ |
|
"stop_reason" |
|
] |
|
elif provider == "ai21": |
|
outputText = ( |
|
completion_response.get("completions")[0].get("data").get("text") |
|
) |
|
elif provider == "meta" or provider == "llama": |
|
outputText = completion_response["generation"] |
|
elif provider == "mistral": |
|
outputText = completion_response["outputs"][0]["text"] |
|
model_response.choices[0].finish_reason = completion_response[ |
|
"outputs" |
|
][0]["stop_reason"] |
|
else: |
|
outputText = completion_response.get("results")[0].get("outputText") |
|
except Exception as e: |
|
raise BedrockError( |
|
message="Error processing={}, Received error={}".format( |
|
response.text, str(e) |
|
), |
|
status_code=422, |
|
) |
|
|
|
try: |
|
if ( |
|
outputText is not None |
|
and len(outputText) > 0 |
|
and hasattr(model_response.choices[0], "message") |
|
and getattr(model_response.choices[0].message, "tool_calls", None) |
|
is None |
|
): |
|
model_response.choices[0].message.content = outputText |
|
elif ( |
|
hasattr(model_response.choices[0], "message") |
|
and getattr(model_response.choices[0].message, "tool_calls", None) |
|
is not None |
|
): |
|
pass |
|
else: |
|
                raise Exception("No output text received in the Bedrock response.")
|
except Exception as e: |
|
raise BedrockError( |
|
message="Error parsing received text={}.\nError-{}".format( |
|
outputText, str(e) |
|
), |
|
status_code=response.status_code, |
|
) |
|
|
|
if stream and provider == "ai21": |
|
streaming_model_response = ModelResponse(stream=True) |
|
streaming_model_response.choices[0].finish_reason = model_response.choices[ |
|
0 |
|
].finish_reason |
|
|
|
streaming_choice = litellm.utils.StreamingChoices() |
|
streaming_choice.index = model_response.choices[0].index |
|
delta_obj = litellm.utils.Delta( |
|
content=getattr(model_response.choices[0].message, "content", None), |
|
role=model_response.choices[0].message.role, |
|
) |
|
streaming_choice.delta = delta_obj |
|
streaming_model_response.choices = [streaming_choice] |
|
mri = ModelResponseIterator(model_response=streaming_model_response) |
|
return CustomStreamWrapper( |
|
completion_stream=mri, |
|
model=model, |
|
custom_llm_provider="cached_response", |
|
logging_obj=logging_obj, |
|
) |
|
|
|
|
|
bedrock_input_tokens = response.headers.get( |
|
"x-amzn-bedrock-input-token-count", None |
|
) |
|
bedrock_output_tokens = response.headers.get( |
|
"x-amzn-bedrock-output-token-count", None |
|
) |
|
|
|
prompt_tokens = int( |
|
bedrock_input_tokens or litellm.token_counter(messages=messages) |
|
) |
|
|
|
completion_tokens = int( |
|
bedrock_output_tokens |
|
or litellm.token_counter( |
|
text=model_response.choices[0].message.content, |
|
count_response_tokens=True, |
|
) |
|
) |
|
|
|
model_response.created = int(time.time()) |
|
model_response.model = model |
|
usage = Usage( |
|
prompt_tokens=prompt_tokens, |
|
completion_tokens=completion_tokens, |
|
total_tokens=prompt_tokens + completion_tokens, |
|
) |
|
setattr(model_response, "usage", usage) |
|
|
|
return model_response |
|
|
|
def encode_model_id(self, model_id: str) -> str: |
|
""" |
|
Double encode the model ID to ensure it matches the expected double-encoded format. |
|
Args: |
|
model_id (str): The model ID to encode. |
|
Returns: |
|
str: The double-encoded model ID. |
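
        Illustrative example (assumed imported-model ARN):
            >>> BedrockLLM().encode_model_id(
            ...     "arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
            ... )
            'arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n'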
|
""" |
|
return urllib.parse.quote(model_id, safe="") |
|
|
|
def completion( |
|
self, |
|
model: str, |
|
messages: list, |
|
api_base: Optional[str], |
|
custom_prompt_dict: dict, |
|
model_response: ModelResponse, |
|
print_verbose: Callable, |
|
encoding, |
|
logging_obj: Logging, |
|
optional_params: dict, |
|
acompletion: bool, |
|
timeout: Optional[Union[float, httpx.Timeout]], |
|
litellm_params=None, |
|
logger_fn=None, |
|
extra_headers: Optional[dict] = None, |
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, |
|
) -> Union[ModelResponse, CustomStreamWrapper]: |
|
try: |
|
from botocore.auth import SigV4Auth |
|
from botocore.awsrequest import AWSRequest |
|
from botocore.credentials import Credentials |
|
except ImportError: |
|
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") |
|
|
|
|
|
stream = optional_params.pop("stream", None) |
|
|
|
provider = self.get_bedrock_invoke_provider(model) |
|
modelId = self.get_bedrock_model_id( |
|
model=model, |
|
provider=provider, |
|
optional_params=optional_params, |
|
) |
|
|
|
|
|
|
|
aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) |
|
aws_access_key_id = optional_params.pop("aws_access_key_id", None) |
|
aws_session_token = optional_params.pop("aws_session_token", None) |
|
aws_region_name = optional_params.pop("aws_region_name", None) |
|
aws_role_name = optional_params.pop("aws_role_name", None) |
|
aws_session_name = optional_params.pop("aws_session_name", None) |
|
aws_profile_name = optional_params.pop("aws_profile_name", None) |
|
aws_bedrock_runtime_endpoint = optional_params.pop( |
|
"aws_bedrock_runtime_endpoint", None |
|
) |
|
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) |
|
aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) |
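
        # Region resolution order: explicit param > AWS_REGION_NAME > AWS_REGION > "us-west-2"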
|
|
|
|
|
if aws_region_name is None: |
|
|
|
litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) |
|
|
|
if litellm_aws_region_name is not None and isinstance( |
|
litellm_aws_region_name, str |
|
): |
|
aws_region_name = litellm_aws_region_name |
|
|
|
standard_aws_region_name = get_secret("AWS_REGION", None) |
|
if standard_aws_region_name is not None and isinstance( |
|
standard_aws_region_name, str |
|
): |
|
aws_region_name = standard_aws_region_name |
|
|
|
if aws_region_name is None: |
|
aws_region_name = "us-west-2" |
|
|
|
credentials: Credentials = self.get_credentials( |
|
aws_access_key_id=aws_access_key_id, |
|
aws_secret_access_key=aws_secret_access_key, |
|
aws_session_token=aws_session_token, |
|
aws_region_name=aws_region_name, |
|
aws_session_name=aws_session_name, |
|
aws_profile_name=aws_profile_name, |
|
aws_role_name=aws_role_name, |
|
aws_web_identity_token=aws_web_identity_token, |
|
aws_sts_endpoint=aws_sts_endpoint, |
|
) |
|
|
|
|
|
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint( |
|
api_base=api_base, |
|
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint, |
|
aws_region_name=aws_region_name, |
|
) |
|
|
|
if (stream is not None and stream is True) and provider != "ai21": |
|
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream" |
|
proxy_endpoint_url = ( |
|
f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream" |
|
) |
|
else: |
|
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke" |
|
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke" |
|
|
|
sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) |
|
|
|
prompt, chat_history = self.convert_messages_to_prompt( |
|
model, messages, provider, custom_prompt_dict |
|
) |
|
inference_params = copy.deepcopy(optional_params) |
|
json_schemas: dict = {} |
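        # Build the provider-specific request body; each branch back-fills defaults
        # from the matching litellm config for any param the caller didn't set.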
|
if provider == "cohere": |
|
if model.startswith("cohere.command-r"): |
|
|
|
config = litellm.AmazonCohereChatConfig().get_config() |
|
for k, v in config.items(): |
|
                    if k not in inference_params:
                        inference_params[k] = v
|
_data = {"message": prompt, **inference_params} |
|
if chat_history is not None: |
|
_data["chat_history"] = chat_history |
|
data = json.dumps(_data) |
|
else: |
|
|
|
config = litellm.AmazonCohereConfig.get_config() |
|
for k, v in config.items(): |
|
                    if k not in inference_params:
                        inference_params[k] = v
|
if stream is True: |
|
inference_params["stream"] = ( |
|
True |
|
) |
|
data = json.dumps({"prompt": prompt, **inference_params}) |
|
elif provider == "anthropic": |
|
if model.startswith("anthropic.claude-3"): |
|
|
|
system_prompt_idx: list[int] = [] |
|
system_messages: list[str] = [] |
|
for idx, message in enumerate(messages): |
|
if message["role"] == "system": |
|
system_messages.append(message["content"]) |
|
system_prompt_idx.append(idx) |
|
if len(system_prompt_idx) > 0: |
|
inference_params["system"] = "\n".join(system_messages) |
|
messages = [ |
|
i for j, i in enumerate(messages) if j not in system_prompt_idx |
|
] |
|
|
|
messages = prompt_factory( |
|
model=model, messages=messages, custom_llm_provider="anthropic_xml" |
|
) |
|
|
|
config = litellm.AmazonAnthropicClaude3Config.get_config() |
|
for k, v in config.items(): |
|
                    if k not in inference_params:
                        inference_params[k] = v
|
|
|
if "tools" in inference_params: |
|
_is_function_call = True |
|
for tool in inference_params["tools"]: |
|
json_schemas[tool["function"]["name"]] = tool["function"].get( |
|
"parameters", None |
|
) |
|
tool_calling_system_prompt = construct_tool_use_system_prompt( |
|
tools=inference_params["tools"] |
|
) |
|
inference_params["system"] = ( |
|
inference_params.get("system", "\n") |
|
+ tool_calling_system_prompt |
|
) |
|
inference_params.pop("tools") |
|
data = json.dumps({"messages": messages, **inference_params}) |
|
else: |
|
|
|
config = litellm.AmazonAnthropicConfig.get_config() |
|
for k, v in config.items(): |
|
                    if k not in inference_params:
                        inference_params[k] = v
|
data = json.dumps({"prompt": prompt, **inference_params}) |
|
elif provider == "ai21": |
|
|
|
config = litellm.AmazonAI21Config.get_config() |
|
for k, v in config.items(): |
|
                if k not in inference_params:
                    inference_params[k] = v
|
|
|
data = json.dumps({"prompt": prompt, **inference_params}) |
|
elif provider == "mistral": |
|
|
|
config = litellm.AmazonMistralConfig.get_config() |
|
for k, v in config.items(): |
|
                if k not in inference_params:
                    inference_params[k] = v
|
|
|
data = json.dumps({"prompt": prompt, **inference_params}) |
|
elif provider == "amazon": |
|
|
|
config = litellm.AmazonTitanConfig.get_config() |
|
for k, v in config.items(): |
|
                if k not in inference_params:
                    inference_params[k] = v
|
|
|
data = json.dumps( |
|
{ |
|
"inputText": prompt, |
|
"textGenerationConfig": inference_params, |
|
} |
|
) |
|
elif provider == "meta" or provider == "llama": |
|
|
|
config = litellm.AmazonLlamaConfig.get_config() |
|
for k, v in config.items(): |
|
                if k not in inference_params:
                    inference_params[k] = v
|
data = json.dumps({"prompt": prompt, **inference_params}) |
|
else: |
|
|
|
logging_obj.pre_call( |
|
input=messages, |
|
api_key="", |
|
additional_args={ |
|
"complete_input_dict": inference_params, |
|
}, |
|
) |
|
raise BedrockError( |
|
status_code=404, |
|
message="Bedrock Invoke HTTPX: Unknown provider={}, model={}. Try calling via converse route - `bedrock/converse/<model>`.".format( |
|
provider, model |
|
), |
|
) |
|
|
|
|
|
|
|
headers = {"Content-Type": "application/json"} |
|
if extra_headers is not None: |
|
headers = {"Content-Type": "application/json", **extra_headers} |
|
request = AWSRequest( |
|
method="POST", url=endpoint_url, data=data, headers=headers |
|
) |
|
sigv4.add_auth(request) |
|
        if extra_headers is not None and "Authorization" in extra_headers:
            # a caller-supplied Authorization header must be applied after SigV4
            # signing, otherwise add_auth would overwrite it
            request.headers["Authorization"] = extra_headers["Authorization"]
|
prepped = request.prepare() |
|
|
|
|
|
logging_obj.pre_call( |
|
input=messages, |
|
api_key="", |
|
additional_args={ |
|
"complete_input_dict": data, |
|
"api_base": proxy_endpoint_url, |
|
"headers": prepped.headers, |
|
}, |
|
) |
|
|
|
|
|
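        # acompletion routes to the async implementations below; a sync
        # HTTPHandler can't be reused there, so it is dropped.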
if acompletion: |
|
if isinstance(client, HTTPHandler): |
|
client = None |
|
if stream is True and provider != "ai21": |
|
return self.async_streaming( |
|
model=model, |
|
messages=messages, |
|
data=data, |
|
api_base=proxy_endpoint_url, |
|
model_response=model_response, |
|
print_verbose=print_verbose, |
|
encoding=encoding, |
|
logging_obj=logging_obj, |
|
optional_params=optional_params, |
|
stream=True, |
|
litellm_params=litellm_params, |
|
logger_fn=logger_fn, |
|
headers=prepped.headers, |
|
timeout=timeout, |
|
client=client, |
|
) |
|
|
|
return self.async_completion( |
|
model=model, |
|
messages=messages, |
|
data=data, |
|
api_base=proxy_endpoint_url, |
|
model_response=model_response, |
|
print_verbose=print_verbose, |
|
encoding=encoding, |
|
logging_obj=logging_obj, |
|
optional_params=optional_params, |
|
stream=stream, |
|
litellm_params=litellm_params, |
|
logger_fn=logger_fn, |
|
headers=prepped.headers, |
|
timeout=timeout, |
|
client=client, |
|
) |
|
|
|
if client is None or isinstance(client, AsyncHTTPHandler): |
|
_params = {} |
|
if timeout is not None: |
|
if isinstance(timeout, float) or isinstance(timeout, int): |
|
timeout = httpx.Timeout(timeout) |
|
_params["timeout"] = timeout |
|
self.client = _get_httpx_client(_params) |
|
else: |
|
self.client = client |
|
if (stream is not None and stream is True) and provider != "ai21": |
|
response = self.client.post( |
|
url=proxy_endpoint_url, |
|
headers=prepped.headers, |
|
data=data, |
|
stream=stream, |
|
logging_obj=logging_obj, |
|
) |
|
|
|
if response.status_code != 200: |
|
raise BedrockError( |
|
status_code=response.status_code, message=str(response.read()) |
|
) |
|
|
|
decoder = AWSEventStreamDecoder(model=model) |
|
|
|
completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) |
|
streaming_response = CustomStreamWrapper( |
|
completion_stream=completion_stream, |
|
model=model, |
|
custom_llm_provider="bedrock", |
|
logging_obj=logging_obj, |
|
) |
|
|
|
|
|
logging_obj.post_call( |
|
input=messages, |
|
api_key="", |
|
original_response=streaming_response, |
|
additional_args={"complete_input_dict": data}, |
|
) |
|
return streaming_response |
|
|
|
try: |
|
response = self.client.post( |
|
url=proxy_endpoint_url, |
|
headers=dict(prepped.headers), |
|
data=data, |
|
logging_obj=logging_obj, |
|
) |
|
response.raise_for_status() |
|
except httpx.HTTPStatusError as err: |
|
error_code = err.response.status_code |
|
raise BedrockError(status_code=error_code, message=err.response.text) |
|
except httpx.TimeoutException: |
|
raise BedrockError(status_code=408, message="Timeout error occurred.") |
|
|
|
return self.process_response( |
|
model=model, |
|
response=response, |
|
model_response=model_response, |
|
stream=stream, |
|
logging_obj=logging_obj, |
|
optional_params=optional_params, |
|
api_key="", |
|
data=data, |
|
messages=messages, |
|
print_verbose=print_verbose, |
|
encoding=encoding, |
|
) |
|
|
|
async def async_completion( |
|
self, |
|
model: str, |
|
messages: list, |
|
api_base: str, |
|
model_response: ModelResponse, |
|
print_verbose: Callable, |
|
data: str, |
|
timeout: Optional[Union[float, httpx.Timeout]], |
|
encoding, |
|
logging_obj: Logging, |
|
stream, |
|
optional_params: dict, |
|
litellm_params=None, |
|
logger_fn=None, |
|
headers={}, |
|
client: Optional[AsyncHTTPHandler] = None, |
|
) -> Union[ModelResponse, CustomStreamWrapper]: |
|
if client is None: |
|
_params = {} |
|
if timeout is not None: |
|
if isinstance(timeout, float) or isinstance(timeout, int): |
|
timeout = httpx.Timeout(timeout) |
|
_params["timeout"] = timeout |
|
            client = get_async_httpx_client(
                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
            )
|
|
|
try: |
|
response = await client.post( |
|
api_base, |
|
headers=headers, |
|
data=data, |
|
timeout=timeout, |
|
logging_obj=logging_obj, |
|
) |
|
response.raise_for_status() |
|
except httpx.HTTPStatusError as err: |
|
error_code = err.response.status_code |
|
raise BedrockError(status_code=error_code, message=err.response.text) |
|
except httpx.TimeoutException: |
|
raise BedrockError(status_code=408, message="Timeout error occurred.") |
|
|
|
return self.process_response( |
|
model=model, |
|
response=response, |
|
model_response=model_response, |
|
stream=stream if isinstance(stream, bool) else False, |
|
logging_obj=logging_obj, |
|
api_key="", |
|
data=data, |
|
messages=messages, |
|
print_verbose=print_verbose, |
|
optional_params=optional_params, |
|
encoding=encoding, |
|
) |
|
|
|
@track_llm_api_timing() |
|
async def async_streaming( |
|
self, |
|
model: str, |
|
messages: list, |
|
api_base: str, |
|
model_response: ModelResponse, |
|
print_verbose: Callable, |
|
data: str, |
|
timeout: Optional[Union[float, httpx.Timeout]], |
|
encoding, |
|
logging_obj: Logging, |
|
stream, |
|
optional_params: dict, |
|
litellm_params=None, |
|
logger_fn=None, |
|
headers={}, |
|
client: Optional[AsyncHTTPHandler] = None, |
|
) -> CustomStreamWrapper: |
|
|
|
|
|
streaming_response = CustomStreamWrapper( |
|
completion_stream=None, |
|
make_call=partial( |
|
make_call, |
|
client=client, |
|
api_base=api_base, |
|
headers=headers, |
|
data=data, |
|
model=model, |
|
messages=messages, |
|
logging_obj=logging_obj, |
|
                fake_stream="ai21" in api_base,  # ai21 invoke has no native streaming
|
), |
|
model=model, |
|
custom_llm_provider="bedrock", |
|
logging_obj=logging_obj, |
|
) |
|
return streaming_response |
|
|
|
@staticmethod |
|
def get_bedrock_invoke_provider( |
|
model: str, |
|
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]: |
|
""" |
|
Helper function to get the bedrock provider from the model |
|
|
|
        handles 2 scenarios:
|
1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic` |
|
2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama` |
|
""" |
|
_split_model = model.split(".")[0] |
|
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL): |
|
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model) |
|
|
|
|
|
provider = BedrockLLM._get_provider_from_model_path(model) |
|
if provider is not None: |
|
return provider |
|
return None |
|
|
|
@staticmethod |
|
def _get_provider_from_model_path( |
|
model_path: str, |
|
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]: |
|
""" |
|
Helper function to get the provider from a model path with format: provider/model-name |
|
|
|
Args: |
|
model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name') |
|
|
|
Returns: |
|
Optional[str]: The provider name, or None if no valid provider found |
|
""" |
|
        provider = model_path.split("/")[0]
        if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
        return None
|
|
|
def get_bedrock_model_id( |
|
self, |
|
optional_params: dict, |
|
provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL], |
|
model: str, |
|
) -> str: |
|
modelId = optional_params.pop("model_id", None) |
|
if modelId is not None: |
|
modelId = self.encode_model_id(model_id=modelId) |
|
else: |
|
modelId = model |
|
|
|
if provider == "llama" and "llama/" in modelId: |
|
modelId = self._get_model_id_for_llama_like_model(modelId) |
|
|
|
return modelId |
|
|
|
def _get_model_id_for_llama_like_model( |
|
self, |
|
model: str, |
|
) -> str: |
|
""" |
|
        Strip the `llama/` prefix from the model ID (`llama` is just the spec custom imported Bedrock models follow), then URL-encode the remainder.
|
""" |
|
model_id = model.replace("llama/", "") |
|
return self.encode_model_id(model_id=model_id) |
|
|
|
|
|
def get_response_stream_shape(): |
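    """
    Lazily load and cache botocore's `ResponseStream` shape for bedrock-runtime,
    used to parse event-stream messages.
    """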
|
global _response_stream_shape_cache |
|
if _response_stream_shape_cache is None: |
|
|
|
from botocore.loaders import Loader |
|
from botocore.model import ServiceModel |
|
|
|
loader = Loader() |
|
bedrock_service_dict = loader.load_service_model("bedrock-runtime", "service-2") |
|
bedrock_service_model = ServiceModel(bedrock_service_dict) |
|
_response_stream_shape_cache = bedrock_service_model.shape_for("ResponseStream") |
|
|
|
return _response_stream_shape_cache |
|
|
|
|
|
class AWSEventStreamDecoder: |
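    """
    Incrementally decodes AWS event-stream (`application/vnd.amazon.eventstream`)
    bytes from Bedrock into GenericStreamingChunk objects.
    """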
|
def __init__(self, model: str) -> None: |
|
from botocore.parsers import EventStreamJSONParser |
|
|
|
self.model = model |
|
self.parser = EventStreamJSONParser() |
|
self.content_blocks: List[ContentBlockDeltaEvent] = [] |
|
|
|
def check_empty_tool_call_args(self) -> bool: |
|
""" |
|
Check if the tool call block so far has been an empty string |
|
""" |
|
args = "" |
|
|
|
if len(self.content_blocks) == 0: |
|
return False |
|
|
|
if "text" in self.content_blocks[0]: |
|
return False |
|
|
|
for block in self.content_blocks: |
|
if "toolUse" in block: |
|
args += block["toolUse"]["input"] |
|
|
|
if len(args) == 0: |
|
return True |
|
return False |
|
|
|
def converse_chunk_parser(self, chunk_data: dict) -> GChunk: |
|
try: |
|
verbose_logger.debug("\n\nRaw Chunk: {}\n\n".format(chunk_data)) |
|
text = "" |
|
tool_use: Optional[ChatCompletionToolCallChunk] = None |
|
is_finished = False |
|
finish_reason = "" |
|
usage: Optional[ChatCompletionUsageBlock] = None |
|
|
|
index = int(chunk_data.get("contentBlockIndex", 0)) |
|
if "start" in chunk_data: |
|
start_obj = ContentBlockStartEvent(**chunk_data["start"]) |
|
self.content_blocks = [] |
|
if ( |
|
start_obj is not None |
|
and "toolUse" in start_obj |
|
and start_obj["toolUse"] is not None |
|
): |
|
|
|
_response_tool_name = start_obj["toolUse"]["name"] |
|
response_tool_name = get_bedrock_tool_name( |
|
response_tool_name=_response_tool_name |
|
) |
|
tool_use = { |
|
"id": start_obj["toolUse"]["toolUseId"], |
|
"type": "function", |
|
"function": { |
|
"name": response_tool_name, |
|
"arguments": "", |
|
}, |
|
"index": index, |
|
} |
|
elif "delta" in chunk_data: |
|
delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"]) |
|
self.content_blocks.append(delta_obj) |
|
if "text" in delta_obj: |
|
text = delta_obj["text"] |
|
elif "toolUse" in delta_obj: |
|
tool_use = { |
|
"id": None, |
|
"type": "function", |
|
"function": { |
|
"name": None, |
|
"arguments": delta_obj["toolUse"]["input"], |
|
}, |
|
"index": index, |
|
} |
|
            elif "contentBlockIndex" in chunk_data:
                # contentBlockStop event: if the tool call streamed no arguments,
                # emit "{}" so downstream JSON parsing doesn't break
|
is_empty = self.check_empty_tool_call_args() |
|
if is_empty: |
|
tool_use = { |
|
"id": None, |
|
"type": "function", |
|
"function": { |
|
"name": None, |
|
"arguments": "{}", |
|
}, |
|
"index": chunk_data["contentBlockIndex"], |
|
} |
|
elif "stopReason" in chunk_data: |
|
finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop")) |
|
is_finished = True |
|
elif "usage" in chunk_data: |
|
usage = ChatCompletionUsageBlock( |
|
prompt_tokens=chunk_data.get("inputTokens", 0), |
|
completion_tokens=chunk_data.get("outputTokens", 0), |
|
total_tokens=chunk_data.get("totalTokens", 0), |
|
) |
|
|
|
response = GChunk( |
|
text=text, |
|
tool_use=tool_use, |
|
is_finished=is_finished, |
|
finish_reason=finish_reason, |
|
usage=usage, |
|
index=index, |
|
) |
|
|
|
if "trace" in chunk_data: |
|
trace = chunk_data.get("trace") |
|
response["provider_specific_fields"] = {"trace": trace} |
|
return response |
|
except Exception as e: |
|
raise Exception("Received streaming error - {}".format(str(e))) |
|
|
|
def _chunk_parser(self, chunk_data: dict) -> GChunk: |
|
text = "" |
|
is_finished = False |
|
finish_reason = "" |
|
if "outputText" in chunk_data: |
|
text = chunk_data["outputText"] |
|
|
|
elif "ai21" in self.model: |
|
text = chunk_data.get("completions")[0].get("data").get("text") |
|
is_finished = True |
|
finish_reason = "stop" |
|
|
|
elif ( |
|
"contentBlockIndex" in chunk_data |
|
or "stopReason" in chunk_data |
|
or "metrics" in chunk_data |
|
or "trace" in chunk_data |
|
): |
|
return self.converse_chunk_parser(chunk_data=chunk_data) |
|
|
|
elif "outputs" in chunk_data: |
|
if ( |
|
len(chunk_data["outputs"]) == 1 |
|
and chunk_data["outputs"][0].get("text", None) is not None |
|
): |
|
text = chunk_data["outputs"][0]["text"] |
|
stop_reason = chunk_data.get("stop_reason", None) |
|
if stop_reason is not None: |
|
is_finished = True |
|
finish_reason = stop_reason |
|
|
|
|
|
elif "generation" in chunk_data: |
|
text = chunk_data["generation"] |
|
|
|
elif "text" in chunk_data: |
|
text = chunk_data["text"] |
|
|
|
elif "finish_reason" in chunk_data: |
|
finish_reason = chunk_data["finish_reason"] |
|
is_finished = True |
|
elif chunk_data.get("completionReason", None): |
|
is_finished = True |
|
finish_reason = chunk_data["completionReason"] |
|
return GChunk( |
|
text=text, |
|
is_finished=is_finished, |
|
finish_reason=finish_reason, |
|
usage=None, |
|
index=0, |
|
tool_use=None, |
|
) |
|
|
|
def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]: |
|
"""Given an iterator that yields lines, iterate over it & yield every event encountered""" |
|
from botocore.eventstream import EventStreamBuffer |
|
|
|
event_stream_buffer = EventStreamBuffer() |
|
for chunk in iterator: |
|
event_stream_buffer.add_data(chunk) |
|
for event in event_stream_buffer: |
|
message = self._parse_message_from_event(event) |
|
if message: |
|
|
|
_data = json.loads(message) |
|
yield self._chunk_parser(chunk_data=_data) |
|
|
|
async def aiter_bytes( |
|
self, iterator: AsyncIterator[bytes] |
|
) -> AsyncIterator[GChunk]: |
|
"""Given an async iterator that yields lines, iterate over it & yield every event encountered""" |
|
from botocore.eventstream import EventStreamBuffer |
|
|
|
event_stream_buffer = EventStreamBuffer() |
|
async for chunk in iterator: |
|
event_stream_buffer.add_data(chunk) |
|
for event in event_stream_buffer: |
|
message = self._parse_message_from_event(event) |
|
if message: |
|
_data = json.loads(message) |
|
yield self._chunk_parser(chunk_data=_data) |
|
|
|
def _parse_message_from_event(self, event) -> Optional[str]: |
|
response_dict = event.to_response_dict() |
|
parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) |
|
|
|
if response_dict["status_code"] != 200: |
|
            decoded_body = response_dict["body"].decode()
            try:
                # error bodies are JSON-encoded; fall back to the raw string if not
                decoded_body = json.loads(decoded_body)
            except json.JSONDecodeError:
                pass
            if isinstance(decoded_body, dict):
                error_message = decoded_body.get("message") or ""
            elif isinstance(decoded_body, str):
                error_message = decoded_body
            else:
                error_message = ""
            exception_status = response_dict["headers"].get(":exception-type") or ""
            error_message = exception_status + " " + error_message
|
raise BedrockError( |
|
status_code=response_dict["status_code"], |
|
                message=error_message,
|
) |
|
if "chunk" in parsed_response: |
|
chunk = parsed_response.get("chunk") |
|
if not chunk: |
|
return None |
|
return chunk.get("bytes").decode() |
|
else: |
|
chunk = response_dict.get("body") |
|
if not chunk: |
|
return None |
|
|
|
return chunk.decode() |
|
|
|
|
|
class MockResponseIterator: |
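    """
    Replays an already-complete ModelResponse as a single-chunk (a)sync stream.
    Used for "fake streaming" providers (e.g. ai21) without native streaming.
    """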
|
def __init__(self, model_response, json_mode: Optional[bool] = False): |
|
self.model_response = model_response |
|
self.json_mode = json_mode |
|
self.is_done = False |
|
|
|
|
|
def __iter__(self): |
|
return self |
|
|
|
def _handle_json_mode_chunk( |
|
self, text: str, tool_calls: Optional[List[ChatCompletionToolCallChunk]] |
|
) -> Tuple[str, Optional[ChatCompletionToolCallChunk]]: |
|
""" |
|
If JSON mode is enabled, convert the tool call to a message. |
|
|
|
        Bedrock returns the JSON schema as part of the tool call, while OpenAI
        returns it as part of the content; this handles placing it in the content.
|
|
|
Args: |
|
text: str |
|
            tool_calls: Optional[List[ChatCompletionToolCallChunk]]
|
Returns: |
|
Tuple[str, Optional[ChatCompletionToolCallChunk]] |
|
|
|
text: The text to use in the content |
|
tool_use: The ChatCompletionToolCallChunk to use in the chunk response |
|
""" |
|
tool_use: Optional[ChatCompletionToolCallChunk] = None |
|
if self.json_mode is True and tool_calls is not None: |
|
message = litellm.AnthropicConfig()._convert_tool_response_to_message( |
|
tool_calls=tool_calls |
|
) |
|
if message is not None: |
|
text = message.content or "" |
|
tool_use = None |
|
elif tool_calls is not None and len(tool_calls) > 0: |
|
tool_use = tool_calls[0] |
|
return text, tool_use |
|
|
|
def _chunk_parser(self, chunk_data: ModelResponse) -> GChunk: |
|
try: |
|
chunk_usage: Usage = getattr(chunk_data, "usage") |
|
text = chunk_data.choices[0].message.content or "" |
|
tool_use = None |
|
_model_response_tool_call = cast( |
|
Optional[List[ChatCompletionMessageToolCall]], |
|
cast(Choices, chunk_data.choices[0]).message.tool_calls, |
|
) |
|
if self.json_mode is True: |
|
text, tool_use = self._handle_json_mode_chunk( |
|
text=text, |
|
tool_calls=chunk_data.choices[0].message.tool_calls, |
|
) |
|
elif _model_response_tool_call is not None: |
|
tool_use = ChatCompletionToolCallChunk( |
|
id=_model_response_tool_call[0].id, |
|
type="function", |
|
function=ChatCompletionToolCallFunctionChunk( |
|
name=_model_response_tool_call[0].function.name, |
|
arguments=_model_response_tool_call[0].function.arguments, |
|
), |
|
index=0, |
|
) |
|
processed_chunk = GChunk( |
|
text=text, |
|
tool_use=tool_use, |
|
is_finished=True, |
|
finish_reason=map_finish_reason( |
|
finish_reason=chunk_data.choices[0].finish_reason or "" |
|
), |
|
usage=ChatCompletionUsageBlock( |
|
prompt_tokens=chunk_usage.prompt_tokens, |
|
completion_tokens=chunk_usage.completion_tokens, |
|
total_tokens=chunk_usage.total_tokens, |
|
), |
|
index=0, |
|
) |
|
return processed_chunk |
|
except Exception as e: |
|
raise ValueError(f"Failed to decode chunk: {chunk_data}. Error: {e}") |
|
|
|
def __next__(self): |
|
if self.is_done: |
|
raise StopIteration |
|
self.is_done = True |
|
return self._chunk_parser(self.model_response) |
|
|
|
|
|
def __aiter__(self): |
|
return self |
|
|
|
async def __anext__(self): |
|
if self.is_done: |
|
raise StopAsyncIteration |
|
self.is_done = True |
|
return self._chunk_parser(self.model_response) |
|
|