prasadnu committed on
Commit
b23a15f
·
1 Parent(s): 55b3b62

RAG updated

Browse files
pages/Multimodal_Conversational_Search.py CHANGED
@@ -1,232 +1,566 @@
1
- # Streamlit app: Chat with PDFs using OpenSearch, RAG, and ColPali
2
-
3
  import streamlit as st
4
  import uuid
5
  import os
 
6
  import sys
7
- import warnings
 
 
8
  import boto3
 
 
 
9
  import json
10
  import random
11
  import string
 
 
12
  import pandas as pd
13
- from PIL import Image
 
 
 
 
 
 
14
  from requests.auth import HTTPBasicAuth
 
15
 
16
- # Suppress Streamlit deprecation warnings
17
  warnings.filterwarnings("ignore", category=DeprecationWarning)
18
 
19
# Make the sibling packages importable. The project root is taken to be two
# path components above this file's resolved location.
base_path = "/".join(os.path.realpath(__file__).split("/")[:-2])
for _package_dir in ("semantic_search", "RAG", "utilities"):
    # Every insert lands at index 1, so the last directory ends up first.
    sys.path.insert(1, base_path + "/" + _package_dir)
24
 
25
- # Local modules
26
- import rag_DocumentLoader
27
- import rag_DocumentSearcher
28
- import colpali
29
 
30
# --- AWS clients and OpenSearch credentials ---------------------------------
region = 'us-east-1'
s3_bucket_ = "pdf-repo-uploads"
bedrock_runtime_client = boto3.client('bedrock-runtime', region_name=region)
# The Polly client is built from the key pair stored in Streamlit secrets.
polly_client = boto3.client(
    'polly',
    region_name=region,
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'],
)
credentials = boto3.Session().get_credentials()
awsauth = HTTPBasicAuth('master', st.secrets['ml_search_demo_api_access'])

# --- Page configuration and static assets ------------------------------------
st.set_page_config(page_icon="images/opensearch_mark_default.png", layout="wide")
parent_dirname = "/".join(os.path.dirname(__file__).split("/")[:-1])
USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"

# --- Session state defaults ---------------------------------------------------
# A user id is generated once and then reused for the rest of the session.
if 'user_id' not in st.session_state:
    st.session_state['user_id'] = str(uuid.uuid4())

