Spaces:
Running
on
T4
Running
on
T4
rerank model
Browse files- RAG/bedrock_agent.py +1 -44
- RAG/rag_DocumentSearcher.py +0 -17
- app.py +0 -24
- pages/AI_Shopping_Assistant.py +7 -404
- pages/Semantic_Search.py +19 -342
- semantic_search/amazon_rekognition.py +2 -47
- utilities/invoke_models.py +2 -84
- utilities/re_ranker.py +0 -127
RAG/bedrock_agent.py
CHANGED
@@ -23,8 +23,6 @@ if "inputs_" not in st.session_state:
|
|
23 |
|
24 |
parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
|
25 |
region = 'us-east-1'
|
26 |
-
print(region)
|
27 |
-
account_id = '445083327804'
|
28 |
# setting logger
|
29 |
logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
|
30 |
logger = logging.getLogger(__name__)
|
@@ -46,9 +44,6 @@ def delete_memory():
|
|
46 |
)
|
47 |
|
48 |
def query_(inputs):
|
49 |
-
## create a random id for session initiator id
|
50 |
-
|
51 |
-
|
52 |
# invoke the agent API
|
53 |
agentResponse = bedrock_agent_runtime_client.invoke_agent(
|
54 |
inputText=inputs['shopping_query'],
|
@@ -71,13 +66,6 @@ def query_(inputs):
|
|
71 |
for event in event_stream:
|
72 |
print("***event*********")
|
73 |
print(event)
|
74 |
-
# if 'chunk' in event:
|
75 |
-
# data = event['chunk']['bytes']
|
76 |
-
# print("***chunk*********")
|
77 |
-
# print(data)
|
78 |
-
# logger.info(f"Final answer ->\n{data.decode('utf8')}")
|
79 |
-
# agent_answer_ = data.decode('utf8')
|
80 |
-
# print(agent_answer_)
|
81 |
if 'trace' in event:
|
82 |
print("trace*****total*********")
|
83 |
print(event['trace'])
|
@@ -109,38 +97,7 @@ def query_(inputs):
|
|
109 |
print(total_context)
|
110 |
except botocore.exceptions.EventStreamError as error:
|
111 |
raise error
|
112 |
-
|
113 |
-
# query_(st.session_state.inputs_)
|
114 |
-
|
115 |
-
# if 'chunk' in event:
|
116 |
-
# data = event['chunk']['bytes']
|
117 |
-
# final_ans = data.decode('utf8')
|
118 |
-
# print(f"Final answer ->\n{final_ans}")
|
119 |
-
# logger.info(f"Final answer ->\n{final_ans}")
|
120 |
-
# agent_answer = final_ans
|
121 |
-
# end_event_received = True
|
122 |
-
# # End event indicates that the request finished successfully
|
123 |
-
# elif 'trace' in event:
|
124 |
-
# logger.info(json.dumps(event['trace'], indent=2))
|
125 |
-
# else:
|
126 |
-
# raise Exception("unexpected event.", event)
|
127 |
-
# except Exception as e:
|
128 |
-
# raise Exception("unexpected event.", e)
|
129 |
return {'text':agent_answer,'source':total_context,'last_tool':{'name':last_tool_name,'response':last_tool}}
|
130 |
|
131 |
-
####### Re-Rank ########
|
132 |
-
|
133 |
-
#print("re-rank")
|
134 |
-
|
135 |
-
# if(st.session_state.input_is_rerank == True and len(total_context)):
|
136 |
-
# ques = [{"question":question}]
|
137 |
-
# ans = [{"answer":total_context}]
|
138 |
-
|
139 |
-
# total_context = re_ranker.re_rank('rag','Cross Encoder',"",ques, ans)
|
140 |
|
141 |
-
# llm_prompt = prompt_template.format(context=total_context[0],question=question)
|
142 |
-
# output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
|
143 |
-
# #print(output)
|
144 |
-
# if(len(images_2)==0):
|
145 |
-
# images_2 = images
|
146 |
-
# return {'text':output,'source':total_context,'image':images_2,'table':df}
|
|
|
23 |
|
24 |
parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
|
25 |
region = 'us-east-1'
|
|
|
|
|
26 |
# setting logger
|
27 |
logging.basicConfig(format='[%(asctime)s] p%(process)s {%(filename)s:%(lineno)d} %(levelname)s - %(message)s', level=logging.INFO)
|
28 |
logger = logging.getLogger(__name__)
|
|
|
44 |
)
|
45 |
|
46 |
def query_(inputs):
|
|
|
|
|
|
|
47 |
# invoke the agent API
|
48 |
agentResponse = bedrock_agent_runtime_client.invoke_agent(
|
49 |
inputText=inputs['shopping_query'],
|
|
|
66 |
for event in event_stream:
|
67 |
print("***event*********")
|
68 |
print(event)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
if 'trace' in event:
|
70 |
print("trace*****total*********")
|
71 |
print(event['trace'])
|
|
|
97 |
print(total_context)
|
98 |
except botocore.exceptions.EventStreamError as error:
|
99 |
raise error
|
100 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
return {'text':agent_answer,'source':total_context,'last_tool':{'name':last_tool_name,'response':last_tool}}
|
102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
|
|
|
|
|
|
|
|
|
|
|
RAG/rag_DocumentSearcher.py
CHANGED
@@ -49,7 +49,6 @@ def query_(awsauth,inputs, session_id,search_types):
|
|
49 |
images = []
|
50 |
|
51 |
for hit in hits:
|
52 |
-
#context.append(hit['_source']['caption'])
|
53 |
images.append({'file':hit['_source']['image'],'caption':hit['_source']['processed_element']})
|
54 |
|
55 |
####### SEARCH ########
|
@@ -102,10 +101,6 @@ def query_(awsauth,inputs, session_id,search_types):
|
|
102 |
}
|
103 |
]
|
104 |
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
SIZE = 5
|
110 |
|
111 |
hybrid_payload = {
|
@@ -159,7 +154,6 @@ def query_(awsauth,inputs, session_id,search_types):
|
|
159 |
|
160 |
if('Sparse Search' in search_types):
|
161 |
|
162 |
-
#print("text expansion is enabled")
|
163 |
sparse_payload = { "neural_sparse": {
|
164 |
"processed_element_embedding_sparse": {
|
165 |
"query_text": question,
|
@@ -301,7 +295,6 @@ def query_(awsauth,inputs, session_id,search_types):
|
|
301 |
images_2.append({'file':hit["_source"]["image"],'caption':hit["_source"]["processed_element"]})
|
302 |
|
303 |
idx = idx +1
|
304 |
-
#images.append(hit['_source']['image'])
|
305 |
|
306 |
# if(is_table_in_result == False):
|
307 |
# df = lazy_get_table()
|
@@ -315,19 +308,9 @@ def query_(awsauth,inputs, session_id,search_types):
|
|
315 |
|
316 |
total_context = context_tables + context
|
317 |
|
318 |
-
####### Re-Rank ########
|
319 |
-
|
320 |
-
#print("re-rank")
|
321 |
-
|
322 |
-
# if(st.session_state.input_is_rerank == True and len(total_context)):
|
323 |
-
# ques = [{"question":question}]
|
324 |
-
# ans = [{"answer":total_context}]
|
325 |
-
|
326 |
-
# total_context = re_ranker.re_rank('rag','Cross Encoder',"",ques, ans)
|
327 |
|
328 |
llm_prompt = prompt_template.format(context=total_context[0],question=question)
|
329 |
output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
|
330 |
-
#print(output)
|
331 |
if(len(images_2)==0):
|
332 |
images_2 = images
|
333 |
return {'text':output,'source':total_context,'image':images_2,'table':df}
|
|
|
49 |
images = []
|
50 |
|
51 |
for hit in hits:
|
|
|
52 |
images.append({'file':hit['_source']['image'],'caption':hit['_source']['processed_element']})
|
53 |
|
54 |
####### SEARCH ########
|
|
|
101 |
}
|
102 |
]
|
103 |
|
|
|
|
|
|
|
|
|
104 |
SIZE = 5
|
105 |
|
106 |
hybrid_payload = {
|
|
|
154 |
|
155 |
if('Sparse Search' in search_types):
|
156 |
|
|
|
157 |
sparse_payload = { "neural_sparse": {
|
158 |
"processed_element_embedding_sparse": {
|
159 |
"query_text": question,
|
|
|
295 |
images_2.append({'file':hit["_source"]["image"],'caption':hit["_source"]["processed_element"]})
|
296 |
|
297 |
idx = idx +1
|
|
|
298 |
|
299 |
# if(is_table_in_result == False):
|
300 |
# df = lazy_get_table()
|
|
|
308 |
|
309 |
total_context = context_tables + context
|
310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
|
312 |
llm_prompt = prompt_template.format(context=total_context[0],question=question)
|
313 |
output = invoke_models.invoke_llm_model( "\n\nHuman: {input}\n\nAssistant:".format(input=llm_prompt) ,False)
|
|
|
314 |
if(len(images_2)==0):
|
315 |
images_2 = images
|
316 |
return {'text':output,'source':total_context,'image':images_2,'table':df}
|
app.py
CHANGED
@@ -152,28 +152,6 @@ spacer_col = st.columns(1)[0]
|
|
152 |
with spacer_col:
|
153 |
st.markdown("<div style='height: 120px;'></div>", unsafe_allow_html=True)
|
154 |
|
155 |
-
#st.image("/home/ubuntu/images/OS_AI_1.png", use_column_width=True)
|
156 |
-
# with col_title:
|
157 |
-
# st.write("")
|
158 |
-
# st.markdown('<div class="title">OpenSearch AI demos</div>', unsafe_allow_html=True)
|
159 |
-
|
160 |
-
# def demo_link_block(icon, title, target_page):
|
161 |
-
# st.markdown(f"""
|
162 |
-
# <a href="/{target_page}" target="_self" style="text-decoration: none;">
|
163 |
-
# <div class="demo-card">
|
164 |
-
# <div class="demo-text">
|
165 |
-
# <span>{icon} {title}</span>
|
166 |
-
# <span class="demo-arrow">→</span>
|
167 |
-
# </div>
|
168 |
-
# </div>
|
169 |
-
# </a>
|
170 |
-
# """, unsafe_allow_html=True)
|
171 |
-
|
172 |
-
|
173 |
-
# st.write("")
|
174 |
-
# demo_link_block("🔍", "AI Search", "Semantic_Search")
|
175 |
-
# demo_link_block("💬","Multimodal Conversational Search", "Multimodal_Conversational_Search")
|
176 |
-
# demo_link_block("🛍️","Agentic Shopping Assistant", "AI_Shopping_Assistant")
|
177 |
|
178 |
|
179 |
col1, col2, col3 = st.columns(3)
|
@@ -225,5 +203,3 @@ st.markdown("""
|
|
225 |
</style>
|
226 |
""", unsafe_allow_html=True)
|
227 |
|
228 |
-
# <div class="card-arrow"></div>
|
229 |
-
|
|
|
152 |
with spacer_col:
|
153 |
st.markdown("<div style='height: 120px;'></div>", unsafe_allow_html=True)
|
154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
|
156 |
|
157 |
col1, col2, col3 = st.columns(3)
|
|
|
203 |
</style>
|
204 |
""", unsafe_allow_html=True)
|
205 |
|
|
|
|
pages/AI_Shopping_Assistant.py
CHANGED
@@ -33,12 +33,7 @@ import bedrock_agent
|
|
33 |
import warnings
|
34 |
|
35 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
st.set_page_config(
|
41 |
-
#page_title="Semantic Search using OpenSearch",
|
42 |
layout="wide",
|
43 |
page_icon="images/opensearch_mark_default.png"
|
44 |
)
|
@@ -47,15 +42,14 @@ USER_ICON = "images/user.png"
|
|
47 |
AI_ICON = "images/opensearch-twitter-card.png"
|
48 |
REGENERATE_ICON = "images/regenerate.png"
|
49 |
s3_bucket_ = "pdf-repo-uploads"
|
50 |
-
|
51 |
polly_client = boto3.Session(
|
52 |
region_name='us-east-1').client('polly')
|
53 |
|
54 |
# Check if the user ID is already stored in the session state
|
55 |
if 'user_id' in st.session_state:
|
56 |
user_id = st.session_state['user_id']
|
57 |
-
|
58 |
-
|
59 |
# If the user ID is not yet stored in the session state, generate a random UUID
|
60 |
else:
|
61 |
user_id = str(uuid.uuid4())
|
@@ -79,9 +73,6 @@ if "questions__" not in st.session_state:
|
|
79 |
|
80 |
if "answers__" not in st.session_state:
|
81 |
st.session_state.answers__ = []
|
82 |
-
|
83 |
-
if "input_index" not in st.session_state:
|
84 |
-
st.session_state.input_index = "hpijan2024hometrack"#"globalwarmingnew"#"hpijan2024hometrack_no_img_no_table"
|
85 |
|
86 |
if "input_is_rerank" not in st.session_state:
|
87 |
st.session_state.input_is_rerank = True
|
@@ -92,22 +83,17 @@ if "input_copali_rerank" not in st.session_state:
|
|
92 |
if "input_table_with_sql" not in st.session_state:
|
93 |
st.session_state.input_table_with_sql = False
|
94 |
|
95 |
-
|
96 |
if "inputs_" not in st.session_state:
|
97 |
st.session_state.inputs_ = {}
|
98 |
|
99 |
if "input_shopping_query" not in st.session_state:
|
100 |
-
st.session_state.input_shopping_query="get me shoes suitable for trekking"
|
101 |
|
102 |
|
103 |
if "input_rag_searchType" not in st.session_state:
|
104 |
st.session_state.input_rag_searchType = ["Sparse Search"]
|
105 |
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
region = 'us-east-1'
|
110 |
-
#bedrock_runtime_client = boto3.client('bedrock-runtime',region_name=region)
|
111 |
output = []
|
112 |
service = 'es'
|
113 |
|
@@ -122,48 +108,6 @@ st.markdown("""
|
|
122 |
</style>
|
123 |
""",unsafe_allow_html=True)
|
124 |
|
125 |
-
################ OpenSearch Py client #####################
|
126 |
-
|
127 |
-
# credentials = boto3.Session().get_credentials()
|
128 |
-
# awsauth = AWSV4SignerAuth(credentials, region, service)
|
129 |
-
|
130 |
-
# ospy_client = OpenSearch(
|
131 |
-
# hosts = [{'host': 'search-opensearchservi-75ucark0bqob-bzk6r6h2t33dlnpgx2pdeg22gi.us-east-1.es.amazonaws.com', 'port': 443}],
|
132 |
-
# http_auth = awsauth,
|
133 |
-
# use_ssl = True,
|
134 |
-
# verify_certs = True,
|
135 |
-
# connection_class = RequestsHttpConnection,
|
136 |
-
# pool_maxsize = 20
|
137 |
-
# )
|
138 |
-
|
139 |
-
################# using boto3 credentials ###################
|
140 |
-
|
141 |
-
|
142 |
-
# credentials = boto3.Session().get_credentials()
|
143 |
-
# awsauth = AWS4Auth(credentials.access_key, credentials.secret_key, region, service, session_token=credentials.token)
|
144 |
-
# service = 'es'
|
145 |
-
|
146 |
-
|
147 |
-
################# using boto3 credentials ####################
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
# if "input_searchType" not in st.session_state:
|
152 |
-
# st.session_state.input_searchType = "Conversational Search (RAG)"
|
153 |
-
|
154 |
-
# if "input_temperature" not in st.session_state:
|
155 |
-
# st.session_state.input_temperature = "0.001"
|
156 |
-
|
157 |
-
# if "input_topK" not in st.session_state:
|
158 |
-
# st.session_state.input_topK = 200
|
159 |
-
|
160 |
-
# if "input_topP" not in st.session_state:
|
161 |
-
# st.session_state.input_topP = 0.95
|
162 |
-
|
163 |
-
# if "input_maxTokens" not in st.session_state:
|
164 |
-
# st.session_state.input_maxTokens = 1024
|
165 |
-
|
166 |
-
|
167 |
def write_logo():
|
168 |
col1, col2, col3 = st.columns([5, 1, 5])
|
169 |
with col2:
|
@@ -175,8 +119,6 @@ def write_top_bar():
|
|
175 |
st.page_link("app.py", label=":orange[Home]", icon="🏠")
|
176 |
st.header("AI Shopping assistant",divider='rainbow')
|
177 |
|
178 |
-
#st.image(AI_ICON, use_column_width='always')
|
179 |
-
|
180 |
with col2:
|
181 |
st.write("")
|
182 |
st.write("")
|
@@ -193,17 +135,10 @@ if clear:
|
|
193 |
st.session_state.input_shopping_query=""
|
194 |
st.session_state.session_id_ = str(uuid.uuid1())
|
195 |
bedrock_agent.delete_memory()
|
196 |
-
|
197 |
-
# st.session_state.input_temperature = "0.001"
|
198 |
-
# st.session_state.input_topK = 200
|
199 |
-
# st.session_state.input_topP = 0.95
|
200 |
-
# st.session_state.input_maxTokens = 1024
|
201 |
|
202 |
|
203 |
def handle_input():
|
204 |
-
print("Question: "+st.session_state.input_shopping_query)
|
205 |
-
print("-----------")
|
206 |
-
print("\n\n")
|
207 |
if(st.session_state.input_shopping_query==''):
|
208 |
return ""
|
209 |
inputs = {}
|
@@ -212,10 +147,6 @@ def handle_input():
|
|
212 |
inputs[key.removeprefix('input_')] = st.session_state[key]
|
213 |
st.session_state.inputs_ = inputs
|
214 |
|
215 |
-
#######
|
216 |
-
|
217 |
-
|
218 |
-
#st.write(inputs)
|
219 |
question_with_id = {
|
220 |
'question': inputs["shopping_query"],
|
221 |
'id': len(st.session_state.questions__)
|
@@ -234,30 +165,6 @@ def handle_input():
|
|
234 |
st.session_state.input_shopping_query=""
|
235 |
|
236 |
|
237 |
-
|
238 |
-
# search_type = st.selectbox('Select the Search type',
|
239 |
-
# ('Conversational Search (RAG)',
|
240 |
-
# 'OpenSearch vector search',
|
241 |
-
# 'LLM Text Generation'
|
242 |
-
# ),
|
243 |
-
|
244 |
-
# key = 'input_searchType',
|
245 |
-
# help = "Select the type of retriever\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers_)"
|
246 |
-
# )
|
247 |
-
|
248 |
-
# col1, col2, col3, col4 = st.columns(4)
|
249 |
-
|
250 |
-
# with col1:
|
251 |
-
# st.text_input('Temperature', value = "0.001", placeholder='LLM Temperature', key = 'input_temperature',help = "Set the temperature of the Large Language model. \n Note: 1. Set this to values lower to 1 in the order of 0.001, 0.0001, such low values reduces hallucination and creativity in the LLM response; 2. This applies only when LLM is a part of the retriever pipeline")
|
252 |
-
# with col2:
|
253 |
-
# st.number_input('Top K', value = 200, placeholder='Top K', key = 'input_topK', step = 50, help = "This limits the LLM's predictions to the top k most probable tokens at each step of generation, this applies only when LLM is a prt of the retriever pipeline")
|
254 |
-
# with col3:
|
255 |
-
# st.number_input('Top P', value = 0.95, placeholder='Top P', key = 'input_topP', step = 0.05, help = "This sets a threshold probability and selects the top tokens whose cumulative probability exceeds the threshold while the tokens are generated by the LLM")
|
256 |
-
# with col4:
|
257 |
-
# st.number_input('Max Output Tokens', value = 500, placeholder='Max Output Tokens', key = 'input_maxTokens', step = 100, help = "This decides the total number of tokens generated as the final response. Note: Values greater than 1000 takes longer response time")
|
258 |
-
|
259 |
-
# st.markdown('---')
|
260 |
-
|
261 |
|
262 |
def write_user_message(md):
|
263 |
col1, col2 = st.columns([3,97])
|
@@ -265,8 +172,6 @@ def write_user_message(md):
|
|
265 |
with col1:
|
266 |
st.image(USER_ICON, use_column_width='always')
|
267 |
with col2:
|
268 |
-
#st.warning(md['question'])
|
269 |
-
|
270 |
st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
|
271 |
|
272 |
|
@@ -283,18 +188,9 @@ def render_answer(question,answer,index):
|
|
283 |
ans_ = answer['answer']
|
284 |
span_ans = ans_.replace('<question>',"<span style='fontSize:18px;color:#f37709;fontStyle:italic;'>").replace("</question>","</span>")
|
285 |
st.markdown("<p>"+span_ans+"</p>",unsafe_allow_html = True)
|
286 |
-
print("answer['source']")
|
287 |
-
print("-------------")
|
288 |
-
print(answer['source'])
|
289 |
-
print("-------------")
|
290 |
-
print(answer['last_tool'])
|
291 |
if(answer['last_tool']['name'] in ["generate_images","get_relevant_items_for_image","get_relevant_items_for_text","retrieve_with_hybrid_search","retrieve_with_keyword_search","get_any_general_recommendation"]):
|
292 |
use_interim_results = True
|
293 |
src_dict =json.loads(answer['last_tool']['response'].replace("'",'"'))
|
294 |
-
print("src_dict")
|
295 |
-
print("-------------")
|
296 |
-
print(src_dict)
|
297 |
-
#if("get_relevant_items_for_text" in src_dict):
|
298 |
if(use_interim_results and answer['last_tool']['name']!= 'generate_images' and answer['last_tool']['name']!= 'get_any_general_recommendation'):
|
299 |
key_ = answer['last_tool']['name']
|
300 |
|
@@ -310,9 +206,7 @@ def render_answer(question,answer,index):
|
|
310 |
if(index ==1):
|
311 |
with img_col2:
|
312 |
st.image(resizedImg,use_column_width = True,caption = item['title'])
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
if(answer['last_tool']['name'] == "generate_images" or answer['last_tool']['name'] == "get_any_general_recommendation"):
|
317 |
st.write("<br>",unsafe_allow_html = True)
|
318 |
gen_img_col1, gen_img_col2,gen_img_col2 = st.columns([30,30,30])
|
@@ -328,143 +222,17 @@ def render_answer(question,answer,index):
|
|
328 |
with gen_img_col1:
|
329 |
st.image(resizedImg,caption = "Generated image for "+key.split(".")[0],use_column_width = True)
|
330 |
st.write("<br>",unsafe_allow_html = True)
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
# def stream_():
|
339 |
-
# #use for streaming response on the client side
|
340 |
-
# for word in ans_.split(" "):
|
341 |
-
# yield word + " "
|
342 |
-
# time.sleep(0.04)
|
343 |
-
# #use for streaming response from Llm directly
|
344 |
-
# if(isinstance(ans_,botocore.eventstream.EventStream)):
|
345 |
-
# for event in ans_:
|
346 |
-
# chunk = event.get('chunk')
|
347 |
-
|
348 |
-
# if chunk:
|
349 |
-
|
350 |
-
# chunk_obj = json.loads(chunk.get('bytes').decode())
|
351 |
-
|
352 |
-
# if('content_block' in chunk_obj or ('delta' in chunk_obj and 'text' in chunk_obj['delta'])):
|
353 |
-
# key_ = list(chunk_obj.keys())[2]
|
354 |
-
# text = chunk_obj[key_]['text']
|
355 |
-
|
356 |
-
# clear_output(wait=True)
|
357 |
-
# output.append(text)
|
358 |
-
# yield text
|
359 |
-
# time.sleep(0.04)
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
# if(index == len(st.session_state.questions_)):
|
364 |
-
# st.write_stream(stream_)
|
365 |
-
# if(isinstance(st.session_state.answers_[index-1]['answer'],botocore.eventstream.EventStream)):
|
366 |
-
# st.session_state.answers_[index-1]['answer'] = "".join(output)
|
367 |
-
# else:
|
368 |
-
# st.write(ans_)
|
369 |
-
|
370 |
-
|
371 |
-
# polly_response = polly_client.synthesize_speech(VoiceId='Joanna',
|
372 |
-
# OutputFormat='ogg_vorbis',
|
373 |
-
# Text = ans_,
|
374 |
-
# Engine = 'neural')
|
375 |
-
|
376 |
-
# audio_col1, audio_col2 = st.columns([50,50])
|
377 |
-
# with audio_col1:
|
378 |
-
# st.audio(polly_response['AudioStream'].read(), format="audio/ogg")
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
#st.markdown("<div style='font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;border-radius: 10px;'>"+ans_+"</div>", unsafe_allow_html = True)
|
383 |
-
#st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Relevant images from the document :</b></div>", unsafe_allow_html = True)
|
384 |
-
#st.write("")
|
385 |
colu1,colu2,colu3 = st.columns([4,82,20])
|
386 |
if(answer['source']!={}):
|
387 |
with colu2:
|
388 |
with st.expander("Agent Traces:"):
|
389 |
st.write(answer['source'])
|
390 |
-
|
391 |
-
# if(len(res_img)>0):
|
392 |
-
# with st.expander("Images:"):
|
393 |
-
# col3,col4,col5 = st.columns([33,33,33])
|
394 |
-
# cols = [col3,col4]
|
395 |
-
# idx = 0
|
396 |
-
# #print(res_img)
|
397 |
-
# for img_ in res_img:
|
398 |
-
# if(img_['file'].lower()!='none' and idx < 2):
|
399 |
-
# img = img_['file'].split(".")[0]
|
400 |
-
# caption = img_['caption']
|
401 |
-
|
402 |
-
# with cols[idx]:
|
403 |
-
|
404 |
-
# st.image(parent_dirname+"/figures/"+st.session_state.input_index+"/"+img+".jpg")
|
405 |
-
# #st.write(caption)
|
406 |
-
# idx = idx+1
|
407 |
-
# #st.markdown("<div style='color:#e28743';padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'><b>Sources from the document:</b></div>", unsafe_allow_html = True)
|
408 |
-
# if(len(answer["table"] )>0):
|
409 |
-
# with st.expander("Table:"):
|
410 |
-
# df = pd.read_csv(answer["table"][0]['name'],skipinitialspace = True, on_bad_lines='skip',delimiter='`')
|
411 |
-
# df.fillna(method='pad', inplace=True)
|
412 |
-
# st.table(df)
|
413 |
-
# with st.expander("Raw sources:"):
|
414 |
-
# st.write(answer["source"])
|
415 |
-
|
416 |
-
|
417 |
-
|
418 |
-
# with col_3:
|
419 |
-
|
420 |
-
# #st.markdown("<div style='color:#e28743;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 5px;'><b>"+",".join(st.session_state.input_rag_searchType)+"</b></div>", unsafe_allow_html = True)
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
# if(index == len(st.session_state.questions_)):
|
425 |
-
|
426 |
-
# rdn_key = ''.join([random.choice(string.ascii_letters)
|
427 |
-
# for _ in range(10)])
|
428 |
-
# currentValue = ''.join(st.session_state.input_rag_searchType)+str(st.session_state.input_is_rerank)+str(st.session_state.input_table_with_sql)+st.session_state.input_index
|
429 |
-
# oldValue = ''.join(st.session_state.inputs_["rag_searchType"])+str(st.session_state.inputs_["is_rerank"])+str(st.session_state.inputs_["table_with_sql"])+str(st.session_state.inputs_["index"])
|
430 |
-
# #print("changing values-----------------")
|
431 |
-
# def on_button_click():
|
432 |
-
# # print("button clicked---------------")
|
433 |
-
# # print(currentValue)
|
434 |
-
# # print(oldValue)
|
435 |
-
# if(currentValue!=oldValue or 1==1):
|
436 |
-
# #print("----------regenerate----------------")
|
437 |
-
# st.session_state.input_query = st.session_state.questions_[-1]["question"]
|
438 |
-
# st.session_state.answers_.pop()
|
439 |
-
# st.session_state.questions_.pop()
|
440 |
-
|
441 |
-
# handle_input()
|
442 |
-
# with placeholder.container():
|
443 |
-
# render_all()
|
444 |
-
|
445 |
-
# if("currentValue" in st.session_state):
|
446 |
-
# del st.session_state["currentValue"]
|
447 |
-
|
448 |
-
# try:
|
449 |
-
# del regenerate
|
450 |
-
# except:
|
451 |
-
# pass
|
452 |
-
|
453 |
-
# #print("------------------------")
|
454 |
-
# #print(st.session_state)
|
455 |
-
|
456 |
-
# placeholder__ = st.empty()
|
457 |
-
|
458 |
-
# placeholder__.button("🔄",key=rdn_key,on_click=on_button_click)
|
459 |
|
460 |
#Each answer will have context of the question asked in order to associate the provided feedback with the respective question
|
461 |
def write_chat_message(md, q,index):
|
462 |
-
#res_img = md['image']
|
463 |
-
#st.session_state['session_id'] = res['session_id'] to be added in memory
|
464 |
chat = st.container()
|
465 |
with chat:
|
466 |
-
#print("st.session_state.input_index------------------")
|
467 |
-
#print(st.session_state.input_index)
|
468 |
render_answer(q,md,index)
|
469 |
|
470 |
def render_all():
|
@@ -480,173 +248,8 @@ with placeholder.container():
|
|
480 |
|
481 |
st.markdown("")
|
482 |
col_2, col_3 = st.columns([75,20])
|
483 |
-
|
484 |
-
# with col_1:
|
485 |
-
# st.markdown("<p style='padding:0px 0px 0px 0px; color:#FF9900;font-size:120%'><b>Ask:</b></p>",unsafe_allow_html=True, help = 'Enter the questions and click on "GO"')
|
486 |
-
|
487 |
with col_2:
|
488 |
-
#st.markdown("")
|
489 |
input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_shopping_query")
|
490 |
with col_3:
|
491 |
-
#hidden = st.button("RUN",disabled=True,key = "hidden")
|
492 |
-
# audio_value = st.audio_input("Record a voice message")
|
493 |
-
# print(audio_value)
|
494 |
play = st.button("Go",on_click=handle_input,key = "play")
|
495 |
-
#with st.sidebar:
|
496 |
-
# st.page_link("/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp/app.py", label=":orange[Home]", icon="🏠")
|
497 |
-
# st.subheader(":blue[Sample Data]")
|
498 |
-
# coln_1,coln_2 = st.columns([70,30])
|
499 |
-
# # index_select = st.radio("Choose one index",["UK Housing","Covid19 impacts on Ireland","Environmental Global Warming","BEIR Research"],
|
500 |
-
# # captions = ['[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)',
|
501 |
-
# # '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)',
|
502 |
-
# # '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)',
|
503 |
-
# # '[preview](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/BEIR.pdf)'],
|
504 |
-
# # key="input_rad_index")
|
505 |
-
# with coln_1:
|
506 |
-
# index_select = st.radio("Choose one index",["UK Housing","Global Warming stats","Covid19 impacts on Ireland"],key="input_rad_index")
|
507 |
-
# with coln_2:
|
508 |
-
# st.markdown("<p style='font-size:15px'>Preview file</p>",unsafe_allow_html=True)
|
509 |
-
# st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/HPI-Jan-2024-Hometrack.pdf)")
|
510 |
-
# st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/global_warming.pdf)")
|
511 |
-
# st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/covid19_ie.pdf)")
|
512 |
-
# #st.write("[:eyes:](https://github.com/aws-samples/AI-search-with-amazon-opensearch-service/blob/b559f82c07dfcca973f457c0a15d6444752553ab/rag/sample_pdfs/BEIR.pdf)")
|
513 |
-
# st.markdown("""
|
514 |
-
# <style>
|
515 |
-
# [data-testid=column]:nth-of-type(2) [data-testid=stVerticalBlock]{
|
516 |
-
# gap: 0rem;
|
517 |
-
# }
|
518 |
-
# [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
|
519 |
-
# gap: 0rem;
|
520 |
-
# }
|
521 |
-
# </style>
|
522 |
-
# """,unsafe_allow_html=True)
|
523 |
-
# # Initialize boto3 to use the S3 client.
|
524 |
-
# s3_client = boto3.resource('s3')
|
525 |
-
# bucket=s3_client.Bucket(s3_bucket_)
|
526 |
-
|
527 |
-
# objects = bucket.objects.filter(Prefix="sample_pdfs/")
|
528 |
-
# urls = []
|
529 |
-
|
530 |
-
# client = boto3.client('s3')
|
531 |
-
|
532 |
-
# for obj in objects:
|
533 |
-
# if obj.key.endswith('.pdf'):
|
534 |
-
|
535 |
-
# # Generate the S3 presigned URL
|
536 |
-
# s3_presigned_url = client.generate_presigned_url(
|
537 |
-
# ClientMethod='get_object',
|
538 |
-
# Params={
|
539 |
-
# 'Bucket': s3_bucket_,
|
540 |
-
# 'Key': obj.key
|
541 |
-
# },
|
542 |
-
# ExpiresIn=3600
|
543 |
-
# )
|
544 |
-
|
545 |
-
# # Print the created S3 presigned URL
|
546 |
-
# print(s3_presigned_url)
|
547 |
-
# urls.append(s3_presigned_url)
|
548 |
-
# #st.write("["+obj.key.split('/')[1]+"]("+s3_presigned_url+")")
|
549 |
-
# st.link_button(obj.key.split('/')[1], s3_presigned_url)
|
550 |
-
|
551 |
-
|
552 |
-
# st.subheader(":blue[Your multi-modal documents]")
|
553 |
-
# pdf_doc_ = st.file_uploader(
|
554 |
-
# "Upload your PDFs here and click on 'Process'", accept_multiple_files=False)
|
555 |
-
|
556 |
-
|
557 |
-
# pdf_docs = [pdf_doc_]
|
558 |
-
# if st.button("Process"):
|
559 |
-
# with st.spinner("Processing"):
|
560 |
-
# if os.path.isdir(parent_dirname+"/pdfs") == False:
|
561 |
-
# os.mkdir(parent_dirname+"/pdfs")
|
562 |
-
|
563 |
-
# for pdf_doc in pdf_docs:
|
564 |
-
# print(type(pdf_doc))
|
565 |
-
# pdf_doc_name = (pdf_doc.name).replace(" ","_")
|
566 |
-
# with open(os.path.join(parent_dirname+"/pdfs",pdf_doc_name),"wb") as f:
|
567 |
-
# f.write(pdf_doc.getbuffer())
|
568 |
-
|
569 |
-
# request_ = { "bucket": s3_bucket_,"key": pdf_doc_name}
|
570 |
-
# # if(st.session_state.input_copali_rerank):
|
571 |
-
# # copali.process_doc(request_)
|
572 |
-
# # else:
|
573 |
-
# rag_DocumentLoader.load_docs(request_)
|
574 |
-
# print('lambda done')
|
575 |
-
# st.success('you can start searching on your PDF')
|
576 |
-
|
577 |
-
# ############## haystach demo temporary addition ############
|
578 |
-
# # st.subheader(":blue[Multimodality]")
|
579 |
-
# # colu1,colu2 = st.columns([50,50])
|
580 |
-
# # with colu1:
|
581 |
-
# # in_images = st.toggle('Images', key = 'in_images', disabled = False)
|
582 |
-
# # with colu2:
|
583 |
-
# # in_tables = st.toggle('Tables', key = 'in_tables', disabled = False)
|
584 |
-
# # if(in_tables):
|
585 |
-
# # st.session_state.input_table_with_sql = True
|
586 |
-
# # else:
|
587 |
-
# # st.session_state.input_table_with_sql = False
|
588 |
-
|
589 |
-
# ############## haystach demo temporary addition ############
|
590 |
-
# if(pdf_doc_ is None or pdf_doc_ == ""):
|
591 |
-
# if(index_select == "Global Warming stats"):
|
592 |
-
# st.session_state.input_index = "globalwarmingnew"
|
593 |
-
# if(index_select == "Covid19 impacts on Ireland"):
|
594 |
-
# st.session_state.input_index = "covid19ie"#"choosetheknnalgorithmforyourbillionscaleusecasewithopensearchawsbigdatablog"
|
595 |
-
# if(index_select == "BEIR"):
|
596 |
-
# st.session_state.input_index = "2104"
|
597 |
-
# if(index_select == "UK Housing"):
|
598 |
-
# st.session_state.input_index = "hpijan2024hometrack"
|
599 |
-
# # if(in_images == True and in_tables == True):
|
600 |
-
# # st.session_state.input_index = "hpijan2024hometrack"
|
601 |
-
# # else:
|
602 |
-
# # if(in_images == True and in_tables == False):
|
603 |
-
# # st.session_state.input_index = "hpijan2024hometrackno_table"
|
604 |
-
# # else:
|
605 |
-
# # if(in_images == False and in_tables == True):
|
606 |
-
# # st.session_state.input_index = "hpijan2024hometrackno_images"
|
607 |
-
# # else:
|
608 |
-
# # st.session_state.input_index = "hpijan2024hometrack_no_img_no_table"
|
609 |
-
|
610 |
-
|
611 |
-
# # if(in_images):
|
612 |
-
# # st.session_state.input_include_images = True
|
613 |
-
# # else:
|
614 |
-
# # st.session_state.input_include_images = False
|
615 |
-
# # if(in_tables):
|
616 |
-
# # st.session_state.input_include_tables = True
|
617 |
-
# # else:
|
618 |
-
# # st.session_state.input_include_tables = False
|
619 |
-
|
620 |
-
# custom_index = st.text_input("If uploaded the file already, enter the original file name", value = "")
|
621 |
-
# if(custom_index!=""):
|
622 |
-
# st.session_state.input_index = re.sub('[^A-Za-z0-9]+', '', (custom_index.lower().replace(".pdf","").split("/")[-1].split(".")[0]).lower())
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
# st.subheader(":blue[Retriever]")
|
627 |
-
# search_type = st.multiselect('Select the Retriever(s)',
|
628 |
-
# ['Keyword Search',
|
629 |
-
# 'Vector Search',
|
630 |
-
# 'Sparse Search',
|
631 |
-
# ],
|
632 |
-
# ['Sparse Search'],
|
633 |
-
|
634 |
-
# key = 'input_rag_searchType',
|
635 |
-
# help = "Select the type of Search, adding more than one search type will activate hybrid search"#\n1. Conversational Search (Recommended) - This will include both the OpenSearch and LLM in the retrieval pipeline \n (note: This will put opensearch response as context to LLM to answer) \n2. OpenSearch vector search - This will put only OpenSearch's vector search in the pipeline, \n(Warning: this will lead to unformatted results )\n3. LLM Text Generation - This will include only LLM in the pipeline, \n(Warning: This will give hallucinated and out of context answers)"
|
636 |
-
# )
|
637 |
-
|
638 |
-
# re_rank = st.checkbox('Re-rank results', key = 'input_re_rank', disabled = False, value = True, help = "Checking this box will re-rank the results using a cross-encoder model")
|
639 |
-
|
640 |
-
# if(re_rank):
|
641 |
-
# st.session_state.input_is_rerank = True
|
642 |
-
# else:
|
643 |
-
# st.session_state.input_is_rerank = False
|
644 |
-
|
645 |
-
# # copali_rerank = st.checkbox("Search and Re-rank with Token level vectors",key = 'copali_rerank',help = "Enabling this option uses 'Copali' model's page level image embeddings to retrieve documents and MaxSim to re-rank the pages.\n\n Hugging Face Model: https://huggingface.co/vidore/colpali")
|
646 |
-
|
647 |
-
# # if(copali_rerank):
|
648 |
-
# # st.session_state.input_copali_rerank = True
|
649 |
-
# # else:
|
650 |
-
# # st.session_state.input_copali_rerank = False
|
651 |
-
|
652 |
-
|
|
|
33 |
import warnings
|
34 |
|
35 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
|
|
|
|
|
|
|
36 |
st.set_page_config(
|
|
|
37 |
layout="wide",
|
38 |
page_icon="images/opensearch_mark_default.png"
|
39 |
)
|
|
|
42 |
AI_ICON = "images/opensearch-twitter-card.png"
|
43 |
REGENERATE_ICON = "images/regenerate.png"
|
44 |
s3_bucket_ = "pdf-repo-uploads"
|
45 |
+
|
46 |
polly_client = boto3.Session(
|
47 |
region_name='us-east-1').client('polly')
|
48 |
|
49 |
# Check if the user ID is already stored in the session state
|
50 |
if 'user_id' in st.session_state:
|
51 |
user_id = st.session_state['user_id']
|
52 |
+
|
|
|
53 |
# If the user ID is not yet stored in the session state, generate a random UUID
|
54 |
else:
|
55 |
user_id = str(uuid.uuid4())
|
|
|
73 |
|
74 |
if "answers__" not in st.session_state:
|
75 |
st.session_state.answers__ = []
|
|
|
|
|
|
|
76 |
|
77 |
if "input_is_rerank" not in st.session_state:
|
78 |
st.session_state.input_is_rerank = True
|
|
|
83 |
if "input_table_with_sql" not in st.session_state:
|
84 |
st.session_state.input_table_with_sql = False
|
85 |
|
|
|
86 |
if "inputs_" not in st.session_state:
|
87 |
st.session_state.inputs_ = {}
|
88 |
|
89 |
if "input_shopping_query" not in st.session_state:
|
90 |
+
st.session_state.input_shopping_query="get me shoes suitable for trekking"
|
91 |
|
92 |
|
93 |
if "input_rag_searchType" not in st.session_state:
|
94 |
st.session_state.input_rag_searchType = ["Sparse Search"]
|
95 |
|
|
|
|
|
|
|
96 |
region = 'us-east-1'
|
|
|
97 |
output = []
|
98 |
service = 'es'
|
99 |
|
|
|
108 |
</style>
|
109 |
""",unsafe_allow_html=True)
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
def write_logo():
|
112 |
col1, col2, col3 = st.columns([5, 1, 5])
|
113 |
with col2:
|
|
|
119 |
st.page_link("app.py", label=":orange[Home]", icon="🏠")
|
120 |
st.header("AI Shopping assistant",divider='rainbow')
|
121 |
|
|
|
|
|
122 |
with col2:
|
123 |
st.write("")
|
124 |
st.write("")
|
|
|
135 |
st.session_state.input_shopping_query=""
|
136 |
st.session_state.session_id_ = str(uuid.uuid1())
|
137 |
bedrock_agent.delete_memory()
|
138 |
+
|
|
|
|
|
|
|
|
|
139 |
|
140 |
|
141 |
def handle_input():
|
|
|
|
|
|
|
142 |
if(st.session_state.input_shopping_query==''):
|
143 |
return ""
|
144 |
inputs = {}
|
|
|
147 |
inputs[key.removeprefix('input_')] = st.session_state[key]
|
148 |
st.session_state.inputs_ = inputs
|
149 |
|
|
|
|
|
|
|
|
|
150 |
question_with_id = {
|
151 |
'question': inputs["shopping_query"],
|
152 |
'id': len(st.session_state.questions__)
|
|
|
165 |
st.session_state.input_shopping_query=""
|
166 |
|
167 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
def write_user_message(md):
|
170 |
col1, col2 = st.columns([3,97])
|
|
|
172 |
with col1:
|
173 |
st.image(USER_ICON, use_column_width='always')
|
174 |
with col2:
|
|
|
|
|
175 |
st.markdown("<div style='color:#e28743';font-size:18px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;'>"+md['question']+"</div>", unsafe_allow_html = True)
|
176 |
|
177 |
|
|
|
188 |
ans_ = answer['answer']
|
189 |
span_ans = ans_.replace('<question>',"<span style='fontSize:18px;color:#f37709;fontStyle:italic;'>").replace("</question>","</span>")
|
190 |
st.markdown("<p>"+span_ans+"</p>",unsafe_allow_html = True)
|
|
|
|
|
|
|
|
|
|
|
191 |
if(answer['last_tool']['name'] in ["generate_images","get_relevant_items_for_image","get_relevant_items_for_text","retrieve_with_hybrid_search","retrieve_with_keyword_search","get_any_general_recommendation"]):
|
192 |
use_interim_results = True
|
193 |
src_dict =json.loads(answer['last_tool']['response'].replace("'",'"'))
|
|
|
|
|
|
|
|
|
194 |
if(use_interim_results and answer['last_tool']['name']!= 'generate_images' and answer['last_tool']['name']!= 'get_any_general_recommendation'):
|
195 |
key_ = answer['last_tool']['name']
|
196 |
|
|
|
206 |
if(index ==1):
|
207 |
with img_col2:
|
208 |
st.image(resizedImg,use_column_width = True,caption = item['title'])
|
209 |
+
|
|
|
|
|
210 |
if(answer['last_tool']['name'] == "generate_images" or answer['last_tool']['name'] == "get_any_general_recommendation"):
|
211 |
st.write("<br>",unsafe_allow_html = True)
|
212 |
gen_img_col1, gen_img_col2,gen_img_col2 = st.columns([30,30,30])
|
|
|
222 |
with gen_img_col1:
|
223 |
st.image(resizedImg,caption = "Generated image for "+key.split(".")[0],use_column_width = True)
|
224 |
st.write("<br>",unsafe_allow_html = True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
colu1,colu2,colu3 = st.columns([4,82,20])
|
226 |
if(answer['source']!={}):
|
227 |
with colu2:
|
228 |
with st.expander("Agent Traces:"):
|
229 |
st.write(answer['source'])
|
230 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
#Each answer will have context of the question asked in order to associate the provided feedback with the respective question
|
233 |
def write_chat_message(md, q,index):
|
|
|
|
|
234 |
chat = st.container()
|
235 |
with chat:
|
|
|
|
|
236 |
render_answer(q,md,index)
|
237 |
|
238 |
def render_all():
|
|
|
248 |
|
249 |
st.markdown("")
|
250 |
col_2, col_3 = st.columns([75,20])
|
251 |
+
|
|
|
|
|
|
|
252 |
with col_2:
|
|
|
253 |
input = st.text_input( "Ask here",label_visibility = "collapsed",key="input_shopping_query")
|
254 |
with col_3:
|
|
|
|
|
|
|
255 |
play = st.button("Go",on_click=handle_input,key = "play")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pages/Semantic_Search.py
CHANGED
@@ -24,24 +24,18 @@ import base64
|
|
24 |
import shutil
|
25 |
import re
|
26 |
from requests.auth import HTTPBasicAuth
|
27 |
-
#import utilities.re_ranker as re_ranker
|
28 |
# from nltk.stem import PorterStemmer
|
29 |
# from nltk.tokenize import word_tokenize
|
30 |
import query_rewrite
|
31 |
import amazon_rekognition
|
|
|
32 |
#from st_click_detector import click_detector
|
33 |
import llm_eval
|
34 |
import all_search_execute
|
35 |
import warnings
|
36 |
|
37 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
st.set_page_config(
|
43 |
-
#page_title="Semantic Search using OpenSearch",
|
44 |
-
#layout="wide",
|
45 |
page_icon="images/opensearch_mark_default.png"
|
46 |
)
|
47 |
parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
|
@@ -58,11 +52,6 @@ st.markdown("""
|
|
58 |
#ps = PorterStemmer()
|
59 |
|
60 |
st.session_state.REGION = 'us-east-1'
|
61 |
-
|
62 |
-
|
63 |
-
#from langchain.callbacks.base import BaseCallbackHandler
|
64 |
-
|
65 |
-
|
66 |
USER_ICON = "images/user.png"
|
67 |
AI_ICON = "images/opensearch-twitter-card.png"
|
68 |
REGENERATE_ICON = "images/regenerate.png"
|
@@ -170,12 +159,6 @@ if "input_ndcg" not in st.session_state:
|
|
170 |
if "gen_image_str" not in st.session_state:
|
171 |
st.session_state.gen_image_str=""
|
172 |
|
173 |
-
# if "input_searchType" not in st.session_state:
|
174 |
-
# st.session_state.input_searchType = ['Keyword Search']
|
175 |
-
|
176 |
-
# if "input_must" not in st.session_state:
|
177 |
-
# st.session_state.input_must = ["Category","Price","Gender","Style"]
|
178 |
-
|
179 |
if "input_NormType" not in st.session_state:
|
180 |
st.session_state.input_NormType = "min_max"
|
181 |
|
@@ -261,25 +244,8 @@ if(search_all_type==True):
|
|
261 |
'Multimodal Search',
|
262 |
'NeuralSparse Search',
|
263 |
]
|
264 |
-
|
265 |
-
|
266 |
-
# html("""
|
267 |
-
# <script>
|
268 |
-
# // Locate elements
|
269 |
-
# var decoration = window.parent.document.querySelectorAll('[data-testid="stDecoration"]')[0];
|
270 |
-
# decoration.style.height = "3.0rem";
|
271 |
-
# decoration.style.right = "45px";
|
272 |
-
# // Adjust text decorations
|
273 |
-
# decoration.innerText = "Semantic Search with OpenSearch!"; // Replace with your desired text
|
274 |
-
# decoration.style.fontWeight = "bold";
|
275 |
-
# decoration.style.display = "flex";
|
276 |
-
# decoration.style.justifyContent = "center";
|
277 |
-
# decoration.style.alignItems = "center";
|
278 |
-
# decoration.style.fontWeight = "bold";
|
279 |
-
# decoration.style.backgroundImage = url('/home/ubuntu/AI-search-with-amazon-opensearch-service/OpenSearchApp/images/service_logo.png'); // Remove background image
|
280 |
-
# decoration.style.backgroundSize = "unset"; // Remove background size
|
281 |
-
# </script>
|
282 |
-
# """, width=0, height=0)
|
283 |
|
284 |
|
285 |
|
@@ -448,31 +414,12 @@ def handle_input():
|
|
448 |
|
449 |
|
450 |
inputs = {}
|
451 |
-
# if(st.session_state.input_imageUpload == 'yes'):
|
452 |
-
# st.session_state.input_searchType = 'Multi-modal Search'
|
453 |
-
# if(st.session_state.input_sparse == 'enabled' or st.session_state.input_is_rewrite_query == 'enabled'):
|
454 |
-
# st.session_state.input_searchType = 'Keyword Search'
|
455 |
if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
|
456 |
old_rekog_label = st.session_state.input_rekog_label
|
457 |
st.session_state.input_rekog_label = amazon_rekognition.extract_image_metadata(st.session_state.bytes_for_rekog)
|
458 |
if(st.session_state.input_text == ""):
|
459 |
st.session_state.input_text = st.session_state.input_rekog_label
|
460 |
|
461 |
-
# if(st.session_state.input_imageUpload == 'yes'):
|
462 |
-
# if(st.session_state.input_searchType!='Multi-modal Search'):
|
463 |
-
# if(st.session_state.input_searchType=='Keyword Search'):
|
464 |
-
# if(st.session_state.input_rekognition != 'enabled'):
|
465 |
-
# st.error('For Keyword Search using images, enable "Enrich metadata for Images" in the left panel',icon = "🚨")
|
466 |
-
# #st.session_state.input_rekognition = 'enabled'
|
467 |
-
# st.switch_page('pages/1_Semantic_Search.py')
|
468 |
-
# #st.stop()
|
469 |
-
|
470 |
-
# else:
|
471 |
-
# st.error('Please set the search type as "Keyword Search (enabling Enrich metadata for Images) or Multi-modal Search"',icon = "🚨")
|
472 |
-
# #st.session_state.input_searchType='Multi-modal Search'
|
473 |
-
# st.switch_page('pages/1_Semantic_Search.py')
|
474 |
-
# #st.stop()
|
475 |
-
|
476 |
|
477 |
weightage = {}
|
478 |
st.session_state.weights_ = []
|
@@ -511,44 +458,13 @@ def handle_input():
|
|
511 |
else:
|
512 |
weightage[original_key] = 0.0
|
513 |
st.session_state[key] = 0.0
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
inputs['weightage']=weightage
|
528 |
st.session_state.input_weightage = weightage
|
529 |
|
530 |
-
print("====================")
|
531 |
-
print(st.session_state.weights_)
|
532 |
-
print(st.session_state.input_weightage )
|
533 |
-
print("====================")
|
534 |
-
#print("***************************")
|
535 |
-
#print(sum(weights_))
|
536 |
-
# if(sum(st.session_state.weights_)!=100):
|
537 |
-
# st.warning('The total weight of selected search type(s) should be equal to 100',icon = "🚨")
|
538 |
-
# refresh = st.button("Re-Enter")
|
539 |
-
# if(refresh):
|
540 |
-
# st.switch_page('pages/1_Semantic_Search.py')
|
541 |
-
# st.stop()
|
542 |
-
|
543 |
-
|
544 |
-
# #st.session_state.input_rekognition = 'enabled'
|
545 |
-
# st.rerun()
|
546 |
-
|
547 |
-
|
548 |
|
549 |
st.session_state.inputs_ = inputs
|
550 |
|
551 |
-
#st.write(inputs)
|
552 |
question_with_id = {
|
553 |
'question': inputs["text"],
|
554 |
'id': len(st.session_state.questions)
|
@@ -567,19 +483,15 @@ def handle_input():
|
|
567 |
|
568 |
if(st.session_state.input_is_rewrite_query == 'enabled' or (st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType)):
|
569 |
query_rewrite.get_new_query_res(st.session_state.input_text)
|
570 |
-
|
571 |
-
print(st.session_state.input_rewritten_query)
|
572 |
-
print("-------------------")
|
573 |
else:
|
574 |
st.session_state.input_rewritten_query = ""
|
575 |
|
576 |
-
|
577 |
-
# ans__ = amazon_rekognition.call(st.session_state.input_text,st.session_state.input_rekog_label)
|
578 |
-
# else:
|
579 |
ans__ = all_search_execute.handler(inputs, st.session_state['session_id'])
|
580 |
|
581 |
st.session_state.answers.append({
|
582 |
-
'answer': ans__
|
583 |
'search_type':inputs['searchType'],
|
584 |
'id': len(st.session_state.questions)
|
585 |
})
|
@@ -587,21 +499,8 @@ def handle_input():
|
|
587 |
st.session_state.answers_none_rank = st.session_state.answers
|
588 |
if(st.session_state.input_evaluate == "enabled"):
|
589 |
llm_eval.eval(st.session_state.questions, st.session_state.answers)
|
590 |
-
|
591 |
-
#st.session_state.input_searchType=st.session_state.input_searchType
|
592 |
-
|
593 |
def write_top_bar():
|
594 |
-
# st.markdown("""
|
595 |
-
# <style>
|
596 |
-
# [data-testid=column]:nth-of-type(1) [data-testid=stVerticalBlock]{
|
597 |
-
# gap: 0rem;
|
598 |
-
# }
|
599 |
-
# </style>
|
600 |
-
# """,unsafe_allow_html=True)
|
601 |
-
#print("top bar")
|
602 |
-
# st.title(':mag: AI powered OpenSearch')
|
603 |
-
# st.write("")
|
604 |
-
# st.write("")
|
605 |
col1, col2,col3,col4 = st.columns([2.5,35,8,7])
|
606 |
with col1:
|
607 |
st.image(TEXT_ICON, use_column_width='always')
|
@@ -630,9 +529,6 @@ def write_top_bar():
|
|
630 |
st.markdown("<div style = 'height:43px'></div>",unsafe_allow_html=True)
|
631 |
st.button("Generate",disabled=False,key = "generate",on_click = generate_images, args=(tab1,"default_img"))
|
632 |
|
633 |
-
# image_select = st.select_slider(
|
634 |
-
# "Select a image",
|
635 |
-
# options=["Image 1","Image 2","Image 3"], value = None, disabled = st.session_state.radio_disabled,key = "image_select")
|
636 |
image_select = st.radio("Choose one image", ["Image 1","Image 2","Image 3"],index=None, horizontal = True,key = 'image_select',disabled = st.session_state.radio_disabled)
|
637 |
st.markdown("""
|
638 |
<style>
|
@@ -642,25 +538,10 @@ def write_top_bar():
|
|
642 |
</style>
|
643 |
""",unsafe_allow_html=True)
|
644 |
if(st.session_state.image_select is not None and st.session_state.image_select !="" and len(st.session_state.img_gen)!=0):
|
645 |
-
print("image_select")
|
646 |
-
print("------------")
|
647 |
-
print(st.session_state.image_select)
|
648 |
st.session_state.input_rad_1 = st.session_state.image_select.split(" ")[1]
|
649 |
else:
|
650 |
st.session_state.input_rad_1 = ""
|
651 |
-
|
652 |
-
# with rad1:
|
653 |
-
# btn1 = st.button("choose image 1", disabled = st.session_state.radio_disabled)
|
654 |
-
# with rad2:
|
655 |
-
# btn2 = st.button("choose image 2", disabled = st.session_state.radio_disabled)
|
656 |
-
# with rad3:
|
657 |
-
# btn3 = st.button("choose image 3", disabled = st.session_state.radio_disabled)
|
658 |
-
# if(btn1):
|
659 |
-
# st.session_state.input_rad_1 = "1"
|
660 |
-
# if(btn2):
|
661 |
-
# st.session_state.input_rad_1 = "2"
|
662 |
-
# if(btn3):
|
663 |
-
# st.session_state.input_rad_1 = "3"
|
664 |
|
665 |
|
666 |
generate_images(tab1,gen_images)
|
@@ -669,19 +550,11 @@ def write_top_bar():
|
|
669 |
with tab2:
|
670 |
st.session_state.img_doc = st.file_uploader(
|
671 |
"Upload image", accept_multiple_files=False,type = ['png', 'jpg'])
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
return clear,tab1
|
678 |
|
679 |
clear,tab_ = write_top_bar()
|
680 |
|
681 |
if clear:
|
682 |
-
|
683 |
-
|
684 |
-
print("clear1")
|
685 |
st.session_state.questions = []
|
686 |
st.session_state.answers = []
|
687 |
|
@@ -697,18 +570,7 @@ if clear:
|
|
697 |
st.session_state.input_rad_1 = ""
|
698 |
|
699 |
|
700 |
-
|
701 |
-
# with placeholder1.container():
|
702 |
-
# generate_images(tab_,st.session_state.image_prompt)
|
703 |
-
|
704 |
-
|
705 |
-
#st.session_state.input_text=""
|
706 |
-
# st.session_state.input_searchType="Conversational Search (RAG)"
|
707 |
-
# st.session_state.input_temperature = "0.001"
|
708 |
-
# st.session_state.input_topK = 200
|
709 |
-
# st.session_state.input_topP = 0.95
|
710 |
-
# st.session_state.input_maxTokens = 1024
|
711 |
-
|
712 |
col1, col3, col4 = st.columns([70,18,12])
|
713 |
|
714 |
with col1:
|
@@ -732,7 +594,7 @@ with col4:
|
|
732 |
evaluate = st.toggle(' ', key = 'evaluate', disabled = False) #help = "Checking this box will use LLM to evaluate results as relevant and irrelevant. \n\n This option increases the latency")
|
733 |
if(evaluate):
|
734 |
st.session_state.input_evaluate = "enabled"
|
735 |
-
|
736 |
else:
|
737 |
st.session_state.input_evaluate = "disabled"
|
738 |
|
@@ -740,11 +602,7 @@ with col4:
|
|
740 |
if(search_all_type == True or 1==1):
|
741 |
with st.sidebar:
|
742 |
st.page_link("app.py", label=":orange[Home]", icon="🏠")
|
743 |
-
|
744 |
-
#st.warning('Note: After changing any of the below settings, click "SEARCH" button or 🔄 to apply the changes', icon="⚠️")
|
745 |
-
#st.header(' :gear: :orange[Fine-tune Search]')
|
746 |
-
#st.write("Note: After changing any of the below settings, click 'SEARCH' button or '🔄' to apply the changes")
|
747 |
-
#st.subheader(':blue[Keyword Search]')
|
748 |
|
749 |
########################## enable for query_rewrite ########################
|
750 |
rewrite_query = st.checkbox('Auto-apply filters', key = 'query_rewrite', disabled = False, help = "Checking this box will use LLM to rewrite your query. \n\n Here your natural language query is transformed into OpenSearch query with added filters and attributes")
|
@@ -754,6 +612,8 @@ if(search_all_type == True or 1==1):
|
|
754 |
key = 'input_must',
|
755 |
)
|
756 |
########################## enable for query_rewrite ########################
|
|
|
|
|
757 |
####### Filters #########
|
758 |
|
759 |
st.subheader(':blue[Filters]')
|
@@ -776,25 +636,6 @@ if(search_all_type == True or 1==1):
|
|
776 |
|
777 |
|
778 |
clear_filter = st.button("Clear Filters",on_click=clear_filter)
|
779 |
-
|
780 |
-
|
781 |
-
# filter_place_holder = st.container()
|
782 |
-
# with filter_place_holder:
|
783 |
-
# st.selectbox("Select one Category", ("accessories", "books","floral","furniture","hot_dispensed","jewelry","tools","apparel","cold_dispensed","food_service","groceries","housewares","outdoors","salty_snacks","videos","beauty","electronics","footwear","homedecor","instruments","seasonal"),index = None,key = "input_category")
|
784 |
-
# st.selectbox("Select one Gender", ("male","female"),index = None,key = "input_gender")
|
785 |
-
# st.slider("Select a range of price", 0, 2000, (0, 0),50, key = "input_price")
|
786 |
-
|
787 |
-
# st.session_state.input_category=None
|
788 |
-
# st.session_state.input_gender=None
|
789 |
-
# st.session_state.input_price=(0,0)
|
790 |
-
|
791 |
-
print("--------------------filters---------------")
|
792 |
-
print(st.session_state.input_gender)
|
793 |
-
print(st.session_state.input_manual_filter)
|
794 |
-
print("--------------------filters---------------")
|
795 |
-
|
796 |
-
|
797 |
-
|
798 |
####### Filters #########
|
799 |
|
800 |
if('NeuralSparse Search' in st.session_state.search_types):
|
@@ -802,111 +643,21 @@ if(search_all_type == True or 1==1):
|
|
802 |
sparse_filter = st.slider('Keep only sparse tokens with weight >=', 0.0, 1.0, 0.5,0.1,key = 'input_sparse_filter', help = 'Use this slider to set the minimum weight that the sparse vector token weights should meet, rest are filtered out')
|
803 |
|
804 |
|
805 |
-
#sql_query = st.checkbox('Re-write as SQL query', key = 'sql_rewrite', disabled = True, help = "In Progress")
|
806 |
st.session_state.input_is_rewrite_query = 'disabled'
|
807 |
st.session_state.input_is_sql_query = 'disabled'
|
808 |
|
809 |
########################## enable for query_rewrite ########################
|
810 |
if rewrite_query:
|
811 |
-
#st.write(st.session_state.inputs_)
|
812 |
st.session_state.input_is_rewrite_query = 'enabled'
|
813 |
-
|
814 |
-
# #st.write(st.session_state.inputs_)
|
815 |
-
# st.session_state.input_is_sql_query = 'enabled'
|
816 |
-
########################## enable for sql conversion ########################
|
817 |
-
|
818 |
-
|
819 |
-
#st.markdown('---')
|
820 |
-
#st.header('Fine-tune keyword Search', divider='rainbow')
|
821 |
-
#st.subheader('Note: The below selection applies only when the Search type is set to Keyword Search')
|
822 |
-
|
823 |
-
|
824 |
-
# st.markdown("<u>Enrich metadata for :</u>",unsafe_allow_html=True)
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
# c3,c4 = st.columns([10,90])
|
829 |
-
# with c4:
|
830 |
-
# rekognition = st.checkbox('Images', key = 'rekognition', help = "Checking this box will use AI to extract metadata for images that are present in query and documents")
|
831 |
-
# if rekognition:
|
832 |
-
# #st.write(st.session_state.inputs_)
|
833 |
-
# st.session_state.input_rekognition = 'enabled'
|
834 |
-
# else:
|
835 |
-
# st.session_state.input_rekognition = "disabled"
|
836 |
-
|
837 |
-
#st.markdown('---')
|
838 |
-
#st.header('Fine-tune Hybrid Search', divider='rainbow')
|
839 |
-
#st.subheader('Note: The below parameters apply only when the Search type is set to Hybrid Search')
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
847 |
-
#st.write("---")
|
848 |
-
#if(st.session_state.max_selections == "None"):
|
849 |
st.subheader(':blue[Hybrid Search]')
|
850 |
-
# st.selectbox('Select the Hybrid Search type',
|
851 |
-
# ("OpenSearch Hybrid Query","Reciprocal Rank Fusion"),key = 'input_hybridType')
|
852 |
-
# equal_weight = st.button("Give equal weights to selected searches")
|
853 |
-
|
854 |
-
|
855 |
-
|
856 |
-
|
857 |
-
|
858 |
-
|
859 |
-
#st.warning('Weight of each of the selected search type should be greater than 0 and the total weight of all the selected search type(s) should be equal to 100',icon = "⚠️")
|
860 |
-
|
861 |
-
|
862 |
-
#st.markdown("<p style = 'font-size:14.5px;font-style:italic;'>Set Weights</p>",unsafe_allow_html=True)
|
863 |
-
|
864 |
with st.expander("Set query Weightage:"):
|
865 |
st.number_input("Keyword %", min_value=0, max_value=100, value=100, step=5, key='input_Keyword-weight', help=None)
|
866 |
st.number_input("Vector %", min_value=0, max_value=100, value=0, step=5, key='input_Vector-weight', help=None)
|
867 |
st.number_input("Multimodal %", min_value=0, max_value=100, value=0, step=5, key='input_Multimodal-weight', help=None)
|
868 |
st.number_input("NeuralSparse %", min_value=0, max_value=100, value=0, step=5, key='input_NeuralSparse-weight', help=None)
|
869 |
|
870 |
-
|
871 |
-
# counter = 0
|
872 |
-
# num_search = len(st.session_state.input_searchType)
|
873 |
-
# weight_type = ["input_Keyword-weight","input_Vector-weight","input_Multimodal-weight","input_NeuralSparse-weight"]
|
874 |
-
# for type in weight_type:
|
875 |
-
# if(type.split("-")[0].replace("input_","")+ " Search" in st.session_state.input_searchType):
|
876 |
-
# print("ssssssssssss")
|
877 |
-
# counter = counter +1
|
878 |
-
# extra_weight = 100%num_search
|
879 |
-
# if(counter == num_search):
|
880 |
-
# cal_weight = math.trunc(100/num_search)+extra_weight
|
881 |
-
# else:
|
882 |
-
# cal_weight = math.trunc(100/num_search)
|
883 |
-
# st.session_state[weight_type] = cal_weight
|
884 |
-
# else:
|
885 |
-
# st.session_state[weight_type] = 0
|
886 |
-
#weight = st.slider('Weight for Vector Search', 0.0, 1.0, 0.5,0.1,key = 'input_weight', help = 'Use this slider to set the weightage for keyword and vector search, higher values of the slider indicate the increased weightage for semantic search.\n\n This applies only when the search type is set to Hybrid Search')
|
887 |
-
# st.selectbox('Select the Normalisation type',
|
888 |
-
# ('min_max',
|
889 |
-
# 'l2'
|
890 |
-
# ),
|
891 |
-
#st.write("---")
|
892 |
-
# key = 'input_NormType',
|
893 |
-
# disabled = True,
|
894 |
-
# help = "Select the type of Normalisation to be applied on the two sets of scores"
|
895 |
-
# )
|
896 |
-
|
897 |
-
# st.selectbox('Select the Score Combination type',
|
898 |
-
# ('arithmetic_mean','geometric_mean','harmonic_mean'
|
899 |
-
# ),
|
900 |
-
|
901 |
-
# key = 'input_CombineType',
|
902 |
-
# disabled = True,
|
903 |
-
# help = "Select the Combination strategy to be used while combining the two scores of the two search queries for every document"
|
904 |
-
# )
|
905 |
-
|
906 |
-
#st.markdown('---')
|
907 |
-
|
908 |
-
#st.header('Select the ML Model for text embedding', divider='rainbow')
|
909 |
-
#st.subheader('Note: The below selection applies only when the Search type is set to Vector or Hybrid Search')
|
910 |
if(st.session_state.re_ranker == "true"):
|
911 |
st.subheader(':blue[Re-ranking]')
|
912 |
reranker = st.selectbox('Choose a Re-Ranker',
|
@@ -916,41 +667,19 @@ if(search_all_type == True or 1==1):
|
|
916 |
|
917 |
key = 'input_reranker',
|
918 |
help = 'Select the Re-Ranker type, select "None" to apply no re-ranking of the results',
|
919 |
-
#on_change = re_ranker.re_rank,
|
920 |
args=(st.session_state.questions, st.session_state.answers)
|
921 |
|
922 |
)
|
923 |
-
|
924 |
-
# st.subheader('Text Embeddings Model')
|
925 |
-
# model_type = st.selectbox('Select the Text Embeddings Model',
|
926 |
-
# ('Titan-Embed-Text-v1','GPT-J-6B'
|
927 |
-
|
928 |
-
# ),
|
929 |
-
|
930 |
-
# key = 'input_modelType',
|
931 |
-
# help = "Select the Text embedding model, this applies only for the vector and hybrid search"
|
932 |
-
# )
|
933 |
-
|
934 |
-
#st.markdown('---')
|
935 |
-
|
936 |
-
|
937 |
-
|
938 |
-
|
939 |
-
|
940 |
-
|
941 |
-
|
942 |
-
#st.markdown('---')
|
943 |
|
944 |
|
945 |
def write_user_message(md,ans):
|
946 |
-
#print(ans)
|
947 |
ans = ans["answer"][0]
|
948 |
col1, col2, col3 = st.columns([3,40,20])
|
949 |
|
950 |
with col1:
|
951 |
st.image(USER_ICON, use_column_width='always')
|
952 |
with col2:
|
953 |
-
#st.warning(md['question'])
|
954 |
st.markdown("<div style='fontSize:15px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'>Input Text: </div><div style='fontSize:25px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;color:#e28743'>"+md['question']+"</div>", unsafe_allow_html = True)
|
955 |
if('query_sparse' in ans):
|
956 |
with st.expander("Expanded Query:"):
|
@@ -1011,10 +740,7 @@ def render_answer(answer,index):
|
|
1011 |
span_color = "red"
|
1012 |
st.markdown("<span style='fontSize:20px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>Relevance:" +str('%.3f'%(st.session_state.input_ndcg)) + "</span><span style='font-size:30px;font-weight:bold;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-weight:bold;font-family:Courier New;color:"+span_color+"'> "+st.session_state.ndcg_increase.split("~")[1]+"</span>", unsafe_allow_html = True)
|
1013 |
|
1014 |
-
|
1015 |
-
#st.markdown("<span style='font-size:30px;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-family:Courier New;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[1]+"</span>",unsafe_allow_html = True)
|
1016 |
-
|
1017 |
-
|
1018 |
|
1019 |
placeholder_no_results = st.empty()
|
1020 |
|
@@ -1030,12 +756,7 @@ def render_answer(answer,index):
|
|
1030 |
continue
|
1031 |
|
1032 |
|
1033 |
-
# imgdata = base64.b64decode(ans['image_binary'])
|
1034 |
format_ = ans['image_url'].split(".")[-1]
|
1035 |
-
|
1036 |
-
#urllib.request.urlretrieve(ans['image_url'], "/home/ubuntu/res_images/"+str(i)+"_."+format_)
|
1037 |
-
|
1038 |
-
|
1039 |
Image.MAX_IMAGE_PIXELS = 100000000
|
1040 |
|
1041 |
width = 500
|
@@ -1066,23 +787,6 @@ def render_answer(answer,index):
|
|
1066 |
desc__ = ans['desc'].split(" ")
|
1067 |
|
1068 |
final_desc = "<p>"
|
1069 |
-
|
1070 |
-
###### stemming and highlighting
|
1071 |
-
|
1072 |
-
# ans_text = ans['desc']
|
1073 |
-
# query_text = st.session_state.input_text
|
1074 |
-
|
1075 |
-
# ans_text_stemmed = set(stem_(ans_text))
|
1076 |
-
# query_text_stemmed = set(stem_(query_text))
|
1077 |
-
|
1078 |
-
# common = ans_text_stemmed.intersection( query_text_stemmed)
|
1079 |
-
# #unique = set(document_1_words).symmetric_difference( )
|
1080 |
-
|
1081 |
-
# desc__stemmed = stem_(desc__)
|
1082 |
-
|
1083 |
-
# for word_ in desc__stemmed:
|
1084 |
-
# if(word_ in common):
|
1085 |
-
|
1086 |
|
1087 |
for word in desc__:
|
1088 |
if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
|
@@ -1104,16 +808,8 @@ def render_answer(answer,index):
|
|
1104 |
filtered_sparse[key] = round(sparse_[key], 2)
|
1105 |
st.write(filtered_sparse)
|
1106 |
with st.expander("Document Metadata:",expanded = False):
|
1107 |
-
# if("rekog" in ans):
|
1108 |
-
# div_size = [50,50]
|
1109 |
-
# else:
|
1110 |
-
# div_size = [99,1]
|
1111 |
-
# div1,div2 = st.columns(div_size)
|
1112 |
-
# with div1:
|
1113 |
-
|
1114 |
st.write(":green[default:]")
|
1115 |
st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True)
|
1116 |
-
#with div2:
|
1117 |
if("rekog" in ans):
|
1118 |
st.write(":green[enriched:]")
|
1119 |
st.json(ans['rekog'],expanded = True)
|
@@ -1128,18 +824,7 @@ def render_answer(answer,index):
|
|
1128 |
st.write(":x:")
|
1129 |
|
1130 |
i = i+1
|
1131 |
-
|
1132 |
-
# if(st.session_state.input_evaluate == "enabled"):
|
1133 |
-
# st.markdown("<div style='fontSize:12px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;font-weight:bold;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>DCG: " +str('%.3f'%(st.session_state.input_ndcg)) + "</div>", unsafe_allow_html = True)
|
1134 |
-
# with col_2_b:
|
1135 |
-
# span_color = "white"
|
1136 |
-
# if("↑" in st.session_state.ndcg_increase):
|
1137 |
-
# span_color = "green"
|
1138 |
-
# if("↓" in st.session_state.ndcg_increase):
|
1139 |
-
# span_color = "red"
|
1140 |
-
# st.markdown("<span style='font-size:30px;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-family:Courier New;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[1]+"</span>",unsafe_allow_html = True)
|
1141 |
-
|
1142 |
-
|
1143 |
with col_3:
|
1144 |
if(index == len(st.session_state.questions)):
|
1145 |
|
@@ -1155,7 +840,6 @@ def render_answer(answer,index):
|
|
1155 |
st.session_state.questions.pop()
|
1156 |
|
1157 |
handle_input()
|
1158 |
-
#re_ranker.re_rank(st.session_state.questions, st.session_state.answers)
|
1159 |
with placeholder.container():
|
1160 |
render_all()
|
1161 |
|
@@ -1169,9 +853,6 @@ def render_answer(answer,index):
|
|
1169 |
except:
|
1170 |
pass
|
1171 |
|
1172 |
-
print("------------------------")
|
1173 |
-
#print(st.session_state)
|
1174 |
-
|
1175 |
placeholder__ = st.empty()
|
1176 |
|
1177 |
placeholder__.button("🔄",key=rdn_key,on_click=on_button_click, help = "This will regenerate the responses with new settings that you entered, Note: To see difference in responses, you should change any of the applicable settings")#,type="primary",use_column_width=True)
|
@@ -1196,8 +877,6 @@ def render_all():
|
|
1196 |
index = 0
|
1197 |
for (q, a) in zip(st.session_state.questions, st.session_state.answers):
|
1198 |
index = index +1
|
1199 |
-
#print("answers----")
|
1200 |
-
#print(a)
|
1201 |
ans_ = st.session_state.answers[0]
|
1202 |
write_user_message(q,ans_)
|
1203 |
write_chat_message(a, q,index)
|
@@ -1206,6 +885,4 @@ placeholder = st.empty()
|
|
1206 |
with placeholder.container():
|
1207 |
render_all()
|
1208 |
|
1209 |
-
#generate_images("",st.session_state.image_prompt)
|
1210 |
-
|
1211 |
st.markdown("")
|
|
|
24 |
import shutil
|
25 |
import re
|
26 |
from requests.auth import HTTPBasicAuth
|
|
|
27 |
# from nltk.stem import PorterStemmer
|
28 |
# from nltk.tokenize import word_tokenize
|
29 |
import query_rewrite
|
30 |
import amazon_rekognition
|
31 |
+
from streamlit.components.v1 import html
|
32 |
#from st_click_detector import click_detector
|
33 |
import llm_eval
|
34 |
import all_search_execute
|
35 |
import warnings
|
36 |
|
37 |
warnings.filterwarnings("ignore", category=DeprecationWarning)
|
|
|
|
|
|
|
|
|
38 |
st.set_page_config(
|
|
|
|
|
39 |
page_icon="images/opensearch_mark_default.png"
|
40 |
)
|
41 |
parent_dirname = "/".join((os.path.dirname(__file__)).split("/")[0:-1])
|
|
|
52 |
#ps = PorterStemmer()
|
53 |
|
54 |
st.session_state.REGION = 'us-east-1'
|
|
|
|
|
|
|
|
|
|
|
55 |
USER_ICON = "images/user.png"
|
56 |
AI_ICON = "images/opensearch-twitter-card.png"
|
57 |
REGENERATE_ICON = "images/regenerate.png"
|
|
|
159 |
if "gen_image_str" not in st.session_state:
|
160 |
st.session_state.gen_image_str=""
|
161 |
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
if "input_NormType" not in st.session_state:
|
163 |
st.session_state.input_NormType = "min_max"
|
164 |
|
|
|
244 |
'Multimodal Search',
|
245 |
'NeuralSparse Search',
|
246 |
]
|
247 |
+
|
248 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
249 |
|
250 |
|
251 |
|
|
|
414 |
|
415 |
|
416 |
inputs = {}
|
|
|
|
|
|
|
|
|
417 |
if(st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType):
|
418 |
old_rekog_label = st.session_state.input_rekog_label
|
419 |
st.session_state.input_rekog_label = amazon_rekognition.extract_image_metadata(st.session_state.bytes_for_rekog)
|
420 |
if(st.session_state.input_text == ""):
|
421 |
st.session_state.input_text = st.session_state.input_rekog_label
|
422 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
423 |
|
424 |
weightage = {}
|
425 |
st.session_state.weights_ = []
|
|
|
458 |
else:
|
459 |
weightage[original_key] = 0.0
|
460 |
st.session_state[key] = 0.0
|
461 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
inputs['weightage']=weightage
|
463 |
st.session_state.input_weightage = weightage
|
464 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
465 |
|
466 |
st.session_state.inputs_ = inputs
|
467 |
|
|
|
468 |
question_with_id = {
|
469 |
'question': inputs["text"],
|
470 |
'id': len(st.session_state.questions)
|
|
|
483 |
|
484 |
if(st.session_state.input_is_rewrite_query == 'enabled' or (st.session_state.input_imageUpload == 'yes' and 'Keyword Search' in st.session_state.input_searchType)):
|
485 |
query_rewrite.get_new_query_res(st.session_state.input_text)
|
486 |
+
|
|
|
|
|
487 |
else:
|
488 |
st.session_state.input_rewritten_query = ""
|
489 |
|
490 |
+
|
|
|
|
|
491 |
ans__ = all_search_execute.handler(inputs, st.session_state['session_id'])
|
492 |
|
493 |
st.session_state.answers.append({
|
494 |
+
'answer': ans__,
|
495 |
'search_type':inputs['searchType'],
|
496 |
'id': len(st.session_state.questions)
|
497 |
})
|
|
|
499 |
st.session_state.answers_none_rank = st.session_state.answers
|
500 |
if(st.session_state.input_evaluate == "enabled"):
|
501 |
llm_eval.eval(st.session_state.questions, st.session_state.answers)
|
502 |
+
|
|
|
|
|
503 |
def write_top_bar():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
504 |
col1, col2,col3,col4 = st.columns([2.5,35,8,7])
|
505 |
with col1:
|
506 |
st.image(TEXT_ICON, use_column_width='always')
|
|
|
529 |
st.markdown("<div style = 'height:43px'></div>",unsafe_allow_html=True)
|
530 |
st.button("Generate",disabled=False,key = "generate",on_click = generate_images, args=(tab1,"default_img"))
|
531 |
|
|
|
|
|
|
|
532 |
image_select = st.radio("Choose one image", ["Image 1","Image 2","Image 3"],index=None, horizontal = True,key = 'image_select',disabled = st.session_state.radio_disabled)
|
533 |
st.markdown("""
|
534 |
<style>
|
|
|
538 |
</style>
|
539 |
""",unsafe_allow_html=True)
|
540 |
if(st.session_state.image_select is not None and st.session_state.image_select !="" and len(st.session_state.img_gen)!=0):
|
|
|
|
|
|
|
541 |
st.session_state.input_rad_1 = st.session_state.image_select.split(" ")[1]
|
542 |
else:
|
543 |
st.session_state.input_rad_1 = ""
|
544 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
545 |
|
546 |
|
547 |
generate_images(tab1,gen_images)
|
|
|
550 |
with tab2:
|
551 |
st.session_state.img_doc = st.file_uploader(
|
552 |
"Upload image", accept_multiple_files=False,type = ['png', 'jpg'])
|
|
|
|
|
|
|
|
|
|
|
553 |
return clear,tab1
|
554 |
|
555 |
clear,tab_ = write_top_bar()
|
556 |
|
557 |
if clear:
|
|
|
|
|
|
|
558 |
st.session_state.questions = []
|
559 |
st.session_state.answers = []
|
560 |
|
|
|
570 |
st.session_state.input_rad_1 = ""
|
571 |
|
572 |
|
573 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
574 |
col1, col3, col4 = st.columns([70,18,12])
|
575 |
|
576 |
with col1:
|
|
|
594 |
evaluate = st.toggle(' ', key = 'evaluate', disabled = False) #help = "Checking this box will use LLM to evaluate results as relevant and irrelevant. \n\n This option increases the latency")
|
595 |
if(evaluate):
|
596 |
st.session_state.input_evaluate = "enabled"
|
597 |
+
|
598 |
else:
|
599 |
st.session_state.input_evaluate = "disabled"
|
600 |
|
|
|
602 |
if(search_all_type == True or 1==1):
|
603 |
with st.sidebar:
|
604 |
st.page_link("app.py", label=":orange[Home]", icon="🏠")
|
605 |
+
|
|
|
|
|
|
|
|
|
606 |
|
607 |
########################## enable for query_rewrite ########################
|
608 |
rewrite_query = st.checkbox('Auto-apply filters', key = 'query_rewrite', disabled = False, help = "Checking this box will use LLM to rewrite your query. \n\n Here your natural language query is transformed into OpenSearch query with added filters and attributes")
|
|
|
612 |
key = 'input_must',
|
613 |
)
|
614 |
########################## enable for query_rewrite ########################
|
615 |
+
|
616 |
+
|
617 |
####### Filters #########
|
618 |
|
619 |
st.subheader(':blue[Filters]')
|
|
|
636 |
|
637 |
|
638 |
clear_filter = st.button("Clear Filters",on_click=clear_filter)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
639 |
####### Filters #########
|
640 |
|
641 |
if('NeuralSparse Search' in st.session_state.search_types):
|
|
|
643 |
sparse_filter = st.slider('Keep only sparse tokens with weight >=', 0.0, 1.0, 0.5,0.1,key = 'input_sparse_filter', help = 'Use this slider to set the minimum weight that the sparse vector token weights should meet, rest are filtered out')
|
644 |
|
645 |
|
|
|
646 |
st.session_state.input_is_rewrite_query = 'disabled'
|
647 |
st.session_state.input_is_sql_query = 'disabled'
|
648 |
|
649 |
########################## enable for query_rewrite ########################
|
650 |
if rewrite_query:
|
|
|
651 |
st.session_state.input_is_rewrite_query = 'enabled'
|
652 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
653 |
st.subheader(':blue[Hybrid Search]')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
654 |
with st.expander("Set query Weightage:"):
|
655 |
st.number_input("Keyword %", min_value=0, max_value=100, value=100, step=5, key='input_Keyword-weight', help=None)
|
656 |
st.number_input("Vector %", min_value=0, max_value=100, value=0, step=5, key='input_Vector-weight', help=None)
|
657 |
st.number_input("Multimodal %", min_value=0, max_value=100, value=0, step=5, key='input_Multimodal-weight', help=None)
|
658 |
st.number_input("NeuralSparse %", min_value=0, max_value=100, value=0, step=5, key='input_NeuralSparse-weight', help=None)
|
659 |
|
660 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
661 |
if(st.session_state.re_ranker == "true"):
|
662 |
st.subheader(':blue[Re-ranking]')
|
663 |
reranker = st.selectbox('Choose a Re-Ranker',
|
|
|
667 |
|
668 |
key = 'input_reranker',
|
669 |
help = 'Select the Re-Ranker type, select "None" to apply no re-ranking of the results',
|
|
|
670 |
args=(st.session_state.questions, st.session_state.answers)
|
671 |
|
672 |
)
|
673 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
674 |
|
675 |
|
676 |
def write_user_message(md,ans):
|
|
|
677 |
ans = ans["answer"][0]
|
678 |
col1, col2, col3 = st.columns([3,40,20])
|
679 |
|
680 |
with col1:
|
681 |
st.image(USER_ICON, use_column_width='always')
|
682 |
with col2:
|
|
|
683 |
st.markdown("<div style='fontSize:15px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;'>Input Text: </div><div style='fontSize:25px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 10px;font-style: italic;color:#e28743'>"+md['question']+"</div>", unsafe_allow_html = True)
|
684 |
if('query_sparse' in ans):
|
685 |
with st.expander("Expanded Query:"):
|
|
|
740 |
span_color = "red"
|
741 |
st.markdown("<span style='fontSize:20px;padding:3px 7px 3px 7px;borderWidth: 0px;borderColor: red;borderStyle: solid;width: fit-content;height: fit-content;border-radius: 20px;font-family:Courier New;color:#e28743'>Relevance:" +str('%.3f'%(st.session_state.input_ndcg)) + "</span><span style='font-size:30px;font-weight:bold;color:"+span_color+"'>"+st.session_state.ndcg_increase.split("~")[0] +"</span><span style='font-size:15px;font-weight:bold;font-family:Courier New;color:"+span_color+"'> "+st.session_state.ndcg_increase.split("~")[1]+"</span>", unsafe_allow_html = True)
|
742 |
|
743 |
+
|
|
|
|
|
|
|
744 |
|
745 |
placeholder_no_results = st.empty()
|
746 |
|
|
|
756 |
continue
|
757 |
|
758 |
|
|
|
759 |
format_ = ans['image_url'].split(".")[-1]
|
|
|
|
|
|
|
|
|
760 |
Image.MAX_IMAGE_PIXELS = 100000000
|
761 |
|
762 |
width = 500
|
|
|
787 |
desc__ = ans['desc'].split(" ")
|
788 |
|
789 |
final_desc = "<p>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
790 |
|
791 |
for word in desc__:
|
792 |
if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
|
|
|
808 |
filtered_sparse[key] = round(sparse_[key], 2)
|
809 |
st.write(filtered_sparse)
|
810 |
with st.expander("Document Metadata:",expanded = False):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
811 |
st.write(":green[default:]")
|
812 |
st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True)
|
|
|
813 |
if("rekog" in ans):
|
814 |
st.write(":green[enriched:]")
|
815 |
st.json(ans['rekog'],expanded = True)
|
|
|
824 |
st.write(":x:")
|
825 |
|
826 |
i = i+1
|
827 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
828 |
with col_3:
|
829 |
if(index == len(st.session_state.questions)):
|
830 |
|
|
|
840 |
st.session_state.questions.pop()
|
841 |
|
842 |
handle_input()
|
|
|
843 |
with placeholder.container():
|
844 |
render_all()
|
845 |
|
|
|
853 |
except:
|
854 |
pass
|
855 |
|
|
|
|
|
|
|
856 |
placeholder__ = st.empty()
|
857 |
|
858 |
placeholder__.button("🔄",key=rdn_key,on_click=on_button_click, help = "This will regenerate the responses with new settings that you entered, Note: To see difference in responses, you should change any of the applicable settings")#,type="primary",use_column_width=True)
|
|
|
877 |
index = 0
|
878 |
for (q, a) in zip(st.session_state.questions, st.session_state.answers):
|
879 |
index = index +1
|
|
|
|
|
880 |
ans_ = st.session_state.answers[0]
|
881 |
write_user_message(q,ans_)
|
882 |
write_chat_message(a, q,index)
|
|
|
885 |
with placeholder.container():
|
886 |
render_all()
|
887 |
|
|
|
|
|
888 |
st.markdown("")
|
semantic_search/amazon_rekognition.py
CHANGED
@@ -24,12 +24,7 @@ def extract_image_metadata(img):
|
|
24 |
MaxLabels = 10,
|
25 |
MinConfidence = 80.0,
|
26 |
Settings = {
|
27 |
-
|
28 |
-
# "LabelCategoryExclusionFilters": [ "string" ],
|
29 |
-
# "LabelCategoryInclusionFilters": [ "string" ],
|
30 |
-
# "LabelExclusionFilters": [ "string" ],
|
31 |
-
# "LabelInclusionFilters": [ "string" ]
|
32 |
-
# },
|
33 |
"ImageProperties": {
|
34 |
"MaxDominantColors": 5
|
35 |
}
|
@@ -76,20 +71,12 @@ def extract_image_metadata(img):
|
|
76 |
objects = " ".join(set(objects))
|
77 |
categories = " ".join(set(categories))
|
78 |
colors = " ".join(set(colors))
|
79 |
-
|
80 |
-
print("^^^^^^^^^^^^^^^^^^")
|
81 |
-
print(colors+ " " + objects + " " + categories)
|
82 |
-
|
83 |
return colors+ " " + objects + " " + categories
|
84 |
|
85 |
def call(a,b):
|
86 |
-
print("'''''''''''''''''''''''")
|
87 |
-
print(b)
|
88 |
-
|
89 |
if(st.session_state.input_is_rewrite_query == 'enabled' and st.session_state.input_rewritten_query!=""):
|
90 |
|
91 |
|
92 |
-
#st.session_state.input_rewritten_query['query']['bool']['should'].pop()
|
93 |
st.session_state.input_rewritten_query['query']['bool']['should'].append( {
|
94 |
"simple_query_string": {
|
95 |
|
@@ -112,36 +99,4 @@ def call(a,b):
|
|
112 |
}
|
113 |
st.session_state.input_rewritten_query = rekog_query
|
114 |
|
115 |
-
|
116 |
-
# body = rekog_query,
|
117 |
-
# index = 'demo-retail-rekognition'
|
118 |
-
# #pipeline = 'RAG-Search-Pipeline'
|
119 |
-
# )
|
120 |
-
|
121 |
-
|
122 |
-
# hits = response['hits']['hits']
|
123 |
-
# print("rewrite-------------------------")
|
124 |
-
# arr = []
|
125 |
-
# for doc in hits:
|
126 |
-
# # if('b5/b5319e00' in doc['_source']['image_s3_url'] ):
|
127 |
-
# # filter_out +=1
|
128 |
-
# # continue
|
129 |
-
|
130 |
-
# res_ = {"desc":doc['_source']['text'].replace(doc['_source']['metadata']['rekog_all']," ^^^ " +doc['_source']['metadata']['rekog_all']),
|
131 |
-
# "image_url":doc['_source']['metadata']['image_s3_url']}
|
132 |
-
# if('highlight' in doc):
|
133 |
-
# res_['highlight'] = doc['highlight']['text']
|
134 |
-
# # if('caption_embedding' in doc['_source']):
|
135 |
-
# # res_['sparse'] = doc['_source']['caption_embedding']
|
136 |
-
# # if('query_sparse' in response_ and len(arr) ==0 ):
|
137 |
-
# # res_['query_sparse'] = response_["query_sparse"]
|
138 |
-
# res_['id'] = doc['_id']
|
139 |
-
# res_['score'] = doc['_score']
|
140 |
-
# res_['title'] = doc['_source']['text']
|
141 |
-
# res_['rekog'] = {'color':doc['_source']['metadata']['rekog_color'],'category': doc['_source']['metadata']['rekog_categories'],'objects':doc['_source']['metadata']['rekog_objects']}
|
142 |
-
|
143 |
-
# arr.append(res_)
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
# return arr
|
|
|
24 |
MaxLabels = 10,
|
25 |
MinConfidence = 80.0,
|
26 |
Settings = {
|
27 |
+
|
|
|
|
|
|
|
|
|
|
|
28 |
"ImageProperties": {
|
29 |
"MaxDominantColors": 5
|
30 |
}
|
|
|
71 |
objects = " ".join(set(objects))
|
72 |
categories = " ".join(set(categories))
|
73 |
colors = " ".join(set(colors))
|
|
|
|
|
|
|
|
|
74 |
return colors+ " " + objects + " " + categories
|
75 |
|
76 |
def call(a,b):
|
|
|
|
|
|
|
77 |
if(st.session_state.input_is_rewrite_query == 'enabled' and st.session_state.input_rewritten_query!=""):
|
78 |
|
79 |
|
|
|
80 |
st.session_state.input_rewritten_query['query']['bool']['should'].append( {
|
81 |
"simple_query_string": {
|
82 |
|
|
|
99 |
}
|
100 |
st.session_state.input_rewritten_query = rekog_query
|
101 |
|
102 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utilities/invoke_models.py
CHANGED
@@ -24,17 +24,6 @@ bedrock_runtime_client = get_bedrock_client()
|
|
24 |
|
25 |
|
26 |
|
27 |
-
# def generate_image_captions_ml():
|
28 |
-
# model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
|
29 |
-
# feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
|
30 |
-
# tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
|
31 |
-
|
32 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
33 |
-
# model.to(device)
|
34 |
-
# max_length = 16
|
35 |
-
# num_beams = 4
|
36 |
-
# gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
|
37 |
-
|
38 |
def invoke_model(input):
|
39 |
response = bedrock_runtime_client.invoke_model(
|
40 |
body=json.dumps({
|
@@ -100,56 +89,7 @@ def invoke_llm_model(input,is_stream):
|
|
100 |
|
101 |
return (json.loads(res))['content'][0]['text']
|
102 |
|
103 |
-
|
104 |
-
# body=json.dumps({
|
105 |
-
# "prompt": input,
|
106 |
-
# "max_tokens_to_sample": 300,
|
107 |
-
# "temperature": 0.5,
|
108 |
-
# "top_k": 250,
|
109 |
-
# "top_p": 1,
|
110 |
-
# "stop_sequences": [
|
111 |
-
# "\n\nHuman:"
|
112 |
-
# ],
|
113 |
-
# # "anthropic_version": "bedrock-2023-05-31"
|
114 |
-
# }),
|
115 |
-
# modelId="anthropic.claude-v2:1",
|
116 |
-
# accept="application/json",
|
117 |
-
# contentType="application/json",
|
118 |
-
# )
|
119 |
-
# stream = response.get('body')
|
120 |
-
|
121 |
-
# return stream
|
122 |
-
|
123 |
-
# else:
|
124 |
-
# response = bedrock_runtime_client.invoke_model_with_response_stream(
|
125 |
-
# modelId= "anthropic.claude-3-sonnet-20240229-v1:0",
|
126 |
-
# contentType = "application/json",
|
127 |
-
# accept = "application/json",
|
128 |
-
|
129 |
-
# body = json.dumps({
|
130 |
-
# "anthropic_version": "bedrock-2023-05-31",
|
131 |
-
# "max_tokens": 1024,
|
132 |
-
# "temperature": 0.0001,
|
133 |
-
# "top_k": 150,
|
134 |
-
# "top_p": 0.7,
|
135 |
-
# "stop_sequences": [
|
136 |
-
# "\n\nHuman:"
|
137 |
-
# ],
|
138 |
-
# "messages": [
|
139 |
-
# {
|
140 |
-
# "role": "user",
|
141 |
-
# "content":input
|
142 |
-
# }
|
143 |
-
# ]
|
144 |
-
# }
|
145 |
-
|
146 |
-
# )
|
147 |
-
# )
|
148 |
-
|
149 |
-
# stream = response.get('body')
|
150 |
-
|
151 |
-
# return stream
|
152 |
-
|
153 |
def read_from_table(file,question):
|
154 |
print("started table analysis:")
|
155 |
print("-----------------------")
|
@@ -175,7 +115,6 @@ def read_from_table(file,question):
|
|
175 |
df = pd.read_csv(file,skipinitialspace = True, on_bad_lines='skip',delimiter = "`")
|
176 |
else:
|
177 |
df = file
|
178 |
-
#df.fillna(method='pad', inplace=True)
|
179 |
agent = create_pandas_dataframe_agent(
|
180 |
model,
|
181 |
df,
|
@@ -188,24 +127,7 @@ def read_from_table(file,question):
|
|
188 |
|
189 |
def generate_image_captions_llm(base64_string,question):
|
190 |
|
191 |
-
|
192 |
-
# MODEL_NAME = "claude-3-opus-20240229"
|
193 |
-
|
194 |
-
# message_list = [
|
195 |
-
# {
|
196 |
-
# "role": 'user',
|
197 |
-
# "content": [
|
198 |
-
# {"type": "image", "source": {"type": "base64", "media_type": "image/jpeg", "data": base64_string}},
|
199 |
-
# {"type": "text", "text": "What is in the image ?"}
|
200 |
-
# ]
|
201 |
-
# }
|
202 |
-
# ]
|
203 |
-
|
204 |
-
# response = ant_client.messages.create(
|
205 |
-
# model=MODEL_NAME,
|
206 |
-
# max_tokens=2048,
|
207 |
-
# messages=message_list
|
208 |
-
# )
|
209 |
response = bedrock_runtime_client.invoke_model(
|
210 |
modelId= "anthropic.claude-3-haiku-20240307-v1:0",
|
211 |
contentType = "application/json",
|
@@ -234,9 +156,5 @@ def generate_image_captions_llm(base64_string,question):
|
|
234 |
}
|
235 |
]
|
236 |
}))
|
237 |
-
#print(response)
|
238 |
response_body = json.loads(response.get("body").read())['content'][0]['text']
|
239 |
-
|
240 |
-
#print(response_body)
|
241 |
-
|
242 |
return response_body
|
|
|
24 |
|
25 |
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def invoke_model(input):
|
28 |
response = bedrock_runtime_client.invoke_model(
|
29 |
body=json.dumps({
|
|
|
89 |
|
90 |
return (json.loads(res))['content'][0]['text']
|
91 |
|
92 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
def read_from_table(file,question):
|
94 |
print("started table analysis:")
|
95 |
print("-----------------------")
|
|
|
115 |
df = pd.read_csv(file,skipinitialspace = True, on_bad_lines='skip',delimiter = "`")
|
116 |
else:
|
117 |
df = file
|
|
|
118 |
agent = create_pandas_dataframe_agent(
|
119 |
model,
|
120 |
df,
|
|
|
127 |
|
128 |
def generate_image_captions_llm(base64_string,question):
|
129 |
|
130 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
response = bedrock_runtime_client.invoke_model(
|
132 |
modelId= "anthropic.claude-3-haiku-20240307-v1:0",
|
133 |
contentType = "application/json",
|
|
|
156 |
}
|
157 |
]
|
158 |
}))
|
|
|
159 |
response_body = json.loads(response.get("body").read())['content'][0]['text']
|
|
|
|
|
|
|
160 |
return response_body
|
utilities/re_ranker.py
DELETED
@@ -1,127 +0,0 @@
|
|
1 |
-
import boto3
|
2 |
-
from botocore.exceptions import ClientError
|
3 |
-
import pprint
|
4 |
-
import time
|
5 |
-
import streamlit as st
|
6 |
-
from sentence_transformers import CrossEncoder
|
7 |
-
|
8 |
-
#model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512)
|
9 |
-
####### Add this Kendra Rescore ranking
|
10 |
-
#kendra_ranking = boto3.client("kendra-ranking",region_name = 'us-east-1')
|
11 |
-
#print("Create a rescore execution plan.")
|
12 |
-
|
13 |
-
# Provide a name for the rescore execution plan
|
14 |
-
#name = "MyRescoreExecutionPlan"
|
15 |
-
# Set your required additional capacity units
|
16 |
-
# Don't set capacity units if you don't require more than 1 unit given by default
|
17 |
-
#capacity_units = 2
|
18 |
-
|
19 |
-
# try:
|
20 |
-
# rescore_execution_plan_response = kendra_ranking.create_rescore_execution_plan(
|
21 |
-
# Name = name,
|
22 |
-
# CapacityUnits = {"RescoreCapacityUnits":capacity_units}
|
23 |
-
# )
|
24 |
-
|
25 |
-
# pprint.pprint(rescore_execution_plan_response)
|
26 |
-
|
27 |
-
# rescore_execution_plan_id = rescore_execution_plan_response["Id"]
|
28 |
-
|
29 |
-
# print("Wait for Amazon Kendra to create the rescore execution plan.")
|
30 |
-
|
31 |
-
# while True:
|
32 |
-
# # Get the details of the rescore execution plan, such as the status
|
33 |
-
# rescore_execution_plan_description = kendra_ranking.describe_rescore_execution_plan(
|
34 |
-
# Id = rescore_execution_plan_id
|
35 |
-
# )
|
36 |
-
# # When status is not CREATING quit.
|
37 |
-
# status = rescore_execution_plan_description["Status"]
|
38 |
-
# print(" Creating rescore execution plan. Status: "+status)
|
39 |
-
# time.sleep(60)
|
40 |
-
# if status != "CREATING":
|
41 |
-
# break
|
42 |
-
|
43 |
-
# except ClientError as e:
|
44 |
-
# print("%s" % e)
|
45 |
-
|
46 |
-
# print("Program ends.")
|
47 |
-
#########################
|
48 |
-
|
49 |
-
@st.cache_resource
|
50 |
-
def re_rank(self_, rerank_type, search_type, question, answers):
|
51 |
-
|
52 |
-
ans = []
|
53 |
-
ids = []
|
54 |
-
ques_ans = []
|
55 |
-
query = question[0]['question']
|
56 |
-
for i in answers[0]['answer']:
|
57 |
-
if(self_ == "search"):
|
58 |
-
|
59 |
-
ans.append({
|
60 |
-
"Id": i['id'],
|
61 |
-
"Body": i["desc"],
|
62 |
-
"OriginalScore": i['score'],
|
63 |
-
"Title":i["desc"]
|
64 |
-
})
|
65 |
-
ids.append(i['id'])
|
66 |
-
ques_ans.append((query,i["desc"]))
|
67 |
-
|
68 |
-
else:
|
69 |
-
ans.append({'text':i})
|
70 |
-
|
71 |
-
ques_ans.append((query,i))
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
re_ranked = [{}]
|
76 |
-
####### Add this Kendra Rescore ranking
|
77 |
-
# if(rerank_type == 'Kendra Rescore'):
|
78 |
-
# rescore_response = kendra_ranking.rescore(
|
79 |
-
# RescoreExecutionPlanId = 'b2a4d4f3-98ff-4e17-8b69-4c61ed7d91eb',
|
80 |
-
# SearchQuery = query,
|
81 |
-
# Documents = ans
|
82 |
-
# )
|
83 |
-
# re_ranked[0]['answer']=[]
|
84 |
-
# for result in rescore_response["ResultItems"]:
|
85 |
-
|
86 |
-
# pos_ = ids.index(result['DocumentId'])
|
87 |
-
|
88 |
-
# re_ranked[0]['answer'].append(answers[0]['answer'][pos_])
|
89 |
-
# re_ranked[0]['search_type']=search_type,
|
90 |
-
# re_ranked[0]['id'] = len(question)
|
91 |
-
# return re_ranked
|
92 |
-
|
93 |
-
# if(rerank_type == 'Cross Encoder'):
|
94 |
-
|
95 |
-
# scores = model.predict(
|
96 |
-
# ques_ans
|
97 |
-
# )
|
98 |
-
|
99 |
-
# index__ = 0
|
100 |
-
# for i in ans:
|
101 |
-
# i['new_score'] = scores[index__]
|
102 |
-
# index__ = index__+1
|
103 |
-
|
104 |
-
# ans_sorted = sorted(ans, key=lambda d: d['new_score'],reverse=True)
|
105 |
-
|
106 |
-
|
107 |
-
# def retreive_only_text(item):
|
108 |
-
# return item['text']
|
109 |
-
|
110 |
-
# if(self_ == 'rag'):
|
111 |
-
# return list(map(retreive_only_text, ans_sorted))
|
112 |
-
|
113 |
-
|
114 |
-
# re_ranked[0]['answer']=[]
|
115 |
-
# for j in ans_sorted:
|
116 |
-
# pos_ = ids.index(j['Id'])
|
117 |
-
# re_ranked[0]['answer'].append(answers[0]['answer'][pos_])
|
118 |
-
# re_ranked[0]['search_type']= search_type,
|
119 |
-
# re_ranked[0]['id'] = len(question)
|
120 |
-
# return re_ranked
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|