Spaces:
Running
on
T4
Running
on
T4
import streamlit as st | |
import uuid | |
import os | |
import re | |
import sys | |
import uuid | |
from io import BytesIO | |
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/semantic_search") | |
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/RAG") | |
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/utilities") | |
import boto3 | |
import requests | |
from boto3 import Session | |
import botocore.session | |
import json | |
import random | |
import string | |
# import rag_DocumentLoader | |
# import rag_DocumentSearcher | |
import pandas as pd | |
from PIL import Image | |
import shutil | |
import base64 | |
import time | |
import botocore | |
#from langchain.callbacks.base import BaseCallbackHandler | |
#import streamlit_nested_layout | |
#from IPython.display import clear_output, display, display_markdown, Markdown | |
from requests_aws4auth import AWS4Auth | |
#import copali | |
from requests.auth import HTTPBasicAuth | |
import bedrock_agent | |
import warnings | |
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Page configuration must be the first Streamlit command the script executes.
st.set_page_config(
    layout="wide",
    page_icon="images/opensearch_mark_default.png",
)

# --- Static configuration ----------------------------------------------------
parent_dirname = '/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp'
USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"
s3_bucket_ = "pdf-repo-uploads"
polly_client = boto3.Session(region_name='us-east-1').client('polly')

# --- Per-session state -------------------------------------------------------
# A random user id is minted on the first run of a browser session and then
# reused from session state on every rerun.
if 'user_id' not in st.session_state:
    st.session_state['user_id'] = str(uuid.uuid4())
user_id = st.session_state['user_id']

if 'session_id_' not in st.session_state:
    st.session_state['session_id_'] = str(uuid.uuid1())

# Remaining defaults, seeded only when the key is absent. Entries are listed in
# the order they are inserted into session state.
_session_defaults = {
    "chats": [{'id': 0, 'question': '', 'answer': ''}],
    "questions__": [],
    "answers__": [],
    "input_is_rerank": True,
    "input_copali_rerank": False,
    "input_table_with_sql": False,
    "inputs_": {},
    "input_shopping_query": "get me shoes suitable for trekking",
    "input_rag_searchType": ["Sparse Search"],
}
for _state_key, _default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default

region = 'us-east-1'
output = []
service = 'es'

# Zero the vertical gap of the vertical blocks inside the first two columns.
st.markdown("""
<style>
[data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
gap: 0rem;
}
[data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
gap: 0rem;
}
</style>
""",unsafe_allow_html=True)
def write_logo():
    """Show the assistant logo in the narrow middle column of a 5:1:5 row."""
    _left, middle, _right = st.columns([5, 1, 5])
    middle.image(AI_ICON, use_column_width='always')
def write_top_bar():
    """Render the page header row and report whether "Clear" was clicked.

    The left column holds a link back to ``app.py`` and the page title; the
    right column holds the "Clear" button.

    Returns:
        The value returned by ``st.button("Clear")``.
    """
    title_col, button_col = st.columns([77, 23])
    with title_col:
        st.page_link("app.py", label=":orange[Home]", icon="🏠")
        st.header("AI Shopping assistant", divider='rainbow')
    with button_col:
        # Blank spacer lines, presumably to push the button down beside the header.
        for _ in range(2):
            st.write("")
        clear_clicked = st.button("Clear")
    for _ in range(2):
        st.write("")
    return clear_clicked
# Draw the header; when "Clear" is clicked, wipe the chat history and query box,
# mint a new session id and ask the agent module to delete its memory.
clear = write_top_bar()
if clear:
    for _history_key in ("questions__", "answers__"):
        st.session_state[_history_key] = []
    # Allowed here: the text_input keyed "input_shopping_query" is only created
    # further down in this script run.
    st.session_state["input_shopping_query"] = ""
    st.session_state["session_id_"] = str(uuid.uuid1())
    bedrock_agent.delete_memory()
