prasadnu commited on
Commit
5ae9dd3
·
1 Parent(s): dc5a08b

rerank model

Browse files
pages/Multimodal_Conversational_Search.py CHANGED
@@ -13,49 +13,39 @@ import botocore.session
13
  import json
14
  import random
15
  import string
16
- #import rag_DocumentLoader
17
  import rag_DocumentSearcher
 
18
  import pandas as pd
19
  from PIL import Image
20
  import shutil
21
  import base64
22
  import time
23
  import botocore
 
 
 
24
  from requests_aws4auth import AWS4Auth
25
  import colpali
26
  from requests.auth import HTTPBasicAuth
27
- import warnings
28
 
29
 
30
- warnings.filterwarnings("ignore", category=DeprecationWarning)
31
 
32
  st.set_page_config(
33
  #page_title="Semantic Search using OpenSearch",
34
  layout="wide",
35
  page_icon="images/opensearch_mark_default.png"
36
  )
37
-
38
  parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
39
  USER_ICON = "images/user.png"
40
  AI_ICON = "images/opensearch-twitter-card.png"
41
  REGENERATE_ICON = "images/regenerate.png"
42
  s3_bucket_ = "pdf-repo-uploads"
43
  #"pdf-repo-uploads"
44
-
45
- # @st.cache_resource
46
- # def get_polly_client():
47
- # return boto3.client('polly',
48
- # aws_access_key_id=st.secrets['user_access_key'],
49
- # aws_secret_access_key=st.secrets['user_secret_key'],
50
- # region_name='us-east-1'
51
- # )
52
-
53
- # polly_client = get_polly_client()
54
-
55
 
56
  # Check if the user ID is already stored in the session state
57
-
58
-
59
  if 'user_id' in st.session_state:
60
  user_id = st.session_state['user_id']
61
  #print(f"User ID: {user_id}")
@@ -80,19 +70,23 @@ if "chats" not in st.session_state:
80
 
81
  if "questions_" not in st.session_state:
82
  st.session_state.questions_ = []
83
-
 
84
  if "show_columns" not in st.session_state:
85
  st.session_state.show_columns = False
 
 
 
86
 
87
  if "answers_" not in st.session_state:
88
  st.session_state.answers_ = []
89
 
90
  if "input_index" not in st.session_state:
91
- st.session_state.input_index = "globalwarming"#"hpijan2024hometrack"#"#"hpijan2024hometrack_no_img_no_table"
92
 
93
  if "input_is_rerank" not in st.session_state:
94
  st.session_state.input_is_rerank = True
95
-
96
  if "input_is_colpali" not in st.session_state:
97
  st.session_state.input_is_colpali = False
98
 
@@ -103,12 +97,25 @@ if "input_table_with_sql" not in st.session_state:
103
  st.session_state.input_table_with_sql = False
104
 
105
  if "input_query" not in st.session_state:
106
- st.session_state.input_query="What is the projected energy percentage from renewable sources in future?"#"which city has the highest average housing price in UK ?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
 
 
 
 
 
 
107
 
108
  if "input_rag_searchType" not in st.session_state:
109
- st.session_state.input_rag_searchType = ["Vector Search"]
 
110
 
111
 
 
 
 
 
 
 
