Spaces:
Running
on
T4
Running
on
T4
rerank model
Browse files- pages/Semantic_Search.py +66 -72
- semantic_search/query_rewrite.py +3 -87
pages/Semantic_Search.py
CHANGED
@@ -747,83 +747,77 @@ def render_answer(answer,index):
|
|
747 |
col_1, col_2,col_3 = st.columns([70,10,20])
|
748 |
i = 0
|
749 |
filter_out = 0
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
762 |
-
|
763 |
-
|
764 |
-
|
765 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
766 |
|
767 |
-
|
768 |
-
|
769 |
-
|
770 |
-
|
771 |
-
|
772 |
-
if("highlight" in ans and 'Keyword Search' in st.session_state.input_searchType):
|
773 |
-
test_strs = ans["highlight"]
|
774 |
-
tag = "em"
|
775 |
-
res__ = []
|
776 |
-
for test_str in test_strs:
|
777 |
-
start_idx = test_str.find("<" + tag + ">")
|
778 |
-
|
779 |
-
while start_idx != -1:
|
780 |
-
end_idx = test_str.find("</" + tag + ">", start_idx)
|
781 |
-
if end_idx == -1:
|
782 |
-
break
|
783 |
-
res__.append(test_str[start_idx+len(tag)+2:end_idx])
|
784 |
-
start_idx = test_str.find("<" + tag + ">", end_idx)
|
785 |
|
|
|
|
|
|
|
|
|
|
|
786 |
|
787 |
-
|
788 |
-
|
789 |
-
final_desc = "<p>"
|
790 |
-
|
791 |
-
for word in desc__:
|
792 |
-
if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
|
793 |
-
final_desc += "<span style='color:#e28743;font-weight:bold'>"+word+"</span> "
|
794 |
-
else:
|
795 |
-
final_desc += word + " "
|
796 |
-
|
797 |
-
final_desc += "</p>"
|
798 |
|
799 |
-
|
800 |
-
|
801 |
-
|
802 |
-
|
803 |
-
|
804 |
-
|
805 |
-
|
806 |
-
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
818 |
-
if(st.session_state.input_evaluate == "enabled"):
|
819 |
-
with st.container(border = False):
|
820 |
-
if("relevant" in ans.keys()):
|
821 |
-
if(ans['relevant']==True):
|
822 |
-
st.write(":white_check_mark:")
|
823 |
-
else:
|
824 |
-
st.write(":x:")
|
825 |
|
826 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
827 |
|
828 |
with col_3:
|
829 |
if(index == len(st.session_state.questions)):
|
|
|
747 |
col_1, col_2,col_3 = st.columns([70,10,20])
|
748 |
i = 0
|
749 |
filter_out = 0
|
750 |
+
if len(answer) == 0:
|
751 |
+
st.write("No results found")
|
752 |
+
else:
|
753 |
+
for ans in answer:
|
754 |
+
if('b5/b5319e00' in ans['image_url'] ):
|
755 |
+
filter_out+=1
|
756 |
+
continue
|
757 |
+
format_ = ans['image_url'].split(".")[-1]
|
758 |
+
Image.MAX_IMAGE_PIXELS = 100000000
|
759 |
+
width = 500
|
760 |
+
height = 500
|
761 |
+
with col_1:
|
762 |
+
inner_col_1,inner_col_2 = st.columns([8,92])
|
763 |
+
with inner_col_2:
|
764 |
+
st.image(ans['image_url'].replace("/home/ec2-user/SageMaker/","/home/user/"))
|
765 |
+
|
766 |
+
if("highlight" in ans and 'Keyword Search' in st.session_state.input_searchType):
|
767 |
+
test_strs = ans["highlight"]
|
768 |
+
tag = "em"
|
769 |
+
res__ = []
|
770 |
+
for test_str in test_strs:
|
771 |
+
start_idx = test_str.find("<" + tag + ">")
|
772 |
+
|
773 |
+
while start_idx != -1:
|
774 |
+
end_idx = test_str.find("</" + tag + ">", start_idx)
|
775 |
+
if end_idx == -1:
|
776 |
+
break
|
777 |
+
res__.append(test_str[start_idx+len(tag)+2:end_idx])
|
778 |
+
start_idx = test_str.find("<" + tag + ">", end_idx)
|
779 |
|
780 |
+
|
781 |
+
desc__ = ans['desc'].split(" ")
|
782 |
+
|
783 |
+
final_desc = "<p>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
784 |
|
785 |
+
for word in desc__:
|
786 |
+
if(re.sub('[^A-Za-z0-9]+', '', word) in res__):
|
787 |
+
final_desc += "<span style='color:#e28743;font-weight:bold'>"+word+"</span> "
|
788 |
+
else:
|
789 |
+
final_desc += word + " "
|
790 |
|
791 |
+
final_desc += "</p>"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
792 |
|
793 |
+
st.markdown(final_desc,unsafe_allow_html = True)
|
794 |
+
else:
|
795 |
+
st.write(ans['desc'])
|
796 |
+
if("sparse" in ans):
|
797 |
+
with st.expander("Expanded document:"):
|
798 |
+
sparse_ = dict(sorted(ans['sparse'].items(), key=lambda item: item[1],reverse=True))
|
799 |
+
filtered_sparse = dict()
|
800 |
+
for key in sparse_:
|
801 |
+
if(sparse_[key]>=1.0):
|
802 |
+
filtered_sparse[key] = round(sparse_[key], 2)
|
803 |
+
st.write(filtered_sparse)
|
804 |
+
with st.expander("Document Metadata:",expanded = False):
|
805 |
+
st.write(":green[default:]")
|
806 |
+
st.json({"category:":ans['category'],"price":str(ans['price']),"gender_affinity":ans['gender_affinity'],"style":ans['style']},expanded = True)
|
807 |
+
if("rekog" in ans):
|
808 |
+
st.write(":green[enriched:]")
|
809 |
+
st.json(ans['rekog'],expanded = True)
|
810 |
+
with inner_col_1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
811 |
|
812 |
+
if(st.session_state.input_evaluate == "enabled"):
|
813 |
+
with st.container(border = False):
|
814 |
+
if("relevant" in ans.keys()):
|
815 |
+
if(ans['relevant']==True):
|
816 |
+
st.write(":white_check_mark:")
|
817 |
+
else:
|
818 |
+
st.write(":x:")
|
819 |
+
|
820 |
+
i = i+1
|
821 |
|
822 |
with col_3:
|
823 |
if(index == len(st.session_state.questions)):
|
semantic_search/query_rewrite.py
CHANGED
@@ -252,16 +252,6 @@ def get_new_query_res(query):
|
|
252 |
query = st.session_state.input_rekog_label
|
253 |
if(st.session_state.input_is_rewrite_query == 'enabled'):
|
254 |
|
255 |
-
# query_struct = query_constructor.invoke(
|
256 |
-
# {
|
257 |
-
# "query": query
|
258 |
-
# }
|
259 |
-
# )
|
260 |
-
# print("***prompt****")
|
261 |
-
# print(prompt)
|
262 |
-
# print("******query_struct******")
|
263 |
-
# print(query_struct)
|
264 |
-
|
265 |
res = invoke_models.invoke_llm_model( prompt_.format(query=query,schema = schema) ,False)
|
266 |
inter_query = res[7:-3].replace('\\"',"'").replace("\n","")
|
267 |
print("inter_query")
|
@@ -294,43 +284,8 @@ def get_new_query_res(query):
|
|
294 |
draft_new_query['bool']['must'].append(q_dash)
|
295 |
else:
|
296 |
draft_new_query['bool']['should'].append(q_dash)
|
297 |
-
|
298 |
-
|
299 |
-
# q__dash = json.loads(json.dumps(q_).replace('term','match' ))
|
300 |
-
# clause = list(q__dash.keys())[0]category
|
301 |
-
# long_field = list(q__dash[clause].keys())[0]
|
302 |
-
# get_attr = long_field.split(".")[1]
|
303 |
-
# q__dash[clause][get_attr] = q__dash[clause][long_field]
|
304 |
-
# draft_new_query['bool']['should'].append(q__dash)
|
305 |
-
|
306 |
-
#print(draft_new_query)
|
307 |
-
query_ = draft_new_query#json.loads(json.dumps(opts.visit_structured_query(query_struct)[1]['filter']).replace("must","should"))#.replace("must","should")
|
308 |
-
|
309 |
-
# if('bool' in query_ and 'should' in query_['bool']):
|
310 |
-
# query_['bool']['should'].append({
|
311 |
-
# "match": {
|
312 |
-
|
313 |
-
# "rekog_description_plus_original_description": query
|
314 |
-
|
315 |
-
# }
|
316 |
-
# })
|
317 |
-
# else:
|
318 |
-
# query_['bool']['should'] = {
|
319 |
-
# "match": {
|
320 |
-
|
321 |
-
# "rekog_description_plus_original_description": query
|
322 |
-
|
323 |
-
# }
|
324 |
-
# }
|
325 |
-
|
326 |
-
# def find_by_key(data, target):
|
327 |
-
# for key, value in data.items():
|
328 |
-
# if isinstance(value, dict):
|
329 |
-
# yield from find_by_key(value, target)
|
330 |
-
# elif key == target:
|
331 |
-
# yield value
|
332 |
-
# for x in find_by_key(query_, "metadata.category.keyword"):
|
333 |
-
# imp_item = x
|
334 |
|
335 |
|
336 |
###### find the main subject of the query
|
@@ -405,46 +360,7 @@ def get_new_query_res(query):
|
|
405 |
|
406 |
st.session_state.input_rewritten_query = {"query":query_}
|
407 |
print(st.session_state.input_rewritten_query)
|
408 |
-
|
409 |
-
# amazon_rekognition.call(st.session_state.input_text,st.session_state.input_rekog_label)
|
410 |
-
|
411 |
-
|
412 |
-
# #return searchWithNewQuery(st.session_state.input_rewritten_query)
|
413 |
-
|
414 |
-
# def searchWithNewQuery(new_query):
|
415 |
-
# response = aos_client.search(
|
416 |
-
# body = new_query,
|
417 |
-
# index = "demo-retail-rekognition"#'self-query-rewrite-retail',
|
418 |
-
# #pipeline = 'RAG-Search-Pipeline'
|
419 |
-
# )
|
420 |
-
|
421 |
-
# hits = response['hits']['hits']
|
422 |
-
# print("rewrite-------------------------")
|
423 |
-
# arr = []
|
424 |
-
# for doc in hits:
|
425 |
-
# # if('b5/b5319e00' in doc['_source']['image_s3_url'] ):
|
426 |
-
# # filter_out +=1
|
427 |
-
# # continue
|
428 |
-
|
429 |
-
# res_ = {"desc":doc['_source']['text'],"image_url":doc['_source']['metadata']['image_s3_url']}
|
430 |
-
# if('highlight' in doc):
|
431 |
-
# res_['highlight'] = doc['highlight']['text']
|
432 |
-
# # if('caption_embedding' in doc['_source']):
|
433 |
-
# # res_['sparse'] = doc['_source']['caption_embedding']
|
434 |
-
# # if('query_sparse' in response_ and len(arr) ==0 ):
|
435 |
-
# # res_['query_sparse'] = response_["query_sparse"]
|
436 |
-
# res_['id'] = doc['_id']
|
437 |
-
# res_['score'] = doc['_score']
|
438 |
-
# res_['title'] = doc['_source']['text']
|
439 |
-
|
440 |
-
# arr.append(res_)
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
# return arr
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
|
449 |
|
450 |
|
|
|
252 |
query = st.session_state.input_rekog_label
|
253 |
if(st.session_state.input_is_rewrite_query == 'enabled'):
|
254 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
res = invoke_models.invoke_llm_model( prompt_.format(query=query,schema = schema) ,False)
|
256 |
inter_query = res[7:-3].replace('\\"',"'").replace("\n","")
|
257 |
print("inter_query")
|
|
|
284 |
draft_new_query['bool']['must'].append(q_dash)
|
285 |
else:
|
286 |
draft_new_query['bool']['should'].append(q_dash)
|
287 |
+
|
288 |
+
query_ = draft_new_query
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
|
290 |
|
291 |
###### find the main subject of the query
|
|
|
360 |
|
361 |
st.session_state.input_rewritten_query = {"query":query_}
|
362 |
print(st.session_state.input_rewritten_query)
|
363 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
|
365 |
|
366 |
|