"""Streamlit page: chat-style "AI Shopping assistant" backed by ``bedrock_agent``.

Each submitted query is sent to ``bedrock_agent.query_``. Questions and answers
are kept in ``st.session_state`` and re-rendered on every script run. Answers
produced by retrieval / image tools also show product thumbnails or a
generated image fetched from S3.

NOTE(review): in the copy this was reconstructed from, HTML markup inside
several string literals appears to have been stripped, leaving only the line
breaks. Those literals are reproduced as the visible ``"\\n"`` text and
flagged inline. Restore the original markup from version control if needed.
"""
import streamlit as st
import ast
import html
import os
import re
import sys
import uuid
from io import BytesIO

# Make the sibling project packages importable (two directory levels above
# this file). Insertion order matches the original, so "utilities" ends up
# first on sys.path, then "RAG", then "semantic_search".
_PROJECT_ROOT = "/".join(os.path.realpath(__file__).split("/")[0:-2])
for _subdir in ("semantic_search", "RAG", "utilities"):
    sys.path.insert(1, _PROJECT_ROOT + "/" + _subdir)

import base64
import json
import random
import shutil
import string
import time
import warnings

import boto3
import botocore
import botocore.session
import pandas as pd
import requests
from boto3 import Session
from PIL import Image
from requests.auth import HTTPBasicAuth
from requests_aws4auth import AWS4Auth

import bedrock_agent

warnings.filterwarnings("ignore", category=DeprecationWarning)

st.set_page_config(
    layout="wide",
    page_icon="images/opensearch_mark_default.png",
)

parent_dirname = '/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp'
USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"
s3_bucket_ = "pdf-repo-uploads"
# Bucket that holds images produced by the "generate_images" tool.
GENERATED_IMAGE_BUCKET = "bedrock-video-generation-us-east-1-lbxkrh"
# Seconds to wait for a product thumbnail before giving up on it.
IMAGE_FETCH_TIMEOUT = 10
# Size every thumbnail is resized to before display.
THUMBNAIL_SIZE = (230, 180)

# Tools whose 'response' carries structured results that this page renders.
STRUCTURED_RESULT_TOOLS = (
    "generate_images",
    "get_relevant_items_for_image",
    "get_relevant_items_for_text",
    "retrieve_with_hybrid_search",
    "retrieve_with_keyword_search",
    "get_any_general_recommendation",
)
# Subset of the above whose result is rendered as a generated S3 image
# instead of a list of product thumbnails.
IMAGE_RESULT_TOOLS = ("generate_images", "get_any_general_recommendation")

polly_client = boto3.Session(
    region_name='us-east-1').client('polly')

# Reuse the user ID stored in the session state, or mint a random one.
if 'user_id' in st.session_state:
    user_id = st.session_state['user_id']
else:
    user_id = str(uuid.uuid4())
    st.session_state['user_id'] = user_id

# Seed the remaining per-session defaults. Existing values are never
# overwritten. Keys prefixed with "input_" are forwarded to the agent by
# handle_input().
for _key, _default in {
    'session_id_': str(uuid.uuid1()),
    'chats': [{'id': 0, 'question': '', 'answer': ''}],
    'questions__': [],
    'answers__': [],
    'input_is_rerank': True,
    'input_copali_rerank': False,
    'input_table_with_sql': False,
    'inputs_': {},
    'input_shopping_query': "get me shoes suitable for trekking",
    'input_rag_searchType': ["Sparse Search"],
}.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

region = 'us-east-1'
output = []
service = 'es'

# NOTE(review): this literal looks emptied (probably a stripped <style> block).
st.markdown(""" """, unsafe_allow_html=True)


def write_logo():
    """Render the assistant icon centred in a three-column row."""
    col1, col2, col3 = st.columns([5, 1, 5])
    with col2:
        st.image(AI_ICON, use_column_width='always')


def write_top_bar():
    """Render the page header and the "Clear" button.

    Returns:
        bool: True on the script run in which "Clear" was clicked.
    """
    col1, col2 = st.columns([77, 23])
    with col1:
        st.page_link("app.py", label=":orange[Home]", icon="🏠")
        st.header("AI Shopping assistant", divider='rainbow')
    with col2:
        st.write("")
        st.write("")
        clear = st.button("Clear")
        st.write("")
        st.write("")
    return clear