112
  st.markdown("""
113
  <style>
114
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
@@ -120,10 +127,48 @@ st.markdown("""
120
  </style>
121
  """,unsafe_allow_html=True)
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  credentials = boto3.Session().get_credentials()
124
- awsauth = HTTPBasicAuth('master',st.secrets['ml_search_demo_api_access'])
125
  service = 'es'
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  def write_logo():
128
  col1, col2, col3 = st.columns([5, 1, 5])
129
  with col2:
@@ -151,9 +196,16 @@ if clear:
151
  st.session_state.questions_ = []
152
  st.session_state.answers_ = []
153
  st.session_state.input_query=""
 
 
 
 
 
154
 
155
 
156
  def handle_input():
 
 
157
  print("Question: "+st.session_state.input_query)
158
  print("-----------")
159
  print("\n\n")
@@ -164,6 +216,11 @@ def handle_input():
164
  if key.startswith('input_'):
165
  inputs[key.removeprefix('input_')] = st.session_state[key]
166
  st.session_state.inputs_ = inputs
 
 
 
 
 
167
  question_with_id = {
168
  'question': inputs["query"],
169
  'id': len(st.session_state.questions_)
@@ -171,7 +228,7 @@ def handle_input():
171
  st.session_state.questions_.append(question_with_id)
172
  if(st.session_state.input_is_colpali):
173
  out_ = colpali.colpali_search_rerank(st.session_state.input_query)
174
-
175
  else:
176
  out_ = rag_DocumentSearcher.query_(awsauth, inputs, st.session_state['session_id'],st.session_state.input_rag_searchType)
177
  st.session_state.answers_.append({
@@ -182,67 +239,122 @@ def handle_input():
182
  'table':out_['table']
183
  })
184
  st.session_state.input_query=""
 
 
 
 
 
 
 
 
 
 
 
 
185
 
 
 
 
 
 
 
 
 
 
 
186
 
187
- def write_user_message(placeholder,md):
188
- placeholder.empty()
189
- with placeholder.container():
190
- col1, col2 = st.columns([3,97])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
 
192
- with col1:
193
- st.image(USER_ICON, use_column_width='always')
194
- with col2:
195
- #st.warning(md['question'])
196
-
197
- st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
198
- return placeholder
199
-
200
-
201
- def render_answer(placeholder,question,answer,index,res_img):
202
- placeholder.empty()
203
- with placeholder.container():
204
- col1, col2, col_3 = st.columns([4,74,22])
205
- with col1:
206
- st.image(AI_ICON, use_column_width='always')
207
- with col2:
208
- ans_ = answer['answer']
209
- st.write(ans_)
210
-
211
- # polly_response = polly_client.synthesize_speech(VoiceId='Joanna',
212
- # OutputFormat='ogg_vorbis',
213
- # Text = ans_,
214
- # Engine = 'neural')
215
-
216
- # audio_col1, audio_col2 = st.columns([50,50])
217
- # with audio_col1:
218
- # st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
219
- rdn_key_1 = ''.join([random.choice(string.ascii_letters)
220
- for _ in range(10)])
221
- def show_maxsim(placeholder,dummy):
222
- st.session_state.show_columns = True
223
- st.session_state.input_query = ""
224
- st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
225
- #handle_input()
226
- # placeholder = st.empty()
227
- # with placeholder.container():
228
- placeholder.empty()
229
- render_all(placeholder)
230
- if(st.session_state.input_is_colpali):
231
- st.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim,args=(placeholder,"default_img"))
232
 
233
- colu1,colu2,colu3 = st.columns([4,82,20])
234
- with colu2:
235
- @st.cache_data
236
- def load_table_from_file(filepath):
237
- df = pd.read_csv(filepath, skipinitialspace=True, on_bad_lines='skip', delimiter='`')
238
- df.fillna(method='pad', inplace=True)
239
- return df
240
-
241
- with st.expander("Relevant Sources:"):
242
- with st.container():
243
- if(len(res_img)>0):
244
- #with st.expander("Images:"):
245
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  idx = 0
247
  print(res_img)
248
  for i in range(0,len(res_img)):
@@ -263,8 +375,13 @@ def render_answer(placeholder,question,answer,index,res_img):
263
  col3_,col4_,col5_ = st.columns([33,33,33])
264
  with col3_:
265
  st.image(res_img[i]['file'])
 
 
 
 
 
266
  else:
267
- if(res_img[i]['file'].lower()!='none' and idx < 1):
268
  col3,col4,col5 = st.columns([33,33,33])
269
  cols = [col3,col4]
270
  img = res_img[i]['file'].split(".")[0]
@@ -273,86 +390,120 @@ def render_answer(placeholder,question,answer,index,res_img):
273
  with cols[idx]:
274
 
275
  st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
 
276
  idx = idx+1
277
  if(st.session_state.show_columns == True):
278
  st.session_state.show_columns = False
279
- if(len(answer["table"] )>0):
280
- #with st.expander("Table:"):
281
- df = load_table_from_file(answer["table"][0]['name'])
 
 
282
  st.table(df)
283
- #with st.expander("Raw sources:"):
284
  st.write(answer["source"])
285
-
 
 
 
 
 
286
 
287
- # with col_3:
288
- # if(index == len(st.session_state.questions_)):
289
-
290
- # rdn_key = ''.join([random.choice(string.ascii_letters)
291
- # for _ in range(10)])
292
- # currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
293
- # oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
294
- # def on_button_click():
295
- # if(currentValue!=oldValue or 1==1):
296
- # st.session_state.input_query = st.session_state.questions_[-1]["question"]
297
- # st.session_state.answers_.pop()
298
- # st.session_state.questions_.pop()
 
 
 
 
299
 
300
- # handle_input()
301
- # with placeholder.container():
302
- # render_all()
303
- # if("currentValue" in st.session_state):
304
- # del st.session_state["currentValue"]
305
-
306
- # try:
307
- # del regenerate
308
- # except:
309
- # pass
310
- # placeholder__ = st.empty()
311
- # placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
312
-
 
 
 
 
 
 
 
313
 
 
 
314
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
315
- def write_chat_message(placeholder, md, q,index):
316
  if(st.session_state.show_columns):
317
  res_img = st.session_state.maxSimImages
318
  else:
319
  res_img = md['image']
320
- render_answer(placeholder,q,md,index,res_img)
 
 
 
 
 
321
 
322
- def render_all(placeholder):
323
  index = 0
324
  for (q, a) in zip(st.session_state.questions_, st.session_state.answers_):
325
  index = index +1
326
 
327
- placeholder = write_user_message(placeholder,q)
328
- write_chat_message(placeholder, a, q,index)
329
 
330
  placeholder = st.empty()
331
- render_all(placeholder)
332
-
333
 
334
-
335
-
336
  st.markdown("")
337
- col_2, col_3 = st.columns([75, 20])
 
 
 
 
338
  with col_2:
339
- input = st.text_input("Ask here", key="input_query", label_visibility="collapsed")
 
340
  with col_3:
341
- play = st.button("Go",on_click=handle_input, key="play")
342
-
343
-
344
- ##### Sidebar #####
345
  with st.sidebar:
346
  st.page_link("app.py", label=":orange[Home]", icon="🏠")
347
  st.subheader(":blue[Sample Data]")
348
  coln_1,coln_2 = st.columns([70,30])
 
 
 
 
 
 
349
  with coln_1:
350
- index_select = st.radio("Choose one index",["Global Warming stats","UK Housing","Covid19 impacts on Ireland"],key="input_rad_index")
351
  with coln_2:
352
  st.markdown("<p style='font-size:15px'>Preview file</p>",unsafe_allow_html=True)
353
- st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
354
  st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
 
355
  st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
 
356
  st.markdown("""
357
  <style>
358
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
@@ -364,23 +515,92 @@ with st.sidebar:
364
  </style>
365
  """,unsafe_allow_html=True)
366
  with st.expander("Sample questions:"):
367
- st.markdown("<span style = 'color:#FF9900;'>Global Warming stats</span> - What is the projected energy percentage from renewable sources in future?",unsafe_allow_html=True)
368
  st.markdown("<span style = 'color:#FF9900;'>UK Housing</span> - which city has the highest average housing price in UK ?",unsafe_allow_html=True)
 
369
  st.markdown("<span style = 'color:#FF9900;'>Covid19 impacts</span> - How many aged above 85 years died due to covid ?",unsafe_allow_html=True)
370
 
371
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
 
373
  ############## haystach demo temporary addition ############
374
  #if(pdf_doc_ is None or pdf_doc_ == ""):
375
  if(index_select == "Global Warming stats"):
376
- st.session_state.input_index = "globalwarming"
377
  if(index_select == "Covid19 impacts on Ireland"):
378
  st.session_state.input_index = "covid19ie"#"choosetheknnalgorithmforyourbillionscaleusecasewithopensearchawsbigdatablog"
379
  if(index_select == "BEIR"):
380
  st.session_state.input_index = "2104"
381
  if(index_select == "UK Housing"):
382
  st.session_state.input_index = "hpijan2024hometrack"
383
-
 
 
 
 
 
384
 
385
  st.subheader(":blue[Retriever]")
386
  search_type = st.multiselect('Select the Retriever(s)',
@@ -403,6 +623,7 @@ with st.sidebar:
403
 
404
  st.subheader(":blue[Multi-vector retrieval]")
405
 
 
406
  colpali_search_rerank = st.checkbox('Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)', key = 'input_colpali', disabled = False, value = False, help = "Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim")
407
 
408
  if(colpali_search_rerank):
@@ -413,6 +634,16 @@ with st.sidebar:
413
 
414
  with st.expander("Sample questions for Colpali retriever:"):
415
  st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
416
-
417
 
418
-
 
 
 
 
 
 
 
 
 
 
 
 
13
  import json
14
  import random
15
  import string
16
+ import rag_DocumentLoader
17
  import rag_DocumentSearcher
18
+ #import colpali
19
  import pandas as pd
20
  from PIL import Image
21
  import shutil
22
  import base64
23
  import time
24
  import botocore
25
+ #from langchain.callbacks.base import BaseCallbackHandler
26
+ import streamlit_nested_layout
27
+ #from IPython.display import clear_output, display, display_markdown, Markdown
28
  from requests_aws4auth import AWS4Auth
29
  import colpali
30
  from requests.auth import HTTPBasicAuth
 
31
 
32
 
 
33
 
34
  st.set_page_config(
35
  #page_title="Semantic Search using OpenSearch",
36
  layout="wide",
37
  page_icon="images/opensearch_mark_default.png"
38
  )
 
39
  parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
40
  USER_ICON = "images/user.png"
41
  AI_ICON = "images/opensearch-twitter-card.png"
42
  REGENERATE_ICON = "images/regenerate.png"
43
  s3_bucket_ = "pdf-repo-uploads"
44
  #"pdf-repo-uploads"
45
+ polly_client = boto3.Session(
46
+ region_name='us-east-1').client('polly')
 
 
 
 
 
 
 
 
 
47
 
48
  # Check if the user ID is already stored in the session state
 
 
49
  if 'user_id' in st.session_state:
50
  user_id = st.session_state['user_id']
51
  #print(f"User ID: {user_id}")
 
70
 
71
  if "questions_" not in st.session_state:
72
  st.session_state.questions_ = []
73
+
74
+
75
  if "show_columns" not in st.session_state:
76
  st.session_state.show_columns = False
77
+
78
+ if "answer_ready" not in st.session_state:
79
+ st.session_state.answer_ready = False
80
 
81
  if "answers_" not in st.session_state:
82
  st.session_state.answers_ = []
83
 
84
  if "input_index" not in st.session_state:
85
+ st.session_state.input_index = "hpijan2024hometrack"#"globalwarmingnew"#"hpijan2024hometrack_no_img_no_table"
86
 
87
  if "input_is_rerank" not in st.session_state:
88
  st.session_state.input_is_rerank = True
89
+
90
  if "input_is_colpali" not in st.session_state:
91
  st.session_state.input_is_colpali = False
92
 
 
97
  st.session_state.input_table_with_sql = False
98
 
99
  if "input_query" not in st.session_state:
100
+ if(st.session_state.input_index == "globalwarming"):
101
+ st.session_state.input_query="What is the projected energy percentage from renewable sources in future ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
102
+ if(st.session_state.input_index == "hpijan2024hometrack"):
103
+ st.session_state.input_query="which city has the highest average housing price in UK ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
104
+ if(st.session_state.input_index == "covid19ie"):
105
+ st.session_state.input_query="How many aged above 85 years died due to covid ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
106
+
107
 
108
  if "input_rag_searchType" not in st.session_state:
109
+ st.session_state.input_rag_searchType = ["Vector Search"]
110
+
111
 
112
 
113
+
114
+ region = 'us-east-1'
115
+ bedrock_runtime_client = boto3.client('bedrock-runtime',region_name=region)
116
+ output = []
117
+ service = 'es'
118
+
119
  st.markdown("""
120
  <style>
121
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
 
127
  </style>
128
  """,unsafe_allow_html=True)
129
 
130
+ ################ OpenSearch Py client #####################
131
+
132
+ # credentials = boto3.Session().get_credentials()
133
+ # awsauth = AWSV4SignerAuth(credentials, region, service)
134
+
135
+ # ospy_client = OpenSearch(
136
+ # hosts = [{'host': 'search-opensearchservi-75ucark0bqob-bzk6r6h2t33dlnpgx2pdeg22gi.us-east-1.es.amazonaws.com', 'port': 443}],
137
+ # http_auth = awsauth,
138
+ # use_ssl = True,
139
+ # verify_certs = True,
140
+ # connection_class = RequestsHttpConnection,
141
+ # pool_maxsize = 20
142
+ # )
143
+
144
+ ################# using boto3 credentials ###################
145
+
146
+
147
  credentials = boto3.Session().get_credentials()
148
+ awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, 'us-west-2', service, session_token=credentials.token)
149
  service = 'es'
150
 
151
+
152
+ ################# using boto3 credentials ####################
153
+
154
+
155
+
156
+ # if "input_searchType" not in st.session_state:
157
+ # st.session_state.input_searchType = "Conversational Search (RAG)"
158
+
159
+ # if "input_temperature" not in st.session_state:
160
+ # st.session_state.input_temperature = "0.001"
161
+
162
+ # if "input_topK" not in st.session_state:
163
+ # st.session_state.input_topK = 200
164
+
165
+ # if "input_topP" not in st.session_state:
166
+ # st.session_state.input_topP = 0.95
167
+
168
+ # if "input_maxTokens" not in st.session_state:
169
+ # st.session_state.input_maxTokens = 1024
170
+
171
+
172
  def write_logo():
173
  col1, col2, col3 = st.columns([5, 1, 5])
174
  with col2:
 
196
  st.session_state.questions_ = []
197
  st.session_state.answers_ = []
198
  st.session_state.input_query=""
199
+ # st.session_state.input_searchType="Conversational Search (RAG)"
200
+ # st.session_state.input_temperature = "0.001"
201
+ # st.session_state.input_topK = 200
202
+ # st.session_state.input_topP = 0.95
203
+ # st.session_state.input_maxTokens = 1024
204
 
205
 
206
  def handle_input():
207
+ # st.session_state.answer_ready = True
208
+ # st.session_state.show_columns = False # reset column display
209
  print("Question: "+st.session_state.input_query)
210
  print("-----------")
211
  print("\n\n")
 
216
  if key.startswith('input_'):
217
  inputs[key.removeprefix('input_')] = st.session_state[key]
218
  st.session_state.inputs_ = inputs
219
+
220
+ #######
221
+
222
+
223
+ #st.write(inputs)
224
  question_with_id = {
225
  'question': inputs["query"],
226
  'id': len(st.session_state.questions_)
 
228
  st.session_state.questions_.append(question_with_id)
229
  if(st.session_state.input_is_colpali):
230
  out_ = colpali.colpali_search_rerank(st.session_state.input_query)
231
+ #print(out_)
232
  else:
233
  out_ = rag_DocumentSearcher.query_(awsauth, inputs, st.session_state['session_id'],st.session_state.input_rag_searchType)
234
  st.session_state.answers_.append({
 
239
  'table':out_['table']
240
  })
241
  st.session_state.input_query=""
242
+
243
+
244
+
245
+ # search_type = st.selectbox('Select the Search type',
246
+ # ('Conversational Search (RAG)',
247
+ # 'OpenSearch vector search',
248
+ # 'LLM Text Generation'
249
+ # ),
250
+
251
+ # key = 'input_searchType',
252
+ # help = "Select the type of retriever\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers_)"
253
+ # )
254
 
255
+ # col1, col2, col3, col4 = st.columns(4)
256
+
257
+ # with col1:
258
+ # st.text_input('Temperature', value = "0.001", placeholder='LLM Temperature', key = 'input_temperature',help = "Set the temperature of the Large Language model. \n Note: 1. Set this to values lower to 1 in the order of 0.001, 0.0001, such low values reduces hallucination and creativity in the LLM response; 2. This applies only when LLM is a part of the retriever pipeline")
259
+ # with col2:
260
+ # st.number_input('Top K', value = 200, placeholder='Top K', key = 'input_topK', step = 50, help = "This limits the LLM's predictions to the top k most probable tokens at each step of generation, this applies only when LLM is a prt of the retriever pipeline")
261
+ # with col3:
262
+ # st.number_input('Top P', value = 0.95, placeholder='Top P', key = 'input_topP', step = 0.05, help = "This sets a threshold probability and selects the top tokens whose cumulative probability exceeds the threshold while the tokens are generated by the LLM")
263
+ # with col4:
264
+ # st.number_input('Max Output Tokens', value = 500, placeholder='Max Output Tokens', key = 'input_maxTokens', step = 100, help = "This decides the total number of tokens generated as the final response. Note: Values greater than 1000 takes longer response time")
265
 
266
+ # st.markdown('---')
267
+
268
+
269
+ def write_user_message(md):
270
+ col1, col2 = st.columns([3,97])
271
+
272
+ with col1:
273
+ st.image(USER_ICON, use_column_width='always')
274
+ with col2:
275
+ #st.warning(md['question'])
276
+
277
+ st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
278
+
279
+
280
+
281
+ def render_answer(question,answer,index,res_img):
282
+
283
+
284
+ col1, col2, col_3 = st.columns([4,74,22])
285
+ with col1:
286
+ st.image(AI_ICON, use_column_width='always')
287
+ with col2:
288
+ ans_ = answer['answer']
289
+ st.write(ans_)
290
+
291
+
292
 
293
+ # def stream_():
294
+ # #use for streaming response on the client side
295
+ # for word in ans_.split(" "):
296
+ # yield word + " "
297
+ # time.sleep(0.04)
298
+ # #use for streaming response from Llm directly
299
+ # if(isinstance(ans_,botocore.eventstream.EventStream)):
300
+ # for event in ans_:
301
+ # chunk = event.get('chunk')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
 
303
+ # if chunk:
304
+
305
+ # chunk_obj = json.loads(chunk.get('bytes').decode())
306
+
307
+ # if('content_block' in chunk_obj or ('delta' in chunk_obj and 'text' in chunk_obj['delta'])):
308
+ # key_ = list(chunk_obj.keys())[2]
309
+ # text = chunk_obj[key_]['text']
 
 
 
 
 
310
 
311
+ # clear_output(wait=True)
312
+ # output.append(text)
313
+ # yield text
314
+ # time.sleep(0.04)
315
+
316
+
317
+
318
+ # if(index == len(st.session_state.questions_)):
319
+ # st.write_stream(stream_)
320
+ # if(isinstance(st.session_state.answers_[index-1]['answer'],botocore.eventstream.EventStream)):
321
+ # st.session_state.answers_[index-1]['answer'] = "".join(output)
322
+ # else:
323
+ # st.write(ans_)
324
+
325
+
326
+ polly_response = polly_client.synthesize_speech(VoiceId='Joanna',
327
+ OutputFormat='ogg_vorbis',
328
+ Text = ans_,
329
+ Engine = 'neural')
330
+
331
+ audio_col1, audio_col2 = st.columns([50,50])
332
+ with audio_col1:
333
+ st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
334
+
335
+ rdn_key_1 = ''.join([random.choice(string.ascii_letters)
336
+ for _ in range(10)])
337
+ def show_maxsim():
338
+ st.session_state.show_columns = True
339
+ st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
340
+ handle_input()
341
+ with placeholder.container():
342
+ render_all()
343
+ if(st.session_state.input_is_colpali):
344
+ st.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
345
+
346
+
347
+
348
+ #st.markdown("<div style='font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;border-radius: 10px;'>"+ans_+"</div>", unsafe_allow_html = True)
349
+ #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Relevant images from the document :</b></div>", unsafe_allow_html = True)
350
+ #st.write("")
351
+ colu1,colu2,colu3 = st.columns([4,82,20])
352
+ with colu2:
353
+ with st.expander("Relevant Sources:"):
354
+ with st.container():
355
+ if(len(res_img)>0):
356
+ with st.expander("Images:"):
357
+
358
  idx = 0
359
  print(res_img)
360
  for i in range(0,len(res_img)):
 
375
  col3_,col4_,col5_ = st.columns([33,33,33])
376
  with col3_:
377
  st.image(res_img[i]['file'])
378
+
379
+
380
+
381
+
382
+
383
  else:
384
+ if(res_img[i]['file'].lower()!='none' and idx < 2):
385
  col3,col4,col5 = st.columns([33,33,33])
386
  cols = [col3,col4]
387
  img = res_img[i]['file'].split(".")[0]
 
390
  with cols[idx]:
391
 
392
  st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
393
+ #st.write(caption)
394
  idx = idx+1
395
  if(st.session_state.show_columns == True):
396
  st.session_state.show_columns = False
397
+ #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Sources from the document:</b></div>", unsafe_allow_html = True)
398
+ if(len(answer["table"] )>0):
399
+ with st.expander("Table:"):
400
+ df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
401
+ df.fillna(method='pad', inplace=True)
402
  st.table(df)
403
+ with st.expander("Raw sources:"):
404
  st.write(answer["source"])
405
+
406
+
407
+
408
+ with col_3:
409
+
410
+ #st.markdown("<div style='color:#e28743;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 5px;'><b>"+",".join(st.session_state.input_rag_searchType)+"</b></div>", unsafe_allow_html = True)
411
 
412
+
413
+
414
+ if(index == len(st.session_state.questions_)):
415
+
416
+ rdn_key = ''.join([random.choice(string.ascii_letters)
417
+ for _ in range(10)])
418
+ # rdn_key_1 = ''.join([random.choice(string.ascii_letters)
419
+ # for _ in range(10)])
420
+ currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
421
+ oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
422
+ #print("changing values-----------------")
423
+ def on_button_click():
424
+ if(currentValue!=oldValue or 1==1):
425
+ st.session_state.input_query = st.session_state.questions_[-1]["question"]
426
+ st.session_state.answers_.pop()
427
+ st.session_state.questions_.pop()
428
 
429
+
430
+ # def show_maxsim():
431
+ # st.session_state.show_columns = True
432
+ # st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
433
+ # handle_input()
434
+ # with placeholder.container():
435
+ # render_all()
436
+
437
+
438
+ if("currentValue" in st.session_state):
439
+ del st.session_state["currentValue"]
440
+
441
+ try:
442
+ del regenerate
443
+ except:
444
+ pass
445
+
446
+ placeholder__ = st.empty()
447
+ placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
448
+ #placeholder__.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
449
 
450
+
451
+
452
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
453
+ def write_chat_message(md, q,index):
454
  if(st.session_state.show_columns):
455
  res_img = st.session_state.maxSimImages
456
  else:
457
  res_img = md['image']
458
+ #st.session_state['session_id'] = res['session_id'] to be added in memory
459
+ chat = st.container()
460
+ with chat:
461
+ #print("st.session_state.input_index------------------")
462
+ #print(st.session_state.input_index)
463
+ render_answer(q,md,index,res_img)
464
 
465
+ def render_all():
466
  index = 0
467
  for (q, a) in zip(st.session_state.questions_, st.session_state.answers_):
468
  index = index +1
469
 
470
+ write_user_message(q)
471
+ write_chat_message(a, q,index)
472
 
473
  placeholder = st.empty()
474
+ with placeholder.container():
475
+ render_all()
476
 
 
 
477
  st.markdown("")
478
+ col_2, col_3 = st.columns([75,20])
479
+ #col_1, col_2, col_3 = st.columns([7.5,71.5,22])
480
+ # with col_1:
481
+ # st.markdown("<p style='padding:0px 0px 0px 0px; color:#FF9900;font-size:120%'><b>Ask:</b></p>",unsafe_allow_html=True, help = 'Enter the questions and click on "GO"')
482
+
483
  with col_2:
484
+ #st.markdown("")
485
+ input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
486
  with col_3:
487
+ #hidden = st.button("RUN",disabled=True,key = "hidden")
488
+ play = st.button("Go",on_click=handle_input,key = "play")
 
 
489
with st.sidebar:
    st.page_link("app.py", label=":orange[Home]", icon="🏠")
    st.subheader(":blue[Sample Data]")

    # Left column: pick which pre-indexed sample dataset to search.
    # Right column: eye-icon links to preview each sample PDF on GitHub,
    # listed in the same order as the radio options.
    coln_1, coln_2 = st.columns([70,30])
    with coln_1:
        index_select = st.radio(
            "Choose one index",
            ["UK Housing", "Global Warming stats", "Covid19 impacts on Ireland"],
            key="input_rad_index",
        )
    with coln_2:
        st.markdown("<p style='font-size:15px'>Preview file</p>",unsafe_allow_html=True)
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
507
  st.markdown("""
508
  <style>
509
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
 
515
  </style>
516
  """,unsafe_allow_html=True)
517
  with st.expander("Sample questions:"):
 
518
  st.markdown("<span style = 'color:#FF9900;'>UK Housing</span> - which city has the highest average housing price in UK ?",unsafe_allow_html=True)
519
+ st.markdown("<span style = 'color:#FF9900;'>Global Warming stats</span> - What is the projected energy percentage from renewable sources in future?",unsafe_allow_html=True)
520
  st.markdown("<span style = 'color:#FF9900;'>Covid19 impacts</span> - How many aged above 85 years died due to covid ?",unsafe_allow_html=True)
521
 
522
+ # Initialize boto3 to use the S3 client.
523
+ s3_client = boto3.resource('s3')
524
+ bucket=s3_client.Bucket(s3_bucket_)
525
+
526
+ objects = bucket.objects.filter(Prefix="sample_pdfs/")
527
+ urls = []
528
+
529
+ client = boto3.client('s3')
530
+
531
+ for obj in objects:
532
+ if obj.key.endswith('.pdf'):
533
+
534
+ # Generate the S3 presigned URL
535
+ s3_presigned_url = client.generate_presigned_url(
536
+ ClientMethod='get_object',
537
+ Params={
538
+ 'Bucket': s3_bucket_,
539
+ 'Key': obj.key
540
+ },
541
+ ExpiresIn=3600
542
+ )
543
+
544
+ # Print the created S3 presigned URL
545
+ print(s3_presigned_url)
546
+ urls.append(s3_presigned_url)
547
+ #st.write("["+obj.key.split('/')[1]+"]("+s3_presigned_url+")")
548
+ st.link_button(obj.key.split('/')[1], s3_presigned_url)
549
+
550
+
551
+ # st.subheader(":blue[Your multi-modal documents]")
552
+ # pdf_doc_ = st.file_uploader(
553
+ # "Upload your PDFs here and click on 'Process'", accept_multiple_files=False)
554
+
555
+
556
+ # pdf_docs = [pdf_doc_]
557
+ # if st.button("Process"):
558
+ # with st.spinner("Processing"):
559
+ # if os.path.isdir(parent_dirname+"/pdfs") == False:
560
+ # os.mkdir(parent_dirname+"/pdfs")
561
+
562
+ # for pdf_doc in pdf_docs:
563
+ # print(type(pdf_doc))
564
+ # pdf_doc_name = (pdf_doc.name).replace(" ","_")
565
+ # with open(os.path.join(parent_dirname+"/pdfs",pdf_doc_name),"wb") as f:
566
+ # f.write(pdf_doc.getbuffer())
567
+
568
+ # request_ = { "bucket": s3_bucket_,"key": pdf_doc_name}
569
+ # # if(st.session_state.input_copali_rerank):
570
+ # # copali.process_doc(request_)
571
+ # # else:
572
+ # rag_DocumentLoader.load_docs(request_)
573
+ # print('lambda done')
574
+ # st.success('you can start searching on your PDF')
575
+
576
+ ############## haystach demo temporary addition ############
577
+ # st.subheader(":blue[Multimodality]")
578
+ # colu1,colu2 = st.columns([50,50])
579
+ # with colu1:
580
+ # in_images = st.toggle('Images', key = 'in_images', disabled = False)
581
+ # with colu2:
582
+ # in_tables = st.toggle('Tables', key = 'in_tables', disabled = False)
583
+ # if(in_tables):
584
+ # st.session_state.input_table_with_sql = True
585
+ # else:
586
+ # st.session_state.input_table_with_sql = False
587
 
588
  ############## haystach demo temporary addition ############
589
  #if(pdf_doc_ is None or pdf_doc_ == ""):
590
  if(index_select == "Global Warming stats"):
591
+ st.session_state.input_index = "globalwarmingnew"
592
  if(index_select == "Covid19 impacts on Ireland"):
593
  st.session_state.input_index = "covid19ie"#"choosetheknnalgorithmforyourbillionscaleusecasewithopensearchawsbigdatablog"
594
  if(index_select == "BEIR"):
595
  st.session_state.input_index = "2104"
596
  if(index_select == "UK Housing"):
597
  st.session_state.input_index = "hpijan2024hometrack"
598
+
599
+ # custom_index = st.text_input("If uploaded the file already, enter the original file name", value = "")
600
+ # if(custom_index!=""):
601
+ # st.session_state.input_index = re.sub('[^A-Za-z0-9]+', '', (custom_index.lower().replace(".pdf","").split("/")[-1].split(".")[0]).lower())
602
+
603
+
604
 
605
  st.subheader(":blue[Retriever]")
606
  search_type = st.multiselect('Select the Retriever(s)',
 
623
 
624
  st.subheader(":blue[Multi-vector retrieval]")
625
 
626
+ #st.write("Dataset indexed: https://huggingface.co/datasets/vespa-engine/gpfg-QA")
627
  colpali_search_rerank = st.checkbox('Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)', key = 'input_colpali', disabled = False, value = False, help = "Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim")
628
 
629
  if(colpali_search_rerank):
 
634
 
635
  with st.expander("Sample questions for Colpali retriever:"):
636
  st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
 
637
 
638
+
639
+ # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
640
+
641
+ # if(copali_rerank):
642
+ # st.session_state.input_copali_rerank = True
643
+ # else:
644
+ # st.session_state.input_copali_rerank = False
645
+
646
+
647
+
648
+
649
+