Jai12345 commited on
Commit
5318dde
1 Parent(s): c6b1bdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -131,9 +131,9 @@ def cross_encode():
131
 
132
 
133
 
134
- def display_as_table(model, top_k=2, score='score'):
135
  # Display the df with text and scores as a table
136
- df = pd.DataFrame([(hit[score], passages[hit['corpus_id']]) for hit in model[0:top_k]], columns=['Score', 'Text'])
137
  df['Score'] = round(df['Score'], 2)
138
 
139
  return df
@@ -146,7 +146,7 @@ window_size = 3
146
 
147
  bi_encoder_type="multi-qa-mpnet-base-dot-v1"
148
  # This function will search all wikipedia articles for passages that answer the query
149
- def search_func(query, top_k=2):
150
  global bi_encoder, cross_encoder
151
 
152
  st.subheader(f"Search Query: {query}")
@@ -162,7 +162,7 @@ def search_func(query, top_k=2):
162
  # Encode the query using the bi-encoder and find potentially relevant passages
163
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
164
  question_embedding = question_embedding.cpu()
165
- hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k, score_function=util.dot_score)
166
  hits = hits[0] # Get the hits for the first query
167
 
168
  # Now, score all retrieved passages with the cross_encoder
@@ -175,18 +175,18 @@ def search_func(query, top_k=2):
175
 
176
  # Output of top hits from bi-encoder
177
  st.markdown("\n-------------------------\n")
178
- st.subheader(f"Top-{top_k} Bi-Encoder Retrieval hits")
179
  hits = sorted(hits, key=lambda x: x['score'], reverse=True)
180
 
181
- cross_df = display_as_table(hits, top_k)
182
  st.write(cross_df.to_html(index=False), unsafe_allow_html=True)
183
 
184
  # Output of top hits from cross encoder
185
  st.markdown("\n-------------------------\n")
186
- st.subheader(f"Top-{top_k} Cross-Encoder Re-ranker hits")
187
  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
188
 
189
- rerank_df = display_as_table(hits, top_k, 'cross-score')
190
  st.write(rerank_df.to_html(index=False), unsafe_allow_html=True)
191
 
192
 
@@ -239,7 +239,7 @@ if search:
239
 
240
  with st.spinner(
241
  text="Embedding completed, searching for relevant text for given query and hits..."):
242
- search_func(search_query, top_k=2)
243
 
244
  st.markdown("""
245
  """)
 
131
 
132
 
133
 
134
+ def display_as_table(model, score='score'):
135
  # Display the df with text and scores as a table
136
+ df = pd.DataFrame([(hit[score], passages[hit['corpus_id']]) for hit in model[0:2]], columns=['Score', 'Text'])
137
  df['Score'] = round(df['Score'], 2)
138
 
139
  return df
 
146
 
147
  bi_encoder_type="multi-qa-mpnet-base-dot-v1"
148
  # This function will search all wikipedia articles for passages that answer the query
149
+ def search_func(query):
150
  global bi_encoder, cross_encoder
151
 
152
  st.subheader(f"Search Query: {query}")
 
162
  # Encode the query using the bi-encoder and find potentially relevant passages
163
  question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
164
  question_embedding = question_embedding.cpu()
165
+ hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=2, score_function=util.dot_score)
166
  hits = hits[0] # Get the hits for the first query
167
 
168
  # Now, score all retrieved passages with the cross_encoder
 
175
 
176
  # Output of top hits from bi-encoder
177
  st.markdown("\n-------------------------\n")
178
+ st.subheader(f"Top 2 Bi-Encoder Retrieval hits")
179
  hits = sorted(hits, key=lambda x: x['score'], reverse=True)
180
 
181
+ cross_df = display_as_table(hits, )
182
  st.write(cross_df.to_html(index=False), unsafe_allow_html=True)
183
 
184
  # Output of top hits from cross encoder
185
  st.markdown("\n-------------------------\n")
186
+ st.subheader(f"Top-2 Cross-Encoder Re-ranker hits")
187
  hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
188
 
189
+ rerank_df = display_as_table(hits, 'cross-score')
190
  st.write(rerank_df.to_html(index=False), unsafe_allow_html=True)
191
 
192
 
 
239
 
240
  with st.spinner(
241
  text="Embedding completed, searching for relevant text for given query and hits..."):
242
+ search_func(search_query)
243
 
244
  st.markdown("""
245
  """)