_SESSION_DEFAULTS = {
    'session_id': "",
    'chats': [{'id': 0, 'question': '', 'answer': ''}],
    'questions_': [],
    'answers_': [],
    'show_columns': False,
    'input_index': "hpijan2024hometrack",
    'input_is_rerank': True,
    'input_is_colpali': False,
    'input_copali_rerank': False,
    'input_table_with_sql': False,
    'input_query': "which city has the highest average housing price in UK ?",
    'input_rag_searchType': ["Vector Search"],
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    # Only seed keys that are missing; existing values are left untouched.
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# --- Custom styling: remove the vertical gap inside the first two columns ----
_COLUMN_GAP_CSS = """
<style>
[data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock],
[data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock] {
    gap: 0rem;
}
</style>
"""
st.markdown(_COLUMN_GAP_CSS, unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
# Top bar with the page title and the clear button
def write_top_bar():
    """Draw the page header and the "Clear" button.

    Returns:
        The value returned by ``st.button("Clear")``.
    """
    title_col, button_col = st.columns([77, 23])
    with title_col:
        st.header("Chat with your data", divider='rainbow')
    with button_col:
        was_clicked = st.button("Clear")
        st.write("")  # spacing
    return was_clicked
86
 
87
# Draw the top bar; when "Clear" was clicked, wipe the conversation and the query box.
clear_clicked = write_top_bar()
if clear_clicked:
    st.session_state["questions_"] = []
    st.session_state["answers_"] = []
    st.session_state["input_query"] = ""
 
 
 
 
 
92
 
93
# Handle user query submission
def handle_input():
    """Run the current query through the selected retriever and record the Q&A pair.

    Collects every ``input_*`` key of ``st.session_state`` (prefix stripped)
    into ``st.session_state.inputs_``, appends the question to ``questions_``,
    calls the ColPali searcher or ``rag_DocumentSearcher.query_``, appends the
    result to ``answers_`` and clears ``input_query``. Returns immediately
    when the query is empty.

    Raises:
        Whatever the searcher raises. In that case the question appended for
        this call is removed again before the exception propagates.
    """
    if st.session_state.input_query == '':
        return

    inputs = {
        key.removeprefix('input_'): st.session_state[key]
        for key in st.session_state
        if key.startswith('input_')
    }
    st.session_state.inputs_ = inputs

    st.session_state.questions_.append({
        'question': inputs["query"],
        'id': len(st.session_state.questions_)
    })

    try:
        if st.session_state.input_is_colpali:
            out_ = colpali.colpali_search_rerank(st.session_state.input_query)
        else:
            out_ = rag_DocumentSearcher.query_(
                awsauth,
                inputs,
                st.session_state['session_id'],
                st.session_state.input_rag_searchType
            )
        answer_entry = {
            'answer': out_['text'],
            'source': out_['source'],
            # Computed after the question was appended, as in the original code.
            'id': len(st.session_state.questions_),
            'image': out_['image'],
            'table': out_['table']
        }
    except Exception:
        # questions_ and answers_ are rendered pairwise with zip(); a question
        # without an answer would shift every later pair, so roll it back.
        st.session_state.questions_.pop()
        raise

    st.session_state.answers_.append(answer_entry)
    st.session_state.input_query = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
# Display user message block
def write_user_message(msg):
    """Render one user question next to the user icon.

    Args:
        msg: Dict whose ``'question'`` entry holds the text typed by the user.
    """
    import html  # local import: only needed to escape the user-supplied text

    col1, col2 = st.columns([3, 97])
    with col1:
        st.image(USER_ICON, use_container_width=True)
    with col2:
        # The question is user input and is rendered with unsafe_allow_html,
        # so it is escaped to keep it from being interpreted as markup.
        safe_question = html.escape(str(msg['question']))
        st.markdown(
            f"<div style='color:#e28743;font-size:18px;padding:3px 7px;border-radius:10px;font-style:italic;'>{safe_question}</div>",
            unsafe_allow_html=True
        )
135
 
136
# Render assistant answer block with optional images and tables
def write_chat_message(response, question, index):
    """Render one assistant answer with audio, images, tables and source text.

    Args:
        response: Answer dict; ``'answer'`` is required, ``'image'``,
            ``'table'`` and ``'source'`` are read with defaults.
        question: The question dict paired with this answer (not read here).
        index: Position of this answer; used to build the button key.
    """
    # NOTE(review): nesting was reconstructed from a flattened diff view;
    # confirm that the expander belongs inside the middle column.
    col1, col2, col3 = st.columns([4, 74, 22])

    with col1:
        st.image(AI_ICON, use_container_width=True)

    with col2:
        answer_text = response['answer']
        st.write(answer_text)

        # Synthesize the answer with Polly and offer it as an audio player.
        polly_response = polly_client.synthesize_speech(
            VoiceId='Joanna', OutputFormat='ogg_vorbis', Text=answer_text, Engine='neural')
        st.audio(polly_response['AudioStream'].read(), format="audio/ogg")

        if st.session_state.input_is_colpali:
            if st.button("Show similarity map", key=f"simmap_{index}"):
                st.session_state.show_columns = True
                st.session_state.maxSimImages = colpali.img_highlight(
                    st.session_state.top_img,
                    st.session_state.query_token_vectors,
                    st.session_state.query_tokens
                )
                handle_input()
                # NOTE(review): this re-renders the conversation from inside a
                # render pass, creating the same "simmap_*" keys again; verify
                # this does not trigger a duplicate-key error.
                with placeholder.container():
                    render_all()

        with st.expander("Relevant Sources"):
            for img in response.get('image', []):
                if isinstance(img, dict) and 'file' in img:
                    st.image(img['file'])

            for tbl in response.get('table', []):
                try:
                    df = pd.read_csv(tbl['name'], skipinitialspace=True, on_bad_lines='skip', delimiter='`')
                    # ffill() replaces the deprecated fillna(method='pad').
                    df.ffill(inplace=True)
                    st.table(df)
                except Exception as e:
                    st.warning(f"Failed to load table: {e}")

            st.write(response.get("source", ""))
177
-
178
# Render all Q&A pairs
def render_all():
    """Render every stored question together with its answer, oldest first."""
    turns = zip(st.session_state.questions_, st.session_state.answers_)
    position = 0
    for user_turn, assistant_turn in turns:
        position += 1  # 1-based position handed to the answer renderer
        write_user_message(user_turn)
        write_chat_message(assistant_turn, user_turn, position)
183
 
184
# Placeholder for dynamic rendering
placeholder = st.empty()
with placeholder.container():
    render_all()

# Input field for user question
col_2, col_3 = st.columns([75, 20])
with col_2:
    st.text_input("Ask here", label_visibility="collapsed", key="input_query")
with col_3:
    st.button("GO", on_click=handle_input, key="play")

# Sidebar configuration
with st.sidebar:
    st.page_link("app.py", label=":orange[Home]", icon="🏠")
    st.subheader(":blue[Sample Data]")
    coln_1, coln_2 = st.columns([70, 30])
    with coln_1:
        index_select = st.radio(
            "Choose one index",
            ["UK Housing", "Global Warming stats", "Covid19 impacts on Ireland"],
            key="input_rad_index",
        )
    with coln_2:
        st.markdown("<p style='font-size:15px'>Preview file</p>", unsafe_allow_html=True)
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")

    # The radio selection was previously never copied into `input_index`, so
    # picking a data set had no effect. Translate the label into an index name.
    index_by_label = {
        "UK Housing": "hpijan2024hometrack",
        "Global Warming stats": "globalwarming",
        "Covid19 impacts on Ireland": "covid19ie",
    }
    if index_select in index_by_label:
        st.session_state.input_index = index_by_label[index_select]

    st.subheader(":blue[Retriever]")
    # `input_rag_searchType` and `input_is_rerank` are seeded in session state
    # before these widgets are built, so no `default=`/`value=` is passed.
    # Passing both makes Streamlit warn on the first run.
    st.multiselect(
        "Select the Retriever(s)",
        ["Keyword Search", "Vector Search", "Sparse Search"],
        key="input_rag_searchType",
    )
    st.checkbox("Re-rank results", key="input_is_rerank")

    st.subheader(":blue[Multi-vector retrieval]")
    colpali_search_rerank = st.checkbox(
        'Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)',
        key='input_colpali',
        disabled=False,
        value=False,
        help="Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim",
    )
    # Mirror the checkbox into the flag that handle_input() reads.
    st.session_state.input_is_colpali = bool(colpali_search_rerank)

    with st.expander("Sample questions for Colpali retriever:"):
        st.write(
            "\n"
            "1. Proportion of female new hires 2021-2023?\n"
            "2. First-half 2021 return on unlisted real estate investments?\n"
            "3. Trend of the fund's expected absolute volatility between January 2014 and January 2016?\n"
            "4. Fund return percentage in 2017?\n"
            "5. Annualized gross return of the fund from 1997 to 2008?\n"
        )
 
 
 
232
 
 
 
 
 
1
  import streamlit as st
2
  import uuid
3
  import os
4
+ import re
5
  import sys
6
+ sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/semantic_search")
7
+ sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/RAG")
8
+ sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/utilities")
9
  import boto3
10
+ import requests
11
+ from boto3 import Session
12
+ import botocore.session
13
  import json
14
  import random
15
  import string
16
+ import rag_DocumentLoader
17
+ import rag_DocumentSearcher
18
  import pandas as pd
19
+ from PIL import Image
20
+ import shutil
21
+ import base64
22
+ import time
23
+ import botocore
24
+ from requests_aws4auth import AWS4Auth
25
+ import colpali
26
  from requests.auth import HTTPBasicAuth
27
+ import warnings
28
 
 
29
  warnings.filterwarnings("ignore", category=DeprecationWarning)
30
 
 
 
 
 
 
31
 
 
 
 
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
# --- Page configuration and static assets ------------------------------------
st.set_page_config(
    page_icon="images/opensearch_mark_default.png",
    layout="wide",
)
parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
USER_ICON = "images/user.png"
AI_ICON = "images/opensearch-twitter-card.png"
REGENERATE_ICON = "images/regenerate.png"
s3_bucket_ = "pdf-repo-uploads"
# The Polly client is built from the key pair stored in Streamlit secrets.
polly_client = boto3.client(
    'polly',
    region_name='us-east-1',
    aws_access_key_id=st.secrets['user_access_key'],
    aws_secret_access_key=st.secrets['user_secret_key'],
)

# --- Session state defaults ---------------------------------------------------
# Reuse the user id stored in the session, or create one on first use.
if 'user_id' not in st.session_state:
    st.session_state['user_id'] = str(uuid.uuid4())
user_id = st.session_state['user_id']

_SESSION_DEFAULTS = {
    'session_id': "",
    'chats': [{'id': 0, 'question': '', 'answer': ''}],
    'questions_': [],
    'show_columns': False,
    'answers_': [],
    'input_index': "hpijan2024hometrack",
    'input_is_rerank': True,
    'input_is_colpali': False,
    'input_copali_rerank': False,
    'input_table_with_sql': False,
    'input_query': "which city has the highest average housing price in UK ?",
    'input_rag_searchType': ["Vector Search"],
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    # Only seed keys that are missing; existing values are left untouched.
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# --- AWS / OpenSearch handles -------------------------------------------------
region = 'us-east-1'
bedrock_runtime_client = boto3.client('bedrock-runtime', region_name=region)
output = []
service = 'es'

# Remove the vertical gap inside the first two columns.
_COLUMN_GAP_CSS = """
<style>
[data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
    gap: 0rem;
}
[data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
    gap: 0rem;
}
</style>
"""
st.markdown(_COLUMN_GAP_CSS, unsafe_allow_html=True)

credentials = boto3.Session().get_credentials()
awsauth = HTTPBasicAuth('master', st.secrets['ml_search_demo_api_access'])
126
+
127
+
128
+
129
+
130
+ # if "input_searchType" not in st.session_state:
131
+ # st.session_state.input_searchType = "Conversational Search (RAG)"
132
+
133
+ # if "input_temperature" not in st.session_state:
134
+ # st.session_state.input_temperature = "0.001"
135
+
136
+ # if "input_topK" not in st.session_state:
137
+ # st.session_state.input_topK = 200
138
+
139
+ # if "input_topP" not in st.session_state:
140
+ # st.session_state.input_topP = 0.95
141
+
142
+ # if "input_maxTokens" not in st.session_state:
143
+ # st.session_state.input_maxTokens = 1024
144
+
145
+
146
def write_logo():
    """Draw the assistant icon in the narrow middle column of a 5:1:5 layout."""
    _, logo_col, _ = st.columns([5, 1, 5])
    with logo_col:
        # use_container_width is a boolean flag; 'always' was a value of the
        # older use_column_width parameter.
        st.image(AI_ICON, use_container_width=True)
150
 
 
def write_top_bar():
    """Draw the page header and the "Clear" button.

    Returns:
        The value returned by ``st.button("Clear")``.
    """
    title_col, button_col = st.columns([77,23])
    with title_col:
        st.write("")
        st.header("Chat with your data",divider='rainbow')
    with button_col:
        # Two empty writes above and below the button act as spacers.
        for _ in range(2):
            st.write("")
        was_clicked = st.button("Clear")
        for _ in range(2):
            st.write("")
    return was_clicked
166
 
167
# Draw the top bar; when "Clear" was clicked, wipe the conversation and the query box.
clear = write_top_bar()

if clear:
    st.session_state["questions_"] = []
    st.session_state["answers_"] = []
    st.session_state["input_query"] = ""
178
 
 
 
 
 
179
 
180
def handle_input():
    """Run the current query through the selected retriever and record the Q&A pair.

    Collects every ``input_*`` key of ``st.session_state`` (prefix stripped)
    into ``st.session_state.inputs_``, appends the question to ``questions_``,
    calls the ColPali searcher or ``rag_DocumentSearcher.query_``, appends the
    result to ``answers_`` and clears ``input_query``.

    Returns:
        ``""`` when the query is empty (kept for backward compatibility),
        otherwise ``None``.

    Raises:
        Whatever the searcher raises. In that case the question appended for
        this call is removed again before the exception propagates.
    """
    if st.session_state.input_query == '':
        return ""

    inputs = {
        key.removeprefix('input_'): st.session_state[key]
        for key in st.session_state
        if key.startswith('input_')
    }
    st.session_state.inputs_ = inputs

    question_with_id = {
        'question': inputs["query"],
        'id': len(st.session_state.questions_)
    }
    st.session_state.questions_.append(question_with_id)

    try:
        if st.session_state.input_is_colpali:
            out_ = colpali.colpali_search_rerank(st.session_state.input_query)
        else:
            out_ = rag_DocumentSearcher.query_(
                awsauth,
                inputs,
                st.session_state['session_id'],
                st.session_state.input_rag_searchType,
            )
        answer_entry = {
            'answer': out_['text'],
            'source': out_['source'],
            # Computed after the question was appended, as in the original code.
            'id': len(st.session_state.questions_),
            'image': out_['image'],
            'table': out_['table']
        }
    except Exception:
        # questions_ and answers_ are rendered pairwise with zip(); a question
        # without an answer would shift every later pair, so roll it back.
        st.session_state.questions_.pop()
        raise

    st.session_state.answers_.append(answer_entry)
    st.session_state.input_query = ""
214
+
215
+
216
+
217
+ # search_type = st.selectbox('Select the Search type',
218
+ # ('Conversational Search (RAG)',
219
+ # 'OpenSearch vector search',
220
+ # 'LLM Text Generation'
221
+ # ),
222
+
223
+ # key = 'input_searchType',
224
+ # help = "Select the type of retriever\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers_)"
225
+ # )
226
+
227
+ # col1, col2, col3, col4 = st.columns(4)
228
+
229
+ # with col1:
230
+ # st.text_input('Temperature', value = "0.001", placeholder='LLM Temperature', key = 'input_temperature',help = "Set the temperature of the Large Language model. \n Note: 1. Set this to values lower to 1 in the order of 0.001, 0.0001, such low values reduces hallucination and creativity in the LLM response; 2. This applies only when LLM is a part of the retriever pipeline")
231
+ # with col2:
232
+ # st.number_input('Top K', value = 200, placeholder='Top K', key = 'input_topK', step = 50, help = "This limits the LLM's predictions to the top k most probable tokens at each step of generation, this applies only when LLM is a prt of the retriever pipeline")
233
+ # with col3:
234
+ # st.number_input('Top P', value = 0.95, placeholder='Top P', key = 'input_topP', step = 0.05, help = "This sets a threshold probability and selects the top tokens whose cumulative probability exceeds the threshold while the tokens are generated by the LLM")
235
+ # with col4:
236
+ # st.number_input('Max Output Tokens', value = 500, placeholder='Max Output Tokens', key = 'input_maxTokens', step = 100, help = "This decides the total number of tokens generated as the final response. Note: Values greater than 1000 takes longer response time")
237
+
238
+ # st.markdown('---')
239
+
240
 
241
def write_user_message(md):
    """Render one user question next to the user icon.

    Args:
        md: Dict whose ``'question'`` entry holds the text typed by the user.
    """
    import html  # local import: only needed to escape the user-supplied text

    col1, col2 = st.columns([3,97])

    with col1:
        st.image(USER_ICON, use_container_width=True)
    with col2:
        # The question is user input and is rendered with unsafe_allow_html,
        # so it is escaped to keep it from being interpreted as markup.
        safe_question = html.escape(str(md['question']))
        # The original style attribute had a stray quote after the colour,
        # which cut off every declaration that followed it.
        st.markdown(
            "<div style='color:#e28743;font-size:18px;padding:3px 7px 3px 7px;"
            "width: fit-content;height: fit-content;border-radius: 10px;"
            "font-style: italic;'>" + safe_question + "</div>",
            unsafe_allow_html=True,
        )
250
+
 
251
 
 
 
252
 
253
def render_answer(question,answer,index,res_img):
    """Render one assistant turn: answer text, audio, sources and the regenerate button.

    Args:
        question: The question dict paired with this answer (not read here).
        answer: Dict providing the ``'answer'``, ``'table'`` and ``'source'`` entries.
        index: 1-based position of this turn; the regenerate button is only
            drawn when it equals the number of stored questions.
        res_img: List of image dicts, each with a ``'file'`` entry.
    """
    # NOTE(review): nesting was reconstructed from a flattened diff view;
    # confirm the placement of the show_columns reset.
    col1, col2, col_3 = st.columns([4,74,22])
    with col1:
        st.image(AI_ICON, use_container_width=True)

    with col2:
        ans_ = answer['answer']
        st.write(ans_)

        # Synthesize the answer with Polly and offer it as an audio player.
        polly_response = polly_client.synthesize_speech(
            VoiceId='Joanna',
            OutputFormat='ogg_vorbis',
            Text=ans_,
            Engine='neural',
        )
        audio_col1, audio_col2 = st.columns([50,50])
        with audio_col1:
            st.audio(polly_response['AudioStream'].read(), format="audio/ogg")

        # A fresh random key is used for the button on every render.
        rdn_key_1 = ''.join(random.choice(string.ascii_letters) for _ in range(10))

        def show_maxsim():
            """Compute the similarity-map images and re-render the conversation."""
            st.session_state.show_columns = True
            st.session_state.maxSimImages = colpali.img_highlight(
                st.session_state.top_img,
                st.session_state.query_token_vectors,
                st.session_state.query_tokens,
            )
            handle_input()
            with placeholder.container():
                render_all()

        if st.session_state.input_is_colpali:
            st.button("Show similarity map", key=rdn_key_1, on_click=show_maxsim)

    colu1, colu2, colu3 = st.columns([4,82,20])
    with colu2:
        with st.expander("Relevant Sources:"):
            with st.container():
                if res_img:
                    if st.session_state.input_is_colpali:
                        if st.session_state.show_columns:
                            # Step by the row width so each image appears once
                            # (stepping by 1 repeated images across rows).
                            cols_per_row = 3
                            for start in range(0, len(res_img), cols_per_row):
                                st.session_state.image_placeholder = st.empty()
                                with st.session_state.image_placeholder.container():
                                    row = st.columns(cols_per_row)
                                    for slot, item in zip(row, res_img[start:start + cols_per_row]):
                                        with slot:
                                            st.image(item['file'])
                        else:
                            # One image per row, in the first of three equal columns.
                            for item in res_img:
                                st.session_state.image_placeholder = st.empty()
                                with st.session_state.image_placeholder.container():
                                    first_col, _, _ = st.columns([33,33,33])
                                    with first_col:
                                        st.image(item['file'])
                    else:
                        # Show only the first image whose file name is not 'none'.
                        for item in res_img:
                            if item['file'].lower() != 'none':
                                first_col, _, _ = st.columns([33,33,33])
                                img = item['file'].split(".")[0]
                                with first_col:
                                    st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
                                break

                    # The similarity-map view is a one-shot: switch it off again.
                    if st.session_state.show_columns:
                        st.session_state.show_columns = False

                if len(answer["table"]) > 0:
                    df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
                    # ffill() replaces the deprecated fillna(method='pad').
                    df.ffill(inplace=True)
                    st.table(df)
                st.write(answer["source"])

    with col_3:
        # Only the most recent answer gets a regenerate button.
        if index == len(st.session_state.questions_):
            rdn_key = ''.join(random.choice(string.ascii_letters) for _ in range(10))

            def on_button_click():
                """Drop the last Q&A pair and ask the same question again."""
                st.session_state.input_query = st.session_state.questions_[-1]["question"]
                st.session_state.answers_.pop()
                st.session_state.questions_.pop()

                handle_input()
                with placeholder.container():
                    render_all()

            if "currentValue" in st.session_state:
                del st.session_state["currentValue"]

            placeholder__ = st.empty()
            placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
411
+
412
+
413
# Each answer is rendered together with the question it belongs to.
def write_chat_message(md, q,index):
    """Render one answer inside its own container.

    Uses ``st.session_state.maxSimImages`` as the image list while
    ``show_columns`` is set, otherwise the answer's own ``'image'`` entry.
    """
    images = st.session_state.maxSimImages if st.session_state.show_columns else md['image']
    with st.container():
        render_answer(q, md, index, images)
422
+
423
def render_all():
    """Render every stored question together with its answer, oldest first."""
    turns = zip(st.session_state.questions_, st.session_state.answers_)
    # Positions are 1-based.
    for position, (user_turn, assistant_turn) in enumerate(turns, start=1):
        write_user_message(user_turn)
        write_chat_message(assistant_turn, user_turn, position)
430
 
 
# Conversation area: everything is drawn inside one replaceable placeholder.
placeholder = st.empty()
with placeholder.container():
    render_all()

st.markdown("")
col_2, col_3 = st.columns([75,20])
with col_2:
    # The text is read back through st.session_state.input_query. The return
    # value is not bound (it used to be bound to `input`, shadowing the builtin).
    st.text_input("Ask here", label_visibility="collapsed", key="input_query")
with col_3:
    play = st.button("GO", on_click=handle_input, key="play")

with st.sidebar:
    st.page_link("app.py", label=":orange[Home]", icon="🏠")
    st.subheader(":blue[Sample Data]")
    coln_1, coln_2 = st.columns([70,30])
    with coln_1:
        index_select = st.radio(
            "Choose one index",
            ["UK Housing", "Global Warming stats", "Covid19 impacts on Ireland"],
            key="input_rad_index",
        )
    with coln_2:
        st.markdown("<p style='font-size:15px'>Preview file</p>", unsafe_allow_html=True)
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
    st.markdown("""
    <style>
    [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
        gap: 0rem;
    }
    [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
        gap: 0rem;
    }
    </style>
    """, unsafe_allow_html=True)
    with st.expander("Sample questions:"):
        st.markdown("<span style = 'color:#FF9900;'>UK Housing</span> - which city has the highest average housing price in UK ?",unsafe_allow_html=True)
        st.markdown("<span style = 'color:#FF9900;'>Global Warming stats</span> - What is the projected energy percentage from renewable sources in future?",unsafe_allow_html=True)
        st.markdown("<span style = 'color:#FF9900;'>Covid19 impacts</span> - How many aged above 85 years died due to covid ?",unsafe_allow_html=True)

    # Translate the selected label into an index name. The former "BEIR"
    # branch was unreachable because the radio never offers that label.
    index_by_label = {
        "UK Housing": "hpijan2024hometrack",
        "Global Warming stats": "globalwarming",
        "Covid19 impacts on Ireland": "covid19ie",
    }
    if index_select in index_by_label:
        st.session_state.input_index = index_by_label[index_select]

    st.subheader(":blue[Retriever]")
    search_type = st.multiselect(
        'Select the Retriever(s)',
        ['Keyword Search', 'Vector Search', 'Sparse Search'],
        ['Vector Search'],
        key='input_rag_searchType',
        help="Select the type of Search, adding more than one search type will activate hybrid search",
    )

    re_rank = st.checkbox('Re-rank results', key = 'input_re_rank', disabled = False, value = True, help = "Checking this box will re-rank the results using a cross-encoder model")
    # Mirror the checkbox into the flag that is collected by handle_input().
    st.session_state.input_is_rerank = bool(re_rank)

    st.subheader(":blue[Multi-vector retrieval]")
    colpali_search_rerank = st.checkbox('Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)', key = 'input_colpali', disabled = False, value = False, help = "Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim")
    # Mirror the checkbox into the flag that handle_input() reads.
    st.session_state.input_is_colpali = bool(colpali_search_rerank)

    with st.expander("Sample questions for Colpali retriever:"):
        st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
556
+
557
+
558
+ # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
559
+
560
+ # if(copali_rerank):
561
+ # st.session_state.input_copali_rerank = True
562
+ # else:
563
+ # st.session_state.input_copali_rerank = False
564
+
565
 
566
+