prasadnu commited on
Commit
1bc0fd8
·
1 Parent(s): dac1d4e

RAG changes

Browse files
pages/Multimodal_Conversational_Search.py CHANGED
@@ -15,17 +15,14 @@ import random
15
  import string
16
  import rag_DocumentLoader
17
  import rag_DocumentSearcher
18
- #import colpali
19
  import pandas as pd
20
  from PIL import Image
21
  import shutil
22
  import base64
23
  import time
24
  import botocore
25
- #from langchain.callbacks.base import BaseCallbackHandler
26
- #from IPython.display import clear_output, display, display_markdown, Markdown
27
  from requests_aws4auth import AWS4Auth
28
- import colpali
29
  from requests.auth import HTTPBasicAuth
30
 
31
 
@@ -41,8 +38,8 @@ AI_ICON = "images/opensearch-twitter-card.png"
41
  REGENERATE_ICON = "images/regenerate.png"
42
  s3_bucket_ = "pdf-repo-uploads"
43
  #"pdf-repo-uploads"
44
- polly_client = boto3.Session(
45
- region_name='us-east-1').client('polly')
46
 
47
  # Check if the user ID is already stored in the session state
48
  if 'user_id' in st.session_state:
@@ -69,13 +66,9 @@ if "chats" not in st.session_state:
69
 
70
  if "questions_" not in st.session_state:
71
  st.session_state.questions_ = []
72
-
73
-
74
  if "show_columns" not in st.session_state:
75
  st.session_state.show_columns = False
76
-
77
- if "answer_ready" not in st.session_state:
78
- st.session_state.answer_ready = False
79
 
80
  if "answers_" not in st.session_state:
81
  st.session_state.answers_ = []
@@ -85,7 +78,7 @@ if "input_index" not in st.session_state:
85
 
86
  if "input_is_rerank" not in st.session_state:
87
  st.session_state.input_is_rerank = True
88
-
89
  if "input_is_colpali" not in st.session_state:
90
  st.session_state.input_is_colpali = False
91
 
@@ -98,10 +91,8 @@ if "input_table_with_sql" not in st.session_state:
98
  if "input_query" not in st.session_state:
99
  st.session_state.input_query="which city has the highest average housing price in UK ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
100
 
101
-
102
  if "input_rag_searchType" not in st.session_state:
103
- st.session_state.input_rag_searchType = ["Vector Search"]
104
-
105
 
106
 
107
 
