prasadnu commited on
Commit
911c8fd
·
1 Parent(s): 10cea74

colpali fix

Browse files
pages/Multimodal_Conversational_Search.py CHANGED
@@ -15,16 +15,12 @@ import random
15
  import string
16
  import rag_DocumentLoader
17
  import rag_DocumentSearcher
18
- #import colpali
19
  import pandas as pd
20
  from PIL import Image
21
  import shutil
22
  import base64
23
  import time
24
  import botocore
25
- #from langchain.callbacks.base import BaseCallbackHandler
26
- #import streamlit_nested_layout
27
- #from IPython.display import clear_output, display, display_markdown, Markdown
28
  from requests_aws4auth import AWS4Auth
29
  import colpali
30
  from requests.auth import HTTPBasicAuth
@@ -41,14 +37,14 @@ USER_ICON = "images/user.png"
41
  AI_ICON = "images/opensearch-twitter-card.png"
42
  REGENERATE_ICON = "images/regenerate.png"
43
  s3_bucket_ = "pdf-repo-uploads"
44
- #"pdf-repo-uploads"
45
  # polly_client = boto3.Session(
46
  # region_name='us-east-1').client('polly')
47
 
48
  # Check if the user ID is already stored in the session state
49
  if 'user_id' in st.session_state:
50
  user_id = st.session_state['user_id']
51
- #print(f"User ID: {user_id}")
52
 
53
  # If the user ID is not yet stored in the session state, generate a random UUID
54
  else:
@@ -105,10 +101,6 @@ if "input_query" not in st.session_state:
105
  st.session_state.input_query="How many aged above 85 years died due to covid ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
106
 
107
 
