import streamlit as st
import math
import io
import uuid
import os
import sys
import boto3
import requests
from requests_aws4auth import AWS4Auth

# Put three directories located two levels above this file on the import
# path. They are inserted one by one at position 1, in the same order as
# before, so "utilities" ends up ahead of "RAG" and "semantic_search".
_REPO_ROOT = "/".join(os.path.realpath(__file__).split("/")[0:-2])
for _subdir in ("semantic_search", "RAG", "utilities"):
    sys.path.insert(1, _REPO_ROOT + "/" + _subdir)

from boto3 import Session
from pathlib import Path
import botocore.session
import subprocess
#import os_index_df_sql
import json
import random
import string
from PIL import Image
import urllib.request
import base64
import shutil
import re
from requests.auth import HTTPBasicAuth
import nltk

# Download the "punkt" tokenizer data only when it is not available locally.
try:
    nltk.data.find("tokenizers/punkt")
except LookupError:
    nltk.download("punkt")

from nltk.stem import PorterStemmer
from nltk.tokenize import word_tokenize
import query_rewrite
import amazon_rekognition
from streamlit.components.v1 import html
#from st_click_detector import click_detector
import llm_eval
import all_search_execute
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)

st.set_page_config(
    page_icon="images/opensearch_mark_default.png"
)

# Directory one level above the directory that contains this file.
parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])

# NOTE(review): the markup inside this call is blank in the file as received
# (it looks like embedded HTML/CSS was stripped) - restore it from version
# control if a style block is expected here.
st.markdown(""" """, unsafe_allow_html=True)
# st.markdown("""
# #
# """, unsafe_allow_html=True)

ps = PorterStemmer()

st.session_state.REGION = 'us-east-1'

USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"
IMAGE_ICON = "images/Image_Icon.png"
TEXT_ICON = "images/text.png"
s3_bucket_ = "pdf-repo-uploads" #"pdf-repo-uploads"

# Upper bound handed to Pillow before images are opened.
_MAX_IMAGE_PIXELS = 100000000

# Check if the user ID is already stored in the session state
if 'user_id' in st.session_state:
    user_id = st.session_state['user_id']
    print(f"User ID: {user_id}")
# If the user ID is not yet stored in the session state, generate a random UUID
# else:
#     user_id = str(uuid.uuid4())
#     st.session_state['user_id'] = user_id
# dynamodb = boto3.resource('dynamodb')
# table = dynamodb.Table('ml-search')

# Session-state defaults, applied only when the key is absent. The order is
# the order of the original one-by-one guards; handle_input() iterates the
# session state, so the relative order of the "input_*" keys is kept.
# The dict is rebuilt on every script run, so the mutable values ([] / {})
# are never shared between runs.
_SESSION_DEFAULTS = {
    'session_id': "",
    'input_reranker': "None",  #"Cross Encoder"
    'chats': [
        {
            'id': 0,
            'question': '',
            'answer': '',
        }
    ],
    'questions': [],
    'input_mvector_rerank': False,  # stored under another key, see below
    'clear_': False,
    'input_clear_filter': False,
    'radio_disabled': True,
    'input_rad_1': "",
    'input_manual_filter': "",
    'input_category': None,
    'input_gender': None,
    # 'input_price': (0, 0),
    'input_sql_query': "",
    'input_rewritten_query': "",
    'input_hybridType': "OpenSearch Hybrid Query",
    'ndcg_increase': " ~ ",
    'inputs_': {},
    'img_container': "",
    'input_rekog_directoutput': {},
    'input_weightage': {},
    'img_gen': [],
    'answers': [],
    'answers_none_rank': [],
    'input_text': "black jacket for men",  #"black jacket for men under 120 dollars"
    'input_ndcg': 0.0,
    'gen_image_str': "",
    'input_NormType': "min_max",
    'input_CombineType': "arithmetic_mean",
    'input_sparse': "disabled",
    'input_evaluate': "disabled",
    'input_is_rewrite_query': "disabled",
    'input_rekog_label': "",
    'input_sparse_filter': 0.5,
    'input_modelType': "Titan-Embed-Text-v1",
    'input_weight': 0.5,
    'image_prompt2': "",
    'image_prompt': "",
    'bytes_for_rekog': "",
    'OpenSearchDomainEndpoint': "search-opensearchservi-shjckef2t7wo-iyv6rajdgxg6jas25aupuxev6i.us-west-2.es.amazonaws.com",
    'max_selections': "None",
    're_ranker': "true",
    'BEDROCK_MULTIMODAL_MODEL_ID': "p_Qk-ZMBcuw9xT4ly3_B",
    # NOTE(review): the original assignment ended with a trailing comma, so
    # the stored value is a 1-tuple holding one comma-separated string. Kept
    # as-is because consumers are not visible here - confirm the intent.
    'search_types': ('Keyword Search,Vector Search,Multimodal Search,NeuralSparse Search',),
    'KendraResourcePlanID': "",
    'SAGEMAKER_CrossEncoder_MODEL_ID': "deBS3pYB5VHEj-qVuPHT",
    'SAGEMAKER_SPARSE_MODEL_ID': "fkol-ZMBTp0efWqBcO2P",
    'BEDROCK_TEXT_MODEL_ID': "usQk-ZMBkiQuoz1QFmXN",
}

