# Spaces: Running on T4
import boto3 | |
import json | |
import time | |
import zipfile | |
from io import BytesIO | |
import uuid | |
import pprint | |
import logging | |
from PIL import Image | |
import os | |
import base64 | |
import re | |
import requests | |
#import utilities.re_ranker as re_ranker | |
import utilities.invoke_models as invoke_models | |
import streamlit as st | |
import time as t | |
import botocore.exceptions | |
# Create the per-session `inputs_` dict when it is missing from session state.
if "inputs_" not in st.session_state:
    st.session_state.inputs_ = {}

# Path of this file's directory with its last "/"-separated component removed.
parent_dirname = "/".join(os.path.dirname(__file__).split("/")[:-1])

region = "us-east-1"

# setting logger
logging.basicConfig(
    format="[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# getting boto3 clients for required AWS services
# (credentials are read from Streamlit secrets)
#bedrock_agent_client = boto3.client('bedrock-agent',region_name=region)
bedrock_agent_runtime_client = boto3.client(
    "bedrock-agent-runtime",
    region_name=region,
    aws_access_key_id=st.secrets["user_access_key"],
    aws_secret_access_key=st.secrets["user_secret_key"],
)

# Flags passed to invoke_agent() as enableTrace / endSession by query_().
enable_trace: bool = True
end_session: bool = False
def delete_memory():
    """Call ``delete_agent_memory`` for the hard-coded agent and alias IDs.

    The service response is not used and nothing is returned.
    """
    bedrock_agent_runtime_client.delete_agent_memory(
        agentId='B4Z7BTURC4',
        agentAliasId='TSTALIASID',
    )
def _raw_model_content(orchestration_trace):
    """Return the raw model response content of a trace step, or '' if absent."""
    raw_response = orchestration_trace.get('modelInvocationOutput', {}).get('rawResponse', {})
    return raw_response.get('content') or ''


def query_(inputs):
    """Send a shopping query to the Bedrock agent and summarise its trace.

    Args:
        inputs: Mapping that must contain the key ``'shopping_query'``; its
            value is sent to the agent as ``inputText``.

    Returns:
        A dict with three keys:

        * ``'text'``: text of the last ``FINISH`` observation seen, or ``''``.
        * ``'source'``: list with one dict per orchestration trace step that
          carried any of ``tool``, ``rationale``, ``invocationInput``,
          ``observation`` or ``thinking``.
        * ``'last_tool'``: ``{'name': ..., 'response': ...}`` holding the
          ``function`` of the last action-group invocation input and the text
          of the last ``ACTION_GROUP`` observation (both ``''`` if none seen).

    Raises:
        KeyError: if ``inputs`` has no ``'shopping_query'`` entry.
        botocore.exceptions.EventStreamError: re-raised unchanged when
            iterating the completion stream fails.

    NOTE(review): reads ``st.session_state.session_id_``, which is not
    initialised in this part of the file -- confirm the caller sets it.
    """
    # invoke the agent API
    agentResponse = bedrock_agent_runtime_client.invoke_agent(
        inputText=inputs['shopping_query'],
        agentId='B4Z7BTURC4',
        agentAliasId='TSTALIASID',
        sessionId=st.session_state.session_id_,
        enableTrace=enable_trace,
        endSession=end_session,
    )

    # pprint.pprint() returns None, so the original logged the string "None";
    # pformat() returns the formatted text that was meant to be logged.
    logger.info("%s", pprint.pformat(agentResponse))
    print("***agent*****response*********")
    print(agentResponse)

    event_stream = agentResponse['completion']
    total_context = []
    last_tool = ""
    last_tool_name = ""
    agent_answer = ""

    try:
        for event in event_stream:
            print("***event*********")
            print(event)
            if 'trace' not in event:
                continue

            print("trace*****total*********")
            print(event['trace'])
            orchestration_trace = event['trace'].get('trace', {}).get('orchestrationTrace')
            if orchestration_trace is None:
                # Only orchestration steps contribute to the summary.
                continue

            raw_content = _raw_model_content(orchestration_trace)
            # Keys are inserted in the same order as before:
            # tool, rationale, invocationInput, observation, thinking.
            total_context_item = {}

            if '<tool_name>' in raw_content:
                total_context_item['tool'] = orchestration_trace['modelInvocationOutput']['rawResponse']

            if 'rationale' in orchestration_trace:
                total_context_item['rationale'] = orchestration_trace['rationale'].get('text', '')

            if 'invocationInput' in orchestration_trace:
                # The invocation input may lack an action-group payload, and
                # the payload may lack 'function'; tolerate both instead of
                # raising KeyError mid-stream.
                action_input = orchestration_trace['invocationInput'].get('actionGroupInvocationInput')
                if action_input is not None:
                    total_context_item['invocationInput'] = action_input
                    last_tool_name = action_input.get('function', '')

            if 'observation' in orchestration_trace:
                print("trace****observation******")
                observation = orchestration_trace['observation']
                total_context_item['observation'] = observation
                print(observation)
                observation_type = observation.get('type')
                if observation_type == 'ACTION_GROUP':
                    last_tool = observation.get('actionGroupInvocationOutput', {}).get('text', '')
                elif observation_type == 'FINISH':
                    agent_answer = observation.get('finalResponse', {}).get('text', '')

            if '<thinking>' in raw_content:
                total_context_item['thinking'] = orchestration_trace['modelInvocationOutput']['rawResponse']

            if total_context_item:
                total_context.append(total_context_item)

        print("total_context------")
        print(total_context)
    except botocore.exceptions.EventStreamError:
        # Record the failure, then propagate with the original traceback.
        logger.exception("Bedrock agent event stream failed")
        raise

    return {
        'text': agent_answer,
        'source': total_context,
        'last_tool': {'name': last_tool_name, 'response': last_tool},
    }