import streamlit as st
import math
import io
import uuid
import os
import sys
import boto3
import requests
from requests_aws4auth import AWS4Auth

sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2]) + "/semantic_search")
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2]) + "/RAG")
sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2]) + "/utilities")

from boto3 import Session
from pathlib import Path
import botocore.session
import subprocess
# import os_index_df_sql
import json
import random
import string
from PIL import Image
import urllib.request
import base64
import shutil
import re
from requests.auth import HTTPBasicAuth
from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
import query_rewrite
import amazon_rekognition
from streamlit.components.v1 import html
# from st_click_detector import click_detector
import llm_eval
import all_search_execute
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)

st.set_page_config(
    page_icon="images/opensearch_mark_default.png"
)

parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])

st.markdown(""" """, unsafe_allow_html=True)  # original inline style markup not recovered
# st.markdown(""" """, unsafe_allow_html=True)

ps = PorterStemmer()

st.session_state.REGION = 'us-east-1'

USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"
IMAGE_ICON = "images/Image_Icon.png"
TEXT_ICON = "images/text.png"
s3_bucket_ = "pdf-repo-uploads"  # "pdf-repo-uploads"

# Check if the user ID is already stored in the session state
if 'user_id' in st.session_state:
    user_id = st.session_state['user_id']
    print(f"User ID: {user_id}")
# If the user ID is not yet stored in the session state, generate a random UUID
# else:
#     user_id = str(uuid.uuid4())
#     st.session_state['user_id'] = user_id

# dynamodb = boto3.resource('dynamodb')
# table = dynamodb.Table('ml-search')

if 'session_id' not in st.session_state:
    st.session_state['session_id'] = ""

if 'input_reranker' not in st.session_state:
    st.session_state['input_reranker'] = "None"  # "Cross Encoder"

if "chats" not in st.session_state:
    st.session_state.chats = [
        {
            'id': 0,
            'question': '',
            'answer': ''
        }
    ]

if "questions" not in st.session_state:
    st.session_state.questions = []

if "input_colBert_rerank" not in st.session_state:
    st.session_state.input_colBert_rerank = False

if "clear_" not in st.session_state:
    st.session_state.clear_ = False

if "input_clear_filter" not in st.session_state:
    st.session_state.input_clear_filter = False

if "radio_disabled" not in st.session_state:
    st.session_state.radio_disabled = True

if "input_rad_1" not in st.session_state:
    st.session_state.input_rad_1 = ""

if "input_manual_filter" not in st.session_state:
    st.session_state.input_manual_filter = ""

if "input_category" not in st.session_state:
    st.session_state.input_category = None

if "input_gender" not in st.session_state:
    st.session_state.input_gender = None

# if "input_price" not in st.session_state:
#     st.session_state.input_price = (0, 0)

if "input_sql_query" not in st.session_state:
    st.session_state.input_sql_query = ""

if "input_rewritten_query" not in st.session_state:
    st.session_state.input_rewritten_query = ""

if "input_hybridType" not in st.session_state:
    st.session_state.input_hybridType = "OpenSearch Hybrid Query"

if "ndcg_increase" not in st.session_state:
    st.session_state.ndcg_increase = " ~ "

if "inputs_" not in st.session_state:
    st.session_state.inputs_ = {}

if "img_container" not in st.session_state:
    st.session_state.img_container = ""
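# -----------------------------------------------------------------------------------
# Remaining defaults: per-search-type inputs plus the service wiring read by the search
# handlers — the OpenSearch domain endpoint, HTTP basic auth taken from st.secrets, the
# Bedrock runtime client, and the model IDs registered in OpenSearch for this deployment.
# -----------------------------------------------------------------------------------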
if "input_rekog_directoutput" not in st.session_state:
    st.session_state.input_rekog_directoutput = {}

if "input_weightage" not in st.session_state:
    st.session_state.input_weightage = {}

if "img_gen" not in st.session_state:
    st.session_state.img_gen = []

if "answers" not in st.session_state:
    st.session_state.answers = []

if "answers_none_rank" not in st.session_state:
    st.session_state.answers_none_rank = []

if "input_text" not in st.session_state:
    st.session_state.input_text = "black jacket for men"  # "black jacket for men under 120 dollars"

if "input_ndcg" not in st.session_state:
    st.session_state.input_ndcg = 0.0

if "gen_image_str" not in st.session_state:
    st.session_state.gen_image_str = ""

if "input_NormType" not in st.session_state:
    st.session_state.input_NormType = "min_max"

if "input_CombineType" not in st.session_state:
    st.session_state.input_CombineType = "arithmetic_mean"

if "input_sparse" not in st.session_state:
    st.session_state.input_sparse = "disabled"

if "input_evaluate" not in st.session_state:
    st.session_state.input_evaluate = "disabled"

if "input_is_rewrite_query" not in st.session_state:
    st.session_state.input_is_rewrite_query = "disabled"

if "input_rekog_label" not in st.session_state:
    st.session_state.input_rekog_label = ""

if "input_sparse_filter" not in st.session_state:
    st.session_state.input_sparse_filter = 0.5

if "input_modelType" not in st.session_state:
    st.session_state.input_modelType = "Titan-Embed-Text-v1"

if "input_weight" not in st.session_state:
    st.session_state.input_weight = 0.5

if "image_prompt2" not in st.session_state:
    st.session_state.image_prompt2 = ""

if "image_prompt" not in st.session_state:
    st.session_state.image_prompt = ""

if "bytes_for_rekog" not in st.session_state:
    st.session_state.bytes_for_rekog = ""

if "OpenSearchDomainEndpoint" not in st.session_state:
    st.session_state.OpenSearchDomainEndpoint = "search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com"

if "max_selections" not in st.session_state:
    st.session_state.max_selections = "None"

if "re_ranker" not in st.session_state:
    st.session_state.re_ranker = "true"

host = 'https://' + st.session_state.OpenSearchDomainEndpoint + '/'
service = 'es'
# credentials = boto3.Session().get_credentials()
awsauth = HTTPBasicAuth('master', st.secrets['ml_search_demo_api_access'])
headers = {"Content-Type": "application/json"}

if "REGION" not in st.session_state:
    st.session_state.REGION = ""

if "BEDROCK_MULTIMODAL_MODEL_ID" not in st.session_state:
    st.session_state.BEDROCK_MULTIMODAL_MODEL_ID = "p_Qk-ZMBcuw9xT4ly3_B"

if "search_types" not in st.session_state:
    st.session_state.search_types = 'Keyword Search,Vector Search,Multimodal Search,NeuralSparse Search'

if "KendraResourcePlanID" not in st.session_state:
    st.session_state.KendraResourcePlanID = ""

if "SAGEMAKER_CrossEncoder_MODEL_ID" not in st.session_state:
    st.session_state.SAGEMAKER_CrossEncoder_MODEL_ID = "deBS3pYB5VHEj-qVuPHT"

if "SAGEMAKER_SPARSE_MODEL_ID" not in st.session_state:
    st.session_state.SAGEMAKER_SPARSE_MODEL_ID = "fkol-ZMBTp0efWqBcO2P"

if "BEDROCK_TEXT_MODEL_ID" not in st.session_state:
    st.session_state.BEDROCK_TEXT_MODEL_ID = "usQk-ZMBkiQuoz1QFmXN"

# bytes_for_rekog = ""

bedrock_ = boto3.client(
    'bedrock-runtime',
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'],
    region_name='us-east-1'
)

search_all_type = True
if(search_all_type == True):
    search_types = [
        'Keyword Search',
        'Vector Search',
        'Multimodal Search',
        'NeuralSparse Search',
    ]
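# generate_images() builds a Bedrock request for the Amazon Titan image generator
# ("amazon.titan-image-generator-v1") from the prompt in st.session_state.image_prompt,
# asks for three 512x512 candidates, writes them plus 200x200 display thumbnails under
# gen_images/, and renders them in three columns so one can be picked as the query image.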
def generate_images(tab, inp_):
    # write_top_bar()
    seed = random.randint(1, 10)
    request = json.dumps(
        {
            "taskType": "TEXT_IMAGE",
            "textToImageParams": {"text": st.session_state.image_prompt},
            "imageGenerationConfig": {
                "numberOfImages": 3,
                "quality": "standard",
                "cfgScale": 8.0,
                "height": 512,
                "width": 512,
                "seed": seed,
            },
        }
    )
    if(inp_ != st.session_state.image_prompt):
        print("call bedrock")
        response = bedrock_.invoke_model(
            modelId="amazon.titan-image-generator-v1", body=request
        )
        response_body = json.loads(response["body"].read())
        st.session_state.img_gen = response_body["images"]
        gen_images_dir = os.path.join(parent_dirname, "gen_images")
        if os.path.exists(gen_images_dir):
            shutil.rmtree(gen_images_dir)
        os.mkdir(gen_images_dir)

    width_ = 200
    height_ = 200
    index_ = 0
    # if(inp_ != st.session_state.image_prompt):
    if(len(st.session_state.img_gen) == 0 and st.session_state.clear_ == True):
        # write_top_bar()
        placeholder1 = st.empty()
        with tab:
            with placeholder1.container():
                st.empty()

    images_dis = []
    for image_ in st.session_state.img_gen:
        st.session_state.radio_disabled = False
        if(index_ == 0):
            # with tab:
            #     rad1, rad2, rad3 = st.columns([98, 1, 1])
            # if(st.session_state.input_rad_1 is None):
            #     rand_ = ""
            # else:
            #     rand_ = st.session_state.input_rad_1
            # if(inp_ != st.session_state.image_prompt + rand_):
            #     with rad1:
            #         sel_rad_1 = st.radio("Choose one image", ["1", "2", "3"], index=None, horizontal=True, key='input_rad_1')
            with tab:
                # sel_image = st.radio("", ["1", "2", "3"], index=None, horizontal=True)
                if(st.session_state.img_container != ""):
                    st.session_state.img_container.empty()
                place_ = st.empty()
                img1, img2, img3 = place_.columns([30, 30, 30])
                st.session_state.img_container = place_
                img_arr = [img1, img2, img3]

        base64_image_data = image_
        # st.session_state.gen_image_str = base64_image_data
        print("perform multimodal search")
        Image.MAX_IMAGE_PIXELS = 100000000
        filename = st.session_state.image_prompt + "_gen_" + str(index_)
        photo = parent_dirname + "/gen_images/" + filename + '.jpg'  # I assume you have a way of picking unique filenames
        imgdata = base64.b64decode(base64_image_data)
        with open(photo, 'wb') as f:
            f.write(imgdata)

        with Image.open(photo) as image:
            file_type = 'jpg'
            path = image.filename.rsplit(".", 1)[0]
            image.thumbnail((width_, height_))
            image.save(parent_dirname + "/gen_images/" + filename + "-resized_display." + file_type)

        with img_arr[index_]:
            placeholder_ = st.empty()
            placeholder_.image(parent_dirname + "/gen_images/" + filename + "-resized_display." + file_type)
        index_ = index_ + 1
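# handle_input() collects every st.session_state key prefixed with "input_", fixes up the
# per-search-type weights so the selected types always sum to 100, optionally runs SQL
# generation, query rewriting and Rekognition-based enrichment of the query, executes the
# search through all_search_execute.handler(), and optionally scores the results with
# llm_eval. A minimal sketch of the weight fix-up, factored out purely for illustration
# (this helper is an assumption for readability and is never called by the app):
def _illustrate_equal_weights(num_search: int) -> list:
    """Split 100 evenly across `num_search` search types, giving the remainder to the last one."""
    base = math.trunc(100 / num_search)
    return [base] * (num_search - 1) + [base + 100 % num_search]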
def handle_input():
    if("text" in st.session_state.inputs_):
        if(st.session_state.inputs_["text"] != st.session_state.input_text):
            st.session_state.input_ndcg = 0.0
    st.session_state.bytes_for_rekog = ""
    print("***")

    if(st.session_state.img_doc is not None or (st.session_state.input_rad_1 is not None and st.session_state.input_rad_1 != "")):  # and st.session_state.input_searchType == 'Multi-modal Search'):
        print("perform multimodal search")
        st.session_state.input_imageUpload = 'yes'
        if(st.session_state.input_rad_1 is not None and st.session_state.input_rad_1 != ""):
            num_str = str(int(st.session_state.input_rad_1.strip()) - 1)
            with open(parent_dirname + "/gen_images/" + st.session_state.image_prompt + "_gen_" + num_str + "-resized_display.jpg", "rb") as image_file:
                input_image = base64.b64encode(image_file.read()).decode("utf8")
            st.session_state.input_image = input_image
            if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
                st.session_state.bytes_for_rekog = Path(parent_dirname + "/gen_images/" + st.session_state.image_prompt + "_gen_" + num_str + ".jpg").read_bytes()
        else:
            Image.MAX_IMAGE_PIXELS = 100000000
            width = 2048
            height = 2048
            uploaded_images = os.path.join(parent_dirname, "uploaded_images")
            if not os.path.exists(uploaded_images):
                os.mkdir(uploaded_images)
            with open(os.path.join(parent_dirname + "/uploaded_images", st.session_state.img_doc.name), "wb") as f:
                f.write(st.session_state.img_doc.getbuffer())
            photo = parent_dirname + "/uploaded_images/" + st.session_state.img_doc.name
            with Image.open(photo) as image:
                image.verify()
            with Image.open(photo) as image:
                width_ = 200
                height_ = 200
                if image.format.upper() in ["JPEG", "PNG", "JPG"]:
                    path = image.filename.rsplit(".", 1)[0]
                    org_file_type = st.session_state.img_doc.name.split(".")[1]
                    image.thumbnail((width, height))
                    if(org_file_type.upper() == "PNG"):
                        file_type = "jpg"
                        image.convert('RGB').save(f"{path}-resized.{file_type}")
                    else:
                        file_type = org_file_type
                        image.save(f"{path}-resized.{file_type}")
                    image.thumbnail((width_, height_))
                    image.save(f"{path}-resized_display.{org_file_type}")
            with open(photo.split(".")[0] + "-resized." + file_type, "rb") as image_file:
                input_image = base64.b64encode(image_file.read()).decode("utf8")
            st.session_state.input_image = input_image
            if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
                st.session_state.bytes_for_rekog = Path(parent_dirname + "/uploaded_images/" + st.session_state.img_doc.name).read_bytes()
    else:
        print("no image uploaded")
        st.session_state.input_imageUpload = 'no'
        st.session_state.input_image = ''

    inputs = {}
    if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
        old_rekog_label = st.session_state.input_rekog_label
        st.session_state.input_rekog_label = amazon_rekognition.extract_image_metadata(st.session_state.bytes_for_rekog)
        if(st.session_state.input_text == ""):
            st.session_state.input_text = st.session_state.input_rekog_label

    weightage = {}
    st.session_state.weights_ = []
    total_weight = 0.0
    counter = 0
    num_search = len(st.session_state.input_searchType)
    any_weight_zero = False

    for type in st.session_state.input_searchType:
        key_weight = "input_" + type.split(" ")[0] + "-weight"
        total_weight = total_weight + st.session_state[key_weight]
        if(st.session_state[key_weight] == 0):
            any_weight_zero = True
    print(total_weight)

    for key in st.session_state:
        if(key.startswith('input_')):
            original_key = key.removeprefix('input_')
            if('weight' not in key):
                inputs[original_key] = st.session_state[key]
            else:
                if(original_key.split("-")[0] + " Search" in st.session_state.input_searchType):
                    counter = counter + 1
                    if(total_weight != 100 or any_weight_zero == True):
                        extra_weight = 100 % num_search
                        if(counter == num_search):
                            cal_weight = math.trunc(100 / num_search) + extra_weight
                        else:
                            cal_weight = math.trunc(100 / num_search)
                        st.session_state[key] = cal_weight
                        weightage[original_key] = cal_weight
                        st.session_state.weights_.append(cal_weight)
                    else:
                        weightage[original_key] = st.session_state[key]
                        st.session_state.weights_.append(st.session_state[key])
                else:
                    weightage[original_key] = 0.0
                    st.session_state[key] = 0.0

    inputs['weightage'] = weightage
    st.session_state.input_weightage = weightage
    st.session_state.inputs_ = inputs

    question_with_id = {
        'question': inputs["text"],
        'id': len(st.session_state.questions)
    }
    st.session_state.questions = []
    st.session_state.questions.append(question_with_id)
    st.session_state.answers = []

    if(st.session_state.input_is_sql_query == 'enabled'):
        # requires the os_index_df_sql module (its import is commented out at the top)
        os_index_df_sql.sql_process(st.session_state.input_text)
        print(st.session_state.input_sql_query)
    else:
        st.session_state.input_sql_query = ""

    if(st.session_state.input_is_rewrite_query == 'enabled' or (st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType)):
        query_rewrite.get_new_query_res(st.session_state.input_text)
    else:
        st.session_state.input_rewritten_query = ""

    ans__ = all_search_execute.handler(inputs, st.session_state['session_id'])

    st.session_state.answers.append({
        'answer': ans__,
        'search_type': inputs['searchType'],
        'id': len(st.session_state.questions)
    })
    st.session_state.answers_none_rank = st.session_state.answers

    if(st.session_state.input_evaluate == "enabled"):
        llm_eval.eval(st.session_state.questions, st.session_state.answers)
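# write_top_bar() renders the query controls: the text input with Search/Clear buttons,
# plus an expander holding two tabs — one to upload a query image and one to generate
# candidate images from a text prompt (via generate_images) and pick one with a radio
# button. It returns the Clear button state and the image-generation tab handle.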
",unsafe_allow_html=True) st.button("Generate",disabled=False,key = "generate",on_click = generate_images, args=(tab1,"default_img")) image_select = st.radio("Choose one image", ["Image 1","Image 2","Image 3"],index=None, horizontal = True,key = 'image_select',disabled = st.session_state.radio_disabled) st.markdown(""" """,unsafe_allow_html=True) if(st.session_state.image_select is not None and st.session_state.image_select !="" and len(st.session_state.img_gen)!=0): st.session_state.input_rad_1 = st.session_state.image_select.split(" ")[1] else: st.session_state.input_rad_1 = "" generate_images(tab1,gen_images) with tab2: st.session_state.img_doc = st.file_uploader( "Upload image", accept_multiple_files=False,type = ['png', 'jpg']) return clear,tab1 clear,tab_ = write_top_bar() if clear: st.session_state.questions = [] st.session_state.answers = [] st.session_state.clear_ = True st.session_state.image_prompt2 = "" st.session_state.input_rekog_label = "" st.session_state.radio_disabled = True if(len(st.session_state.img_gen)!=0): st.session_state.img_container.empty() st.session_state.img_gen = [] st.session_state.input_rad_1 = "" col1, col3, col4 = st.columns([70,18,12]) with col1: if(st.session_state.max_selections == "" or st.session_state.max_selections == "1"): st.session_state.max_selections = 1 if(st.session_state.max_selections == "None"): st.session_state.max_selections = None search_type = st.multiselect('Select the Search type(s)', search_types,['Keyword Search'], max_selections = st.session_state.max_selections, key = 'input_searchType', help = "Select the type of Search, adding more than one search type will activate hybrid search"#\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers)" ) with col3: st.number_input("No. of docs", min_value=1, max_value=50, value=5, step=5, key='input_K', help=None) with col4: st.markdown("'+json.dumps(st.session_state.input_rekog_directoutput)+'
    # The markup originally rendered here, together with a span of code that followed it
    # (the remaining top-level controls, write_user_message(), which render_all() below
    # still calls, and the opening of render_answer()), was not recovered. The surviving
    # fragments are kept, commented out, just below.
    st.markdown("", unsafe_allow_html=True)

# ... (unrecovered span; surviving fragments, in original order) ...
#             st.markdown('' + json.dumps(st.session_state.input_rekog_directoutput) + '', unsafe_allow_html=True)
#         else:
#             st.markdown("No results found, please try again with different query", unsafe_allow_html=True)
#     else:


def render_answer(answer, index):
    # Signature reconstructed from the call in write_chat_message(); the original opening
    # of this function (the col_1/col_3 column layout, the i and filter_out counters, the
    # stem_() helper and placeholder_no_results used below) was not recovered.
", unsafe_allow_html = True) else: for ans in answer: if('b5/b5319e00' in ans['image_url'] ): filter_out+=1 continue format_ = ans['image_url'].split(".")[-1] Image.MAX_IMAGE_PIXELS = 100000000 width = 500 height = 500 with col_1: inner_col_1,inner_col_2 = st.columns([8,92]) with inner_col_2: st.image(ans['image_url'].replace("/home/ec2-user/SageMaker/","/home/user/")) if('max_score_dict_list_sorted' in ans and 'Vector Search' in st.session_state.input_searchType): desc___ = ans['desc'].split(" ") res___ = [] for o in ans['max_score_dict_list_sorted']: res___.append(o['doc_token']) final_desc_ = "" for word_ in desc___: str_=re.sub('[^A-Za-z0-9]+', '', word_).lower() stemmed_word = next(iter(set(stem_(str_)))) if(stemmed_word in res___ or str_ in res___): if(stemmed_word in res___): mod_word = stemmed_word else: mod_word = str_ if(res___.index(mod_word)==0): final_desc_ += ""+word_+" " elif(res___.index(mod_word)==1): final_desc_ += ""+word_+" " else: final_desc_ += ""+word_+" " else: final_desc_ += word_ + " " final_desc_ += "
" for word in desc__: if(re.sub('[^A-Za-z0-9]+', '', word) in res__): final_desc += ""+word+" " else: final_desc += word + " " final_desc += "
" st.markdown(final_desc,unsafe_allow_html = True) else: st.write(ans['desc']) if("sparse" in ans): with st.expander("Expanded document:"): sparse_ = dict(sorted(ans['sparse'].items(), key=lambda item: item[1],reverse=True)) filtered_sparse = dict() for key in sparse_: if(sparse_[key]>=1.0): filtered_sparse[key] = round(sparse_[key], 2) st.write(filtered_sparse) with st.expander("Document Metadata:",expanded = False): st.write(":green[default:]") st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True) if("rekog" in ans): st.write(":green[enriched:]") st.json(ans['rekog'],expanded = True) with inner_col_1: if(st.session_state.input_evaluate == "enabled"): with st.container(border = False): if("relevant" in ans.keys()): if(ans['relevant']==True): st.write(":white_check_mark:") else: st.write(":x:") i = i+1 with col_3: if(index == len(st.session_state.questions)): rdn_key = ''.join([random.choice(string.ascii_letters) for _ in range(10)]) currentValue = "".join(st.session_state.input_searchType)+st.session_state.input_imageUpload+json.dumps(st.session_state.input_weightage)+st.session_state.input_NormType+st.session_state.input_CombineType+str(st.session_state.input_K)+st.session_state.input_sparse+st.session_state.input_reranker+st.session_state.input_is_rewrite_query+st.session_state.input_evaluate+st.session_state.input_image+st.session_state.input_rad_1+st.session_state.input_reranker+st.session_state.input_hybridType+st.session_state.input_manual_filter oldValue = "".join(st.session_state.inputs_["searchType"])+st.session_state.inputs_["imageUpload"]+str(st.session_state.inputs_["weightage"])+st.session_state.inputs_["NormType"]+st.session_state.inputs_["CombineType"]+str(st.session_state.inputs_["K"])+st.session_state.inputs_["sparse"]+st.session_state.inputs_["reranker"]+st.session_state.inputs_["is_rewrite_query"]+st.session_state.inputs_["evaluate"]+st.session_state.inputs_["image"]+st.session_state.inputs_["rad_1"]+st.session_state.inputs_["reranker"]+st.session_state.inputs_["hybridType"]+st.session_state.inputs_["manual_filter"] def on_button_click(): if(currentValue!=oldValue): st.session_state.input_text = st.session_state.questions[-1]["question"] st.session_state.answers.pop() st.session_state.questions.pop() handle_input() with placeholder.container(): render_all() if("currentValue" in st.session_state): del st.session_state["currentValue"] try: del regenerate except: pass placeholder__ = st.empty() placeholder__.button("🔄",key=rdn_key,on_click=on_button_click, help = "This will regenerate the responses with new settings that you entered, Note: To see difference in responses, you should change any of the applicable settings")#,type="primary",use_column_width=True) if(filter_out > 0): placeholder_no_results.text(str(filter_out)+" result(s) removed due to missing or in-appropriate content") #Each answer will have context of the question asked in order to associate the provided feedback with the respective question def write_chat_message(md, q,index): if('body' in md['answer']): res = json.loads(md['answer']['body']) else: res = md['answer'] st.session_state['session_id'] = "1234" chat = st.container() with chat: render_answer(res,index) def render_all(): index = 0 for (q, a) in zip(st.session_state.questions, st.session_state.answers): index = index +1 ans_ = st.session_state.answers[0] write_user_message(q,ans_) write_chat_message(a, q,index) placeholder = st.empty() with 
with placeholder.container():
    render_all()

st.markdown("")
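# To run this page locally (an assumption about the surrounding repo layout):
#   streamlit run <path-to-this-file>
# with AWS credentials available and `ml_search_demo_api_access`, `user_access_key` and
# `user_secret_key` defined in .streamlit/secrets.toml.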