Update app.py
Browse files
app.py
CHANGED
@@ -106,10 +106,10 @@ def preprocess_plain_text(text, window_size=3):
|
|
106 |
return passages
|
107 |
|
108 |
|
109 |
-
def bi_encode(passages):
|
110 |
global bi_encoder
|
111 |
# We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
|
112 |
-
bi_encoder = SentenceTransformer(
|
113 |
|
114 |
|
115 |
# Compute the embeddings
|
@@ -139,6 +139,7 @@ def display_as_table(model, score='score'):
|
|
139 |
|
140 |
st.title("Search Your Query Here")
|
141 |
window_size = 3
|
|
|
142 |
|
143 |
# This will search articles for passages to answer the query
|
144 |
def search_func(query):
|
@@ -217,11 +218,11 @@ with col2:
|
|
217 |
clear = st.button("Clear Text Input", on_click=clear_text, key='clear',help='Click to clear the URL and query')
|
218 |
|
219 |
if search:
|
220 |
-
if bi_encoder_type:
|
221 |
with st.spinner(
|
222 |
text=f"Loading..........................."
|
223 |
):
|
224 |
-
bi_encoder, corpus_embeddings = bi_encode(passages)
|
225 |
cross_encoder = cross_encode()
|
226 |
|
227 |
with st.spinner(
|
|
|
106 |
return passages
|
107 |
|
108 |
|
109 |
+
def bi_encode(bi_enc,passages):
|
110 |
global bi_encoder
|
111 |
# We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
|
112 |
+
bi_encoder = SentenceTransformer(bi_enc)
|
113 |
|
114 |
|
115 |
# Compute the embeddings
|
|
|
139 |
|
140 |
st.title("Search Your Query Here")
|
141 |
window_size = 3
|
142 |
+
bi_enc_options = "multi-qa-mpnet-base-dot-v1"
|
143 |
|
144 |
# This will search articles for passages to answer the query
|
145 |
def search_func(query):
|
|
|
218 |
clear = st.button("Clear Text Input", on_click=clear_text, key='clear',help='Click to clear the URL and query')
|
219 |
|
220 |
if search:
|
221 |
+
if bi_encoder_type=="multi-qa-mpnet-base-dot-v1":
|
222 |
with st.spinner(
|
223 |
text=f"Loading..........................."
|
224 |
):
|
225 |
+
bi_encoder, corpus_embeddings = bi_encode(bi_encoder_type,passages)
|
226 |
cross_encoder = cross_encode()
|
227 |
|
228 |
with st.spinner(
|