@@ -139,7 +130,7 @@ st.markdown("""
139
 
140
 
141
  credentials = boto3.Session().get_credentials()
142
- awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, 'us-west-2', service, session_token=credentials.token)
143
  service = 'es'
144
 
145
 
@@ -166,7 +157,7 @@ service = 'es'
166
  def write_logo():
167
  col1, col2, col3 = st.columns([5, 1, 5])
168
  with col2:
169
- st.image(AI_ICON, use_column_width='always')
170
 
171
  def write_top_bar():
172
  col1, col2 = st.columns([77,23])
@@ -174,7 +165,7 @@ def write_top_bar():
174
  st.write("")
175
  st.header("Chat with your data",divider='rainbow')
176
 
177
- #st.image(AI_ICON, use_column_width='always')
178
 
179
  with col2:
180
  st.write("")
@@ -198,8 +189,6 @@ if clear:
198
 
199
 
200
  def handle_input():
201
- # st.session_state.answer_ready = True
202
- # st.session_state.show_columns = False # reset column display
203
  print("Question: "+st.session_state.input_query)
204
  print("-----------")
205
  print("\n\n")
@@ -264,7 +253,7 @@ def write_user_message(md):
264
  col1, col2 = st.columns([3,97])
265
 
266
  with col1:
267
- st.image(USER_ICON, use_column_width='always')
268
  with col2:
269
  #st.warning(md['question'])
270
 
@@ -277,7 +266,7 @@ def render_answer(question,answer,index,res_img):
277
 
278
  col1, col2, col_3 = st.columns([4,74,22])
279
  with col1:
280
- st.image(AI_ICON, use_column_width='always')
281
  with col2:
282
  ans_ = answer['answer']
283
  st.write(ans_)
@@ -387,7 +376,6 @@ def render_answer(question,answer,index,res_img):
387
  st.write(answer["source"])
388
 
389
 
390
-
391
  with col_3:
392
 
393
  #st.markdown("<div style='color:#e28743;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 5px;'><b>"+",".join(st.session_state.input_rag_searchType)+"</b></div>", unsafe_allow_html = True)
@@ -398,26 +386,28 @@ def render_answer(question,answer,index,res_img):
398
 
399
  rdn_key = ''.join([random.choice(string.ascii_letters)
400
  for _ in range(10)])
401
- rdn_key_1 = ''.join([random.choice(string.ascii_letters)
402
- for _ in range(10)])
403
  currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
404
  oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
405
  #print("changing values-----------------")
406
  def on_button_click():
 
 
 
407
  if(currentValue!=oldValue or 1==1):
 
408
  st.session_state.input_query = st.session_state.questions_[-1]["question"]
409
  st.session_state.answers_.pop()
410
  st.session_state.questions_.pop()
411
 
412
-
 
 
413
  def show_maxsim():
414
  st.session_state.show_columns = True
415
  st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
416
  handle_input()
417
  with placeholder.container():
418
  render_all()
419
-
420
-
421
  if("currentValue" in st.session_state):
422
  del st.session_state["currentValue"]
423
 
@@ -426,18 +416,17 @@ def render_answer(question,answer,index,res_img):
426
  except:
427
  pass
428
 
 
 
 
429
  placeholder__ = st.empty()
 
430
  placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
431
  placeholder__.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
432
 
433
-
434
-
435
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
436
  def write_chat_message(md, q,index):
437
- if(st.session_state.show_columns):
438
- res_img = st.session_state.maxSimImages
439
- else:
440
- res_img = md['image']
441
  #st.session_state['session_id'] = res['session_id'] to be added in memory
442
  chat = st.container()
443
  with chat:
@@ -468,7 +457,7 @@ with col_2:
468
  input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
469
  with col_3:
470
  #hidden = st.button("RUN",disabled=True,key = "hidden")
471
- play = st.button("Go",on_click=handle_input,key = "play")
472
  with st.sidebar:
473
  st.page_link("app.py", label=":orange[Home]", icon="🏠")
474
  st.subheader(":blue[Sample Data]")
@@ -627,6 +616,4 @@ with st.sidebar:
627
  # st.session_state.input_copali_rerank = False
628
 
629
 
630
-
631
-
632
-
 
15
  import string
16
  import rag_DocumentLoader
17
  import rag_DocumentSearcher
 
18
  import pandas as pd
19
  from PIL import Image
20
  import shutil
21
  import base64
22
  import time
23
  import botocore
 
 
24
  from requests_aws4auth import AWS4Auth
25
+ import copali
26
  from requests.auth import HTTPBasicAuth
27
 
28
 
 
38
  REGENERATE_ICON = "images/regenerate.png"
39
  s3_bucket_ = "pdf-repo-uploads"
40
  #"pdf-repo-uploads"
41
+ polly_client = boto3.client('polly',aws_access_key_id=st.secrets['user_access_key'],
42
+ aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
43
 
44
  # Check if the user ID is already stored in the session state
45
  if 'user_id' in st.session_state:
 
66
 
67
  if "questions_" not in st.session_state:
68
  st.session_state.questions_ = []
69
+
 
70
  if "show_columns" not in st.session_state:
71
  st.session_state.show_columns = False
 
 
 
72
 
73
  if "answers_" not in st.session_state:
74
  st.session_state.answers_ = []
 
78
 
79
  if "input_is_rerank" not in st.session_state:
80
  st.session_state.input_is_rerank = True
81
+
82
  if "input_is_colpali" not in st.session_state:
83
  st.session_state.input_is_colpali = False
84
 
 
91
  if "input_query" not in st.session_state:
92
  st.session_state.input_query="which city has the highest average housing price in UK ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
93
 
 
94
  if "input_rag_searchType" not in st.session_state:
95
+ st.session_state.input_rag_searchType = ["Vector Search"]
 
96
 
97
 
98
 
 
130
 
131
 
132
  credentials = boto3.Session().get_credentials()
133
+ awsauth = HTTPBasicAuth('prasadnu',st.secrets['rag_shopping_assistant_os_api_access'])
134
  service = 'es'
135
 
136
 
 
157
  def write_logo():
158
  col1, col2, col3 = st.columns([5, 1, 5])
159
  with col2:
160
+ st.image(AI_ICON, use_container_width='always')
161
 
162
  def write_top_bar():
163
  col1, col2 = st.columns([77,23])
 
165
  st.write("")
166
  st.header("Chat with your data",divider='rainbow')
167
 
168
+ #st.image(AI_ICON, use_container_width='always')
169
 
170
  with col2:
171
  st.write("")
 
189
 
190
 
191
  def handle_input():
 
 
192
  print("Question: "+st.session_state.input_query)
193
  print("-----------")
194
  print("\n\n")
 
253
  col1, col2 = st.columns([3,97])
254
 
255
  with col1:
256
+ st.image(USER_ICON, use_container_width='always')
257
  with col2:
258
  #st.warning(md['question'])
259
 
 
266
 
267
  col1, col2, col_3 = st.columns([4,74,22])
268
  with col1:
269
+ st.image(AI_ICON, use_container_width='always')
270
  with col2:
271
  ans_ = answer['answer']
272
  st.write(ans_)
 
376
  st.write(answer["source"])
377
 
378
 
 
379
  with col_3:
380
 
381
  #st.markdown("<div style='color:#e28743;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 5px;'><b>"+",".join(st.session_state.input_rag_searchType)+"</b></div>", unsafe_allow_html = True)
 
386
 
387
  rdn_key = ''.join([random.choice(string.ascii_letters)
388
  for _ in range(10)])
 
 
389
  currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
390
  oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
391
  #print("changing values-----------------")
392
  def on_button_click():
393
+ # print("button clicked---------------")
394
+ # print(currentValue)
395
+ # print(oldValue)
396
  if(currentValue!=oldValue or 1==1):
397
+ #print("----------regenerate----------------")
398
  st.session_state.input_query = st.session_state.questions_[-1]["question"]
399
  st.session_state.answers_.pop()
400
  st.session_state.questions_.pop()
401
 
402
+ handle_input()
403
+ with placeholder.container():
404
+ render_all()
405
  def show_maxsim():
406
  st.session_state.show_columns = True
407
  st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
408
  handle_input()
409
  with placeholder.container():
410
  render_all()
 
 
411
  if("currentValue" in st.session_state):
412
  del st.session_state["currentValue"]
413
 
 
416
  except:
417
  pass
418
 
419
+ #print("------------------------")
420
+ #print(st.session_state)
421
+
422
  placeholder__ = st.empty()
423
+
424
  placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
425
  placeholder__.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
426
 
 
 
427
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
428
  def write_chat_message(md, q,index):
429
+ res_img = md['image']
 
 
 
430
  #st.session_state['session_id'] = res['session_id'] to be added in memory
431
  chat = st.container()
432
  with chat:
 
457
  input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
458
  with col_3:
459
  #hidden = st.button("RUN",disabled=True,key = "hidden")
460
+ play = st.button("GO",on_click=handle_input,key = "play")
461
  with st.sidebar:
462
  st.page_link("app.py", label=":orange[Home]", icon="🏠")
463
  st.subheader(":blue[Sample Data]")
 
616
  # st.session_state.input_copali_rerank = False
617
 
618
 
619
+