Update app.py
Browse files
app.py
CHANGED
@@ -131,9 +131,9 @@ def cross_encode():
|
|
131 |
|
132 |
|
133 |
|
134 |
-
def display_as_table(model,
|
135 |
# Display the df with text and scores as a table
|
136 |
-
df = pd.DataFrame([(hit[score], passages[hit['corpus_id']]) for hit in model[0:
|
137 |
df['Score'] = round(df['Score'], 2)
|
138 |
|
139 |
return df
|
@@ -146,7 +146,7 @@ window_size = 3
|
|
146 |
|
147 |
bi_encoder_type="multi-qa-mpnet-base-dot-v1"
|
148 |
# This function will search all wikipedia articles for passages that answer the query
|
149 |
-
def search_func(query
|
150 |
global bi_encoder, cross_encoder
|
151 |
|
152 |
st.subheader(f"Search Query: {query}")
|
@@ -162,7 +162,7 @@ def search_func(query, top_k=2):
|
|
162 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
163 |
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
164 |
question_embedding = question_embedding.cpu()
|
165 |
-
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=
|
166 |
hits = hits[0] # Get the hits for the first query
|
167 |
|
168 |
# Now, score all retrieved passages with the cross_encoder
|
@@ -175,18 +175,18 @@ def search_func(query, top_k=2):
|
|
175 |
|
176 |
# Output of top hits from bi-encoder
|
177 |
st.markdown("\n-------------------------\n")
|
178 |
-
st.subheader(f"Top
|
179 |
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
|
180 |
|
181 |
-
cross_df = display_as_table(hits,
|
182 |
st.write(cross_df.to_html(index=False), unsafe_allow_html=True)
|
183 |
|
184 |
# Output of top hits from cross encoder
|
185 |
st.markdown("\n-------------------------\n")
|
186 |
-
st.subheader(f"Top-
|
187 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
188 |
|
189 |
-
rerank_df = display_as_table(hits,
|
190 |
st.write(rerank_df.to_html(index=False), unsafe_allow_html=True)
|
191 |
|
192 |
|
@@ -239,7 +239,7 @@ if search:
|
|
239 |
|
240 |
with st.spinner(
|
241 |
text="Embedding completed, searching for relevant text for given query and hits..."):
|
242 |
-
search_func(search_query
|
243 |
|
244 |
st.markdown("""
|
245 |
""")
|
|
|
131 |
|
132 |
|
133 |
|
134 |
+
def display_as_table(model, score='score', top_k=2):
    """Build a (Score, Text) table from the leading search hits.

    Args:
        model: Sequence of hit dicts. Each hit must contain a
            'corpus_id' key and the key named by ``score``. Hits are
            taken in the order given; no sorting happens here.
        score: Key of the hit dict whose value fills the 'Score'
            column (e.g. 'score' or 'cross-score').
        top_k: Number of leading hits to include. Defaults to 2, the
            count that was previously hard-coded in the slice.

    Returns:
        A DataFrame with columns 'Score' (rounded to 2 decimals) and
        'Text'.
    """
    # `passages` is a module-level name defined elsewhere in app.py;
    # presumably it maps a corpus_id to its passage text -- confirm.
    rows = [(hit[score], passages[hit['corpus_id']]) for hit in model[:top_k]]
    df = pd.DataFrame(rows, columns=['Score', 'Text'])
    df['Score'] = df['Score'].round(2)

    return df
|
|
|
146 |
|
147 |
bi_encoder_type="multi-qa-mpnet-base-dot-v1"
|
148 |
# This function will search all wikipedia articles for passages that answer the query
|
149 |
+
def search_func(query):
|
150 |
global bi_encoder, cross_encoder
|
151 |
|
152 |
st.subheader(f"Search Query: {query}")
|
|
|
162 |
# Encode the query using the bi-encoder and find potentially relevant passages
|
163 |
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
164 |
question_embedding = question_embedding.cpu()
|
165 |
+
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=2, score_function=util.dot_score)
|
166 |
hits = hits[0] # Get the hits for the first query
|
167 |
|
168 |
# Now, score all retrieved passages with the cross_encoder
|
|
|
175 |
|
176 |
# Output of top hits from bi-encoder
|
177 |
st.markdown("\n-------------------------\n")
|
178 |
+
st.subheader(f"Top 2 Bi-Encoder Retrieval hits")
|
179 |
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
|
180 |
|
181 |
+
cross_df = display_as_table(hits, )
|
182 |
st.write(cross_df.to_html(index=False), unsafe_allow_html=True)
|
183 |
|
184 |
# Output of top hits from cross encoder
|
185 |
st.markdown("\n-------------------------\n")
|
186 |
+
st.subheader(f"Top-2 Cross-Encoder Re-ranker hits")
|
187 |
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
188 |
|
189 |
+
rerank_df = display_as_table(hits, 'cross-score')
|
190 |
st.write(rerank_df.to_html(index=False), unsafe_allow_html=True)
|
191 |
|
192 |
|
|
|
239 |
|
240 |
with st.spinner(
|
241 |
text="Embedding completed, searching for relevant text for given query and hits..."):
|
242 |
+
search_func(search_query)
|
243 |
|
244 |
st.markdown("""
|
245 |
""")
|