def handle_input():
    """Send the current query to the agent and record the exchange.

    Registered as the "Go" button's ``on_click`` callback. Every ``input_*``
    session-state entry is forwarded to ``bedrock_agent.query_`` with the
    prefix stripped (e.g. ``input_shopping_query`` -> ``shopping_query``).

    The question and its answer are appended together, only after the agent
    call returns. ``render_all`` pairs ``questions__`` with ``answers__`` by
    position via ``zip``, so a failed call must not leave an orphan question
    behind; that would shift every later answer onto the wrong question.
    Exceptions from the agent call still propagate.

    Returns:
        ``""`` when the query box is empty (nothing is sent), otherwise None.
    """
    if st.session_state.input_shopping_query == '':
        return ""

    inputs = {
        key.removeprefix('input_'): st.session_state[key]
        for key in st.session_state
        if key.startswith('input_')
    }
    st.session_state.inputs_ = inputs

    question_id = len(st.session_state.questions__)
    question_with_id = {
        'question': inputs["shopping_query"],
        'id': question_id,
    }
    print(inputs)

    out_ = bedrock_agent.query_(inputs)

    st.session_state.questions__.append(question_with_id)
    # NOTE(review): answer ids are question id + 1 (len() after the append in
    # the previous version); kept unchanged in case something reads them.
    st.session_state.answers__.append({
        'answer': out_['text'],
        'source': out_['source'],
        'last_tool': out_['last_tool'],
        'id': question_id + 1,
    })
    st.session_state.input_shopping_query = ""
def write_user_message(md):
    """Render one user question next to the user avatar.

    Args:
        md: question record from ``st.session_state.questions__``; only its
            ``'question'`` text is used.
    """
    import html  # local import: not imported at module level

    col1, col2 = st.columns([3, 97])
    with col1:
        st.image(USER_ICON, use_column_width='always')
    with col2:
        # All declarations stay inside the single-quoted style attribute and use
        # real (kebab-case) CSS property names so they actually take effect.
        # The question is free user text rendered with unsafe_allow_html, so it
        # is HTML-escaped to keep it from injecting markup.
        st.markdown(
            "<div style='color:#e28743;font-size:18px;padding:3px 7px 3px 7px;"
            "border-width:0px;border-color:red;border-style:solid;"
            "width:fit-content;height:fit-content;border-radius:10px;"
            "font-style:italic;'>"
            + html.escape(md['question'])
            + "</div>",
            unsafe_allow_html=True,
        )
# Tools whose response maps the tool's own name to a list of product hits,
# each carrying an 'image' URL and a 'title'.
_PRODUCT_RESULT_TOOLS = frozenset({
    "get_relevant_items_for_image",
    "get_relevant_items_for_text",
    "retrieve_with_hybrid_search",
    "retrieve_with_keyword_search",
})
# Tools whose response carries an s3:// URI under the 'generate_images' key.
_GENERATED_IMAGE_TOOLS = frozenset({"generate_images", "get_any_general_recommendation"})
# Generated images are always read from this bucket; the bucket named in the
# URI is ignored. NOTE(review): confirm the two always match.
_GENERATED_IMAGE_BUCKET = "bedrock-video-generation-us-east-1-lbxkrh"
_THUMBNAIL_SIZE = (230, 180)
_IMAGE_FETCH_TIMEOUT_S = 10


def _parse_tool_response(raw):
    """Decode a tool's response string into Python data.

    Tries strict JSON first, then a Python literal (a ``repr`` with single
    quotes or True/None), and only as a last resort the legacy trick of
    swapping every ``'`` for ``"`` (which corrupts values containing
    apostrophes or double quotes). Already-decoded dicts/lists pass through.

    Raises:
        json.JSONDecodeError: if no strategy can parse ``raw``.
    """
    import ast  # local import: not imported at module level

    if isinstance(raw, (dict, list)):
        return raw
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        pass
    return json.loads(raw.replace("'", '"'))


