OpenSearch-AI / pages /AI_Shopping_Assistant.py
prasadnu's picture
rerank model
eb03410
import streamlit as st
import uuid
import os
import re
import sys
import uuid
from io import BytesIO
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/semantic_search")
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/RAG")
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/utilities")
import boto3
import requests
from boto3 import Session
import botocore.session
import json
import random
import string
# import rag_DocumentLoader
# import rag_DocumentSearcher
import pandas as pd
from PIL import Image
import shutil
import base64
import time
import botocore
#from langchain.callbacks.base import BaseCallbackHandler
#import streamlit_nested_layout
#from IPython.display import clear_output, display, display_markdown, Markdown
from requests_aws4auth import AWS4Auth
#import copali
from requests.auth import HTTPBasicAuth
import bedrock_agent
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Page configuration: this is the first Streamlit call made by the page.
st.set_page_config(
    layout="wide",
    page_icon="images/opensearch_mark_default.png"
)

# Static assets / resource names used by this page.
parent_dirname = '/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp'
USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"
s3_bucket_ = "pdf-repo-uploads"
polly_client = boto3.Session(region_name='us-east-1').client('polly')

# Session-state defaults, in the order they are first created. Each value is a
# factory so that fresh objects (ids, lists) are only built when the key is
# missing. Order matters: handle_input forwards every "input_*" key in
# session-state iteration order.
_SESSION_DEFAULTS = {
    'user_id': lambda: str(uuid.uuid4()),
    'session_id_': lambda: str(uuid.uuid1()),
    'chats': lambda: [{'id': 0, 'question': '', 'answer': ''}],
    'questions__': list,
    'answers__': list,
    'input_is_rerank': lambda: True,
    'input_copali_rerank': lambda: False,
    'input_table_with_sql': lambda: False,
    'inputs_': dict,
    'input_shopping_query': lambda: "get me shoes suitable for trekking",
    'input_rag_searchType': lambda: ["Sparse Search"],
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
user_id = st.session_state['user_id']

region = 'us-east-1'
output = []
service = 'es'

# Remove the vertical gap between elements stacked in the first two columns.
_COLUMN_GAP_CSS = """
<style>
[data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
gap: 0rem;
}
[data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
gap: 0rem;
}
</style>
"""
st.markdown(_COLUMN_GAP_CSS, unsafe_allow_html=True)
def write_logo():
    """Show the assistant logo in the narrow centre column of a 5:1:5 split."""
    _left, centre, _right = st.columns([5, 1, 5])
    with centre:
        st.image(AI_ICON, use_column_width='always')
def write_top_bar():
    """Draw the home link and page title, plus a "Clear" button on the right.

    Returns:
        The value returned by ``st.button("Clear")`` (truthy on the run in
        which the button was clicked).
    """
    title_col, button_col = st.columns([77, 23])
    with title_col:
        st.page_link("app.py", label=":orange[Home]", icon="🏠")
        st.header("AI Shopping assistant", divider='rainbow')
    with button_col:
        # Two empty writes push the button down next to the header.
        for _ in range(2):
            st.write("")
        clicked = st.button("Clear")
    for _ in range(2):
        st.write("")
    return clicked
clear = write_top_bar()
if clear:
    # Reset the conversation: empty both histories, blank the input box,
    # start a new session id and drop the agent's memory.
    for _history_key in ("questions__", "answers__"):
        st.session_state[_history_key] = []
    st.session_state["input_shopping_query"] = ""
    st.session_state["session_id_"] = str(uuid.uuid1())
    bedrock_agent.delete_memory()
def handle_input():
    """Send the current query to the Bedrock agent and record the exchange.

    Registered as the ``on_click`` callback of the "Go" button. Every
    session-state key starting with ``input_`` is forwarded (prefix stripped)
    to ``bedrock_agent.query_``. The question and its answer are appended to
    ``questions__`` / ``answers__`` under the *same* ``id`` so they can be
    associated later, and the input box is cleared.

    Returns:
        ``""`` when the query box is empty (nothing is sent); otherwise None.

    Raises:
        Whatever ``bedrock_agent.query_`` raises, or ``KeyError`` if its
        result lacks ``text``/``source``/``last_tool``. In that case the
        question is removed again so ``questions__`` and ``answers__`` stay
        index-aligned for ``render_all``'s ``zip``.
    """
    if st.session_state.input_shopping_query == '':
        return ""
    inputs = {
        key.removeprefix('input_'): st.session_state[key]
        for key in st.session_state
        if key.startswith('input_')
    }
    st.session_state.inputs_ = inputs
    question_id = len(st.session_state.questions__)
    st.session_state.questions__.append({
        'question': inputs["shopping_query"],
        'id': question_id,
    })
    print(inputs)
    try:
        out_ = bedrock_agent.query_(inputs)
        answer_record = {
            'answer': out_['text'],
            'source': out_['source'],
            'last_tool': out_['last_tool'],
            # Same id as the question it answers (previously off by one).
            'id': question_id,
        }
    except Exception:
        # Without an answer the two lists would drift out of step and every
        # later answer would be shown under the wrong question.
        st.session_state.questions__.pop()
        raise
    st.session_state.answers__.append(answer_record)
    st.session_state.input_shopping_query = ""
def format_user_message_html(question):
    """Return the HTML snippet used to display a user's question.

    The question is user input and is rendered with ``unsafe_allow_html=True``,
    so it is HTML-escaped before being embedded.

    Args:
        question: The raw question text.

    Returns:
        A ``<div>`` string with a single, well-formed ``style`` attribute.
    """
    import html

    # All declarations live inside the one quoted style attribute (the old
    # markup closed the quote after the colour, dropping everything else).
    return (
        "<div style='color:#e28743;font-size:18px;padding:3px 7px 3px 7px;"
        "border-width:0px;border-color:red;border-style:solid;"
        "width:fit-content;height:fit-content;border-radius:10px;"
        "font-style:italic;'>"
        + html.escape(question)
        + "</div>"
    )


def write_user_message(md):
    """Render the user's question ``md['question']`` next to the user avatar."""
    col1, col2 = st.columns([3, 97])
    with col1:
        st.image(USER_ICON, use_column_width='always')
    with col2:
        st.markdown(format_user_message_html(md['question']), unsafe_allow_html=True)
# Tools whose string response is decoded and visualised below the answer text.
_TOOLS_WITH_RESULTS = (
    "generate_images",
    "get_relevant_items_for_image",
    "get_relevant_items_for_text",
    "retrieve_with_hybrid_search",
    "retrieve_with_keyword_search",
    "get_any_general_recommendation",
)
# Subset of the above whose decoded response is read under 'generate_images'
# (an s3:// URI) rather than as an item list keyed by the tool's own name.
_IMAGE_GENERATING_TOOLS = ("generate_images", "get_any_general_recommendation")
# Bucket generated images are read from; the bucket part of the URI is ignored.
# NOTE(review): confirm the agent always writes generated images to this bucket.
_GENERATED_IMAGES_BUCKET = "bedrock-video-generation-us-east-1-lbxkrh"


def parse_tool_response(raw):
    """Decode a tool's string response into a Python object.

    Tries strict JSON first, then a Python literal (single-quoted repr,
    ``True``/``None``) via ``ast.literal_eval``. As a last resort it applies
    the legacy "swap ' for \"" heuristic, which fails on apostrophes inside
    values.

    Args:
        raw: The tool response string.

    Returns:
        The decoded object.

    Raises:
        json.JSONDecodeError: if none of the strategies can decode ``raw``.
    """
    import ast

    try:
        return json.loads(raw)
    except ValueError:
        pass
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        pass
    return json.loads(raw.replace("'", '"'))


def _render_product_images(items):
    """Show the first two items' images (with titles) side by side."""
    st.write("<br><br>", unsafe_allow_html=True)
    img_col1, img_col2, _spacer = st.columns([30, 30, 40])
    # Only two columns are displayed, so only the first two images are fetched.
    for column, item in zip((img_col1, img_col2), items):
        response_ = requests.get(item['image'], timeout=30)
        img = Image.open(BytesIO(response_.content))
        resized = img.resize((230, 180), Image.Resampling.LANCZOS)
        with column:
            st.image(resized, use_column_width=True, caption=item['title'])


def _render_generated_image(s3_uri):
    """Download a generated image from S3 and show it in the first column."""
    st.write("<br>", unsafe_allow_html=True)
    gen_img_col1, _gen_img_col2, _gen_img_col3 = st.columns([30, 30, 30])
    # "s3://bucket/path/to/key.png" -> key "path/to/key.png"
    _bucket, _, key = s3_uri.removeprefix('s3://').partition('/')
    s3_ = boto3.resource(
        's3',
        aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'],
        region_name='us-east-1',
    )
    s3_stream = s3_.Object(_GENERATED_IMAGES_BUCKET, key).get()['Body'].read()
    img_ = Image.open(BytesIO(s3_stream))
    resized = img_.resize((230, 180), Image.Resampling.LANCZOS)
    file_stem = key.rsplit('/', 1)[-1].split(".")[0]
    with gen_img_col1:
        st.image(resized, caption="Generated image for " + file_stem, use_column_width=True)


def render_answer(question, answer, index):
    """Render one assistant answer: text, tool visuals and agent traces.

    Args:
        question: The matching question record (unused here; kept for the
            caller's interface).
        answer: Answer record with ``answer`` (text, where ``<question>`` tags
            are highlighted), ``source`` (agent traces, shown when not ``{}``)
            and ``last_tool`` (dict with ``name`` and string ``response``).
        index: Position of the answer in the chat (unused here).
    """
    col1, col2, _col3 = st.columns([4, 74, 22])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        span_ans = (
            answer['answer']
            .replace('<question>', "<span style='font-size:18px;color:#f37709;font-style:italic;'>")
            .replace("</question>", "</span>")
        )
        st.markdown("<p>" + span_ans + "</p>", unsafe_allow_html=True)
        tool_name = answer['last_tool']['name']
        if tool_name in _TOOLS_WITH_RESULTS:
            src_dict = parse_tool_response(answer['last_tool']['response'])
            if tool_name in _IMAGE_GENERATING_TOOLS:
                # A general recommendation may come without a generated image.
                generated_uri = src_dict.get('generate_images')
                if generated_uri:
                    _render_generated_image(generated_uri)
            else:
                _render_product_images(src_dict[tool_name])
    st.write("<br>", unsafe_allow_html=True)
    _left, trace_col, _right = st.columns([4, 82, 20])
    if answer['source'] != {}:
        with trace_col:
            with st.expander("Agent Traces:"):
                st.write(answer['source'])
def write_chat_message(md, q, index):
    """Render answer ``md`` inside its own container.

    The answer is rendered together with its question ``q`` so that any
    feedback on the answer can be associated with the question it responds to.
    """
    with st.container():
        render_answer(q, md, index)
def render_all():
    """Render the whole conversation.

    Questions and answers are paired by position (``zip`` stops at the
    shorter list); each answer is given its 1-based position as ``index``.
    """
    history = zip(st.session_state.questions__, st.session_state.answers__)
    for index, (q, a) in enumerate(history, start=1):
        write_user_message(q)
        write_chat_message(a, q, index)
# Conversation history, rendered inside a single placeholder container.
placeholder = st.empty()
with placeholder.container():
    render_all()
st.markdown("")

# Query box bound to session-state key "input_shopping_query" and the "Go"
# button, which has handle_input registered as its on_click callback.
col_2, col_3 = st.columns([75, 20])
with col_2:
    # Named so it no longer shadows the builtin input().
    shopping_query = st.text_input("Ask here", label_visibility="collapsed", key="input_shopping_query")
with col_3:
    play = st.button("Go", on_click=handle_input, key="play")