clear = write_top_bar()

if clear:
    # Reset the visible conversation, start a fresh session id and drop the
    # agent-side memory.
    st.session_state.questions__ = []
    st.session_state.answers__ = []
    st.session_state.input_shopping_query = ""
    st.session_state.session_id_ = str(uuid.uuid1())
    bedrock_agent.delete_memory()


def handle_input():
    """Send the current query to the agent and record the exchange.

    Used as the "Go" button's ``on_click`` callback. Every session-state key
    prefixed with ``input_`` is passed to ``bedrock_agent.query_`` with the
    prefix removed. The question and answer are appended to
    ``questions__`` / ``answers__``, and the query box is then cleared.

    Returns:
        str: ``""`` when the query is empty or whitespace-only (Streamlit
        ignores callback return values). Otherwise ``None``.
    """
    query = st.session_state.input_shopping_query
    if not query or not query.strip():
        return ""

    inputs = {
        key.removeprefix('input_'): st.session_state[key]
        for key in st.session_state
        if key.startswith('input_')
    }
    st.session_state.inputs_ = inputs

    question_with_id = {
        'question': inputs["shopping_query"],
        'id': len(st.session_state.questions__),
    }
    st.session_state.questions__.append(question_with_id)
    print(inputs)
    try:
        out_ = bedrock_agent.query_(inputs)
    except Exception:
        # Roll back so questions__ and answers__ stay index-aligned; otherwise
        # render_all() would pair later questions with the wrong answers.
        st.session_state.questions__.pop()
        raise

    # NOTE(review): answer 'id' is len(questions) *after* the append, i.e.
    # question id + 1. Kept as-is in case other code relies on it.
    st.session_state.answers__.append({
        'answer': out_['text'],
        'source': out_['source'],
        'last_tool': out_['last_tool'],
        'id': len(st.session_state.questions__),
    })
    st.session_state.input_shopping_query = ""


def write_user_message(md):
    """Render one user question (``md['question']``) next to the user icon."""
    col1, col2 = st.columns([3, 97])
    with col1:
        st.image(USER_ICON, use_column_width='always')
    with col2:
        # The question is raw user text; escape it because the markdown is
        # rendered with unsafe_allow_html.
        # NOTE(review): wrapper literals look stripped of HTML markup.
        st.markdown("\n" + html.escape(md['question']) + "\n",
                    unsafe_allow_html=True)


def _parse_tool_response(raw):
    """Parse a tool's ``response`` payload into Python data.

    The payload may be JSON or a Python-literal repr (single-quoted strings).
    Blindly swapping quote styles corrupts any text containing an apostrophe
    (e.g. "Men's"). So strict JSON is tried first, then ``ast.literal_eval``
    (which is safe and evaluates literals only), and finally the legacy
    quote swap.
    """
    if isinstance(raw, dict):
        return raw
    try:
        return json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        pass
    try:
        return ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        pass
    return json.loads(raw.replace("'", '"'))