def _fetch_thumbnail(url):
    """Download the image at ``url`` and resize it to ``_THUMBNAIL_SIZE``.

    Returns None instead of raising when the request fails, times out, returns
    an HTTP error status, or the body is not a readable image, so one broken
    product image does not abort rendering of the whole chat history.
    """
    try:
        response_ = requests.get(url, timeout=_IMAGE_FETCH_TIMEOUT_S)
        response_.raise_for_status()
        img = Image.open(BytesIO(response_.content))
        return img.resize(_THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
    except (requests.RequestException, OSError):
        return None


def _load_generated_image(s3_uri):
    """Read a generated image from S3; return ``(thumbnail, object_key)``."""
    # Strip the scheme and the bucket segment, keeping the full (possibly
    # nested) object key.
    key = s3_uri.removeprefix('s3://').partition('/')[2]
    s3_ = boto3.resource(
        's3',
        aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'],
        region_name='us-east-1',
    )
    body = s3_.Object(_GENERATED_IMAGE_BUCKET, key).get()['Body'].read()
    img_ = Image.open(BytesIO(body))
    return img_.resize(_THUMBNAIL_SIZE, Image.Resampling.LANCZOS), key


def render_answer(question, answer, index):
    """Render one agent answer beside the assistant avatar.

    The answer text comes first. Depending on the name of the agent's last tool
    call, up to two product thumbnails or one generated image from S3 follow
    it. The agent trace (``answer['source']``) is shown in a collapsible
    expander when it is non-empty.

    Args:
        question: the matching question record; unused, kept for the existing
            call signature.
        answer: record from ``st.session_state.answers__`` with keys
            ``'answer'``, ``'source'`` and ``'last_tool'`` (a dict with
            ``'name'`` and ``'response'``).
        index: 1-based position in the conversation (from ``render_all``);
            unused.
    """
    col1, col2, _ = st.columns([4, 74, 22])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        tool = answer['last_tool']
        tool_name = tool['name']

        # <question>...</question> tags in the answer become a styled span.
        # NOTE(review): the agent's answer is inserted as raw HTML
        # (unsafe_allow_html) without escaping, presumably because it may
        # intentionally contain markup; confirm this is acceptable.
        span_ans = (
            answer['answer']
            .replace('<question>', "<span style='font-size:18px;color:#f37709;font-style:italic;'>")
            .replace("</question>", "</span>")
        )
        st.markdown("<p>" + span_ans + "</p>", unsafe_allow_html=True)

        src_dict = {}
        if tool_name in _PRODUCT_RESULT_TOOLS or tool_name in _GENERATED_IMAGE_TOOLS:
            src_dict = _parse_tool_response(tool['response'])

        if tool_name in _PRODUCT_RESULT_TOOLS:
            st.write("<br><br>", unsafe_allow_html=True)
            img_col1, img_col2, _ = st.columns([30, 30, 40])
            # Only the first two hits have a column, so only those are downloaded.
            for column, item in zip((img_col1, img_col2), src_dict[tool_name]):
                thumbnail = _fetch_thumbnail(item['image'])
                with column:
                    if thumbnail is None:
                        st.caption(f"{item['title']} (image unavailable)")
                    else:
                        st.image(thumbnail, use_column_width=True, caption=item['title'])

        s3_uri = src_dict.get('generate_images') if tool_name in _GENERATED_IMAGE_TOOLS else None
        if s3_uri:
            st.write("<br>", unsafe_allow_html=True)
            gen_img_col, _, _ = st.columns([30, 30, 30])
            thumbnail, key = _load_generated_image(s3_uri)
            # Caption with the object's file name minus its extension.
            stem = key.rsplit('/', 1)[-1].split(".")[0]
            with gen_img_col:
                st.image(thumbnail, caption="Generated image for " + stem, use_column_width=True)
                st.write("<br>", unsafe_allow_html=True)

    _, trace_col, _ = st.columns([4, 82, 20])
    if answer['source']:
        with trace_col:
            with st.expander("Agent Traces:"):
                st.write(answer['source'])
def write_chat_message(md, q, index):
    """Render answer ``md`` for question ``q`` inside its own container.

    The question is passed along with the answer so the answer keeps the
    context of what was asked (the intent is presumably to tie user feedback
    to the right question; NOTE(review): no feedback mechanism is visible in
    this file).
    """
    with st.container():
        render_answer(q, md, index)
def render_all():
    """Render the whole conversation: each stored question, then its answer.

    Questions and answers are paired by position. The 1-based position of each
    pair is passed on as ``index``.
    """
    history = zip(st.session_state.questions__, st.session_state.answers__)
    for index, (q, a) in enumerate(history, start=1):
        write_user_message(q)
        write_chat_message(a, q, index)
# Chat history is rendered inside an st.empty() placeholder container.
placeholder = st.empty()
with placeholder.container():
    render_all()
st.markdown("")

# Input row: the query box is bound to session key "input_shopping_query", and
# the "Go" button runs handle_input as its on_click callback.
# `query_text` (previously `input`) no longer shadows the builtin input().
col_2, col_3 = st.columns([75, 20])
with col_2:
    query_text = st.text_input("Ask here", label_visibility="collapsed", key="input_shopping_query")
with col_3:
    play = st.button("Go", on_click=handle_input, key="play")