HarryLee committed on
Commit
5e1fd6b
·
1 Parent(s): 048a704

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -2
app.py CHANGED
@@ -200,10 +200,87 @@ def search(query):
200
 
201
  return show_out
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def reranking():
204
  rerank_list = []
205
- rerank_list = search(query = user_query)
206
- st.write(rerank_list[0:maxtags_sidebar])
207
  random.shuffle(rerank_list[0:maxtags_sidebar])
208
  for i in rerank_list[0:maxtags_sidebar]:
209
  st.write(i)
 
200
 
201
  return show_out
202
 
203
+ def search_nolog(query):
204
+ total_qe = []
205
+ ##### BM25 search (lexical search) #####
206
+ bm25_scores = bm25.get_scores(bm25_tokenizer(query))
207
+ top_n = np.argpartition(bm25_scores, -5)[-5:]
208
+ bm25_hits = [{'corpus_id': idx, 'score': bm25_scores[idx]} for idx in top_n]
209
+ bm25_hits = sorted(bm25_hits, key=lambda x: x['score'], reverse=True)
210
+
211
+ qe_string = []
212
+ for hit in bm25_hits[0:1000]:
213
+ if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
214
+ qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
215
+
216
+ sub_string = []
217
+ for item in qe_string:
218
+ for sub_item in item.split(","):
219
+ sub_string.append(sub_item)
220
+ total_qe.append(sub_string)
221
+
222
+ ##### Semantic Search #####
223
+ # Encode the query using the bi-encoder and find potentially relevant passages
224
+ query_embedding = bi_encoder.encode(query, convert_to_tensor=True)
225
+ hits = util.semantic_search(query_embedding, corpus_embeddings, top_k=top_k)
226
+ hits = hits[0] # Get the hits for the first query
227
+
228
+ ##### Re-Ranking #####
229
+ # Now, score all retrieved passages with the cross_encoder
230
+ cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
231
+ cross_scores = cross_encoder.predict(cross_inp)
232
+
233
+ # Sort results by the cross-encoder scores
234
+ for idx in range(len(cross_scores)):
235
+ hits[idx]['cross-score'] = cross_scores[idx]
236
+
237
+ # Collect passages from the hits (at most 1000), ordered by bi-encoder score
238
+ hits = sorted(hits, key=lambda x: x['score'], reverse=True)
239
+ qe_string = []
240
+ for hit in hits[0:1000]:
241
+ if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
242
+ qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
243
+ total_qe.append(qe_string)
244
+
245
+ # Collect passages from the hits (at most 1000), ordered by re-ranker (cross-encoder) score
246
+ hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
247
+ qe_string = []
248
+ for hit in hits[0:1000]:
249
+ if passages[hit['corpus_id']].replace("\n", " ") not in qe_string:
250
+ qe_string.append(passages[hit['corpus_id']].replace("\n", ""))
251
+ total_qe.append(qe_string)
252
+
253
+ # Total Results
254
+ total_qe.append(qe_string)
255
+
256
+ res = []
257
+ for sub_list in total_qe:
258
+ for i in sub_list:
259
+ rs = re.sub("([^\u0030-\u0039\u0041-\u007a])", ' ', i)
260
+ rs_final = re.sub("\x20\x20", "\n", rs)
261
+ res.append(rs_final.strip())
262
+
263
+ res_clean = []
264
+ for out in res:
265
+ if len(out) > 20:
266
+ keywords = custom_kw_extractor.extract_keywords(out)
267
+ for key in keywords:
268
+ res_clean.append(key[0])
269
+ else:
270
+ res_clean.append(out)
271
+
272
+ show_out = []
273
+ for i in res_clean:
274
+ num = word_len(i)
275
+ if num > 1:
276
+ show_out.append(i)
277
+
278
+ return show_out
279
+
280
  def reranking():
281
  rerank_list = []
282
+ rerank_list = search_nolog(query = user_query)
283
+ random.seed(7)
284
  random.shuffle(rerank_list[0:maxtags_sidebar])
285
  for i in rerank_list[0:maxtags_sidebar]:
286
  st.write(i)