Jai12345 commited on
Commit
6ba0728
·
1 Parent(s): dcd92b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -106,10 +106,10 @@ def preprocess_plain_text(text, window_size=3):
106
  return passages
107
 
108
 
109
- def bi_encode(passages):
110
  global bi_encoder
111
  # We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
112
- bi_encoder = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
113
 
114
 
115
  # Compute the embeddings
@@ -139,6 +139,7 @@ def display_as_table(model, score='score'):
139
 
140
  st.title("Search Your Query Here")
141
  window_size = 3
 
142
 
143
  # This will search articles for passages to answer the query
144
  def search_func(query):
@@ -217,11 +218,11 @@ with col2:
217
  clear = st.button("Clear Text Input", on_click=clear_text, key='clear',help='Click to clear the URL and query')
218
 
219
  if search:
220
- if bi_encoder_type:
221
  with st.spinner(
222
  text=f"Loading..........................."
223
  ):
224
- bi_encoder, corpus_embeddings = bi_encode(passages)
225
  cross_encoder = cross_encode()
226
 
227
  with st.spinner(
 
106
  return passages
107
 
108
 
109
+ def bi_encode(bi_enc,passages):
110
  global bi_encoder
111
  # We use the Bi-Encoder to encode all passages, so that we can use it with sematic search
112
+ bi_encoder = SentenceTransformer(bi_enc)
113
 
114
 
115
  # Compute the embeddings
 
139
 
140
  st.title("Search Your Query Here")
141
  window_size = 3
142
+ bi_enc_options = "multi-qa-mpnet-base-dot-v1"
143
 
144
  # This will search articles for passages to answer the query
145
  def search_func(query):
 
218
  clear = st.button("Clear Text Input", on_click=clear_text, key='clear',help='Click to clear the URL and query')
219
 
220
  if search:
221
+ if bi_encoder_type=="multi-qa-mpnet-base-dot-v1":
222
  with st.spinner(
223
  text=f"Loading..........................."
224
  ):
225
+ bi_encoder, corpus_embeddings = bi_encode(bi_encoder_type,passages)
226
  cross_encoder = cross_encode()
227
 
228
  with st.spinner(