File size: 4,364 Bytes
2e2dda5
 
 
 
 
 
 
 
 
 
 
 
 
 
5a7796a
2e2dda5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eb03410
2e2dda5
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import boto3
import json
import time
import zipfile
from io import BytesIO
import uuid
import pprint
import logging
print(boto3.__version__)
from PIL import Image 
import os
import base64
import re
import requests 
#import utilities.re_ranker as re_ranker
import utilities.invoke_models as invoke_models
import streamlit as st
import time as t
import botocore.exceptions

# Create the `inputs_` entry in Streamlit session state the first time
# this module runs in a session.
if "inputs_" not in st.session_state:
    st.session_state.inputs_ = {}

# Path of this file's directory with its last "/"-separated component removed.
parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])

# AWS region used for the Bedrock client created below.
region = 'us-east-1'

# setting logger
logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# getting boto3 clients for required AWS services
#bedrock_agent_client = boto3.client('bedrock-agent',region_name=region)
bedrock_agent_runtime_client = boto3.client(
    'bedrock-agent-runtime',
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'],
    # Use the shared constant so the region is configured in one place only.
    region_name=region,
)

# Flags forwarded to invoke_agent by query_.
enable_trace: bool = True
end_session: bool = False

def delete_memory():
    """Request deletion of the stored memory for the hard-coded agent alias.

    Calls ``delete_agent_memory`` on the module-level Bedrock agent runtime
    client. The service response is discarded and nothing is returned.
    """
    bedrock_agent_runtime_client.delete_agent_memory(
        agentId='B4Z7BTURC4',
        agentAliasId='TSTALIASID',
    )
    
def query_(inputs):
    """Send the shopping query to the Bedrock agent and summarise its trace.

    Args:
        inputs: Mapping that must contain ``'shopping_query'``, the text
            sent to the agent as ``inputText``.

    Returns:
        A dict with three keys:

        * ``'text'``: final response text taken from the ``FINISH``
          observation, or ``""`` if no such observation was seen.
        * ``'source'``: list with one dict per orchestration trace event
          that carried any of ``tool``, ``rationale``, ``invocationInput``,
          ``observation`` or ``thinking``.
        * ``'last_tool'``: ``{'name': ..., 'response': ...}`` describing the
          most recent action-group invocation and its output text.

    Raises:
        botocore.exceptions.EventStreamError: If reading the response
            stream fails; the error is logged and re-raised unchanged.
    """
    # invoke the agent API
    agentResponse = bedrock_agent_runtime_client.invoke_agent(
        inputText=inputs['shopping_query'],
        agentId='B4Z7BTURC4',
        agentAliasId='TSTALIASID',
        sessionId=st.session_state.session_id_,
        enableTrace=enable_trace,
        endSession=end_session,
    )

    # pprint.pprint() returns None, so format first to log the real payload.
    logger.info("%s", pprint.pformat(agentResponse))
    print("***agent*****response*********")
    print(agentResponse)
    event_stream = agentResponse['completion']
    total_context = []
    last_tool = ""
    last_tool_name = ""
    agent_answer = ""
    try:
        for event in event_stream:
            print("***event*********")
            print(event)
            if 'trace' not in event:
                continue
            print("trace*****total*********")
            print(event['trace'])

            # Only orchestration traces are summarised; other trace kinds
            # (or a trace part without an inner 'trace') are skipped.
            orchestration_trace = (event['trace'].get('trace') or {}).get('orchestrationTrace')
            if orchestration_trace is None:
                continue

            # The raw model response is optional on modelInvocationOutput,
            # so look it up defensively once and reuse it below.
            raw_response = (orchestration_trace.get('modelInvocationOutput') or {}).get('rawResponse')
            raw_content = (raw_response or {}).get('content') or ''

            total_context_item = {}
            if '<tool_name>' in raw_content:
                total_context_item['tool'] = raw_response
            if 'rationale' in orchestration_trace:
                total_context_item['rationale'] = orchestration_trace['rationale'].get('text', '')
            if 'invocationInput' in orchestration_trace:
                # Non-action-group invocations carry no
                # actionGroupInvocationInput; skip them instead of raising.
                action_input = orchestration_trace['invocationInput'].get('actionGroupInvocationInput')
                if action_input is not None:
                    total_context_item['invocationInput'] = action_input
                    # 'function' may be absent; report an empty name rather
                    # than a stale one from an earlier call.
                    last_tool_name = action_input.get('function', '')
            if 'observation' in orchestration_trace:
                print("trace****observation******")
                tool_output_last_obs = orchestration_trace['observation']
                total_context_item['observation'] = tool_output_last_obs
                print(tool_output_last_obs)
                observation_type = tool_output_last_obs.get('type')
                if observation_type == 'ACTION_GROUP':
                    last_tool = (tool_output_last_obs.get('actionGroupInvocationOutput') or {}).get('text', '')
                elif observation_type == 'FINISH':
                    agent_answer = (tool_output_last_obs.get('finalResponse') or {}).get('text', '')
            if '<thinking>' in raw_content:
                total_context_item['thinking'] = raw_response
            if total_context_item:
                total_context.append(total_context_item)
        print("total_context------")
        print(total_context)
    except botocore.exceptions.EventStreamError:
        # Record the failure, then propagate with the original traceback.
        logger.exception("Bedrock agent event stream failed")
        raise

    return {'text':agent_answer,'source':total_context,'last_tool':{'name':last_tool_name,'response':last_tool}}