# NOTE(review): the original guard tested "input_mvector_rerank" but assigned
# "input_colBert_rerank". That behaviour is preserved; confirm which key the
# rest of the app actually reads.
_DEFAULT_TARGET_OVERRIDES = {'input_mvector_rerank': 'input_colBert_rerank'}

for _key, _value in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_DEFAULT_TARGET_OVERRIDES.get(_key, _key)] = _value
# The former `if "REGION" not in st.session_state` default is gone: REGION is
# assigned unconditionally above, so that guard could never fire.

host = 'https://'+st.session_state.OpenSearchDomainEndpoint+'/'
service = 'es'
#credentials = boto3.Session().get_credentials()
awsauth = HTTPBasicAuth('master', st.secrets['ml_search_demo_api_access'])
headers = {"Content-Type": "application/json"}

#bytes_for_rekog = ""
bedrock_ = boto3.client(
    'bedrock-runtime',
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'],
    region_name='us-east-1',
)

search_all_type = True
if search_all_type:
    search_types = [
        'Keyword Search',
        'Vector Search',
        'Multimodal Search',
        'NeuralSparse Search',
    ]


def generate_images(tab, inp_):
    """Generate images for the current prompt (if needed) and display them.

    Args:
        tab: Streamlit container the thumbnails are rendered into.
        inp_: Value compared with ``st.session_state.image_prompt``. A new
            Bedrock request is made only when the two differ (the "Generate"
            button passes the sentinel ``"default_img"``; the per-run call in
            ``write_top_bar`` passes the prompt itself, so it only redraws).

    Side effects:
        * May replace ``st.session_state.img_gen`` with the base64 images
          returned by Bedrock and recreate ``<parent_dirname>/gen_images``.
        * Writes ``<prompt>_gen_<i>.jpg`` and a 200x200-bounded
          ``<prompt>_gen_<i>-resized_display.jpg`` for every image in
          ``img_gen``; ``handle_input`` reads these files back.
        * Sets ``radio_disabled`` to False and stores the image placeholder
          in ``st.session_state.img_container``.
    """
    prompt = st.session_state.image_prompt
    gen_images_dir = os.path.join(parent_dirname, "gen_images")

    # Skip the request for an empty prompt: previously an empty text was
    # sent to the image model whenever "Generate" was clicked.
    if inp_ != prompt and prompt.strip():
        request = json.dumps(
            {
                "taskType": "TEXT_IMAGE",
                "textToImageParams": {"text": prompt},
                "imageGenerationConfig": {
                    "numberOfImages": 3,
                    "quality": "standard",
                    "cfgScale": 8.0,
                    "height": 512,
                    "width": 512,
                    "seed": random.randint(1, 10),
                },
            }
        )
        print("call bedrocck")
        response = bedrock_.invoke_model(
            modelId="amazon.titan-image-generator-v1", body=request
        )
        response_body = json.loads(response["body"].read())
        st.session_state.img_gen = response_body["images"]
        # Start from an empty directory for the new batch.
        if os.path.exists(gen_images_dir):
            shutil.rmtree(gen_images_dir)
        os.mkdir(gen_images_dir)

    thumb_size = (200, 200)

    if len(st.session_state.img_gen) == 0 and st.session_state.clear_:
        placeholder1 = st.empty()
        with tab:
            with placeholder1.container():
                st.empty()

    if st.session_state.img_gen:
        # The files are rewritten on every run; make sure the target
        # directory is there even if it was removed in the meantime.
        os.makedirs(gen_images_dir, exist_ok=True)
        Image.MAX_IMAGE_PIXELS = _MAX_IMAGE_PIXELS

    img_columns = []
    for index_, base64_image_data in enumerate(st.session_state.img_gen):
        st.session_state.radio_disabled = False
        if index_ == 0:
            with tab:
                # Drop the previously rendered row before drawing a new one.
                if st.session_state.img_container != "":
                    st.session_state.img_container.empty()
                place_ = st.empty()
                # Three columns, matching numberOfImages in the request.
                img_columns = place_.columns([30, 30, 30])
                st.session_state.img_container = place_

        print("perform multimodal search")
        # NOTE(review): the prompt text is used verbatim in the file name, so
        # a prompt containing a path separator will fail here. Not changed
        # because other parts of the app may rebuild the same name.
        filename = prompt + "_gen_" + str(index_)
        photo = os.path.join(gen_images_dir, filename + '.jpg')
        display_photo = os.path.join(gen_images_dir, filename + "-resized_display.jpg")

        with open(photo, 'wb') as f:
            f.write(base64.b64decode(base64_image_data))
        with Image.open(photo) as image:
            image.thumbnail(thumb_size)
            image.save(display_photo)

        with img_columns[index_]:
            placeholder_ = st.empty()
            placeholder_.image(display_photo)


