domenicrosati committed
Commit 1f1e9bd · 1 parent: 5cc7b84

don't use concat

Files changed (1)
  1. app.py (+21 -8)
app.py CHANGED
@@ -145,7 +145,7 @@ def init_models():
         "question-answering", model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
         device=device
     )
-    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
+    reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device=device)
     # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
     return question_answerer, reranker, stop, device
@@ -211,6 +211,9 @@ st.markdown("""
 """, unsafe_allow_html=True)
 
 with st.expander("Settings (strictness, context limit, top hits)"):
+    concat_passages = st.radio(
+        "Concatenate passages as one long context?",
+        ('no', 'yes'))
     support_all = st.radio(
         "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
         ('yes', 'no'))
@@ -224,8 +227,8 @@ with st.expander("Settings (strictness, context limit, top hits)"):
     use_reranking = st.radio(
         "Use Reranking? Reranking will rerank the top hits using semantic similarity of document and query.",
         ('yes', 'no'))
-    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 50)
-    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 10)
+    top_hits_limit = st.slider('Top hits? How many documents to use for reranking. Larger is slower but higher quality', 10, 300, 10)
+    context_lim = st.slider('Context limit? How many documents to use for answering from. Larger is slower but higher quality', 10, 300, 5)
 
 # def paraphrase(text, max_length=128):
 #     input_ids = queryexp_tokenizer.encode(text, return_tensors="pt", add_special_tokens=True)
@@ -313,14 +316,24 @@ def run_query(query):
         scores = reranker.predict(sentence_pairs, batch_size=len(sentence_pairs), show_progress_bar=False)
         hits = {contexts[idx]: scores[idx] for idx in range(len(scores))}
         sorted_contexts = [k for k,v in sorted(hits.items(), key=lambda x: x[0], reverse=True)]
-        context = '\n---'.join(sorted_contexts[:context_limit])
+        contexts = sorted_contexts[:context_limit]
     else:
-        context = '\n---'.join(contexts[:context_limit])
+        contexts = contexts[:context_limit]
+
+    if concat_passages == 'yes':
+        context = '\n---'.join(contexts)
+        model_results = qa_model(question=query, context=context, top_k=10)
+    else:
+        context = ['\n---\n'+ctx for ctx in contexts]
+        model_results = qa_model(question=[query]*len(contexts), context=context)
 
     results = []
-    model_results = qa_model(question=query, context=query+'---'+context, top_k=10)
-    for result in model_results:
-        matched = matched_context(result['start'], result['end'], context)
+
+    for i, result in enumerate(model_results):
+        if concat_passages == 'yes':
+            matched = matched_context(result['start'], result['end'], context)
+        else:
+            matched = matched_context(result['start'], result['end'], context[i])
         support = find_source(result['answer'], orig_docs, matched)
         if not support:
             continue
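
The substance of the commit is the new non-concatenated branch: instead of joining every passage into one '---'-separated context, the question-answering pipeline is now called once per passage, so each answer's start/end offsets index into a single passage. A minimal sketch of that batched call, with an illustrative query and passages standing in for the app's reranked contexts:

from transformers import pipeline

# Same model the app loads in init_models()
qa_model = pipeline(
    "question-answering",
    model='sultan/BioM-ELECTRA-Large-SQuAD2-BioASQ8B',
)

query = "What does ACE2 bind to?"                        # illustrative query
contexts = ["First passage ...", "Second passage ..."]   # illustrative passages

# Parallel lists of questions and contexts yield one answer dict per pair,
# each with 'answer', 'score', and 'start'/'end' offsets into its own context.
model_results = qa_model(question=[query] * len(contexts), context=contexts)
for i, result in enumerate(model_results):
    print(result['score'], result['answer'], '<-', contexts[i][:40])

Note that the batched call leaves top_k at its default of one answer per passage, whereas the concatenated branch keeps top_k=10 over the single long context.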
 
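The reranker swap, ms-marco-MiniLM-L-2-v2 to ms-marco-MiniLM-L-6-v2, trades a little latency for a deeper cross-encoder that generally scores MS MARCO-style query/passage pairs more accurately. A minimal sketch of the scoring step, again with illustrative pairs:

from sentence_transformers import CrossEncoder

reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

query = "What does ACE2 bind to?"                        # illustrative query
contexts = ["First passage ...", "Second passage ..."]   # illustrative passages
sentence_pairs = [[query, ctx] for ctx in contexts]

# One relevance score per (query, passage) pair; higher means more relevant.
scores = reranker.predict(sentence_pairs, show_progress_bar=False)
reranked = [ctx for _, ctx in
            sorted(zip(scores, contexts), key=lambda x: x[0], reverse=True)]

The sketch orders passages by score; the unchanged line in the diff sorts hits.items() by its first element (the passage text, x[0]) rather than the score (x[1]), so sorting on the score value is presumably what was intended there.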