Update app.py
Browse files
app.py
CHANGED
@@ -14,9 +14,6 @@ nltk.download('punkt')
|
|
14 |
|
15 |
from nltk import sent_tokenize
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
def extract_text_from_url(url: str):
|
21 |
'''Extract text from url'''
|
22 |
|
@@ -109,7 +106,6 @@ def preprocess_plain_text(text, window_size=3):
|
|
109 |
return passages
|
110 |
|
111 |
|
112 |
-
#@st.experimental_memo(suppress_st_warning=True)
|
113 |
def bi_encode(bi_enc, passages):
|
114 |
global bi_encoder
|
115 |
# We use the Bi-Encoder to encode all passages, so that we can use it with semantic search
|
@@ -122,12 +118,11 @@ def bi_encode(bi_enc, passages):
|
|
122 |
with st.spinner('Encoding passages into a vector space...'):
|
123 |
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
124 |
|
125 |
-
st.success(f"Embeddings computed.
|
126 |
|
127 |
return bi_encoder, corpus_embeddings
|
128 |
|
129 |
|
130 |
-
#@st.experimental_singleton(suppress_st_warning=True)
|
131 |
def cross_encode():
|
132 |
global cross_encoder
|
133 |
# We use a cross-encoder, to re-rank the results list to improve the quality
|
@@ -144,7 +139,7 @@ def display_as_table(model, top_k=2, score='score'):
|
|
144 |
return df
|
145 |
|
146 |
|
147 |
-
|
148 |
|
149 |
st.title("Search with Retrieve & Rerank")
|
150 |
window_size = 3
|
@@ -178,7 +173,7 @@ def search_func(query, top_k=2):
|
|
178 |
for idx in range(len(cross_scores)):
|
179 |
hits[idx]['cross-score'] = cross_scores[idx]
|
180 |
|
181 |
-
# Output of top
|
182 |
st.markdown("\n-------------------------\n")
|
183 |
st.subheader(f"Top-{top_k} Bi-Encoder Retrieval hits")
|
184 |
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
|
@@ -186,7 +181,7 @@ def search_func(query, top_k=2):
|
|
186 |
cross_df = display_as_table(hits, top_k)
|
187 |
st.write(cross_df.to_html(index=False), unsafe_allow_html=True)
|
188 |
|
189 |
-
# Output of top
|
190 |
st.markdown("\n-------------------------\n")
|
191 |
st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
|
192 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
@@ -204,9 +199,7 @@ def clear_search_text():
|
|
204 |
st.session_state["text_input"] = ""
|
205 |
|
206 |
|
207 |
-
url_text = st.text_input("Please Enter a url here",
|
208 |
-
value="https://en.wikipedia.org/wiki/Virat_Kohli",
|
209 |
-
key='text_url', on_change=clear_search_text)
|
210 |
|
211 |
st.markdown(
|
212 |
"<h3 style='text-align: center; color: red;'>OR</h3>",
|
@@ -234,8 +227,7 @@ with col1:
|
|
234 |
search = st.button("Search", key='search_but', help='Click to Search!!')
|
235 |
|
236 |
with col2:
|
237 |
-
clear = st.button("Clear Text Input", on_click=clear_text, key='clear',
|
238 |
-
help='Click to clear the URL input and search query')
|
239 |
|
240 |
if search:
|
241 |
if bi_encoder_type:
|
|
|
14 |
|
15 |
from nltk import sent_tokenize
|
16 |
|
|
|
|
|
|
|
17 |
def extract_text_from_url(url: str):
|
18 |
'''Extract text from url'''
|
19 |
|
|
|
106 |
return passages
|
107 |
|
108 |
|
|
|
109 |
def bi_encode(bi_enc, passages):
|
110 |
global bi_encoder
|
111 |
# We use the Bi-Encoder to encode all passages, so that we can use it with semantic search
|
|
|
118 |
with st.spinner('Encoding passages into a vector space...'):
|
119 |
corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
|
120 |
|
121 |
+
st.success(f"Embeddings computed.")
|
122 |
|
123 |
return bi_encoder, corpus_embeddings
|
124 |
|
125 |
|
|
|
126 |
def cross_encode():
|
127 |
global cross_encoder
|
128 |
# We use a cross-encoder, to re-rank the results list to improve the quality
|
|
|
139 |
return df
|
140 |
|
141 |
|
142 |
+
|
143 |
|
144 |
st.title("Search with Retrieve & Rerank")
|
145 |
window_size = 3
|
|
|
173 |
for idx in range(len(cross_scores)):
|
174 |
hits[idx]['cross-score'] = cross_scores[idx]
|
175 |
|
176 |
+
# Output of top hits from bi-encoder
|
177 |
st.markdown("\n-------------------------\n")
|
178 |
st.subheader(f"Top-{top_k} Bi-Encoder Retrieval hits")
|
179 |
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
|
|
|
181 |
cross_df = display_as_table(hits, top_k)
|
182 |
st.write(cross_df.to_html(index=False), unsafe_allow_html=True)
|
183 |
|
184 |
+
# Output of top hits from cross encoder
|
185 |
st.markdown("\n-------------------------\n")
|
186 |
st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
|
187 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
|
|
199 |
st.session_state["text_input"] = ""
|
200 |
|
201 |
|
202 |
+
url_text = st.text_input("Please Enter a url here",value="https://en.wikipedia.org/wiki/Virat_Kohli",key='text_url', on_change=clear_search_text)
|
|
|
|
|
203 |
|
204 |
st.markdown(
|
205 |
"<h3 style='text-align: center; color: red;'>OR</h3>",
|
|
|
227 |
search = st.button("Search", key='search_but', help='Click to Search!!')
|
228 |
|
229 |
with col2:
|
230 |
+
clear = st.button("Clear Text Input", on_click=clear_text, key='clear',help='Click to clear the URL and query')
|
|
|
231 |
|
232 |
if search:
|
233 |
if bi_encoder_type:
|