def _needs_rekognition():
    """Return True when an image is in play and Keyword Search is selected."""
    return (
        st.session_state.input_imageUpload == 'yes'
        and 'Keyword Search' in st.session_state.input_searchType
    )


def _encode_file_b64(path):
    """Return the content of the file at *path* as a base64 text string."""
    with open(path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf8")


def _use_generated_image():
    """Load the AI-generated image selected via ``input_rad_1`` as query image."""
    # The radio value is 1-based, the files on disk are 0-based.
    num_str = str(int(st.session_state.input_rad_1.strip()) - 1)
    stem = parent_dirname + "/gen_images/" + st.session_state.image_prompt + "_gen_" + num_str
    st.session_state.input_image = _encode_file_b64(stem + "-resized_display.jpg")
    if _needs_rekognition():
        st.session_state.bytes_for_rekog = Path(stem + ".jpg").read_bytes()


def _use_uploaded_image():
    """Persist the uploaded image, create resized copies and load it as query image.

    Raises:
        ValueError: if the file has no extension or Pillow reports a format
            other than JPEG/PNG (this used to surface as a NameError).
    """
    Image.MAX_IMAGE_PIXELS = _MAX_IMAGE_PIXELS
    upload = st.session_state.img_doc

    upload_dir = os.path.join(parent_dirname, "uploaded_images")
    os.makedirs(upload_dir, exist_ok=True)

    # Only the base name is used so the client-supplied name cannot point
    # outside the upload directory.
    photo = os.path.join(upload_dir, os.path.basename(upload.name))
    with open(photo, "wb") as f:
        f.write(upload.getbuffer())

    # verify() leaves the image object unusable, hence the second open below.
    with Image.open(photo) as image:
        image.verify()

    # Split on the LAST dot of the file name only. The previous code mixed
    # split(".")[1], split(".")[0] and rsplit(".", 1)[0], which disagreed
    # as soon as the name or a parent directory contained another dot.
    stem, ext = os.path.splitext(photo)
    org_file_type = ext.lstrip(".")

    with Image.open(photo) as image:
        image_format = (image.format or "").upper()
        if not org_file_type or image_format not in ("JPEG", "PNG", "JPG"):
            raise ValueError(
                f"Unsupported image upload: {upload.name!r} (format {image.format!r})"
            )
        image.thumbnail((2048, 2048))
        if org_file_type.upper() == "PNG":
            # PNG uploads are converted to RGB and stored as JPEG.
            file_type = "jpg"
            image.convert('RGB').save(f"{stem}-resized.{file_type}")
        else:
            file_type = org_file_type
            image.save(f"{stem}-resized.{file_type}")
        # Second, smaller copy for display; keeps the original extension.
        image.thumbnail((200, 200))
        image.save(f"{stem}-resized_display.{org_file_type}")

    st.session_state.input_image = _encode_file_b64(f"{stem}-resized.{file_type}")
    if _needs_rekognition():
        st.session_state.bytes_for_rekog = Path(photo).read_bytes()


def _collect_inputs_and_weights():
    """Snapshot the ``input_*`` session keys and resolve per-search weights.

    Returns:
        dict: every ``input_<name>`` key without ``weight`` in its name,
        stored under ``<name>``, plus ``'weightage'`` mapping each key that
        contains ``weight`` to its resolved weight.

    When the weights of the selected search types do not add up to 100, or
    any of them is 0, 100 is split evenly between the selected types and the
    remainder goes to the last one processed. Keys containing ``weight``
    that do not belong to a selected search type are reset to 0.0.
    """
    inputs = {}
    weightage = {}
    st.session_state.weights_ = []

    selected = st.session_state.input_searchType
    num_search = len(selected)

    total_weight = 0.0
    any_weight_zero = False
    for search_name in selected:
        # e.g. "Keyword Search" -> "input_Keyword-weight"; these keys are
        # presumably created by widgets elsewhere in the file - TODO confirm.
        key_weight = "input_" + search_name.split(" ")[0] + "-weight"
        total_weight = total_weight + st.session_state[key_weight]
        if st.session_state[key_weight] == 0:
            any_weight_zero = True
    print(total_weight)

    rebalance = total_weight != 100 or any_weight_zero
    counter = 0
    for key in st.session_state:
        if not key.startswith('input_'):
            continue
        original_key = key.removeprefix('input_')
        if 'weight' not in key:
            inputs[original_key] = st.session_state[key]
            continue
        if original_key.split("-")[0] + " Search" not in selected:
            weightage[original_key] = 0.0
            st.session_state[key] = 0.0
            continue

        counter = counter + 1
        if rebalance:
            cal_weight = math.trunc(100 / num_search)
            if counter == num_search:
                cal_weight += 100 % num_search
            st.session_state[key] = cal_weight
            weightage[original_key] = cal_weight
            st.session_state.weights_.append(cal_weight)
        else:
            weightage[original_key] = st.session_state[key]
            st.session_state.weights_.append(st.session_state[key])

    inputs['weightage'] = weightage
    return inputs


def handle_input():
    """Run a search for the current query; used as the "Search" button callback.

    Steps, in order:
        1. Reset ``input_ndcg`` when the query text differs from the last run.
        2. Resolve the query image (selected generated image first, else the
           uploaded file) into ``input_image`` / ``bytes_for_rekog``.
        3. With an image and Keyword Search selected, store the result of
           ``amazon_rekognition.extract_image_metadata`` in
           ``input_rekog_label`` and use it as query text if that is empty.
        4. Collect the ``input_*`` values and search weights.
        5. Optionally run the SQL / query-rewrite helpers, call
           ``all_search_execute.handler`` and store the single question and
           answer in the session state; optionally run ``llm_eval.eval``.
    """
    previous = st.session_state.inputs_
    if "text" in previous and previous["text"] != st.session_state.input_text:
        st.session_state.input_ndcg = 0.0

    st.session_state.bytes_for_rekog = ""
    print("***")

    selected_generated = st.session_state.input_rad_1
    has_generated = selected_generated is not None and selected_generated != ""
    if st.session_state.img_doc is not None or has_generated:
        print("perform multimodal search")
        st.session_state.input_imageUpload = 'yes'
        if has_generated:
            _use_generated_image()
        else:
            _use_uploaded_image()
    else:
        print("no image uploaded")
        st.session_state.input_imageUpload = 'no'
        st.session_state.input_image = ''

    if _needs_rekognition():
        st.session_state.input_rekog_label = amazon_rekognition.extract_image_metadata(
            st.session_state.bytes_for_rekog
        )
        if st.session_state.input_text == "":
            st.session_state.input_text = st.session_state.input_rekog_label

    inputs = _collect_inputs_and_weights()
    st.session_state.input_weightage = inputs['weightage']
    st.session_state.inputs_ = inputs

    question_with_id = {
        'question': inputs["text"],
        'id': len(st.session_state.questions),
    }
    # Only the latest question/answer pair is kept.
    st.session_state.questions = [question_with_id]
    st.session_state.answers = []

    # .get(): "input_is_sql_query" has no default in this part of the file.
    if st.session_state.get("input_is_sql_query") == 'enabled':
        # NOTE(review): `import os_index_df_sql` is commented out at the top
        # of the file, so this call raises NameError unless that import is
        # restored - confirm whether the SQL path is still supported.
        os_index_df_sql.sql_process(st.session_state.input_text)
        print(st.session_state.input_sql_query)
    else:
        st.session_state.input_sql_query = ""

    if st.session_state.input_is_rewrite_query == 'enabled' or _needs_rekognition():
        query_rewrite.get_new_query_res(st.session_state.input_text)
    else:
        st.session_state.input_rewritten_query = ""

    ans__ = all_search_execute.handler(inputs, st.session_state['session_id'])
    st.session_state.answers.append({
        'answer': ans__,
        'search_type': inputs['searchType'],
        'id': len(st.session_state.questions),
    })
    # Same list object on purpose, as before (not a copy).
    st.session_state.answers_none_rank = st.session_state.answers

    if st.session_state.input_evaluate == "enabled":
        llm_eval.eval(st.session_state.questions, st.session_state.answers)


def write_top_bar():
    """Render the query bar and the image-search expander.

    Returns:
        tuple: ``(clear, tab1)`` - whether the "Clear" button was pressed in
        this run, and the "Generate Image by AI" tab container.

    Side effects: stores the uploaded file (or None) in
    ``st.session_state.img_doc``, derives ``input_rad_1`` from the image
    radio selection and redraws the generated thumbnails.
    """
    col1, col2, col3, col4 = st.columns([2.5, 35, 8, 7])
    with col1:
        st.image(TEXT_ICON, use_column_width='always')
    with col2:
        # The value is read through st.session_state.input_text.
        st.text_input(
            "Ask here",
            label_visibility="collapsed",
            key="input_text",
            placeholder="Type your query",
        )
    with col3:
        st.button("Search", on_click=handle_input, key="play")
    with col4:
        clear = st.button("Clear")

    col5, col6 = st.columns([4.5, 95])
    with col5:
        st.image(IMAGE_ICON, use_column_width='always')
    with col6:
        with st.expander(':green[Search by using an image]'):
            tab2, tab1 = st.tabs(["Upload Image", "Generate Image by AI"])
            with tab1:
                c1, c2 = st.columns([80, 20])
                with c1:
                    gen_images = st.text_area(
                        "Text2Image:",
                        placeholder="Enter the text prompt to generate images",
                        height=68,
                        key="image_prompt",
                    )
                with c2:
                    # NOTE(review): the markup of this call was lost in the
                    # file as received; only a line break remained between
                    # the quotes. Restore the original element if needed.
                    st.markdown("\n", unsafe_allow_html=True)
                    st.button(
                        "Generate",
                        disabled=False,
                        key="generate",
                        on_click=generate_images,
                        args=(tab1, "default_img"),
                    )
                st.radio(
                    "Choose one image",
                    ["Image 1", "Image 2", "Image 3"],
                    index=None,
                    horizontal=True,
                    key='image_select',
                    disabled=st.session_state.radio_disabled,
                )
                # NOTE(review): blank markup, see the note near the top.
                st.markdown(""" """, unsafe_allow_html=True)

                choice = st.session_state.image_select
                if choice is not None and choice != "" and len(st.session_state.img_gen) != 0:
                    # "Image 2" -> "2"
                    st.session_state.input_rad_1 = choice.split(" ")[1]
                else:
                    st.session_state.input_rad_1 = ""

                # Same value as the stored prompt, so this only redraws.
                generate_images(tab1, gen_images)
            with tab2:
                st.session_state.img_doc = st.file_uploader(
                    "Upload image", accept_multiple_files=False, type=['png', 'jpg']
                )
    return clear, tab1


clear, tab_ = write_top_bar()

if clear:
    st.session_state.questions = []
    st.session_state.answers = []
    st.session_state.clear_ = True
    st.session_state.image_prompt2 = ""
    st.session_state.input_rekog_label = ""
    st.session_state.radio_disabled = True
    if len(st.session_state.img_gen) != 0:
        st.session_state.img_container.empty()
        st.session_state.img_gen = []
    st.session_state.input_rad_1 = ""

col1, col3, col4 = st.columns([70, 18, 12])
with col1:
    # max_selections is kept as a string in the session state; translate it
    # into what st.multiselect expects (an int or None).
    if st.session_state.max_selections == "" or st.session_state.max_selections == "1":
        st.session_state.max_selections = 1
    if st.session_state.max_selections == "None":
        st.session_state.max_selections = None
    search_type = st.multiselect(
        'Select the Search type(s)',
        search_types,
        ['Keyword Search'],
        max_selections=st.session_state.max_selections,
        key='input_searchType',
        help="Select the type of Search, adding more than one search type will activate hybrid search",
        #\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2.
    )
OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers)" ) with col3: st.number_input("No. of docs", min_value=1, max_value=50, value=5, step=5, key='input_K', help=None) with col4: st.markdown("'+json.dumps(st.session_state.input_rekog_directoutput)+'
',unsafe_allow_html=True) else: st.markdown("
No results found, please try again with different query
", unsafe_allow_html = True) else: for ans in answer: if('b5/b5319e00' in ans['image_url'] ): filter_out+=1 continue format_ = ans['image_url'].split(".")[-1] Image.MAX_IMAGE_PIXELS = 100000000 width = 500 height = 500 with col_1: inner_col_1,inner_col_2 = st.columns([8,92]) with inner_col_2: st.image(ans['image_url'].replace("/home/ec2-user/SageMaker/","/home/user/app/")) if('max_score_dict_list_sorted' in ans and 'Vector Search' in st.session_state.input_searchType): desc___ = ans['desc'].split(" ") res___ = [] for o in ans['max_score_dict_list_sorted']: res___.append(o['doc_token']) final_desc_ = "" for word_ in desc___: str_=re.sub('[^A-Za-z0-9]+', '', word_).lower() stemmed_word = next(iter(set(stem_(str_)))) if(stemmed_word in res___ or str_ in res___): if(stemmed_word in res___): mod_word = stemmed_word else: mod_word = str_ if(res___.index(mod_word)==0): final_desc_ += ""+word_+" " elif(res___.index(mod_word)==1): final_desc_ += ""+word_+" " else: final_desc_ += ""+word_+" " else: final_desc_ += word_ + " " final_desc_ += "
" for word in desc__: if(re.sub('[^A-Za-z0-9]+', '', word) in res__): final_desc += ""+word+" " else: final_desc += word + " " final_desc += "
" st.markdown(final_desc,unsafe_allow_html = True) else: st.write(ans['desc']) if("sparse" in ans): with st.expander("Expanded document:"): sparse_ = dict(sorted(ans['sparse'].items(), key=lambda item: item[1],reverse=True)) filtered_sparse = dict() for key in sparse_: if(sparse_[key]>=1.0): filtered_sparse[key] = round(sparse_[key], 2) st.write(filtered_sparse) with st.expander("Document Metadata:",expanded = False): st.write(":green[default:]") st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True) if("rekog" in ans): st.write(":green[enriched:]") st.json(ans['rekog'],expanded = True) with inner_col_1: if(st.session_state.input_evaluate == "enabled"): with st.container(border = False): if("relevant" in ans.keys()): if(ans['relevant']==True): st.write(":white_check_mark:") else: st.write(":x:") i = i+1 with col_3: if(index == len(st.session_state.questions)): rdn_key = ''.join([random.choice(string.ascii_letters) for _ in range(10)]) currentValue = "".join(st.session_state.input_searchType)+st.session_state.input_imageUpload+json.dumps(st.session_state.input_weightage)+st.session_state.input_NormType+st.session_state.input_CombineType+str(st.session_state.input_K)+st.session_state.input_sparse+st.session_state.input_reranker+st.session_state.input_is_rewrite_query+st.session_state.input_evaluate+st.session_state.input_image+st.session_state.input_rad_1+st.session_state.input_reranker+st.session_state.input_hybridType+st.session_state.input_manual_filter oldValue = 
"".join(st.session_state.inputs_["searchType"])+st.session_state.inputs_["imageUpload"]+str(st.session_state.inputs_["weightage"])+st.session_state.inputs_["NormType"]+st.session_state.inputs_["CombineType"]+str(st.session_state.inputs_["K"])+st.session_state.inputs_["sparse"]+st.session_state.inputs_["reranker"]+st.session_state.inputs_["is_rewrite_query"]+st.session_state.inputs_["evaluate"]+st.session_state.inputs_["image"]+st.session_state.inputs_["rad_1"]+st.session_state.inputs_["reranker"]+st.session_state.inputs_["hybridType"]+st.session_state.inputs_["manual_filter"] def on_button_click(): if(currentValue!=oldValue): st.session_state.input_text = st.session_state.questions[-1]["question"] st.session_state.answers.pop() st.session_state.questions.pop() handle_input() with placeholder.container(): render_all() if("currentValue" in st.session_state): del st.session_state["currentValue"] try: del regenerate except: pass placeholder__ = st.empty() placeholder__.button("🔄",key=rdn_key,on_click=on_button_click, help = "This will regenerate the responses with new settings that you entered, Note: To see difference in responses, you should change any of the applicable settings")#,type="primary",use_column_width=True) if(filter_out > 0): placeholder_no_results.text(str(filter_out)+" result(s) removed due to missing or in-appropriate content") #Each answer will have context of the question asked in order to associate the provided feedback with the respective question def write_chat_message(md, q,index): if('body' in md['answer']): res = json.loads(md['answer']['body']) else: res = md['answer'] st.session_state['session_id'] = "1234" chat = st.container() with chat: render_answer(res,index) def render_all(): index = 0 for (q, a) in zip(st.session_state.questions, st.session_state.answers): index = index +1 ans_ = st.session_state.answers[0] write_user_message(q,ans_) write_chat_message(a, q,index) placeholder = st.empty() with placeholder.container(): render_all() 
st.markdown("")