Spaces:
Running
on
T4
Running
on
T4
mvectors
Browse files- pages/Semantic_Search.py +58 -10
- semantic_search/all_search_execute.py +2 -0
- utilities/mvectors.py +1 -1
pages/Semantic_Search.py
CHANGED
@@ -24,8 +24,8 @@ import base64
|
|
24 |
import shutil
|
25 |
import re
|
26 |
from requests.auth import HTTPBasicAuth
|
27 |
-
|
28 |
-
|
29 |
import query_rewrite
|
30 |
import amazon_rekognition
|
31 |
from streamlit.components.v1 import html
|
@@ -71,7 +71,7 @@ st.markdown("""
|
|
71 |
|
72 |
|
73 |
|
74 |
-
|
75 |
|
76 |
st.session_state.REGION = 'us-east-1'
|
77 |
USER_ICON = "images/user.png"
|
@@ -113,6 +113,9 @@ if "chats" not in st.session_state:
|
|
113 |
|
114 |
if "questions" not in st.session_state:
|
115 |
st.session_state.questions = []
|
|
|
|
|
|
|
116 |
|
117 |
if "clear_" not in st.session_state:
|
118 |
st.session_state.clear_ = False
|
@@ -744,14 +747,14 @@ def write_user_message(md,ans):
|
|
744 |
st.markdown('---')
|
745 |
|
746 |
|
747 |
-
|
748 |
-
|
749 |
|
750 |
-
|
751 |
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
|
756 |
def render_answer(answer,index):
|
757 |
column1, column2 = st.columns([6,90])
|
@@ -790,7 +793,52 @@ def render_answer(answer,index):
|
|
790 |
with inner_col_2:
|
791 |
st.image(ans['image_url'].replace("/home/ec2-user/SageMaker/","/home/user/"))
|
792 |
|
793 |
-
if(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
794 |
test_strs = ans["highlight"]
|
795 |
tag = "em"
|
796 |
res__ = []
|
|
|
24 |
import shutil
|
25 |
import re
|
26 |
from requests.auth import HTTPBasicAuth
|
27 |
+
from nltk.stem import PorterStemmer
|
28 |
+
from nltk.tokenize import word_tokenize
|
29 |
import query_rewrite
|
30 |
import amazon_rekognition
|
31 |
from streamlit.components.v1 import html
|
|
|
71 |
|
72 |
|
73 |
|
74 |
+
ps = PorterStemmer()
|
75 |
|
76 |
st.session_state.REGION = 'us-east-1'
|
77 |
USER_ICON = "images/user.png"
|
|
|
113 |
|
114 |
if "questions" not in st.session_state:
|
115 |
st.session_state.questions = []
|
116 |
+
|
117 |
+
if "input_mvector_rerank" not in st.session_state:
|
118 |
+
st.session_state.input_colBert_rerank = False
|
119 |
|
120 |
if "clear_" not in st.session_state:
|
121 |
st.session_state.clear_ = False
|
|
|
747 |
st.markdown('---')
|
748 |
|
749 |
|
750 |
+
def stem_(sentence):
|
751 |
+
words = word_tokenize(sentence)
|
752 |
|
753 |
+
words_stem = []
|
754 |
|
755 |
+
for w in words:
|
756 |
+
words_stem.append( ps.stem(w))
|
757 |
+
return words_stem
|
758 |
|
759 |
def render_answer(answer,index):
|
760 |
column1, column2 = st.columns([6,90])
|
|
|
793 |
with inner_col_2:
|
794 |
st.image(ans['image_url'].replace("/home/ec2-user/SageMaker/","/home/user/"))
|
795 |
|
796 |
+
if('max_score_dict_list_sorted' in ans and 'Vector Search' in st.session_state.input_searchType):
|
797 |
+
desc___ = ans['desc'].split(" ")
|
798 |
+
res___ = []
|
799 |
+
for o in ans['max_score_dict_list_sorted']:
|
800 |
+
res___.append(o['doc_token'])
|
801 |
+
final_desc_ = "<p></p><p>"
|
802 |
+
for word_ in desc___:
|
803 |
+
str_=re.sub('[^A-Za-z0-9]+', '', word_).lower()
|
804 |
+
###### stemming and highlighting
|
805 |
+
|
806 |
+
# ans_text = ans['desc']
|
807 |
+
# query_text = st.session_state.input_text
|
808 |
+
|
809 |
+
stemmed_word = next(iter(set(stem_(str_))))
|
810 |
+
# print("stemmed_word-------------------")
|
811 |
+
# print(stemmed_word)
|
812 |
+
|
813 |
+
|
814 |
+
# common = ans_text_stemmed.intersection( query_text_stemmed)
|
815 |
+
# #unique = set(document_1_words).symmetric_difference( )
|
816 |
+
|
817 |
+
# desc__stemmed = stem_(desc__)
|
818 |
+
|
819 |
+
#print(str_)
|
820 |
+
if(stemmed_word in res___ or str_ in res___):
|
821 |
+
if(stemmed_word in res___):
|
822 |
+
mod_word = stemmed_word
|
823 |
+
else:
|
824 |
+
mod_word = str_
|
825 |
+
#print(str_)
|
826 |
+
if(res___.index(mod_word)==0):
|
827 |
+
#print(str_)
|
828 |
+
final_desc_ += "<span style='color:#ffffff;background-color:#8B0001;font-weight:bold'>"+word_+"</span> "
|
829 |
+
elif(res___.index(mod_word)==1):
|
830 |
+
#print(str_)
|
831 |
+
final_desc_ += "<span style='color:#ffffff;background-color:#C34632;font-weight:bold'>"+word_+"</span> "
|
832 |
+
else:
|
833 |
+
#print(str_)
|
834 |
+
final_desc_ += "<span style='color:#ffffff;background-color:#E97452;font-weight:bold'>"+word_+"</span> "
|
835 |
+
else:
|
836 |
+
final_desc_ += word_ + " "
|
837 |
+
|
838 |
+
final_desc_ += "</p><br>"
|
839 |
+
#print(final_desc_)
|
840 |
+
st.markdown(final_desc_,unsafe_allow_html = True)
|
841 |
+
elif("highlight" in ans and 'Keyword Search' in st.session_state.input_searchType):
|
842 |
test_strs = ans["highlight"]
|
843 |
tag = "em"
|
844 |
res__ = []
|
semantic_search/all_search_execute.py
CHANGED
@@ -512,6 +512,8 @@ def handler(input_,session_id):
|
|
512 |
"style":doc['_source']['style'],
|
513 |
|
514 |
}
|
|
|
|
|
515 |
if('highlight' in doc):
|
516 |
res_['highlight'] = doc['highlight']['product_description']
|
517 |
if('NeuralSparse Search' in search_types):
|
|
|
512 |
"style":doc['_source']['style'],
|
513 |
|
514 |
}
|
515 |
+
if('max_score_dict_list_sorted' in doc):
|
516 |
+
res_['max_score_dict_list_sorted'] = doc['max_score_dict_list_sorted']
|
517 |
if('highlight' in doc):
|
518 |
res_['highlight'] = doc['highlight']['product_description']
|
519 |
if('NeuralSparse Search' in search_types):
|
utilities/mvectors.py
CHANGED
@@ -56,7 +56,7 @@ def search(hits):
|
|
56 |
doc={"_source":
|
57 |
{
|
58 |
"description":j["_source"]["description"],"caption":j["_source"]["title"],
|
59 |
-
"
|
60 |
"style":j["_source"]["style"],"category":j["_source"]["category"]},"_id":j["_id"],"_score":j["_score"]}
|
61 |
|
62 |
if("gender_affinity" in j["_source"]):
|
|
|
56 |
doc={"_source":
|
57 |
{
|
58 |
"description":j["_source"]["description"],"caption":j["_source"]["title"],
|
59 |
+
"image_url":j["_source"]["image_s3_url"],"price":j["_source"]["price"],
|
60 |
"style":j["_source"]["style"],"category":j["_source"]["category"]},"_id":j["_id"],"_score":j["_score"]}
|
61 |
|
62 |
if("gender_affinity" in j["_source"]):
|