Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -200,10 +200,87 @@ def search(query):
|
|
200 |
|
201 |
return show_out
|
202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
203 |
def reranking():
|
204 |
rerank_list = []
|
205 |
-
rerank_list =
|
206 |
-
|
207 |
random.shuffle(rerank_list[0:maxtags_sidebar])
|
208 |
for i in rerank_list[0:maxtags_sidebar]:
|
209 |
st.write(i)
|
|
|
200 |
|
201 |
return show_out
|
202 |
|
203 |
+
def search_nolog(query):
|
204 |
+
total_qe = []
|
205 |
+
##### BM25 search (lexical search) #####
|
206 |
+
bm25_scores = bm25.get_scores(bm25_tokenizer(query))
|
207 |
+
top_n = np.argpartition(bm25_scores, -5)[-5:]
|
208 |
+
bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
|
209 |
+
bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
|
210 |
+
|
211 |
+
qe_string = []
|
212 |
+
for hit in bm25_hits[0:1000]:
|
213 |
+
if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
|
214 |
+
qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
|
215 |
+
|
216 |
+
sub_string = []
|
217 |
+
for item in qe_string:
|
218 |
+
for sub_item in item.split(","):
|
219 |
+
sub_string.append(sub_item)
|
220 |
+
total_qe.append(sub_string)
|
221 |
+
|
222 |
+
##### Sematic Search #####
|
223 |
+
# Encode the query using the bi-encoder and find potentially relevant passages
|
224 |
+
query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
225 |
+
hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
|
226 |
+
hits = hits[0] # Get the hits for the first query
|
227 |
+
|
228 |
+
##### Re-Ranking #####
|
229 |
+
# Now, score all retrieved passages with the cross_encoder
|
230 |
+
cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
|
231 |
+
cross_scores = cross_encoder.predict(cross_inp)
|
232 |
+
|
233 |
+
# Sort results by the cross-encoder scores
|
234 |
+
for idx in range(len(cross_scores)):
|
235 |
+
hits[idx]['cross-score'] = cross_scores[idx]
|
236 |
+
|
237 |
+
# Output of top-10 hits from bi-encoder
|
238 |
+
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
|
239 |
+
qe_string = []
|
240 |
+
for hit in hits[0:1000]:
|
241 |
+
if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
|
242 |
+
qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
|
243 |
+
total_qe.append(qe_string)
|
244 |
+
|
245 |
+
# Output of top-10 hits from re-ranker
|
246 |
+
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
247 |
+
qe_string = []
|
248 |
+
for hit in hits[0:1000]:
|
249 |
+
if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
|
250 |
+
qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
|
251 |
+
total_qe.append(qe_string)
|
252 |
+
|
253 |
+
# Total Results
|
254 |
+
total_qe.append(qe_string)
|
255 |
+
|
256 |
+
res = []
|
257 |
+
for sub_list in total_qe:
|
258 |
+
for i in sub_list:
|
259 |
+
rs = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', i)
|
260 |
+
rs_final = re.sub("\x20\x20", "\n", rs)
|
261 |
+
res.append(rs_final.strip())
|
262 |
+
|
263 |
+
res_clean = []
|
264 |
+
for out in res:
|
265 |
+
if len(out) > 20:
|
266 |
+
keywords = custom_kw_extractor.extract_keywords(out)
|
267 |
+
for key in keywords:
|
268 |
+
res_clean.append(key[0])
|
269 |
+
else:
|
270 |
+
res_clean.append(out)
|
271 |
+
|
272 |
+
show_out = []
|
273 |
+
for i in res_clean:
|
274 |
+
num = word_len(i)
|
275 |
+
if num > 1:
|
276 |
+
show_out.append(i)
|
277 |
+
|
278 |
+
return show_out
|
279 |
+
|
280 |
def reranking():
    """Fetch results for the global ``user_query``, shuffle the first
    ``maxtags_sidebar`` of them with a fixed seed, and write each one to the
    Streamlit page.
    """
    rerank_list = search_nolog(query=user_query)
    random.seed(7)  # fixed seed so the shuffle is reproducible across reruns
    # BUG FIX: random.shuffle(rerank_list[0:maxtags_sidebar]) shuffled a
    # temporary copy of the slice and discarded it, so the displayed order
    # was never shuffled.  Shuffle a named slice and iterate that instead.
    head = rerank_list[0:maxtags_sidebar]
    random.shuffle(head)
    for i in head:
        st.write(i)
|