def _fetch_thumbnail(url):
    """Download the image at *url* and resize it to THUMBNAIL_SIZE.

    Returns:
        PIL.Image.Image | None: the thumbnail, or None if the download or
        decode fails, so that one bad image does not break the page.
    """
    try:
        response_ = requests.get(url, timeout=IMAGE_FETCH_TIMEOUT)
        img = Image.open(BytesIO(response_.content))
        return img.resize(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
    except (requests.RequestException, OSError):
        return None


def _render_retrieved_items(items):
    """Show thumbnails for the first two retrieved items side by side.

    Each item must provide 'image' (URL) and 'title'. Only two columns are
    rendered, so only the first two images are downloaded.
    """
    img_col1, img_col2, img_col3 = st.columns([30, 30, 40])
    for column, item in zip((img_col1, img_col2), items):
        thumbnail = _fetch_thumbnail(item['image'])
        with column:
            if thumbnail is None:
                st.caption(item['title'])
            else:
                st.image(thumbnail, use_column_width=True, caption=item['title'])


def _render_generated_image(src_dict):
    """Fetch the generated image referenced by ``src_dict['generate_images']``.

    The value is expected to be an ``s3://bucket/key`` URI. The object is read
    from GENERATED_IMAGE_BUCKET using the key part of the URI. Nothing is
    shown if the URI is absent.
    """
    gen_img_col1, gen_img_col2, gen_img_col3 = st.columns([30, 30, 30])
    s3_uri = src_dict.get('generate_images') if isinstance(src_dict, dict) else None
    if not s3_uri:
        return
    res = s3_uri.replace('s3://', '')
    s3_ = boto3.resource(
        's3',
        aws_access_key_id=st.secrets['user_access_key'],
        aws_secret_access_key=st.secrets['user_secret_key'],
        region_name='us-east-1',
    )
    # Keep the full object key (everything after the bucket), not just the
    # first path segment, so nested keys resolve correctly.
    key = res.split('/', 1)[1]
    s3_stream = s3_.Object(GENERATED_IMAGE_BUCKET, key).get()['Body'].read()
    img_ = Image.open(BytesIO(s3_stream))
    resized_img = img_.resize(THUMBNAIL_SIZE, Image.Resampling.LANCZOS)
    with gen_img_col1:
        st.image(resized_img, caption="Generated image for " + key.split(".")[0],
                 use_column_width=True)


def render_answer(question, answer, index):
    """Render one assistant answer plus any tool results and agent traces.

    Args:
        question: the paired question dict (currently unused).
        answer: dict with 'answer' (text), 'source' (agent traces) and
            'last_tool' ({'name': ..., 'response': ...}).
        index: 1-based position in the chat (currently unused).
    """
    col1, col2, col_3 = st.columns([4, 74, 22])
    with col1:
        st.image(AI_ICON, use_column_width='always')
    with col2:
        ans_ = answer['answer']
        # NOTE(review): these replace() search strings look stripped (probably
        # tag names). As written they are no-ops.
        span_ans = ans_.replace('', "").replace("", "")
        # NOTE(review): wrapper literals look stripped of HTML markup.
        st.markdown("\n\n" + span_ans + "\n\n", unsafe_allow_html=True)

        tool_name = answer['last_tool']['name']
        src_dict = {}
        if tool_name in STRUCTURED_RESULT_TOOLS:
            src_dict = _parse_tool_response(answer['last_tool']['response'])

        if tool_name in STRUCTURED_RESULT_TOOLS and tool_name not in IMAGE_RESULT_TOOLS:
            st.write("\n\n", unsafe_allow_html=True)
            _render_retrieved_items(src_dict[tool_name])

        if tool_name in IMAGE_RESULT_TOOLS:
            st.write("\n", unsafe_allow_html=True)
            _render_generated_image(src_dict)

        # NOTE(review): original indentation was lost; placing this spacer
        # at the end of col2 is a best guess.
        st.write("\n", unsafe_allow_html=True)

    colu1, colu2, colu3 = st.columns([4, 82, 20])
    if answer['source']:
        with colu2:
            with st.expander("Agent Traces:"):
                st.write(answer['source'])


# Each answer carries the context of its question so that feedback can be
# associated with the respective question.
def write_chat_message(md, q, index):
    """Render answer *md* (for question *q*) inside its own container."""
    chat = st.container()
    with chat:
        render_answer(q, md, index)


def render_all():
    """Render the whole conversation in order, one question/answer pair each."""
    pairs = zip(st.session_state.questions__, st.session_state.answers__)
    for index, (q, a) in enumerate(pairs, start=1):
        write_user_message(q)
        write_chat_message(a, q, index)


placeholder = st.empty()
with placeholder.container():
    render_all()

st.markdown("")
col_2, col_3 = st.columns([75, 20])
with col_2:
    # Bound to session key "input_shopping_query", which handle_input() reads.
    query_input = st.text_input(
        "Ask here", label_visibility="collapsed", key="input_shopping_query")
with col_3:
    play = st.button("Go", on_click=handle_input, key="play")