prasadnu committed
Commit eb03410 · 1 Parent(s): 7862398

rerank model

RAG/bedrock_agent.py CHANGED
@@ -23,8 +23,6 @@ if "inputs_" not in st.session_state:
 
 parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
 region = 'us-east-1'
-print(region)
-account_id = '445083327804'
 # setting logger
 logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -46,9 +44,6 @@ def delete_memory():
     )
 
 def query_(inputs):
-    ## create a random id for session initiator id
-
-
     # invoke the agent API
     agentResponse = bedrock_agent_runtime_client.invoke_agent(
         inputText=inputs['shopping_query'],
@@ -71,13 +66,6 @@ def query_(inputs):
     for event in event_stream:
         print("***event*********")
         print(event)
-        # if 'chunk' in event:
-        #     data = event['chunk']['bytes']
-        #     print("***chunk*********")
-        #     print(data)
-        #     logger.info(f"Final answer ->\n{data.decode('utf8')}")
-        #     agent_answer_ = data.decode('utf8')
-        #     print(agent_answer_)
         if 'trace' in event:
             print("trace*****total*********")
             print(event['trace'])
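
Note on the deleted chunk handler: the final answer still has to come out of the same stream. A minimal sketch of the usual pattern, assuming the boto3 bedrock-agent-runtime client (variable names follow the surrounding code; 'completion' is the key boto3 documents for the invoke_agent response stream):

# Hedged sketch, not the repo's exact code: read the agent's final answer
# and traces from an invoke_agent event stream.
agent_answer = ""
for event in agentResponse['completion']:
    if 'chunk' in event:
        # chunk bytes carry fragments of the final answer text
        agent_answer += event['chunk']['bytes'].decode('utf8')
    elif 'trace' in event:
        # trace events expose intermediate reasoning and tool calls
        logger.info(json.dumps(event['trace'], indent=2))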
@@ -109,38 +97,7 @@ def query_(inputs):
             print(total_context)
     except botocore.exceptions.EventStreamError as error:
         raise error
-        # t.sleep(2)
-        # query_(st.session_state.inputs_)
-
-        # if 'chunk' in event:
-        #     data = event['chunk']['bytes']
-        #     final_ans = data.decode('utf8')
-        #     print(f"Final answer ->\n{final_ans}")
-        #     logger.info(f"Final answer ->\n{final_ans}")
-        #     agent_answer = final_ans
-        #     end_event_received = True
-        #     # End event indicates that the request finished successfully
-        # elif 'trace' in event:
-        #     logger.info(json.dumps(event['trace'], indent=2))
-        # else:
-        #     raise Exception("unexpected event.", event)
-        # except Exception as e:
-        #     raise Exception("unexpected event.", e)
+
     return {'text':agent_answer,'source':total_context,'last_tool':{'name':last_tool_name,'response':last_tool}}
 
-    ####### Re-Rank ########
-
-    #print("re-rank")
-
-    # if(st.session_state.input_is_rerank == True and len(total_context)):
-    #     ques = [{"question":question}]
-    #     ans = [{"answer":total_context}]
-
-    #     total_context = re_ranker.re_rank('rag','Cross Encoder',"",ques, ans)
 
-    # llm_prompt = prompt_template.format(context=total_context[0],question=question)
-    # output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
-    # #print(output)
-    # if(len(images_2)==0):
-    #     images_2 = images
-    # return {'text':output,'source':total_context,'image':images_2,'table':df}
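
The commented-out "Re-Rank" block deleted here (and its twin in rag_DocumentSearcher.py below) pointed at a cross-encoder pass over the retrieved context via re_ranker.re_rank('rag','Cross Encoder',...). That helper's internals are not part of this commit, but the idea is sketched below, assuming a sentence-transformers cross-encoder; the model name is an illustrative choice, not necessarily the repo's:

# Hedged sketch of cross-encoder re-ranking of retrieved passages.
from sentence_transformers import CrossEncoder

def re_rank_context(question, passages):
    # score each (question, passage) pair, then sort passages by relevance
    model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
    scores = model.predict([(question, passage) for passage in passages])
    ranked = sorted(zip(passages, scores), key=lambda pair: pair[1], reverse=True)
    return [passage for passage, _ in ranked]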
 
RAG/rag_DocumentSearcher.py CHANGED
@@ -49,7 +49,6 @@ def query_(awsauth,inputs, session_id,search_types):
     images = []
 
     for hit in hits:
-        #context.append(hit['_source']['caption'])
         images.append({'file':hit['_source']['image'],'caption':hit['_source']['processed_element']})
 
     ####### SEARCH ########
@@ -102,10 +101,6 @@ def query_(awsauth,inputs, session_id,search_types):
         }
     ]
 
-
-
-
-
     SIZE = 5
 
     hybrid_payload = {
@@ -159,7 +154,6 @@ def query_(awsauth,inputs, session_id,search_types):
 
     if('Sparse Search' in search_types):
 
-        #print("text expansion is enabled")
         sparse_payload = { "neural_sparse": {
             "processed_element_embedding_sparse": {
                 "query_text": question,
@@ -301,7 +295,6 @@ def query_(awsauth,inputs, session_id,search_types):
             images_2.append({'file':hit["_source"]["image"],'caption':hit["_source"]["processed_element"]})
 
         idx = idx +1
-        #images.append(hit['_source']['image'])
 
     # if(is_table_in_result == False):
     #     df = lazy_get_table()
@@ -315,19 +308,9 @@ def query_(awsauth,inputs, session_id,search_types):
 
     total_context = context_tables + context
 
-    ####### Re-Rank ########
-
-    #print("re-rank")
-
-    # if(st.session_state.input_is_rerank == True and len(total_context)):
-    #     ques = [{"question":question}]
-    #     ans = [{"answer":total_context}]
-
-    #     total_context = re_ranker.re_rank('rag','Cross Encoder',"",ques, ans)
 
     llm_prompt = prompt_template.format(context=total_context[0],question=question)
     output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
-    #print(output)
     if(len(images_2)==0):
         images_2 = images
     return {'text':output,'source':total_context,'image':images_2,'table':df}
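
For reference on the sparse path kept above: neural_sparse is an OpenSearch query clause, so the full request body around it would look roughly like this (the field name comes from the diff; the model id is a placeholder for whatever sparse encoder is deployed, and the client variable is assumed):

# Hedged sketch: a complete neural sparse search request built around the
# clause shown in the diff.
sparse_query = {
    "size": 5,
    "query": {
        "neural_sparse": {
            "processed_element_embedding_sparse": {
                "query_text": question,
                "model_id": "<deployed-sparse-encoder-id>"
            }
        }
    }
}
# hits = os_client.search(index=index_name, body=sparse_query)["hits"]["hits"]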
 
app.py CHANGED
@@ -152,28 +152,6 @@ spacer_col = st.columns(1)[0]
 with spacer_col:
     st.markdown("<div style='height: 120px;'></div>", unsafe_allow_html=True)
 
-#st.image("/home/ubuntu/images/OS_AI_1.png", use_column_width=True)
-# with col_title:
-#     st.write("")
-#     st.markdown('<div class="title">OpenSearch AI demos</div>', unsafe_allow_html=True)
-
-# def demo_link_block(icon, title, target_page):
-#     st.markdown(f"""
-#         <a href="/{target_page}" target="_self" style="text-decoration: none;">
-#             <div class="demo-card">
-#                 <div class="demo-text">
-#                     <span>{icon} {title}</span>
-#                     <span class="demo-arrow">→</span>
-#                 </div>
-#             </div>
-#         </a>
-#     """, unsafe_allow_html=True)
-
-
-# st.write("")
-# demo_link_block("🔍", "AI Search", "Semantic_Search")
-# demo_link_block("💬","Multimodal Conversational Search", "Multimodal_Conversational_Search")
-# demo_link_block("🛍️","Agentic Shopping Assistant", "AI_Shopping_Assistant")
 
 
 col1, col2, col3 = st.columns(3)
@@ -225,5 +203,3 @@ st.markdown("""
     </style>
     """, unsafe_allow_html=True)
 
-    # <div class="card-arrow"></div>
-
 
pages/AI_Shopping_Assistant.py CHANGED
@@ -33,12 +33,7 @@ import bedrock_agent
 import warnings
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
-
-
-
-
 st.set_page_config(
-    #page_title="Semantic Search using OpenSearch",
     layout="wide",
     page_icon="images/opensearch_mark_default.png"
 )
@@ -47,15 +42,14 @@ USER_ICON = "images/user.png"
 AI_ICON = "images/opensearch-twitter-card.png"
 REGENERATE_ICON = "images/regenerate.png"
 s3_bucket_ = "pdf-repo-uploads"
-#"pdf-repo-uploads"
+
 polly_client = boto3.Session(
     region_name='us-east-1').client('polly')
 
 # Check if the user ID is already stored in the session state
 if 'user_id' in st.session_state:
     user_id = st.session_state['user_id']
-    #print(f"User ID: {user_id}")
-
+
 # If the user ID is not yet stored in the session state, generate a random UUID
 else:
     user_id = str(uuid.uuid4())
@@ -79,9 +73,6 @@ if "questions__" not in st.session_state:
 
 if "answers__" not in st.session_state:
     st.session_state.answers__ = []
-
-if "input_index" not in st.session_state:
-    st.session_state.input_index = "hpijan2024hometrack"#"globalwarmingnew"#"hpijan2024hometrack_no_img_no_table"
 
 if "input_is_rerank" not in st.session_state:
     st.session_state.input_is_rerank = True
@@ -92,22 +83,17 @@ if "input_copali_rerank" not in st.session_state:
 if "input_table_with_sql" not in st.session_state:
     st.session_state.input_table_with_sql = False
 
-
 if "inputs_" not in st.session_state:
     st.session_state.inputs_ = {}
 
 if "input_shopping_query" not in st.session_state:
-    st.session_state.input_shopping_query="get me shoes suitable for trekking"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
+    st.session_state.input_shopping_query="get me shoes suitable for trekking"
 
 
 if "input_rag_searchType" not in st.session_state:
     st.session_state.input_rag_searchType = ["Sparse Search"]
 
-
-
-
 region = 'us-east-1'
-#bedrock_runtime_client = boto3.client('bedrock-runtime',region_name=region)
 output = []
 service = 'es'
 
@@ -122,48 +108,6 @@ st.markdown("""
     </style>
     """,unsafe_allow_html=True)
 
-################ OpenSearch Py client #####################
-
-# credentials = boto3.Session().get_credentials()
-# awsauth = AWSV4SignerAuth(credentials, region, service)
-
-# ospy_client = OpenSearch(
-#     hosts = [{'host': 'search-opensearchservi-75ucark0bqob-bzk6r6h2t33dlnpgx2pdeg22gi.us-east-1.es.amazonaws.com', 'port': 443}],
-#     http_auth = awsauth,
-#     use_ssl = True,
-#     verify_certs = True,
-#     connection_class = RequestsHttpConnection,
-#     pool_maxsize = 20
-# )
-
-################# using boto3 credentials ###################
-
-
-# credentials = boto3.Session().get_credentials()
-# awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, service, session_token=credentials.token)
-# service = 'es'
-
-
-################# using boto3 credentials ####################
-
-
-
-# if "input_searchType" not in st.session_state:
-#     st.session_state.input_searchType = "Conversational Search (RAG)"
-
-# if "input_temperature" not in st.session_state:
-#     st.session_state.input_temperature = "0.001"
-
-# if "input_topK" not in st.session_state:
-#     st.session_state.input_topK = 200
-
-# if "input_topP" not in st.session_state:
-#     st.session_state.input_topP = 0.95
-
-# if "input_maxTokens" not in st.session_state:
-#     st.session_state.input_maxTokens = 1024
-
-
 def write_logo():
     col1, col2, col3 = st.columns([5, 1, 5])
     with col2:
@@ -175,8 +119,6 @@ def write_top_bar():
     st.page_link("app.py", label=":orange[Home]", icon="🏠")
     st.header("AI Shopping assistant",divider='rainbow')
 
-    #st.image(AI_ICON, use_column_width='always')
-
     with col2:
         st.write("")
         st.write("")
@@ -193,17 +135,10 @@ if clear:
     st.session_state.input_shopping_query=""
     st.session_state.session_id_ = str(uuid.uuid1())
     bedrock_agent.delete_memory()
-    # st.session_state.input_searchType="Conversational Search (RAG)"
-    # st.session_state.input_temperature = "0.001"
-    # st.session_state.input_topK = 200
-    # st.session_state.input_topP = 0.95
-    # st.session_state.input_maxTokens = 1024
+
 
 
 def handle_input():
-    print("Question: "+st.session_state.input_shopping_query)
-    print("-----------")
-    print("\n\n")
    if(st.session_state.input_shopping_query==''):
        return ""
    inputs = {}
@@ -212,10 +147,6 @@ def handle_input():
         inputs[key.removeprefix('input_')] = st.session_state[key]
     st.session_state.inputs_ = inputs
 
-    #######
-
-
-    #st.write(inputs)
     question_with_id = {
         'question': inputs["shopping_query"],
         'id': len(st.session_state.questions__)
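
A side note on the surviving input-collection idiom above: every session key prefixed with input_ is copied into a plain dict with the prefix stripped (str.removeprefix needs Python 3.9+). In isolation, with a stand-in dict for st.session_state:

# Hedged sketch of the handle_input() idiom.
session_state = {
    'input_shopping_query': 'get me shoes suitable for trekking',
    'input_rag_searchType': ['Sparse Search'],
    'user_id': 'ignored',                    # no 'input_' prefix, skipped
}
inputs = {}
for key in session_state:
    if key.startswith('input_'):
        inputs[key.removeprefix('input_')] = session_state[key]
# inputs == {'shopping_query': '...', 'rag_searchType': ['Sparse Search']}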
@@ -234,30 +165,6 @@ def handle_input():
     st.session_state.input_shopping_query=""
 
 
-
-# search_type = st.selectbox('Select the Search type',
-#     ('Conversational Search (RAG)',
-#     'OpenSearch vector search',
-#     'LLM Text Generation'
-#     ),
-
-#     key = 'input_searchType',
-#     help = "Select the type of retriever\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers_)"
-#     )
-
-# col1, col2, col3, col4 = st.columns(4)
-
-# with col1:
-#     st.text_input('Temperature', value = "0.001", placeholder='LLM Temperature', key = 'input_temperature',help = "Set the temperature of the Large Language model. \n Note: 1. Set this to values lower to 1 in the order of 0.001, 0.0001, such low values reduces hallucination and creativity in the LLM response; 2. This applies only when LLM is a part of the retriever pipeline")
-# with col2:
-#     st.number_input('Top K', value = 200, placeholder='Top K', key = 'input_topK', step = 50, help = "This limits the LLM's predictions to the top k most probable tokens at each step of generation, this applies only when LLM is a prt of the retriever pipeline")
-# with col3:
-#     st.number_input('Top P', value = 0.95, placeholder='Top P', key = 'input_topP', step = 0.05, help = "This sets a threshold probability and selects the top tokens whose cumulative probability exceeds the threshold while the tokens are generated by the LLM")
-# with col4:
-#     st.number_input('Max Output Tokens', value = 500, placeholder='Max Output Tokens', key = 'input_maxTokens', step = 100, help = "This decides the total number of tokens generated as the final response. Note: Values greater than 1000 takes longer response time")
-
-# st.markdown('---')
-
 
 def write_user_message(md):
     col1, col2 = st.columns([3,97])
@@ -265,8 +172,6 @@ def write_user_message(md):
     with col1:
         st.image(USER_ICON, use_column_width='always')
     with col2:
-        #st.warning(md['question'])
-
         st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
 
 
@@ -283,18 +188,9 @@ def render_answer(question,answer,index):
     ans_ = answer['answer']
     span_ans = ans_.replace('<question>',"<span style='fontSize:18px;color:#f37709;fontStyle:italic;'>").replace("</question>","</span>")
     st.markdown("<p>"+span_ans+"</p>",unsafe_allow_html = True)
-    print("answer['source']")
-    print("-------------")
-    print(answer['source'])
-    print("-------------")
-    print(answer['last_tool'])
     if(answer['last_tool']['name'] in ["generate_images","get_relevant_items_for_image","get_relevant_items_for_text","retrieve_with_hybrid_search","retrieve_with_keyword_search","get_any_general_recommendation"]):
         use_interim_results = True
         src_dict =json.loads(answer['last_tool']['response'].replace("'",'"'))
-        print("src_dict")
-        print("-------------")
-        print(src_dict)
-        #if("get_relevant_items_for_text" in src_dict):
         if(use_interim_results and answer['last_tool']['name']!= 'generate_images' and answer['last_tool']['name']!= 'get_any_general_recommendation'):
             key_ = answer['last_tool']['name']
 
@@ -310,9 +206,7 @@ def render_answer(question,answer,index):
                 if(index ==1):
                     with img_col2:
                         st.image(resizedImg,use_column_width = True,caption = item['title'])
-                #st.image(parent_dirname+"/retrieved_esci_images/"+item['id']+"_resized.jpg",caption = item['title'],use_column_width = True)
-
-
+
     if(answer['last_tool']['name'] == "generate_images" or answer['last_tool']['name'] == "get_any_general_recommendation"):
         st.write("<br>",unsafe_allow_html = True)
         gen_img_col1, gen_img_col2,gen_img_col2 = st.columns([30,30,30])
@@ -328,143 +222,17 @@ def render_answer(question,answer,index):
         with gen_img_col1:
             st.image(resizedImg,caption = "Generated image for "+key.split(".")[0],use_column_width = True)
         st.write("<br>",unsafe_allow_html = True)
-
-
-
-
-
-
-
-    # def stream_():
-    #     #use for streaming response on the client side
-    #     for word in ans_.split(" "):
-    #         yield word + " "
-    #         time.sleep(0.04)
-    #     #use for streaming response from Llm directly
-    #     if(isinstance(ans_,botocore.eventstream.EventStream)):
-    #         for event in ans_:
-    #             chunk = event.get('chunk')
-
-    #             if chunk:
-
-    #                 chunk_obj = json.loads(chunk.get('bytes').decode())
-
-    #                 if('content_block' in chunk_obj or ('delta' in chunk_obj and 'text' in chunk_obj['delta'])):
-    #                     key_ = list(chunk_obj.keys())[2]
-    #                     text = chunk_obj[key_]['text']
-
-    #                     clear_output(wait=True)
-    #                     output.append(text)
-    #                     yield text
-    #                     time.sleep(0.04)
-
-
-
-    # if(index == len(st.session_state.questions_)):
-    #     st.write_stream(stream_)
-    # if(isinstance(st.session_state.answers_[index-1]['answer'],botocore.eventstream.EventStream)):
-    #     st.session_state.answers_[index-1]['answer'] = "".join(output)
-    # else:
-    #     st.write(ans_)
-
-
-    # polly_response = polly_client.synthesize_speech(VoiceId='Joanna',
-    #             OutputFormat='ogg_vorbis',
-    #             Text = ans_,
-    #             Engine = 'neural')
-
-    # audio_col1, audio_col2 = st.columns([50,50])
-    # with audio_col1:
-    #     st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
-
-
-
-    #st.markdown("<div style='font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;border-radius: 10px;'>"+ans_+"</div>", unsafe_allow_html = True)
-    #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Relevant images from the document :</b></div>", unsafe_allow_html = True)
-    #st.write("")
     colu1,colu2,colu3 = st.columns([4,82,20])
     if(answer['source']!={}):
         with colu2:
             with st.expander("Agent Traces:"):
                 st.write(answer['source'])
+
-    # with st.container():
-    #     if(len(res_img)>0):
-    #         with st.expander("Images:"):
-    #             col3,col4,col5 = st.columns([33,33,33])
-    #             cols = [col3,col4]
-    #             idx = 0
-    #             #print(res_img)
-    #             for img_ in res_img:
-    #                 if(img_['file'].lower()!='none' and idx < 2):
-    #                     img = img_['file'].split(".")[0]
-    #                     caption = img_['caption']
-
-    #                     with cols[idx]:
-
-    #                         st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
-    #                         #st.write(caption)
-    #                     idx = idx+1
-    # #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Sources from the document:</b></div>", unsafe_allow_html = True)
-    # if(len(answer["table"] )>0):
-    #     with st.expander("Table:"):
-    #         df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
-    #         df.fillna(method='pad', inplace=True)
-    #         st.table(df)
-    # with st.expander("Raw sources:"):
-    #     st.write(answer["source"])
-
-
-
-    # with col_3:
-
-    #     #st.markdown("<div style='color:#e28743;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 5px;'><b>"+",".join(st.session_state.input_rag_searchType)+"</b></div>", unsafe_allow_html = True)
-
-
-
-    #     if(index == len(st.session_state.questions_)):
-
-    #         rdn_key = ''.join([random.choice(string.ascii_letters)
-    #                         for _ in range(10)])
-    #         currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
-    #         oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
-    #         #print("changing values-----------------")
-    #         def on_button_click():
-    #             # print("button clicked---------------")
-    #             # print(currentValue)
-    #             # print(oldValue)
-    #             if(currentValue!=oldValue or 1==1):
-    #                 #print("----------regenerate----------------")
-    #                 st.session_state.input_query = st.session_state.questions_[-1]["question"]
-    #                 st.session_state.answers_.pop()
-    #                 st.session_state.questions_.pop()
-
-    #                 handle_input()
-    #                 with placeholder.container():
-    #                     render_all()
-
-    #         if("currentValue" in st.session_state):
-    #             del st.session_state["currentValue"]
-
-    #         try:
-    #             del regenerate
-    #         except:
-    #             pass
-
-    #         #print("------------------------")
-    #         #print(st.session_state)
-
-    #         placeholder__ = st.empty()
-
-    #         placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
 
 #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
 def write_chat_message(md, q,index):
-    #res_img = md['image']
-    #st.session_state['session_id'] = res['session_id'] to be added in memory
     chat = st.container()
     with chat:
-        #print("st.session_state.input_index------------------")
-        #print(st.session_state.input_index)
         render_answer(q,md,index)
 
 def render_all():
@@ -480,173 +248,8 @@ with placeholder.container():
 
     st.markdown("")
     col_2, col_3 = st.columns([75,20])
-    #col_1, col_2, col_3 = st.columns([7.5,71.5,22])
-    # with col_1:
-    #     st.markdown("<p style='padding:0px 0px 0px 0px; color:#FF9900;font-size:120%'><b>Ask:</b></p>",unsafe_allow_html=True, help = 'Enter the questions and click on "GO"')
-
+
     with col_2:
-        #st.markdown("")
         input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_shopping_query")
     with col_3:
-        #hidden = st.button("RUN",disabled=True,key = "hidden")
-        # audio_value = st.audio_input("Record a voice message")
-        # print(audio_value)
         play = st.button("Go",on_click=handle_input,key = "play")
-#with st.sidebar:
-    # st.page_link("/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp/app.py", label=":orange[Home]", icon="🏠")
-    # st.subheader(":blue[Sample Data]")
-    # coln_1,coln_2 = st.columns([70,30])
-    # # index_select = st.radio("Choose one index",["UK Housing","Covid19 impacts on Ireland","Environmental Global Warming","BEIR Research"],
-    # #     captions = ['[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)',
-    # #     '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)',
-    # #     '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)',
-    # #     '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/BEIR.pdf)'],
-    # #     key="input_rad_index")
-    # with coln_1:
-    #     index_select = st.radio("Choose one index",["UK Housing","Global Warming stats","Covid19 impacts on Ireland"],key="input_rad_index")
-    # with coln_2:
-    #     st.markdown("<p style='font-size:15px'>Preview file</p>",unsafe_allow_html=True)
-    #     st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
-    #     st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
-    #     st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
-    #     #st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/BEIR.pdf)")
-    # st.markdown("""
-    #     <style>
-    #     [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
-    #         gap: 0rem;
-    #     }
-    #     [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
-    #         gap: 0rem;
-    #     }
-    #     </style>
-    #     """,unsafe_allow_html=True)
-    # # Initialize boto3 to use the S3 client.
-    # s3_client = boto3.resource('s3')
-    # bucket=s3_client.Bucket(s3_bucket_)
-
-    # objects = bucket.objects.filter(Prefix="sample_pdfs/")
-    # urls = []
-
-    # client = boto3.client('s3')
-
-    # for obj in objects:
-    #     if obj.key.endswith('.pdf'):
-
-    #         # Generate the S3 presigned URL
-    #         s3_presigned_url = client.generate_presigned_url(
-    #             ClientMethod='get_object',
-    #             Params={
-    #                 'Bucket': s3_bucket_,
-    #                 'Key': obj.key
-    #             },
-    #             ExpiresIn=3600
-    #         )
-
-    #         # Print the created S3 presigned URL
-    #         print(s3_presigned_url)
-    #         urls.append(s3_presigned_url)
-    #         #st.write("["+obj.key.split('/')[1]+"]("+s3_presigned_url+")")
-    #         st.link_button(obj.key.split('/')[1], s3_presigned_url)
-
-
-    # st.subheader(":blue[Your multi-modal documents]")
-    # pdf_doc_ = st.file_uploader(
-    #     "Upload your PDFs here and click on 'Process'", accept_multiple_files=False)
-
-
-    # pdf_docs = [pdf_doc_]
-    # if st.button("Process"):
-    #     with st.spinner("Processing"):
-    #         if os.path.isdir(parent_dirname+"/pdfs") == False:
-    #             os.mkdir(parent_dirname+"/pdfs")
-
-    #         for pdf_doc in pdf_docs:
-    #             print(type(pdf_doc))
-    #             pdf_doc_name = (pdf_doc.name).replace(" ","_")
-    #             with open(os.path.join(parent_dirname+"/pdfs",pdf_doc_name),"wb") as f:
-    #                 f.write(pdf_doc.getbuffer())
-
-    #             request_ = { "bucket": s3_bucket_,"key": pdf_doc_name}
-    #             # if(st.session_state.input_copali_rerank):
-    #             #     copali.process_doc(request_)
-    #             # else:
-    #             rag_DocumentLoader.load_docs(request_)
-    #             print('lambda done')
-    #         st.success('you can start searching on your PDF')
-
-    # ############## haystach demo temporary addition ############
-    # # st.subheader(":blue[Multimodality]")
-    # # colu1,colu2 = st.columns([50,50])
-    # # with colu1:
-    # #     in_images = st.toggle('Images', key = 'in_images', disabled = False)
-    # # with colu2:
-    # #     in_tables = st.toggle('Tables', key = 'in_tables', disabled = False)
-    # # if(in_tables):
-    # #     st.session_state.input_table_with_sql = True
-    # # else:
-    # #     st.session_state.input_table_with_sql = False
-
-    # ############## haystach demo temporary addition ############
-    # if(pdf_doc_ is None or pdf_doc_ == ""):
-    #     if(index_select == "Global Warming stats"):
-    #         st.session_state.input_index = "globalwarmingnew"
-    #     if(index_select == "Covid19 impacts on Ireland"):
-    #         st.session_state.input_index = "covid19ie"#"choosetheknnalgorithmforyourbillionscaleusecasewithopensearchawsbigdatablog"
-    #     if(index_select == "BEIR"):
-    #         st.session_state.input_index = "2104"
-    #     if(index_select == "UK Housing"):
-    #         st.session_state.input_index = "hpijan2024hometrack"
-    #     # if(in_images == True and in_tables == True):
-    #     #     st.session_state.input_index = "hpijan2024hometrack"
-    #     # else:
-    #     #     if(in_images == True and in_tables == False):
-    #     #         st.session_state.input_index = "hpijan2024hometrackno_table"
-    #     #     else:
-    #     #         if(in_images == False and in_tables == True):
-    #     #             st.session_state.input_index = "hpijan2024hometrackno_images"
-    #     #         else:
-    #     #             st.session_state.input_index = "hpijan2024hometrack_no_img_no_table"
-
-
-    # # if(in_images):
-    # #     st.session_state.input_include_images = True
-    # # else:
-    # #     st.session_state.input_include_images = False
-    # # if(in_tables):
-    # #     st.session_state.input_include_tables = True
-    # # else:
-    # #     st.session_state.input_include_tables = False
-
-    # custom_index = st.text_input("If uploaded the file already, enter the original file name", value = "")
-    # if(custom_index!=""):
-    #     st.session_state.input_index = re.sub('[^A-Za-z0-9]+', '', (custom_index.lower().replace(".pdf","").split("/")[-1].split(".")[0]).lower())
-
-
-
-    # st.subheader(":blue[Retriever]")
-    # search_type = st.multiselect('Select the Retriever(s)',
-    #     ['Keyword Search',
-    #     'Vector Search',
-    #     'Sparse Search',
-    #     ],
-    #     ['Sparse Search'],
-
-    #     key = 'input_rag_searchType',
-    #     help = "Select the type of Search, adding more than one search type will activate hybrid search"#\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers)"
-    #     )
-
-    # re_rank = st.checkbox('Re-rank results', key = 'input_re_rank', disabled = False, value = True, help = "Checking this box will re-rank the results using a cross-encoder model")
-
-    # if(re_rank):
-    #     st.session_state.input_is_rerank = True
-    # else:
-    #     st.session_state.input_is_rerank = False
-
-    # # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
-
-    # # if(copali_rerank):
-    # #     st.session_state.input_copali_rerank = True
-    # # else:
-    # #     st.session_state.input_copali_rerank = False
-
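
One fragile line this commit keeps in render_answer() is src_dict = json.loads(answer['last_tool']['response'].replace("'",'"')), which breaks as soon as a value contains an apostrophe. A safer sketch, assuming the tool response is either valid JSON or a Python-literal string:

# Hedged sketch: parse a tool response without the quote-swapping hack.
import ast
import json

def parse_tool_response(raw):
    try:
        return json.loads(raw)          # already valid JSON
    except json.JSONDecodeError:
        return ast.literal_eval(raw)    # Python-style dict/list string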
 
pages/Semantic_Search.py CHANGED
@@ -24,24 +24,18 @@ import base64
 import shutil
 import re
 from requests.auth import HTTPBasicAuth
-#import utilities.re_ranker as re_ranker
 # from nltk.stem import PorterStemmer
 # from nltk.tokenize import word_tokenize
 import query_rewrite
 import amazon_rekognition
+
 #from st_click_detector import click_detector
 import llm_eval
 import all_search_execute
 import warnings
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
-
-
-
-
 st.set_page_config(
-    #page_title="Semantic Search using OpenSearch",
-    #layout="wide",
     page_icon="images/opensearch_mark_default.png"
 )
 parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
@@ -58,11 +52,6 @@ st.markdown("""
 #ps = PorterStemmer()
 
 st.session_state.REGION = 'us-east-1'
-
-
-#from langchain.callbacks.base import BaseCallbackHandler
-
-
 USER_ICON = "images/user.png"
 AI_ICON = "images/opensearch-twitter-card.png"
 REGENERATE_ICON = "images/regenerate.png"
@@ -170,12 +159,6 @@ if "input_ndcg" not in st.session_state:
 if "gen_image_str" not in st.session_state:
     st.session_state.gen_image_str=""
 
-# if "input_searchType" not in st.session_state:
-#     st.session_state.input_searchType = ['Keyword Search']
-
-# if "input_must" not in st.session_state:
-#     st.session_state.input_must = ["Category","Price","Gender","Style"]
-
 if "input_NormType" not in st.session_state:
     st.session_state.input_NormType = "min_max"
@@ -261,25 +244,8 @@ if(search_all_type==True):
         'Multimodal Search',
         'NeuralSparse Search',
     ]
-    from streamlit.components.v1 import html
-    # with st.container():
-    #     html("""
-    #     <script>
-    #     // Locate elements
-    #     var decoration = window.parent.document.querySelectorAll('[data-testid="stDecoration"]')[0];
-    #     decoration.style.height = "3.0rem";
-    #     decoration.style.right = "45px";
-    #     // Adjust text decorations
-    #     decoration.innerText = "Semantic Search with OpenSearch!"; // Replace with your desired text
-    #     decoration.style.fontWeight = "bold";
-    #     decoration.style.display = "flex";
-    #     decoration.style.justifyContent = "center";
-    #     decoration.style.alignItems = "center";
-    #     decoration.style.fontWeight = "bold";
-    #     decoration.style.backgroundImage = url('/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp/images/service_logo.png'); // Remove background image
-    #     decoration.style.backgroundSize = "unset"; // Remove background size
-    #     </script>
-    #     """, width=0, height=0)
 
 
 
@@ -448,31 +414,12 @@ def handle_input():
 
 
     inputs = {}
-    # if(st.session_state.input_imageUpload == 'yes'):
-    #     st.session_state.input_searchType = 'Multi-modal Search'
-    # if(st.session_state.input_sparse == 'enabled' or st.session_state.input_is_rewrite_query == 'enabled'):
-    #     st.session_state.input_searchType = 'Keyword Search'
     if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
         old_rekog_label = st.session_state.input_rekog_label
         st.session_state.input_rekog_label = amazon_rekognition.extract_image_metadata(st.session_state.bytes_for_rekog)
         if(st.session_state.input_text == ""):
             st.session_state.input_text = st.session_state.input_rekog_label
 
-    # if(st.session_state.input_imageUpload == 'yes'):
-    #     if(st.session_state.input_searchType!='Multi-modal Search'):
-    #         if(st.session_state.input_searchType=='Keyword Search'):
-    #             if(st.session_state.input_rekognition != 'enabled'):
-    #                 st.error('For Keyword Search using images, enable "Enrich metadata for Images" in the left panel',icon = "🚨")
-    #                 #st.session_state.input_rekognition = 'enabled'
-    #                 st.switch_page('pages/1_Semantic_Search.py')
-    #                 #st.stop()
-
-    #     else:
-    #         st.error('Please set the search type as "Keyword Search (enabling Enrich metadata for Images) or Multi-modal Search"',icon = "🚨")
-    #         #st.session_state.input_searchType='Multi-modal Search'
-    #         st.switch_page('pages/1_Semantic_Search.py')
-    #         #st.stop()
-
 
     weightage = {}
     st.session_state.weights_ = []
@@ -511,44 +458,13 @@ def handle_input():
         else:
             weightage[original_key] = 0.0
             st.session_state[key] = 0.0
-
-
-
-
-
-
-
-
-
-
-
-
-
     inputs['weightage']=weightage
     st.session_state.input_weightage = weightage
 
-    print("====================")
-    print(st.session_state.weights_)
-    print(st.session_state.input_weightage )
-    print("====================")
-    #print("***************************")
-    #print(sum(weights_))
-    # if(sum(st.session_state.weights_)!=100):
-    #     st.warning('The total weight of selected search type(s) should be equal to 100',icon = "🚨")
-    #     refresh = st.button("Re-Enter")
-    #     if(refresh):
-    #         st.switch_page('pages/1_Semantic_Search.py')
-    #     st.stop()
-
-
-    #     #st.session_state.input_rekognition = 'enabled'
-    #     st.rerun()
-
-
 
     st.session_state.inputs_ = inputs
 
-    #st.write(inputs)
     question_with_id = {
         'question': inputs["text"],
         'id': len(st.session_state.questions)
@@ -567,19 +483,15 @@ def handle_input():
 
     if(st.session_state.input_is_rewrite_query == 'enabled' or (st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType)):
         query_rewrite.get_new_query_res(st.session_state.input_text)
-        print("-------------------")
-        print(st.session_state.input_rewritten_query)
-        print("-------------------")
     else:
         st.session_state.input_rewritten_query = ""
 
-    # elif(st.session_state.input_rekog_label!="" and st.session_state.input_rekognition == 'enabled'):
-    #     ans__ = amazon_rekognition.call(st.session_state.input_text,st.session_state.input_rekog_label)
-    # else:
     ans__ = all_search_execute.handler(inputs, st.session_state['session_id'])
 
     st.session_state.answers.append({
-        'answer': ans__,#all_search_api.call(json.dumps(inputs), st.session_state['session_id']),
+        'answer': ans__,
         'search_type':inputs['searchType'],
         'id': len(st.session_state.questions)
     })
@@ -587,21 +499,8 @@ def handle_input():
     st.session_state.answers_none_rank = st.session_state.answers
     if(st.session_state.input_evaluate == "enabled"):
         llm_eval.eval(st.session_state.questions, st.session_state.answers)
-    #st.session_state.input_text=""
-    #st.session_state.input_searchType=st.session_state.input_searchType
-
 def write_top_bar():
-    # st.markdown("""
-    #     <style>
-    #     [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
-    #         gap: 0rem;
-    #     }
-    #     </style>
-    #     """,unsafe_allow_html=True)
-    #print("top bar")
-    # st.title(':mag: AI powered OpenSearch')
-    # st.write("")
-    # st.write("")
     col1, col2,col3,col4 = st.columns([2.5,35,8,7])
     with col1:
         st.image(TEXT_ICON, use_column_width='always')
@@ -630,9 +529,6 @@ def write_top_bar():
             st.markdown("<div style = 'height:43px'></div>",unsafe_allow_html=True)
             st.button("Generate",disabled=False,key = "generate",on_click = generate_images, args=(tab1,"default_img"))
 
-        # image_select = st.select_slider(
-        #     "Select a image",
-        #     options=["Image 1","Image 2","Image 3"], value = None, disabled = st.session_state.radio_disabled,key = "image_select")
         image_select = st.radio("Choose one image", ["Image 1","Image 2","Image 3"],index=None, horizontal = True,key = 'image_select',disabled = st.session_state.radio_disabled)
         st.markdown("""
             <style>
@@ -642,25 +538,10 @@ def write_top_bar():
            </style>
            """,unsafe_allow_html=True)
        if(st.session_state.image_select is not None and st.session_state.image_select !="" and len(st.session_state.img_gen)!=0):
-            print("image_select")
-            print("------------")
-            print(st.session_state.image_select)
            st.session_state.input_rad_1 = st.session_state.image_select.split(" ")[1]
        else:
            st.session_state.input_rad_1 = ""
-        # rad1, rad2,rad3 = st.columns([33,33,33])
-        # with rad1:
-        #     btn1 = st.button("choose image 1", disabled = st.session_state.radio_disabled)
-        # with rad2:
-        #     btn2 = st.button("choose image 2", disabled = st.session_state.radio_disabled)
-        # with rad3:
-        #     btn3 = st.button("choose image 3", disabled = st.session_state.radio_disabled)
-        # if(btn1):
-        #     st.session_state.input_rad_1 = "1"
-        # if(btn2):
-        #     st.session_state.input_rad_1 = "2"
-        # if(btn3):
-        #     st.session_state.input_rad_1 = "3"
 
 
        generate_images(tab1,gen_images)
@@ -669,19 +550,11 @@ def write_top_bar():
 
     with tab2:
         st.session_state.img_doc = st.file_uploader(
            "Upload image", accept_multiple_files=False,type = ['png', 'jpg'])
-
-
-
-
-
     return clear,tab1
 
 clear,tab_ = write_top_bar()
 
 if clear:
-
-
-    print("clear1")
    st.session_state.questions = []
    st.session_state.answers = []
@@ -697,18 +570,7 @@ if clear:
     st.session_state.input_rad_1 = ""
 
 
-    # placeholder1 = st.empty()
-    # with placeholder1.container():
-    #     generate_images(tab_,st.session_state.image_prompt)
-
-
-    #st.session_state.input_text=""
-    # st.session_state.input_searchType="Conversational Search (RAG)"
-    # st.session_state.input_temperature = "0.001"
-    # st.session_state.input_topK = 200
-    # st.session_state.input_topP = 0.95
-    # st.session_state.input_maxTokens = 1024
-
 col1, col3, col4 = st.columns([70,18,12])
 
 with col1:
@@ -732,7 +594,7 @@ with col4:
     evaluate = st.toggle(' ', key = 'evaluate', disabled = False) #help = "Checking this box will use LLM to evaluate results as relevant and irrelevant. \n\n This option increases the latency")
     if(evaluate):
         st.session_state.input_evaluate = "enabled"
-        #llm_eval.eval(st.session_state.questions, st.session_state.answers)
+
     else:
         st.session_state.input_evaluate = "disabled"
@@ -740,11 +602,7 @@ with col4:
 if(search_all_type == True or 1==1):
     with st.sidebar:
         st.page_link("app.py", label=":orange[Home]", icon="🏠")
-        #st.image('/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp/images/service_logo.png', width = 300)
-        #st.warning('Note: After changing any of the below settings, click "SEARCH" button or 🔄 to apply the changes', icon="⚠️")
-        #st.header(' :gear: :orange[Fine-tune Search]')
-        #st.write("Note: After changing any of the below settings, click 'SEARCH' button or '🔄' to apply the changes")
-        #st.subheader(':blue[Keyword Search]')
 
         ########################## enable for query_rewrite ########################
         rewrite_query = st.checkbox('Auto-apply filters', key = 'query_rewrite', disabled = False, help = "Checking this box will use LLM to rewrite your query. \n\n Here your natural language query is transformed into OpenSearch query with added filters and attributes")
@@ -754,6 +612,8 @@ if(search_all_type == True or 1==1):
             key = 'input_must',
         )
         ########################## enable for query_rewrite ########################
+
+
         ####### Filters #########
 
         st.subheader(':blue[Filters]')
@@ -776,25 +636,6 @@ if(search_all_type == True or 1==1):
 
 
         clear_filter = st.button("Clear Filters",on_click=clear_filter)
-
-
-        # filter_place_holder = st.container()
-        # with filter_place_holder:
-        #     st.selectbox("Select one Category", ("accessories", "books","floral","furniture","hot_dispensed","jewelry","tools","apparel","cold_dispensed","food_service","groceries","housewares","outdoors","salty_snacks","videos","beauty","electronics","footwear","homedecor","instruments","seasonal"),index = None,key = "input_category")
-        #     st.selectbox("Select one Gender", ("male","female"),index = None,key = "input_gender")
-        #     st.slider("Select a range of price", 0, 2000, (0, 0),50, key = "input_price")
-
-        # st.session_state.input_category=None
-        # st.session_state.input_gender=None
-        # st.session_state.input_price=(0,0)
-
-        print("--------------------filters---------------")
-        print(st.session_state.input_gender)
-        print(st.session_state.input_manual_filter)
-        print("--------------------filters---------------")
-
-
-
         ####### Filters #########
 
         if('NeuralSparse Search' in st.session_state.search_types):
@@ -802,111 +643,21 @@ if(search_all_type == True or 1==1):
             sparse_filter = st.slider('Keep only sparse tokens with weight >=', 0.0, 1.0, 0.5,0.1,key = 'input_sparse_filter', help = 'Use this slider to set the minimum weight that the sparse vector token weights should meet, rest are filtered out')
 
 
-        #sql_query = st.checkbox('Re-write as SQL query', key = 'sql_rewrite', disabled = True, help = "In Progress")
         st.session_state.input_is_rewrite_query = 'disabled'
         st.session_state.input_is_sql_query = 'disabled'
 
         ########################## enable for query_rewrite ########################
         if rewrite_query:
-            #st.write(st.session_state.inputs_)
            st.session_state.input_is_rewrite_query = 'enabled'
-        # if sql_query:
-        #     #st.write(st.session_state.inputs_)
-        #     st.session_state.input_is_sql_query = 'enabled'
-        ########################## enable for sql conversion ########################
-
-
-        #st.markdown('---')
-        #st.header('Fine-tune keyword Search', divider='rainbow')
-        #st.subheader('Note: The below selection applies only when the Search type is set to Keyword Search')
-
-
-        # st.markdown("<u>Enrich metadata for :</u>",unsafe_allow_html=True)
-
-
-
-        # c3,c4 = st.columns([10,90])
-        # with c4:
-        #     rekognition = st.checkbox('Images', key = 'rekognition', help = "Checking this box will use AI to extract metadata for images that are present in query and documents")
-        #     if rekognition:
-        #         #st.write(st.session_state.inputs_)
-        #         st.session_state.input_rekognition = 'enabled'
-        #     else:
-        #         st.session_state.input_rekognition = "disabled"
-
-        #st.markdown('---')
-        #st.header('Fine-tune Hybrid Search', divider='rainbow')
-        #st.subheader('Note: The below parameters apply only when the Search type is set to Hybrid Search')
-
-
-
-
-
-
-
-        #st.write("---")
-        #if(st.session_state.max_selections == "None"):
        st.subheader(':blue[Hybrid Search]')
-        # st.selectbox('Select the Hybrid Search type',
-        #     ("OpenSearch Hybrid Query","Reciprocal Rank Fusion"),key = 'input_hybridType')
-        # equal_weight = st.button("Give equal weights to selected searches")
-
-
-
-
-
-
-        #st.warning('Weight of each of the selected search type should be greater than 0 and the total weight of all the selected search type(s) should be equal to 100',icon = "⚠️")
-
-
-        #st.markdown("<p style = 'font-size:14.5px;font-style:italic;'>Set Weights</p>",unsafe_allow_html=True)
-
        with st.expander("Set query Weightage:"):
            st.number_input("Keyword %", min_value=0, max_value=100, value=100, step=5, key='input_Keyword-weight', help=None)
            st.number_input("Vector %", min_value=0, max_value=100, value=0, step=5, key='input_Vector-weight', help=None)
            st.number_input("Multimodal %", min_value=0, max_value=100, value=0, step=5, key='input_Multimodal-weight', help=None)
            st.number_input("NeuralSparse %", min_value=0, max_value=100, value=0, step=5, key='input_NeuralSparse-weight', help=None)
 
-        # if(equal_weight):
-        #     counter = 0
-        #     num_search = len(st.session_state.input_searchType)
-        #     weight_type = ["input_Keyword-weight","input_Vector-weight","input_Multimodal-weight","input_NeuralSparse-weight"]
-        #     for type in weight_type:
-        #         if(type.split("-")[0].replace("input_","")+ " Search" in st.session_state.input_searchType):
-        #             print("ssssssssssss")
-        #             counter = counter +1
-        #     extra_weight = 100%num_search
-        #     if(counter == num_search):
-        #         cal_weight = math.trunc(100/num_search)+extra_weight
-        #     else:
-        #         cal_weight = math.trunc(100/num_search)
-        #     st.session_state[weight_type] = cal_weight
-        #     else:
-        #         st.session_state[weight_type] = 0
-        #weight = st.slider('Weight for Vector Search', 0.0, 1.0, 0.5,0.1,key = 'input_weight', help = 'Use this slider to set the weightage for keyword and vector search, higher values of the slider indicate the increased weightage for semantic search.\n\n This applies only when the search type is set to Hybrid Search')
-        # st.selectbox('Select the Normalisation type',
-        #     ('min_max',
-        #     'l2'
-        #     ),
-        #st.write("---")
-        #     key = 'input_NormType',
-        #     disabled = True,
-        #     help = "Select the type of Normalisation to be applied on the two sets of scores"
-        #     )
-
-        # st.selectbox('Select the Score Combination type',
-        #     ('arithmetic_mean','geometric_mean','harmonic_mean'
-        #     ),
-
-        #     key = 'input_CombineType',
-        #     disabled = True,
-        #     help = "Select the Combination strategy to be used while combining the two scores of the two search queries for every document"
-        #     )
-
-        #st.markdown('---')
-
-        #st.header('Select the ML Model for text embedding', divider='rainbow')
-        #st.subheader('Note: The below selection applies only when the Search type is set to Vector or Hybrid Search')
        if(st.session_state.re_ranker == "true"):
            st.subheader(':blue[Re-ranking]')
            reranker = st.selectbox('Choose a Re-Ranker',
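
The weightage expander above is where the hybrid percentages come from; downstream they normally become normalized weights in an OpenSearch hybrid-search pipeline. A sketch under that assumption (normalization-processor with min_max and arithmetic_mean matches the options named in the commented-out selectboxes; the exact wiring inside all_search_execute.py is not shown in this commit):

# Hedged sketch: convert UI percentages into a search-pipeline definition.
weightage = {'Keyword-weight': 70, 'Vector-weight': 30,
             'Multimodal-weight': 0, 'NeuralSparse-weight': 0}
weights = [round(v / 100.0, 2) for v in weightage.values() if v > 0]
search_pipeline = {
    "phase_results_processors": [{
        "normalization-processor": {
            "normalization": {"technique": "min_max"},
            "combination": {
                "technique": "arithmetic_mean",
                "parameters": {"weights": weights},
            },
        }
    }]
}
# os_client.transport.perform_request("PUT", "/_search/pipeline/hybrid_pipeline",
#                                     body=search_pipeline)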
@@ -916,41 +667,19 @@ if(search_all_type == True or 1==1):
916
 
917
  key = 'input_reranker',
918
  help = 'Select the Re-Ranker type, select "None" to apply no re-ranking of the results',
919
- #on_change = re_ranker.re_rank,
920
  args=(st.session_state.questions, st.session_state.answers)
921
 
922
  )
- # st.write("---")
- # st.subheader('Text Embeddings Model')
- # model_type = st.selectbox('Select the Text Embeddings Model',
- #     ('Titan-Embed-Text-v1','GPT-J-6B'
- #     ),
- #     key = 'input_modelType',
- #     help = "Select the Text embedding model, this applies only for the vector and hybrid search"
- #     )
-
- #st.markdown('---')
-
- #st.markdown('---')

 def write_user_message(md,ans):
-    #print(ans)
     ans = ans["answer"][0]
     col1, col2, col3 = st.columns([3,40,20])

     with col1:
         st.image(USER_ICON, use_column_width='always')
     with col2:
-        #st.warning(md['question'])
         st.markdown("<div style='fontSize:15px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'>Input Text: </div><div style='fontSize:25px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;color:#e28743'>"+md['question']+"</div>", unsafe_allow_html = True)
         if('query_sparse' in ans):
             with st.expander("Expanded Query:"):
@@ -1011,10 +740,7 @@ def render_answer(answer,index):
         span_color = "red"
     st.markdown("<span style='fontSize:20px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>Relevance:" +str('%.3f'%(st.session_state.input_ndcg)) + "</span><span style='font-size:30px;font-weight:bold;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-weight:bold;font-family:Courier New;color:"+span_color+"'> "+st.session_state.ndcg_increase.split("~")[1]+"</span>", unsafe_allow_html = True)

-
-    #st.markdown("<span style='font-size:30px;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-family:Courier New;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[1]+"</span>",unsafe_allow_html = True)
-

 placeholder_no_results = st.empty()
@@ -1030,12 +756,7 @@ def render_answer(answer,index):
             continue

-        # imgdata = base64.b64decode(ans['image_binary'])
         format_ = ans['image_url'].split(".")[-1]
-
-        #urllib.request.urlretrieve(ans['image_url'], "/home/ubuntu/res_images/"+str(i)+"_."+format_)
-
         Image.MAX_IMAGE_PIXELS = 100000000

         width = 500
@@ -1066,23 +787,6 @@ def render_answer(answer,index):
         desc__ = ans['desc'].split(" ")

         final_desc = "<p>"
-
-        ###### stemming and highlighting
-
-        # ans_text = ans['desc']
-        # query_text = st.session_state.input_text
-
-        # ans_text_stemmed = set(stem_(ans_text))
-        # query_text_stemmed = set(stem_(query_text))
-
-        # common = ans_text_stemmed.intersection( query_text_stemmed)
-        # #unique = set(document_1_words).symmetric_difference( )
-
-        # desc__stemmed = stem_(desc__)
-
-        # for word_ in desc__stemmed:
-        #     if(word_ in common):
-

         for word in desc__:
             if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
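The block deleted above sketched stem-based highlighting. A hedged reconstruction of that idea, using the nltk imports that stay commented out at the top of this file (the stem_ helper referenced there is assumed to have wrapped PorterStemmer roughly like this):

    import re
    from nltk.stem import PorterStemmer

    ps = PorterStemmer()

    def stem_(text):
        words = text.split(" ") if isinstance(text, str) else text
        return [ps.stem(re.sub('[^A-Za-z0-9]+', '', w).lower()) for w in words]

    def highlight(desc, query):
        common = set(stem_(desc)) & set(stem_(query))
        return " ".join(
            "<b>" + w + "</b>" if stem_([w])[0] in common else w
            for w in desc.split(" ")
        )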
@@ -1104,16 +808,8 @@ def render_answer(answer,index):
                 filtered_sparse[key] = round(sparse_[key], 2)
             st.write(filtered_sparse)
         with st.expander("Document Metadata:",expanded = False):
-            # if("rekog" in ans):
-            #     div_size = [50,50]
-            # else:
-            #     div_size = [99,1]
-            # div1,div2 = st.columns(div_size)
-            # with div1:
-
             st.write(":green[default:]")
             st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True)
-            #with div2:
             if("rekog" in ans):
                 st.write(":green[enriched:]")
                 st.json(ans['rekog'],expanded = True)
@@ -1128,18 +824,7 @@ def render_answer(answer,index):
                 st.write(":x:")

             i = i+1
-        # with col_2:
-        #     if(st.session_state.input_evaluate == "enabled"):
-        #         st.markdown("<div style='fontSize:12px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;font-weight:bold;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>DCG: " +str('%.3f'%(st.session_state.input_ndcg)) + "</div>", unsafe_allow_html = True)
-        # with col_2_b:
-        #     span_color = "white"
-        #     if("&uarr;" in st.session_state.ndcg_increase):
-        #         span_color = "green"
-        #     if("&darr;" in st.session_state.ndcg_increase):
-        #         span_color = "red"
-        #     st.markdown("<span style='font-size:30px;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-family:Courier New;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[1]+"</span>",unsafe_allow_html = True)
-
         with col_3:
             if(index == len(st.session_state.questions)):

@@ -1155,7 +840,6 @@ def render_answer(answer,index):
         st.session_state.questions.pop()

         handle_input()
-        #re_ranker.re_rank(st.session_state.questions, st.session_state.answers)
         with placeholder.container():
             render_all()

@@ -1169,9 +853,6 @@ def render_answer(answer,index):
     except:
         pass

-    print("------------------------")
-    #print(st.session_state)
-
     placeholder__ = st.empty()

     placeholder__.button("🔄",key=rdn_key,on_click=on_button_click, help = "This will regenerate the responses with new settings that you entered, Note: To see difference in responses, you should change any of the applicable settings")#,type="primary",use_column_width=True)
@@ -1196,8 +877,6 @@ def render_all():
     index = 0
     for (q, a) in zip(st.session_state.questions, st.session_state.answers):
         index = index +1
-        #print("answers----")
-        #print(a)
         ans_ = st.session_state.answers[0]
         write_user_message(q,ans_)
         write_chat_message(a, q,index)
@@ -1206,6 +885,4 @@ placeholder = st.empty()
 with placeholder.container():
     render_all()

- #generate_images("",st.session_state.image_prompt)
-
 st.markdown("")
 
 import shutil
 import re
 from requests.auth import HTTPBasicAuth
 # from nltk.stem import PorterStemmer
 # from nltk.tokenize import word_tokenize
 import query_rewrite
 import amazon_rekognition
+ from streamlit.components.v1 import html
 #from st_click_detector import click_detector
 import llm_eval
 import all_search_execute
 import warnings

 warnings.filterwarnings("ignore", category=DeprecationWarning)

 st.set_page_config(
     page_icon="images/opensearch_mark_default.png"
 )
 parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])

 #ps = PorterStemmer()

 st.session_state.REGION = 'us-east-1'

 USER_ICON = "images/user.png"
 AI_ICON = "images/opensearch-twitter-card.png"
 REGENERATE_ICON = "images/regenerate.png"

 if "gen_image_str" not in st.session_state:
     st.session_state.gen_image_str=""

 if "input_NormType" not in st.session_state:
     st.session_state.input_NormType = "min_max"

     'Multimodal Search',
     'NeuralSparse Search',
     ]
+
+

 inputs = {}
 if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
     old_rekog_label = st.session_state.input_rekog_label
     st.session_state.input_rekog_label = amazon_rekognition.extract_image_metadata(st.session_state.bytes_for_rekog)
     if(st.session_state.input_text == ""):
         st.session_state.input_text = st.session_state.input_rekog_label

 weightage = {}
 st.session_state.weights_ = []

 else:
     weightage[original_key] = 0.0
     st.session_state[key] = 0.0
+

 inputs['weightage']=weightage
 st.session_state.input_weightage = weightage

 st.session_state.inputs_ = inputs

 question_with_id = {
     'question': inputs["text"],
     'id': len(st.session_state.questions)

 if(st.session_state.input_is_rewrite_query == 'enabled' or (st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType)):
     query_rewrite.get_new_query_res(st.session_state.input_text)
+
 else:
     st.session_state.input_rewritten_query = ""
+
 ans__ = all_search_execute.handler(inputs, st.session_state['session_id'])

 st.session_state.answers.append({
+    'answer': ans__,
     'search_type':inputs['searchType'],
     'id': len(st.session_state.questions)
 })

 st.session_state.answers_none_rank = st.session_state.answers
 if(st.session_state.input_evaluate == "enabled"):
     llm_eval.eval(st.session_state.questions, st.session_state.answers)
+

 def write_top_bar():
     col1, col2,col3,col4 = st.columns([2.5,35,8,7])
     with col1:
         st.image(TEXT_ICON, use_column_width='always')

     st.markdown("<div style = 'height:43px'></div>",unsafe_allow_html=True)
     st.button("Generate",disabled=False,key = "generate",on_click = generate_images, args=(tab1,"default_img"))

     image_select = st.radio("Choose one image", ["Image 1","Image 2","Image 3"],index=None, horizontal = True,key = 'image_select',disabled = st.session_state.radio_disabled)
     st.markdown("""
     <style>
     </style>
     """,unsafe_allow_html=True)
     if(st.session_state.image_select is not None and st.session_state.image_select !="" and len(st.session_state.img_gen)!=0):
         st.session_state.input_rad_1 = st.session_state.image_select.split(" ")[1]
     else:
         st.session_state.input_rad_1 = ""
+

     generate_images(tab1,gen_images)

     with tab2:
         st.session_state.img_doc = st.file_uploader(
             "Upload image", accept_multiple_files=False,type = ['png', 'jpg'])

     return clear,tab1

 clear,tab_ = write_top_bar()

 if clear:
     st.session_state.questions = []
     st.session_state.answers = []

     st.session_state.input_rad_1 = ""

+
 col1, col3, col4 = st.columns([70,18,12])

 with col1:

 evaluate = st.toggle(' ', key = 'evaluate', disabled = False) #help = "Checking this box will use LLM to evaluate results as relevant and irrelevant. \n\n This option increases the latency")
 if(evaluate):
     st.session_state.input_evaluate = "enabled"
+
 else:
     st.session_state.input_evaluate = "disabled"

 if(search_all_type == True or 1==1):
     with st.sidebar:
         st.page_link("app.py", label=":orange[Home]", icon="🏠")
+

         ########################## enable for query_rewrite ########################
         rewrite_query = st.checkbox('Auto-apply filters', key = 'query_rewrite', disabled = False, help = "Checking this box will use LLM to rewrite your query. \n\n Here your natural language query is transformed into OpenSearch query with added filters and attributes")

             key = 'input_must',
         )
         ########################## enable for query_rewrite ########################
+
+
         ####### Filters #########

         st.subheader(':blue[Filters]')

         clear_filter = st.button("Clear Filters",on_click=clear_filter)

         ####### Filters #########

         if('NeuralSparse Search' in st.session_state.search_types):
             sparse_filter = st.slider('Keep only sparse tokens with weight >=', 0.0, 1.0, 0.5,0.1,key = 'input_sparse_filter', help = 'Use this slider to set the minimum weight that the sparse vector token weights should meet, rest are filtered out')
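A small sketch of what this threshold does downstream, assuming the answer's sparse vector is a token-to-weight dict as in the render path below:

    import streamlit as st

    def filter_sparse(sparse_vector):
        cutoff = st.session_state.get("input_sparse_filter", 0.5)
        return {tok: round(w, 2) for tok, w in sparse_vector.items() if w >= cutoff}

    # filter_sparse({"shoe": 1.42, "red": 0.61, "the": 0.07}) -> {"shoe": 1.42, "red": 0.61}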

         st.session_state.input_is_rewrite_query = 'disabled'
         st.session_state.input_is_sql_query = 'disabled'

         ########################## enable for query_rewrite ########################
         if rewrite_query:
             st.session_state.input_is_rewrite_query = 'enabled'
+

         st.subheader(':blue[Hybrid Search]')
         with st.expander("Set query Weightage:"):
             st.number_input("Keyword %", min_value=0, max_value=100, value=100, step=5, key='input_Keyword-weight', help=None)
             st.number_input("Vector %", min_value=0, max_value=100, value=0, step=5, key='input_Vector-weight', help=None)
             st.number_input("Multimodal %", min_value=0, max_value=100, value=0, step=5, key='input_Multimodal-weight', help=None)
             st.number_input("NeuralSparse %", min_value=0, max_value=100, value=0, step=5, key='input_NeuralSparse-weight', help=None)
+

         if(st.session_state.re_ranker == "true"):
             st.subheader(':blue[Re-ranking]')
             reranker = st.selectbox('Choose a Re-Ranker',

                 key = 'input_reranker',
                 help = 'Select the Re-Ranker type, select "None" to apply no re-ranking of the results',
                 args=(st.session_state.questions, st.session_state.answers)

             )
+

 def write_user_message(md,ans):
     ans = ans["answer"][0]
     col1, col2, col3 = st.columns([3,40,20])

     with col1:
         st.image(USER_ICON, use_column_width='always')
     with col2:
         st.markdown("<div style='fontSize:15px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'>Input Text: </div><div style='fontSize:25px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;color:#e28743'>"+md['question']+"</div>", unsafe_allow_html = True)
         if('query_sparse' in ans):
             with st.expander("Expanded Query:"):

         span_color = "red"
     st.markdown("<span style='fontSize:20px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>Relevance:" +str('%.3f'%(st.session_state.input_ndcg)) + "</span><span style='font-size:30px;font-weight:bold;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-weight:bold;font-family:Courier New;color:"+span_color+"'> "+st.session_state.ndcg_increase.split("~")[1]+"</span>", unsafe_allow_html = True)
+

     placeholder_no_results = st.empty()

             continue

         format_ = ans['image_url'].split(".")[-1]
         Image.MAX_IMAGE_PIXELS = 100000000

         width = 500

         desc__ = ans['desc'].split(" ")

         final_desc = "<p>"

         for word in desc__:
             if(re.sub('[^A-Za-z0-9]+', '', word) in res__):

                 filtered_sparse[key] = round(sparse_[key], 2)
             st.write(filtered_sparse)
         with st.expander("Document Metadata:",expanded = False):
             st.write(":green[default:]")
             st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True)
             if("rekog" in ans):
                 st.write(":green[enriched:]")
                 st.json(ans['rekog'],expanded = True)

                 st.write(":x:")

             i = i+1
+

         with col_3:
             if(index == len(st.session_state.questions)):

         st.session_state.questions.pop()

         handle_input()
         with placeholder.container():
             render_all()

     except:
         pass

     placeholder__ = st.empty()

     placeholder__.button("🔄",key=rdn_key,on_click=on_button_click, help = "This will regenerate the responses with new settings that you entered, Note: To see difference in responses, you should change any of the applicable settings")#,type="primary",use_column_width=True)

     index = 0
     for (q, a) in zip(st.session_state.questions, st.session_state.answers):
         index = index +1
         ans_ = st.session_state.answers[0]
         write_user_message(q,ans_)
         write_chat_message(a, q,index)

 with placeholder.container():
     render_all()

 st.markdown("")
semantic_search/amazon_rekognition.py CHANGED
@@ -24,12 +24,7 @@ def extract_image_metadata(img):
         MaxLabels = 10,
         MinConfidence = 80.0,
         Settings = {
-        # "GeneralLabels": {
-        #     "LabelCategoryExclusionFilters": [ "string" ],
-        #     "LabelCategoryInclusionFilters": [ "string" ],
-        #     "LabelExclusionFilters": [ "string" ],
-        #     "LabelInclusionFilters": [ "string" ]
-        # },
         "ImageProperties": {
             "MaxDominantColors": 5
         }
@@ -76,20 +71,12 @@ def extract_image_metadata(img):
     objects = " ".join(set(objects))
     categories = " ".join(set(categories))
     colors = " ".join(set(colors))
-
-    print("^^^^^^^^^^^^^^^^^^")
-    print(colors+ " " + objects + " " + categories)
-
     return colors+ " " + objects + " " + categories

 def call(a,b):
-    print("'''''''''''''''''''''''")
-    print(b)
-
     if(st.session_state.input_is_rewrite_query == 'enabled' and st.session_state.input_rewritten_query!=""):

-        #st.session_state.input_rewritten_query['query']['bool']['should'].pop()
         st.session_state.input_rewritten_query['query']['bool']['should'].append( {
             "simple_query_string": {

@@ -112,36 +99,4 @@ def call(a,b):
     }
     st.session_state.input_rewritten_query = rekog_query

-    # response = aos_client.search(
-    #     body = rekog_query,
-    #     index = 'demo-retail-rekognition'
-    #     #pipeline = 'RAG-Search-Pipeline'
-    # )
-
-    # hits = response['hits']['hits']
-    # print("rewrite-------------------------")
-    # arr = []
-    # for doc in hits:
-    #     # if('b5/b5319e00' in doc['_source']['image_s3_url'] ):
-    #     #     filter_out +=1
-    #     #     continue
-
-    #     res_ = {"desc":doc['_source']['text'].replace(doc['_source']['metadata']['rekog_all']," ^^^ " +doc['_source']['metadata']['rekog_all']),
-    #         "image_url":doc['_source']['metadata']['image_s3_url']}
-    #     if('highlight' in doc):
-    #         res_['highlight'] = doc['highlight']['text']
-    #     # if('caption_embedding' in doc['_source']):
-    #     #     res_['sparse'] = doc['_source']['caption_embedding']
-    #     # if('query_sparse' in response_ and len(arr) ==0 ):
-    #     #     res_['query_sparse'] = response_["query_sparse"]
-    #     res_['id'] = doc['_id']
-    #     res_['score'] = doc['_score']
-    #     res_['title'] = doc['_source']['text']
-    #     res_['rekog'] = {'color':doc['_source']['metadata']['rekog_color'],'category': doc['_source']['metadata']['rekog_categories'],'objects':doc['_source']['metadata']['rekog_objects']}
-
-    #     arr.append(res_)
-
-    # return arr
 
         MaxLabels = 10,
         MinConfidence = 80.0,
         Settings = {
+
         "ImageProperties": {
             "MaxDominantColors": 5
         }

     objects = " ".join(set(objects))
     categories = " ".join(set(categories))
     colors = " ".join(set(colors))
     return colors+ " " + objects + " " + categories

 def call(a,b):
     if(st.session_state.input_is_rewrite_query == 'enabled' and st.session_state.input_rewritten_query!=""):

         st.session_state.input_rewritten_query['query']['bool']['should'].append( {
             "simple_query_string": {

     }
     st.session_state.input_rewritten_query = rekog_query
+
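For reference, a hedged sketch of the detect_labels request this module builds and the objects/categories/colors join it returns; the field names follow the public Rekognition API, and the helper name is illustrative:

    import boto3

    rekognition = boto3.client("rekognition", region_name="us-east-1")

    def extract_metadata(image_bytes):
        resp = rekognition.detect_labels(
            Image={"Bytes": image_bytes},
            MaxLabels=10,
            MinConfidence=80.0,
            Features=["GENERAL_LABELS", "IMAGE_PROPERTIES"],
            Settings={"ImageProperties": {"MaxDominantColors": 5}},
        )
        objects = {label["Name"] for label in resp["Labels"]}
        categories = {c["Name"] for label in resp["Labels"]
                      for c in label.get("Categories", [])}
        colors = {c["SimplifiedColor"]
                  for c in resp.get("ImageProperties", {}).get("DominantColors", [])}
        return " ".join(colors | objects | categories)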
utilities/invoke_models.py CHANGED
@@ -24,17 +24,6 @@ bedrock_runtime_client = get_bedrock_client()

- # def generate_image_captions_ml():
- #     model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
- #     feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
- #     tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
-
- #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- #     model.to(device)
- #     max_length = 16
- #     num_beams = 4
- #     gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
-
 def invoke_model(input):
     response = bedrock_runtime_client.invoke_model(
         body=json.dumps({
@@ -100,56 +89,7 @@ def invoke_llm_model(input,is_stream):

     return (json.loads(res))['content'][0]['text']

-    # response = bedrock_runtime_client.invoke_model_with_response_stream(
-    #     body=json.dumps({
-    #         "prompt": input,
-    #         "max_tokens_to_sample": 300,
-    #         "temperature": 0.5,
-    #         "top_k": 250,
-    #         "top_p": 1,
-    #         "stop_sequences": [
-    #             "\n\nHuman:"
-    #         ],
-    #         # "anthropic_version": "bedrock-2023-05-31"
-    #     }),
-    #     modelId="anthropic.claude-v2:1",
-    #     accept="application/json",
-    #     contentType="application/json",
-    # )
-    # stream = response.get('body')
-
-    # return stream
-
-    # else:
-    #     response = bedrock_runtime_client.invoke_model_with_response_stream(
-    #         modelId= "anthropic.claude-3-sonnet-20240229-v1:0",
-    #         contentType = "application/json",
-    #         accept = "application/json",
-
-    #         body = json.dumps({
-    #             "anthropic_version": "bedrock-2023-05-31",
-    #             "max_tokens": 1024,
-    #             "temperature": 0.0001,
-    #             "top_k": 150,
-    #             "top_p": 0.7,
-    #             "stop_sequences": [
-    #                 "\n\nHuman:"
-    #             ],
-    #             "messages": [
-    #                 {
-    #                     "role": "user",
-    #                     "content":input
-    #                 }
-    #             ]
-    #         }
-
-    #     )
-    #     )
-
-    #     stream = response.get('body')
-
-    #     return stream
-
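What the deleted branch did, as a minimal sketch: Bedrock's invoke_model_with_response_stream returns an event stream whose 'chunk' events carry incremental bytes (model id and request body are left to the caller):

    import json

    def stream_llm(bedrock_runtime_client, model_id, body):
        response = bedrock_runtime_client.invoke_model_with_response_stream(
            modelId=model_id,
            accept="application/json",
            contentType="application/json",
            body=json.dumps(body),
        )
        for event in response.get("body"):      # EventStream of dicts
            if "chunk" in event:
                yield event["chunk"]["bytes"].decode("utf8")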
 def read_from_table(file,question):
     print("started table analysis:")
     print("-----------------------")
@@ -175,7 +115,6 @@ def read_from_table(file,question):
         df = pd.read_csv(file,skipinitialspace = True, on_bad_lines='skip',delimiter = "`")
     else:
         df = file
-    #df.fillna(method='pad', inplace=True)
     agent = create_pandas_dataframe_agent(
         model,
         df,
@@ -188,24 +127,7 @@ def read_from_table(file,question):

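A hedged usage sketch for the agent constructed above; the import paths and the allow_dangerous_code flag follow current LangChain releases and may differ from the versions this repo pins:

    import pandas as pd
    from langchain_aws import ChatBedrock                       # assumed wrapper
    from langchain_experimental.agents import create_pandas_dataframe_agent

    model = ChatBedrock(model_id="anthropic.claude-3-haiku-20240307-v1:0")
    df = pd.DataFrame({"price": [12.5, 99.0], "category": ["shoes", "bags"]})
    agent = create_pandas_dataframe_agent(
        model, df, verbose=True,
        allow_dangerous_code=True,  # required by newer langchain_experimental
    )
    print(agent.invoke("What is the average price?"))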
 def generate_image_captions_llm(base64_string,question):

-    # ant_client = Anthropic()
-    # MODEL_NAME = "claude-3-opus-20240229"
-
-    # message_list = [
-    #     {
-    #         "role": 'user',
-    #         "content": [
-    #             {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": base64_string}},
-    #             {"type": "text", "text": "What is in the image ?"}
-    #         ]
-    #     }
-    # ]
-
-    # response = ant_client.messages.create(
-    #     model=MODEL_NAME,
-    #     max_tokens=2048,
-    #     messages=message_list
-    # )
     response = bedrock_runtime_client.invoke_model(
         modelId= "anthropic.claude-3-haiku-20240307-v1:0",
         contentType = "application/json",
@@ -234,9 +156,5 @@ def generate_image_captions_llm(base64_string,question):
             }
         ]
     }))
-    #print(response)
     response_body = json.loads(response.get("body").read())['content'][0]['text']
-
-    #print(response_body)
-
     return response_body
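Shown standalone, the request body this function sends; the shape follows the Bedrock Anthropic messages API, with max_tokens and the media type as assumptions:

    import json

    def caption_request(base64_string, question):
        return json.dumps({
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image",
                     "source": {"type": "base64", "media_type": "image/jpeg",
                                "data": base64_string}},
                    {"type": "text", "text": question},
                ],
            }],
        })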
 

 def invoke_model(input):
     response = bedrock_runtime_client.invoke_model(
         body=json.dumps({

     return (json.loads(res))['content'][0]['text']
+

 def read_from_table(file,question):
     print("started table analysis:")
     print("-----------------------")

         df = pd.read_csv(file,skipinitialspace = True, on_bad_lines='skip',delimiter = "`")
     else:
         df = file
     agent = create_pandas_dataframe_agent(
         model,
         df,

 def generate_image_captions_llm(base64_string,question):
+
     response = bedrock_runtime_client.invoke_model(
         modelId= "anthropic.claude-3-haiku-20240307-v1:0",
         contentType = "application/json",

             }
         ]
     }))
     response_body = json.loads(response.get("body").read())['content'][0]['text']

     return response_body
utilities/re_ranker.py DELETED
@@ -1,127 +0,0 @@
- import boto3
- from botocore.exceptions import ClientError
- import pprint
- import time
- import streamlit as st
- from sentence_transformers import CrossEncoder
-
- #model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512)
- ####### Add this Kendra Rescore ranking
- #kendra_ranking = boto3.client("kendra-ranking",region_name = 'us-east-1')
- #print("Create a rescore execution plan.")
-
- # Provide a name for the rescore execution plan
- #name = "MyRescoreExecutionPlan"
- # Set your required additional capacity units
- # Don't set capacity units if you don't require more than 1 unit given by default
- #capacity_units = 2
-
- # try:
- #     rescore_execution_plan_response = kendra_ranking.create_rescore_execution_plan(
- #         Name = name,
- #         CapacityUnits = {"RescoreCapacityUnits":capacity_units}
- #     )
-
- #     pprint.pprint(rescore_execution_plan_response)
-
- #     rescore_execution_plan_id = rescore_execution_plan_response["Id"]
-
- #     print("Wait for Amazon Kendra to create the rescore execution plan.")
-
- #     while True:
- #         # Get the details of the rescore execution plan, such as the status
- #         rescore_execution_plan_description = kendra_ranking.describe_rescore_execution_plan(
- #             Id = rescore_execution_plan_id
- #         )
- #         # When status is not CREATING quit.
- #         status = rescore_execution_plan_description["Status"]
- #         print(" Creating rescore execution plan. Status: "+status)
- #         time.sleep(60)
- #         if status != "CREATING":
- #             break
-
- # except ClientError as e:
- #     print("%s" % e)
-
- # print("Program ends.")
- #########################
-
- @st.cache_resource
- def re_rank(self_, rerank_type, search_type, question, answers):
-
-     ans = []
-     ids = []
-     ques_ans = []
-     query = question[0]['question']
-     for i in answers[0]['answer']:
-         if(self_ == "search"):
-
-             ans.append({
-                 "Id": i['id'],
-                 "Body": i["desc"],
-                 "OriginalScore": i['score'],
-                 "Title":i["desc"]
-             })
-             ids.append(i['id'])
-             ques_ans.append((query,i["desc"]))
-
-         else:
-             ans.append({'text':i})
-
-             ques_ans.append((query,i))
-
-     re_ranked = [{}]
-     ####### Add this Kendra Rescore ranking
-     # if(rerank_type == 'Kendra Rescore'):
-     #     rescore_response = kendra_ranking.rescore(
-     #         RescoreExecutionPlanId = 'b2a4d4f3-98ff-4e17-8b69-4c61ed7d91eb',
-     #         SearchQuery = query,
-     #         Documents = ans
-     #     )
-     #     re_ranked[0]['answer']=[]
-     #     for result in rescore_response["ResultItems"]:
-
-     #         pos_ = ids.index(result['DocumentId'])
-
-     #         re_ranked[0]['answer'].append(answers[0]['answer'][pos_])
-     #     re_ranked[0]['search_type']=search_type,
-     #     re_ranked[0]['id'] = len(question)
-     #     return re_ranked
-
-     # if(rerank_type == 'Cross Encoder'):
-
-     #     scores = model.predict(
-     #         ques_ans
-     #     )
-
-     #     index__ = 0
-     #     for i in ans:
-     #         i['new_score'] = scores[index__]
-     #         index__ = index__+1
-
-     #     ans_sorted = sorted(ans, key=lambda d: d['new_score'],reverse=True)
-
-     #     def retreive_only_text(item):
-     #         return item['text']
-
-     #     if(self_ == 'rag'):
-     #         return list(map(retreive_only_text, ans_sorted))
-
-     #     re_ranked[0]['answer']=[]
-     #     for j in ans_sorted:
-     #         pos_ = ids.index(j['Id'])
-     #         re_ranked[0]['answer'].append(answers[0]['answer'][pos_])
-     #     re_ranked[0]['search_type']= search_type,
-     #     re_ranked[0]['id'] = len(question)
-     #     return re_ranked
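The deleted module's only active path built (query, passage) pairs; its commented-out cross-encoder branch amounted to this sketch (same model id as the commented import above):

    from sentence_transformers import CrossEncoder

    model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512)

    def re_rank(query, passages):
        scores = model.predict([(query, p) for p in passages])
        ranked = sorted(zip(scores, passages), key=lambda t: t[0], reverse=True)
        return [p for _, p in ranked]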