prasadnu committed on
Commit cccd8e6 · 1 Parent(s): 4ca2226

change container width param

pages/Multimodal_Conversational_Search.py CHANGED
@@ -1,214 +1,566 @@
-# Streamlit app: Chat with PDFs using OpenSearch, RAG, and ColPali
-
 import streamlit as st
 import uuid
 import os
 import sys
-import warnings
 import boto3
 import json
 import random
 import string
 import pandas as pd
-from PIL import Image
 from requests.auth import HTTPBasicAuth
 
-# Suppress Streamlit deprecation warnings
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
-# Add necessary module paths
-base_path = "/".join(os.path.realpath(__file__).split("/")[:-2])
-sys.path.insert(1, f"{base_path}/semantic_search")
-sys.path.insert(1, f"{base_path}/RAG")
-sys.path.insert(1, f"{base_path}/utilities")
-
-# Local modules
-import rag_DocumentLoader
-import rag_DocumentSearcher
-import colpali
-
-# AWS & OpenSearch setup
-region = 'us-east-1'
-s3_bucket_ = "pdf-repo-uploads"
-bedrock_runtime_client = boto3.client('bedrock-runtime', region_name=region)
-polly_client = boto3.client(
-    'polly',
-    aws_access_key_id=st.secrets['user_access_key'],
-    aws_secret_access_key=st.secrets['user_secret_key'],
-    region_name=region
-)
-credentials = boto3.Session().get_credentials()
-awsauth = HTTPBasicAuth('master', st.secrets['ml_search_demo_api_access'])
-
-# App configuration
-st.set_page_config(layout="wide", page_icon="images/opensearch_mark_default.png")
-parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[:-1])
 USER_ICON = "images/user.png"
 AI_ICON = "images/opensearch-twitter-card.png"
 REGENERATE_ICON = "images/regenerate.png"
 
-# Session state setup
-if 'user_id' not in st.session_state:
-    st.session_state['user_id'] = str(uuid.uuid4())
-
-st.session_state.setdefault('session_id', "")
-st.session_state.setdefault('chats', [{'id': 0, 'question': '', 'answer': ''}])
-st.session_state.setdefault('questions_', [])
-st.session_state.setdefault('answers_', [])
-st.session_state.setdefault('show_columns', False)
-st.session_state.setdefault('input_index', "hpijan2024hometrack")
-st.session_state.setdefault('input_is_rerank', True)
-st.session_state.setdefault('input_is_colpali', False)
-st.session_state.setdefault('input_copali_rerank', False)
-st.session_state.setdefault('input_table_with_sql', False)
-st.session_state.setdefault('input_query', "which city has the highest average housing price in UK ?")
-st.session_state.setdefault('input_rag_searchType', ["Vector Search"])
-
-# Custom styling
 st.markdown("""
 <style>
-[data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock],
-[data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock] {
     gap: 0rem;
 }
 </style>
-""", unsafe_allow_html=True)
 
-# Top bar with app logo and clear button
 def write_top_bar():
-    col1, col2 = st.columns([77, 23])
     with col1:
-        st.header("Chat with your data", divider='rainbow')
     with col2:
         clear = st.button("Clear")
-        st.write("")  # spacing
     return clear
 
-# Reset inputs when Clear is clicked
-if write_top_bar():
     st.session_state.questions_ = []
     st.session_state.answers_ = []
-    st.session_state.input_query = ""
 
-# Handle user query submission
-def handle_input():
-    if st.session_state.input_query == '':
-        return
 
-    # Extract all input values from session state
-    inputs = {key.removeprefix('input_'): st.session_state[key] for key in st.session_state if key.startswith('input_')}
     st.session_state.inputs_ = inputs
-
-    # Save the question
-    st.session_state.questions_.append({
         'question': inputs["query"],
         'id': len(st.session_state.questions_)
-    })
-
-    # Choose retrieval method
-    if st.session_state.input_is_colpali:
         out_ = colpali.colpali_search_rerank(st.session_state.input_query)
     else:
-        out_ = rag_DocumentSearcher.query_(
-            awsauth,
-            inputs,
-            st.session_state['session_id'],
-            st.session_state.input_rag_searchType
-        )
-
-    # Save the answer and clear input
     st.session_state.answers_.append({
         'answer': out_['text'],
-        'source': out_['source'],
         'id': len(st.session_state.questions_),
         'image': out_['image'],
-        'table': out_['table']
     })
-    st.session_state.input_query = ""
 
-# Display user message block
-def write_user_message(msg):
-    col1, col2 = st.columns([3, 97])
     with col1:
-        st.image(USER_ICON, use_container_width=True)
     with col2:
-        st.markdown(
-            f"<div style='color:#e28743;font-size:18px;padding:3px 7px;border-radius:10px;font-style:italic;'>{msg['question']}</div>",
-            unsafe_allow_html=True
-        )
 
-# Render assistant answer block with optional images and tables
-def write_chat_message(response, question, index):
-    col1, col2, col3 = st.columns([4, 74, 22])
-
-    with col1:
-        st.image(AI_ICON, use_container_width=True)
 
     with col2:
-        answer_text = response['answer']
-        st.write(answer_text)
-
-        # Add voice playback using AWS Polly
-        polly_response = polly_client.synthesize_speech(
-            VoiceId='Joanna', OutputFormat='ogg_vorbis', Text=answer_text, Engine='neural')
-        st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
-
-        # Optionally show similarity map if enabled
-        if st.session_state.input_is_colpali:
-            if st.button("Show similarity map", key=f"simmap_{index}"):
-                st.session_state.show_columns = True
-                st.session_state.maxSimImages = colpali.img_highlight(
-                    st.session_state.top_img,
-                    st.session_state.query_token_vectors,
-                    st.session_state.query_tokens
-                )
-                handle_input()
-                with placeholder.container():
-                    render_all()
-
-        with st.expander("Relevant Sources"):
-            # Render related images
-            for img in response.get('image', []):
-                if isinstance(img, dict) and 'file' in img:
-                    st.image(img['file'])
-
-            # Render related tables
-            for tbl in response.get('table', []):
-                try:
-                    df = pd.read_csv(tbl['name'], skipinitialspace=True, on_bad_lines='skip', delimiter='`')
                     df.fillna(method='pad', inplace=True)
                     st.table(df)
-                except Exception as e:
-                    st.warning(f"Failed to load table: {e}")
 
-            # Show source text
-            st.write(response.get("source", ""))
 
-# Render all Q&A pairs
-def render_all():
-    for index, (q, a) in enumerate(zip(st.session_state.questions_, st.session_state.answers_), start=1):
         write_user_message(q)
-        write_chat_message(a, q, index)
 
-# Placeholder for dynamic rendering
 placeholder = st.empty()
 with placeholder.container():
-    render_all()
 
-# Input field for user question
-col_2, col_3 = st.columns([75, 20])
 with col_2:
-    st.text_input("Ask here", label_visibility="collapsed", key="input_query")
 with col_3:
-    st.button("GO", on_click=handle_input, key="play")
-
-# Sidebar configuration
 with st.sidebar:
     st.page_link("app.py", label=":orange[Home]", icon="🏠")
     st.subheader(":blue[Sample Data]")
-    st.radio("Choose one index", ["UK Housing", "Global Warming stats", "Covid19 impacts on Ireland"], key="input_rad_index")
     st.subheader(":blue[Retriever]")
-    st.multiselect("Select the Retriever(s)", ["Keyword Search", "Vector Search", "Sparse Search"], default=["Vector Search"], key="input_rag_searchType")
-    st.checkbox("Re-rank results", key="input_is_rerank", value=True)
     st.subheader(":blue[Multi-vector retrieval]")
-    st.checkbox("Try Colpali multi-vector retrieval", key="input_is_colpali")
 import streamlit as st
 import uuid
 import os
+import re
 import sys
+sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/semantic_search")
+sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/RAG")
+sys.path.insert(1, "/".join(os.path.realpath(__file__).split("/")[0:-2])+"/utilities")
 import boto3
+import requests
+from boto3 import Session
+import botocore.session
 import json
 import random
 import string
+import rag_DocumentLoader
+import rag_DocumentSearcher
 import pandas as pd
+from PIL import Image
+import shutil
+import base64
+import time
+import botocore
+from requests_aws4auth import AWS4Auth
+import colpali
 from requests.auth import HTTPBasicAuth
+import warnings
 
 warnings.filterwarnings("ignore", category=DeprecationWarning)
 
+st.set_page_config(
+    #page_title="Semantic Search using OpenSearch",
+    layout="wide",
+    page_icon="images/opensearch_mark_default.png"
+)
+parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
 USER_ICON = "images/user.png"
 AI_ICON = "images/opensearch-twitter-card.png"
 REGENERATE_ICON = "images/regenerate.png"
+s3_bucket_ = "pdf-repo-uploads"
+#"pdf-repo-uploads"
+polly_client = boto3.client('polly',aws_access_key_id=st.secrets['user_access_key'],
+                            aws_secret_access_key=st.secrets['user_secret_key'], region_name = 'us-east-1')
+
+
+# Check if the user ID is already stored in the session state
+if 'user_id' in st.session_state:
+    user_id = st.session_state['user_id']
+    #print(f"User ID: {user_id}")
+
+# If the user ID is not yet stored in the session state, generate a random UUID
+else:
+    user_id = str(uuid.uuid4())
+    st.session_state['user_id'] = user_id
+
+
+if 'session_id' not in st.session_state:
+    st.session_state['session_id'] = ""
+
+if "chats" not in st.session_state:
+    st.session_state.chats = [
+        {
+            'id': 0,
+            'question': '',
+            'answer': ''
+        }
+    ]
+
+if "questions_" not in st.session_state:
+    st.session_state.questions_ = []
+
+if "show_columns" not in st.session_state:
+    st.session_state.show_columns = False
+
+if "answers_" not in st.session_state:
+    st.session_state.answers_ = []
+
+if "input_index" not in st.session_state:
+    st.session_state.input_index = "hpijan2024hometrack"#"globalwarmingnew"#"hpijan2024hometrack_no_img_no_table"
+
+if "input_is_rerank" not in st.session_state:
+    st.session_state.input_is_rerank = True
+
+if "input_is_colpali" not in st.session_state:
+    st.session_state.input_is_colpali = False
+
+if "input_copali_rerank" not in st.session_state:
+    st.session_state.input_copali_rerank = False
+
+if "input_table_with_sql" not in st.session_state:
+    st.session_state.input_table_with_sql = False
+
+if "input_query" not in st.session_state:
+    st.session_state.input_query="which city has the highest average housing price in UK ?"#"What is the projected energy percentage from renewable sources in future?"#"Which city in United Kingdom has the highest average housing price ?"#"How many aged above 85 years died due to covid ?"# What is the projected energy from renewable sources ?"
+
+if "input_rag_searchType" not in st.session_state:
+    st.session_state.input_rag_searchType = ["Vector Search"]
+
+
+region = 'us-east-1'
+bedrock_runtime_client = boto3.client('bedrock-runtime',region_name=region)
+output = []
+service = 'es'
 
 st.markdown("""
 <style>
+[data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
+    gap: 0rem;
+}
+[data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
     gap: 0rem;
 }
 </style>
+""",unsafe_allow_html=True)
+
+
+credentials = boto3.Session().get_credentials()
+awsauth = HTTPBasicAuth('master',st.secrets['ml_search_demo_api_access'])
+service = 'es'
+
+
+# if "input_searchType" not in st.session_state:
+#     st.session_state.input_searchType = "Conversational Search (RAG)"
+
+# if "input_temperature" not in st.session_state:
+#     st.session_state.input_temperature = "0.001"
+
+# if "input_topK" not in st.session_state:
+#     st.session_state.input_topK = 200
+
+# if "input_topP" not in st.session_state:
+#     st.session_state.input_topP = 0.95
+
+# if "input_maxTokens" not in st.session_state:
+#     st.session_state.input_maxTokens = 1024
+
+
+def write_logo():
+    col1, col2, col3 = st.columns([5, 1, 5])
+    with col2:
+        st.image(AI_ICON, use_container_width='always')
 
 def write_top_bar():
+    col1, col2 = st.columns([77,23])
     with col1:
+        st.write("")
+        st.header("Chat with your data",divider='rainbow')
+
+        #st.image(AI_ICON, use_container_width='always')
+
     with col2:
+        st.write("")
+        st.write("")
         clear = st.button("Clear")
+        st.write("")
+        st.write("")
     return clear
 
+clear = write_top_bar()
+
+if clear:
     st.session_state.questions_ = []
     st.session_state.answers_ = []
+    st.session_state.input_query=""
+    # st.session_state.input_searchType="Conversational Search (RAG)"
+    # st.session_state.input_temperature = "0.001"
+    # st.session_state.input_topK = 200
+    # st.session_state.input_topP = 0.95
+    # st.session_state.input_maxTokens = 1024
 
+def handle_input():
+    print("Question: "+st.session_state.input_query)
+    print("-----------")
+    print("\n\n")
+    if(st.session_state.input_query==''):
+        return ""
+    inputs = {}
+    for key in st.session_state:
+        if key.startswith('input_'):
+            inputs[key.removeprefix('input_')] = st.session_state[key]
     st.session_state.inputs_ = inputs
+
+    #######
+
+    #st.write(inputs)
+    question_with_id = {
         'question': inputs["query"],
         'id': len(st.session_state.questions_)
+    }
+    st.session_state.questions_.append(question_with_id)
+    if(st.session_state.input_is_colpali):
         out_ = colpali.colpali_search_rerank(st.session_state.input_query)
+        #print(out_)
     else:
+        out_ = rag_DocumentSearcher.query_(awsauth, inputs, st.session_state['session_id'],st.session_state.input_rag_searchType)
     st.session_state.answers_.append({
         'answer': out_['text'],
+        'source':out_['source'],
         'id': len(st.session_state.questions_),
         'image': out_['image'],
+        'table':out_['table']
     })
+    st.session_state.input_query=""
+
+# search_type = st.selectbox('Select the Search type',
+#     ('Conversational Search (RAG)',
+#      'OpenSearch vector search',
+#      'LLM Text Generation'
+#     ),
+
+#     key = 'input_searchType',
+#     help = "Select the type of retriever\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers_)"
+# )
+
+# col1, col2, col3, col4 = st.columns(4)
+
+# with col1:
+#     st.text_input('Temperature', value = "0.001", placeholder='LLM Temperature', key = 'input_temperature',help = "Set the temperature of the Large Language model. \n Note: 1. Set this to values lower to 1 in the order of 0.001, 0.0001, such low values reduces hallucination and creativity in the LLM response; 2. This applies only when LLM is a part of the retriever pipeline")
+# with col2:
+#     st.number_input('Top K', value = 200, placeholder='Top K', key = 'input_topK', step = 50, help = "This limits the LLM's predictions to the top k most probable tokens at each step of generation, this applies only when LLM is a prt of the retriever pipeline")
+# with col3:
+#     st.number_input('Top P', value = 0.95, placeholder='Top P', key = 'input_topP', step = 0.05, help = "This sets a threshold probability and selects the top tokens whose cumulative probability exceeds the threshold while the tokens are generated by the LLM")
+# with col4:
+#     st.number_input('Max Output Tokens', value = 500, placeholder='Max Output Tokens', key = 'input_maxTokens', step = 100, help = "This decides the total number of tokens generated as the final response. Note: Values greater than 1000 takes longer response time")
 
+# st.markdown('---')
+
+
+def write_user_message(md):
+    col1, col2 = st.columns([3,97])
+
     with col1:
+        st.image(USER_ICON, use_container_width='always')
     with col2:
+        #st.warning(md['question'])
 
+        st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
+
 
 
+def render_answer(question,answer,index,res_img):
+
+
+    col1, col2, col_3 = st.columns([4,74,22])
+    with col1:
+        st.image(AI_ICON, use_container_width='always')
     with col2:
+        ans_ = answer['answer']
+        st.write(ans_)
+
+
+        # def stream_():
+        #     #use for streaming response on the client side
+        #     for word in ans_.split(" "):
+        #         yield word + " "
+        #         time.sleep(0.04)
+        #     #use for streaming response from Llm directly
+        #     if(isinstance(ans_,botocore.eventstream.EventStream)):
+        #         for event in ans_:
+        #             chunk = event.get('chunk')
+
+        #             if chunk:
+
+        #                 chunk_obj = json.loads(chunk.get('bytes').decode())
+
+        #                 if('content_block' in chunk_obj or ('delta' in chunk_obj and 'text' in chunk_obj['delta'])):
+        #                     key_ = list(chunk_obj.keys())[2]
+        #                     text = chunk_obj[key_]['text']
+
+        #                     clear_output(wait=True)
+        #                     output.append(text)
+        #                     yield text
+        #                     time.sleep(0.04)
+
+
+        # if(index == len(st.session_state.questions_)):
+        #     st.write_stream(stream_)
+        #     if(isinstance(st.session_state.answers_[index-1]['answer'],botocore.eventstream.EventStream)):
+        #         st.session_state.answers_[index-1]['answer'] = "".join(output)
+        # else:
+        #     st.write(ans_)
+
+
+        polly_response = polly_client.synthesize_speech(VoiceId='Joanna',
+                                                        OutputFormat='ogg_vorbis',
+                                                        Text = ans_,
+                                                        Engine = 'neural')
+
+        audio_col1, audio_col2 = st.columns([50,50])
+        with audio_col1:
+            st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
+        rdn_key_1 = ''.join([random.choice(string.ascii_letters)
+                             for _ in range(10)])
+        def show_maxsim():
+            st.session_state.show_columns = True
+            st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
+            handle_input()
+            with placeholder.container():
+                render_all()
+        if(st.session_state.input_is_colpali):
+            st.button("Show similarity map",key=rdn_key_1,on_click = show_maxsim)
+
+
+    #st.markdown("<div style='font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;border-radius: 10px;'>"+ans_+"</div>", unsafe_allow_html = True)
+    #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Relevant images from the document :</b></div>", unsafe_allow_html = True)
+    #st.write("")
+    colu1,colu2,colu3 = st.columns([4,82,20])
+    with colu2:
+        with st.expander("Relevant Sources:"):
+            with st.container():
+                if(len(res_img)>0):
+                    #with st.expander("Images:"):
+
+                    idx = 0
+                    print(res_img)
+                    for i in range(0,len(res_img)):
+
+                        if(st.session_state.input_is_colpali):
+                            if(st.session_state.show_columns == True):
+                                cols_per_row = 3
+                                st.session_state.image_placeholder=st.empty()
+                                with st.session_state.image_placeholder.container():
+                                    row = st.columns(cols_per_row)
+                                    for j, item in enumerate(res_img[i:i+cols_per_row]):
+                                        with row[j]:
+                                            st.image(item['file'])
+
+                            else:
+                                st.session_state.image_placeholder = st.empty()
+                                with st.session_state.image_placeholder.container():
+                                    col3_,col4_,col5_ = st.columns([33,33,33])
+                                    with col3_:
+                                        st.image(res_img[i]['file'])
+
+                        else:
+                            if(res_img[i]['file'].lower()!='none' and idx < 1):
+                                col3,col4,col5 = st.columns([33,33,33])
+                                cols = [col3,col4]
+                                img = res_img[i]['file'].split(".")[0]
+                                caption = res_img[i]['caption']
+
+                                with cols[idx]:
+
+                                    st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
+                                    #st.write(caption)
+                                idx = idx+1
+                    if(st.session_state.show_columns == True):
+                        st.session_state.show_columns = False
+                #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Sources from the document:</b></div>", unsafe_allow_html = True)
+                if(len(answer["table"] )>0):
+                    #with st.expander("Table:"):
+                    df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
                     df.fillna(method='pad', inplace=True)
                     st.table(df)
+                #with st.expander("Raw sources:"):
+                st.write(answer["source"])
+
+
+    with col_3:
+        if(index == len(st.session_state.questions_)):
 
+            rdn_key = ''.join([random.choice(string.ascii_letters)
+                               for _ in range(10)])
+            # rdn_key_1 = ''.join([random.choice(string.ascii_letters)
+            #                      for _ in range(10)])
+            currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
+            oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
+            def on_button_click():
+                if(currentValue!=oldValue or 1==1):
+                    st.session_state.input_query = st.session_state.questions_[-1]["question"]
+                    st.session_state.answers_.pop()
+                    st.session_state.questions_.pop()
+
+                    handle_input()
+                    with placeholder.container():
+                        render_all()
+            # def show_maxsim():
+            #     st.session_state.show_columns = True
+            #     st.session_state.maxSimImages = colpali.img_highlight(st.session_state.top_img, st.session_state.query_token_vectors, st.session_state.query_tokens)
+            #     handle_input()
+            #     with placeholder.container():
+            #         render_all()
+            if("currentValue" in st.session_state):
+                del st.session_state["currentValue"]
 
+            try:
+                del regenerate
+            except:
+                pass
+            placeholder__ = st.empty()
+            placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
+
+
+#Each answer will have context of the question asked in order to associate the provided feedback with the respective question
+def write_chat_message(md, q,index):
+    if(st.session_state.show_columns):
+        res_img = st.session_state.maxSimImages
+    else:
+        res_img = md['image']
+    chat = st.container()
+    with chat:
+        render_answer(q,md,index,res_img)
+
+def render_all():
+    index = 0
+    for (q, a) in zip(st.session_state.questions_, st.session_state.answers_):
+        index = index +1
+
         write_user_message(q)
+        write_chat_message(a, q,index)
 
 placeholder = st.empty()
 with placeholder.container():
+    render_all()
 
+st.markdown("")
+col_2, col_3 = st.columns([75,20])
 with col_2:
+    #st.markdown("")
+    input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_query")
 with col_3:
+    #hidden = st.button("RUN",disabled=True,key = "hidden")
+    play = st.button("GO",on_click=handle_input,key = "play")
 with st.sidebar:
     st.page_link("app.py", label=":orange[Home]", icon="🏠")
     st.subheader(":blue[Sample Data]")
+    coln_1,coln_2 = st.columns([70,30])
+    with coln_1:
+        index_select = st.radio("Choose one index",["UK Housing","Global Warming stats","Covid19 impacts on Ireland"],key="input_rad_index")
+    with coln_2:
+        st.markdown("<p style='font-size:15px'>Preview file</p>",unsafe_allow_html=True)
+        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
+        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
+        st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
+    st.markdown("""
+    <style>
+    [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
+        gap: 0rem;
+    }
+    [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
+        gap: 0rem;
+    }
+    </style>
+    """,unsafe_allow_html=True)
+    with st.expander("Sample questions:"):
+        st.markdown("<span style = 'color:#FF9900;'>UK Housing</span> - which city has the highest average housing price in UK ?",unsafe_allow_html=True)
+        st.markdown("<span style = 'color:#FF9900;'>Global Warming stats</span> - What is the projected energy percentage from renewable sources in future?",unsafe_allow_html=True)
+        st.markdown("<span style = 'color:#FF9900;'>Covid19 impacts</span> - How many aged above 85 years died due to covid ?",unsafe_allow_html=True)
+
+
+    #st.subheader(":blue[Your multi-modal documents]")
+    # pdf_doc_ = st.file_uploader(
+    #     "Upload your PDFs here and click on 'Process'", accept_multiple_files=False)
+
+
+    # pdf_docs = [pdf_doc_]
+    # if st.button("Process"):
+    #     with st.spinner("Processing"):
+    #         if os.path.isdir(parent_dirname+"/pdfs") == False:
+    #             os.mkdir(parent_dirname+"/pdfs")
+
+    #         for pdf_doc in pdf_docs:
+    #             print(type(pdf_doc))
+    #             pdf_doc_name = (pdf_doc.name).replace(" ","_")
+    #             with open(os.path.join(parent_dirname+"/pdfs",pdf_doc_name),"wb") as f:
+    #                 f.write(pdf_doc.getbuffer())
+
+    #             request_ = { "bucket": s3_bucket_,"key": pdf_doc_name}
+    #             # if(st.session_state.input_copali_rerank):
+    #             #     copali.process_doc(request_)
+    #             # else:
+    #             rag_DocumentLoader.load_docs(request_)
+    #             print('lambda done')
+    #         st.success('you can start searching on your PDF')
+
+    ############## haystach demo temporary addition ############
+    # st.subheader(":blue[Multimodality]")
+    # colu1,colu2 = st.columns([50,50])
+    # with colu1:
+    #     in_images = st.toggle('Images', key = 'in_images', disabled = False)
+    # with colu2:
+    #     in_tables = st.toggle('Tables', key = 'in_tables', disabled = False)
+    # if(in_tables):
+    #     st.session_state.input_table_with_sql = True
+    # else:
+    #     st.session_state.input_table_with_sql = False
+
+    ############## haystach demo temporary addition ############
+    #if(pdf_doc_ is None or pdf_doc_ == ""):
+    if(index_select == "Global Warming stats"):
+        st.session_state.input_index = "globalwarming"
+    if(index_select == "Covid19 impacts on Ireland"):
+        st.session_state.input_index = "covid19ie"#"choosetheknnalgorithmforyourbillionscaleusecasewithopensearchawsbigdatablog"
+    if(index_select == "BEIR"):
+        st.session_state.input_index = "2104"
+    if(index_select == "UK Housing"):
+        st.session_state.input_index = "hpijan2024hometrack"
+
+    # custom_index = st.text_input("If uploaded the file already, enter the original file name", value = "")
+    # if(custom_index!=""):
+    #     st.session_state.input_index = re.sub('[^A-Za-z0-9]+', '', (custom_index.lower().replace(".pdf","").split("/")[-1].split(".")[0]).lower())
+
+
     st.subheader(":blue[Retriever]")
+    search_type = st.multiselect('Select the Retriever(s)',
+        ['Keyword Search',
+         'Vector Search',
+         'Sparse Search',
+        ],
+        ['Vector Search'],
+
+        key = 'input_rag_searchType',
+        help = "Select the type of Search, adding more than one search type will activate hybrid search"#\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers)"
+    )
+
+    re_rank = st.checkbox('Re-rank results', key = 'input_re_rank', disabled = False, value = True, help = "Checking this box will re-rank the results using a cross-encoder model")
+
+    if(re_rank):
+        st.session_state.input_is_rerank = True
+    else:
+        st.session_state.input_is_rerank = False
+
     st.subheader(":blue[Multi-vector retrieval]")
+
+    #st.write("Dataset indexed: https://huggingface.co/datasets/vespa-engine/gpfg-QA")
+    colpali_search_rerank = st.checkbox('Try Colpali multi-vector retrieval on the [sample dataset](https://huggingface.co/datasets/vespa-engine/gpfg-QA)', key = 'input_colpali', disabled = False, value = False, help = "Checking this box will use colpali as the embedding model and retrieval is performed using multi-vectors followed by re-ranking using MaxSim")
+
+    if(colpali_search_rerank):
+        st.session_state.input_is_colpali = True
+        #st.session_state.input_query = ""
+    else:
+        st.session_state.input_is_colpali = False
+
+    with st.expander("Sample questions for Colpali retriever:"):
+        st.write("1. Proportion of female new hires 2021-2023? \n\n 2. First-half 2021 return on unlisted real estate investments? \n\n 3. Trend of the fund's expected absolute volatility between January 2014 and January 2016? \n\n 4. Fund return percentage in 2017? \n\n 5. Annualized gross return of the fund from 1997 to 2008?")
+
+
+    # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
+
+    # if(copali_rerank):
+    #     st.session_state.input_copali_rerank = True
+    # else:
+    #     st.session_state.input_copali_rerank = False
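Note on the commit message: the "container width param" change is visible in the st.image calls on the two sides of the diff — the removed side passes use_container_width=True, while the committed side passes use_container_width='always'. A minimal sketch of the two variants (a hypothetical standalone script, not part of the repo; the icon path is reused from the diff, and since use_container_width is documented as a boolean in recent Streamlit releases, the string 'always' is simply truthy and behaves like True — 'always' comes from the older use_column_width parameter):

    import streamlit as st

    # Before this commit: boolean form — the image stretches to the enclosing container.
    st.image("images/user.png", use_container_width=True)

    # After this commit: string form, as committed in write_logo(), write_user_message() and render_answer().
    st.image("images/user.png", use_container_width='always')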