108
-
109
-
110
-
111
-
112
  st.markdown("""
113
  <style>
114
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
@@ -120,43 +112,11 @@ st.markdown("""
120
  </style>
121
  """,unsafe_allow_html=True)
122
 
123
- ################ OpenSearch Py client #####################
124
-
125
- # credentials = boto3.Session().get_credentials()
126
- # awsauth = AWSV4SignerAuth(credentials, region, service)
127
-
128
- # ospy_client = OpenSearch(
129
- # hosts = [{'host': 'search-opensearchservi-75ucark0bqob-bzk6r6h2t33dlnpgx2pdeg22gi.us-east-1.es.amazonaws.com', 'port': 443}],
130
- # http_auth = awsauth,
131
- # use_ssl = True,
132
- # verify_certs = True,
133
- # connection_class = RequestsHttpConnection,
134
- # pool_maxsize = 20
135
- # )
136
-
137
  ################# using boto3 credentials ###################
138
 
139
  awsauth = HTTPBasicAuth('master',st.secrets['ml_search_demo_api_access'])
140
 
141
 
142
- ################# using boto3 credentials ####################
143
-
144
-
145
-
146
- # if "input_searchType" not in st.session_state:
147
- # st.session_state.input_searchType = "Conversational Search (RAG)"
148
-
149
- # if "input_temperature" not in st.session_state:
150
- # st.session_state.input_temperature = "0.001"
151
-
152
- # if "input_topK" not in st.session_state:
153
- # st.session_state.input_topK = 200
154
-
155
- # if "input_topP" not in st.session_state:
156
- # st.session_state.input_topP = 0.95
157
-
158
- # if "input_maxTokens" not in st.session_state:
159
- # st.session_state.input_maxTokens = 1024
160
 
161
 
162
  def write_logo():
@@ -186,20 +146,14 @@ if clear:
186
  st.session_state.questions_ = []
187
  st.session_state.answers_ = []
188
  st.session_state.input_query=""
189
- # st.session_state.input_searchType="Conversational Search (RAG)"
190
- # st.session_state.input_temperature = "0.001"
191
- # st.session_state.input_topK = 200
192
- # st.session_state.input_topP = 0.95
193
- # st.session_state.input_maxTokens = 1024
194
 
195
 
196
  def handle_input(state,dummy):
197
  if(state == 'colpali_show_similarity_map'):
198
  st.session_state.show_columns = True
199
- # st.session_state.answer_ready = True
200
- # st.session_state.show_columns = False # reset column display
201
  print("Question: "+st.session_state.input_query)
202
- print("-----------")
203
  print("\n\n")
204
  # if(st.session_state.input_query==''):
205
  # return ""
@@ -234,8 +188,6 @@ def write_user_message(md):
234
  with col1:
235
  st.image(USER_ICON, use_column_width='always')
236
  with col2:
237
- #st.warning(md['question'])
238
-
239
  st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
240
 
241
 
@@ -251,25 +203,12 @@ def render_answer(question,answer,index,res_img):
251
  st.write(ans_)
252
  rdn_key_1 = ''.join([random.choice(string.ascii_letters)
253
  for _ in range(10)])
254
- # def show_maxsim():
255
- # st.session_state.show_columns = True
256
- # # st.session_state.input_query = st.session_state.questions_[-1]["question"]
257
- # # st.session_state.answers_.pop()
258
- # # st.session_state.questions_.pop()
259
- # handle_input()
260
- # print("*"*20)
261
- # print(st.session_state.input_query)
262
- # print(st.session_state.answers_)
263
- # print(st.session_state.questions_)
264
- # print("*"*20)
265
- # with placeholder.container():
266
- # render_all()
267
 
268
  if(st.session_state.input_is_colpali):
269
  placeholder__ = st.empty()
270
  placeholder__.button("Show similarity map",key=rdn_key_1,on_click=handle_input,args=('colpali_show_similarity_map',True))
271
- # with placeholder.container():
272
- # render_all()
273
 
274
  colu1,colu2,colu3 = st.columns([4,82,20])
275
  with colu2:
@@ -320,19 +259,13 @@ def render_answer(question,answer,index,res_img):
320
  st.table(df)
321
  #with st.expander("Raw sources:"):
322
  st.write(answer["source"])
323
-
324
-
325
-
326
  with col_3:
327
  if(index == len(st.session_state.questions_)):
328
 
329
  rdn_key = ''.join([random.choice(string.ascii_letters)
330
  for _ in range(10)])
331
- # rdn_key_1 = ''.join([random.choice(string.ascii_letters)
332
- # for _ in range(10)])
333
  currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
334
  oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
335
- #print("changing values-----------------")
336
  def on_button_click():
337
  if(currentValue!=oldValue or 1==1):
338
  st.session_state.input_query = st.session_state.questions_[-1]["question"]
@@ -341,14 +274,7 @@ def render_answer(question,answer,index,res_img):
341
  handle_input("regenerate_",None)
342
  with placeholder.container():
343
  render_all()
344
-
345
-
346
- # def show_maxsim():
347
- # st.session_state.show_columns = True
348
- # st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
349
- # handle_input()
350
- # with placeholder.container():
351
- # render_all()
352
 
353
 
354
  if("currentValue" in st.session_state):
@@ -361,8 +287,7 @@ def render_answer(question,answer,index,res_img):
361
 
362
  placeholder__ = st.empty()
363
  placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
364
- #placeholder__.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
365
-
366
 
367
 
368
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
@@ -371,8 +296,6 @@ def write_chat_message(md, q,index):
371
  #st.session_state['session_id'] = res['session_id'] to be added in memory
372
  chat = st.container()
373
  with chat:
374
- #print("st.session_state.input_index------------------")
375
- #print(st.session_state.input_index)
376
  render_answer(q,md,index,res_img)
377
 
378
  def render_all():
@@ -389,12 +312,8 @@ with placeholder.container():
389
 
390
  st.markdown("")
391
  col_2, col_3 = st.columns([75,20])
392
- #col_1, col_2, col_3 = st.columns([7.5,71.5,22])
393
- # with col_1:
394
- # st.markdown("<p style='padding:0px 0px 0px 0px; color:#FF9900;font-size:120%'><b>Ask:</b></p>",unsafe_allow_html=True, help = 'Enter the questions and click on "GO"')
395
 
396
  with col_2:
397
- #st.markdown("")
398
  input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
399
  with col_3:
400
  #hidden = st.button("RUN",disabled=True,key = "hidden")
@@ -403,12 +322,6 @@ with st.sidebar:
403
  st.page_link("app.py", label=":orange[Home]", icon="🏠")
404
  st.subheader(":blue[Sample Data]")
405
  coln_1,coln_2 = st.columns([70,30])
406
- # index_select = st.radio("Choose one index",["UK Housing","Covid19 impacts on Ireland","Environmental Global Warming","BEIR Research"],
407
- # captions = ['[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)',
408
- # '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)',
409
- # '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)',
410
- # '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/BEIR.pdf)'],
411
- # key="input_rad_index")
412
  with coln_1:
413
  index_select = st.radio("Choose one index",["UK Housing","Global Warming stats","Covid19 impacts on Ireland"],key="input_rad_index")
414
  with coln_2:
@@ -416,7 +329,6 @@ with st.sidebar:
416
  st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
417
  st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
418
  st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
419
- #st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/BEIR.pdf)")
420
  st.markdown("""
421
  <style>
422
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
@@ -435,41 +347,29 @@ with st.sidebar:
435
 
436
 
437
  # st.subheader(":blue[Your multi-modal documents]")
438
- pdf_doc_ = st.file_uploader(
439
- "Upload your PDFs here and click on 'Process'", accept_multiple_files=False)
440
 
441
 
442
- pdf_docs = [pdf_doc_]
443
- if st.button("Process"):
444
- with st.spinner("Processing"):
445
- if os.path.isdir(parent_dirname+"/pdfs") == False:
446
- os.mkdir(parent_dirname+"/pdfs")
447
 
448
- for pdf_doc in pdf_docs:
449
- print(type(pdf_doc))
450
- pdf_doc_name = (pdf_doc.name).replace(" ","_")
451
- with open(os.path.join(parent_dirname+"/pdfs",pdf_doc_name),"wb") as f:
452
- f.write(pdf_doc.getbuffer())
453
 
454
- request_ = { "bucket": s3_bucket_,"key": pdf_doc_name}
455
- # if(st.session_state.input_copali_rerank):
456
- # copali.process_doc(request_)
457
- # else:
458
- rag_DocumentLoader.load_docs(request_)
459
- print('lambda done')
460
- st.success('you can start searching on your PDF')
461
-
462
- ############## haystach demo temporary addition ############
463
- # st.subheader(":blue[Multimodality]")
464
- # colu1,colu2 = st.columns([50,50])
465
- # with colu1:
466
- # in_images = st.toggle('Images', key = 'in_images', disabled = False)
467
- # with colu2:
468
- # in_tables = st.toggle('Tables', key = 'in_tables', disabled = False)
469
- # if(in_tables):
470
- # st.session_state.input_table_with_sql = True
471
- # else:
472
- # st.session_state.input_table_with_sql = False
473
 
474
  ############## haystach demo temporary addition ############
475
  #if(pdf_doc_ is None or pdf_doc_ == ""):
@@ -509,7 +409,6 @@ with st.sidebar:
509
 
510
  st.subheader(":blue[Multi-vector retrieval]")
511
 
512
- #st.write("Dataset indexed: https://huggingface.co/datasets/vespa-engine/gpfg-QA")
513
  colpali_search_rerank = st.checkbox('Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)', key = 'input_colpali', disabled = False, value = False, help = "Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim")
514
 
515
  if(colpali_search_rerank):
@@ -524,12 +423,6 @@ with st.sidebar:
524
  st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
525
 
526
 
527
- # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
528
-
529
- # if(copali_rerank):
530
- # st.session_state.input_copali_rerank = True
531
- # else:
532
- # st.session_state.input_copali_rerank = False
533
 
534
 
535
 
 
15
  import string
16
  import rag_DocumentLoader
17
  import rag_DocumentSearcher
 
18
  import pandas as pd
19
  from PIL import Image
20
  import shutil
21
  import base64
22
  import time
23
  import botocore
 
 
 
24
  from requests_aws4auth import AWS4Auth
25
  import colpali
26
  from requests.auth import HTTPBasicAuth
 
37
  AI_ICON = "images/opensearch-twitter-card.png"
38
  REGENERATE_ICON = "images/regenerate.png"
39
  s3_bucket_ = "pdf-repo-uploads"
40
+
41
  # polly_client = boto3.Session(
42
  # region_name='us-east-1').client('polly')
43
 
44
  # Check if the user ID is already stored in the session state
45
  if 'user_id' in st.session_state:
46
  user_id = st.session_state['user_id']
47
+
48
 
49
  # If the user ID is not yet stored in the session state, generate a random UUID
50
  else:
 
101
  st.session_state.input_query="How many aged above 85 years died due to covid ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
102
 
103
 
 
 
 
 
104
  st.markdown("""
105
  <style>
106
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
 
112
  </style>
113
  """,unsafe_allow_html=True)
114
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  ################# using boto3 credentials ###################
116
 
117
  awsauth = HTTPBasicAuth('master',st.secrets['ml_search_demo_api_access'])
118
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
121
 
122
  def write_logo():
 
146
  st.session_state.questions_ = []
147
  st.session_state.answers_ = []
148
  st.session_state.input_query=""
149
+
 
 
 
 
150
 
151
 
152
  def handle_input(state,dummy):
153
  if(state == 'colpali_show_similarity_map'):
154
  st.session_state.show_columns = True
 
 
155
  print("Question: "+st.session_state.input_query)
156
+ print("-"*20)
157
  print("\n\n")
158
  # if(st.session_state.input_query==''):
159
  # return ""
 
188
  with col1:
189
  st.image(USER_ICON, use_column_width='always')
190
  with col2:
 
 
191
  st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
192
 
193
 
 
203
  st.write(ans_)
204
  rdn_key_1 = ''.join([random.choice(string.ascii_letters)
205
  for _ in range(10)])
206
+
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
  if(st.session_state.input_is_colpali):
209
  placeholder__ = st.empty()
210
  placeholder__.button("Show similarity map",key=rdn_key_1,on_click=handle_input,args=('colpali_show_similarity_map',True))
211
+
 
212
 
213
  colu1,colu2,colu3 = st.columns([4,82,20])
214
  with colu2:
 
259
  st.table(df)
260
  #with st.expander("Raw sources:"):
261
  st.write(answer["source"])
 
 
 
262
  with col_3:
263
  if(index == len(st.session_state.questions_)):
264
 
265
  rdn_key = ''.join([random.choice(string.ascii_letters)
266
  for _ in range(10)])
 
 
267
  currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
268
  oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
 
269
  def on_button_click():
270
  if(currentValue!=oldValue or 1==1):
271
  st.session_state.input_query = st.session_state.questions_[-1]["question"]
 
274
  handle_input("regenerate_",None)
275
  with placeholder.container():
276
  render_all()
277
+
 
 
 
 
 
 
 
278
 
279
 
280
  if("currentValue" in st.session_state):
 
287
 
288
  placeholder__ = st.empty()
289
  placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
290
+
 
291
 
292
 
293
  #Each answer will have context of the question asked in order to associate the provided feedback with the respective question
 
296
  #st.session_state['session_id'] = res['session_id'] to be added in memory
297
  chat = st.container()
298
  with chat:
 
 
299
  render_answer(q,md,index,res_img)
300
 
301
  def render_all():
 
312
 
313
  st.markdown("")
314
  col_2, col_3 = st.columns([75,20])
 
 
 
315
 
316
  with col_2:
 
317
  input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
318
  with col_3:
319
  #hidden = st.button("RUN",disabled=True,key = "hidden")
 
322
  st.page_link("app.py", label=":orange[Home]", icon="🏠")
323
  st.subheader(":blue[Sample Data]")
324
  coln_1,coln_2 = st.columns([70,30])
 
 
 
 
 
 
325
  with coln_1:
326
  index_select = st.radio("Choose one index",["UK Housing","Global Warming stats","Covid19 impacts on Ireland"],key="input_rad_index")
327
  with coln_2:
 
329
  st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
330
  st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
331
  st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
 
332
  st.markdown("""
333
  <style>
334
  [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
 
347
 
348
 
349
  # st.subheader(":blue[Your multi-modal documents]")
350
+ # pdf_doc_ = st.file_uploader(
351
+ # "Upload your PDFs here and click on 'Process'", accept_multiple_files=False)
352
 
353
 
354
+ # pdf_docs = [pdf_doc_]
355
+ # if st.button("Process"):
356
+ # with st.spinner("Processing"):
357
+ # if os.path.isdir(parent_dirname+"/pdfs") == False:
358
+ # os.mkdir(parent_dirname+"/pdfs")
359
 
360
+ # for pdf_doc in pdf_docs:
361
+ # print(type(pdf_doc))
362
+ # pdf_doc_name = (pdf_doc.name).replace(" ","_")
363
+ # with open(os.path.join(parent_dirname+"/pdfs",pdf_doc_name),"wb") as f:
364
+ # f.write(pdf_doc.getbuffer())
365
 
366
+ # request_ = { "bucket": s3_bucket_,"key": pdf_doc_name}
367
+ # # if(st.session_state.input_copali_rerank):
368
+ # # copali.process_doc(request_)
369
+ # # else:
370
+ # rag_DocumentLoader.load_docs(request_)
371
+ # print('lambda done')
372
+ # st.success('you can start searching on your PDF')
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  ############## haystach demo temporary addition ############
375
  #if(pdf_doc_ is None or pdf_doc_ == ""):
 
409
 
410
  st.subheader(":blue[Multi-vector retrieval]")
411
 
 
412
  colpali_search_rerank = st.checkbox('Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)', key = 'input_colpali', disabled = False, value = False, help = "Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim")
413
 
414
  if(colpali_search_rerank):
 
423
  st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
424
 
425
 
 
 
 
 
 
 
426
 
427
 
428