domenicrosati committed on
Commit
5cc7b84
·
1 Parent(s): f1fd3e1

remove summarization

Browse files
Files changed (1) hide show
  1. app.py +3 -83
app.py CHANGED
@@ -78,7 +78,6 @@ def search(term, limit=10, clean=True, strict=True, all_mode=True, abstracts=Tru
78
  except:
79
  pass
80
 
81
-
82
  return (
83
  contexts,
84
  docs
@@ -149,11 +148,9 @@ def init_models():
149
  reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
150
  # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
151
  # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
152
- summ_tok = AutoTokenizer.from_pretrained('allenai/led-base-16384-ms2')
153
- summ_mdl = LEDForConditionalGeneration.from_pretrained('allenai/led-base-16384-ms2')
154
- return question_answerer, reranker, stop, device, summ_mdl, summ_tok
155
 
156
- qa_model, reranker, stop, device, summ_mdl, summ_tok = init_models() # queryexp_model, queryexp_tokenizer
157
 
158
 
159
  def clean_query(query, strict=True, clean=True):
@@ -214,9 +211,6 @@ st.markdown("""
214
  """, unsafe_allow_html=True)
215
 
216
  with st.expander("Settings (strictness, context limit, top hits)"):
217
- use_mds = st.radio(
218
- "Use multi-document summarization to summarize answer?",
219
- ('yes', 'no'))
220
  support_all = st.radio(
221
  "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
222
  ('yes', 'no'))
@@ -271,77 +265,6 @@ def matched_context(start_i, end_i, contexts_string, seperator='---'):
271
  return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
272
  return None
273
 
274
- def process_document(documents, tokenizer, docsep_token_id, pad_token_id, device=device):
275
- input_ids_all=[]
276
- for data in documents:
277
- all_docs = data.split("|||||")
278
- for i, doc in enumerate(all_docs):
279
- doc = doc.replace("\n", " ")
280
- doc = " ".join(doc.split())
281
- all_docs[i] = doc
282
-
283
- #### concat with global attention on doc-sep
284
- input_ids = []
285
- for doc in all_docs:
286
- input_ids.extend(
287
- tokenizer.encode(
288
- doc,
289
- truncation=True,
290
- max_length=4096 // len(all_docs),
291
- )[1:-1]
292
- )
293
- input_ids.append(docsep_token_id)
294
- input_ids = (
295
- [tokenizer.bos_token_id]
296
- + input_ids
297
- + [tokenizer.eos_token_id]
298
- )
299
- input_ids_all.append(torch.tensor(input_ids))
300
- input_ids = torch.nn.utils.rnn.pad_sequence(
301
- input_ids_all, batch_first=True, padding_value=pad_token_id
302
- )
303
- return input_ids
304
-
305
-
306
- def batch_process(batch, model, tokenizer, docsep_token_id, pad_token_id, device=device):
307
- input_ids=process_document(batch['document'], tokenizer, docsep_token_id, pad_token_id)
308
- # get the input ids and attention masks together
309
- global_attention_mask = torch.zeros_like(input_ids).to(device)
310
- input_ids = input_ids.to(device)
311
- # put global attention on <s> token
312
-
313
- global_attention_mask[:, 0] = 1
314
- global_attention_mask[input_ids == docsep_token_id] = 1
315
- generated_ids = model.generate(
316
- input_ids=input_ids,
317
- global_attention_mask=global_attention_mask,
318
- use_cache=True,
319
- max_length=1024,
320
- num_beams=5,
321
- )
322
- generated_str = tokenizer.batch_decode(
323
- generated_ids.tolist(), skip_special_tokens=True
324
- )
325
- result={}
326
- result['generated_summaries'] = generated_str
327
- return result
328
-
329
-
330
- def gen_summary(query, sorted_result):
331
- pad_token_id = summ_tok.pad_token_id
332
- docsep_token_id = summ_tok.convert_tokens_to_ids("</s>")
333
- out = batch_process({ 'document': [f'||||'.join([f'{query} '.join(r['texts']) + r['context'] for r in sorted_result])]}, summ_mdl, summ_tok, docsep_token_id, pad_token_id)
334
- st.markdown(f"""
335
- <div class="container-fluid">
336
- <div class="row align-items-start">
337
- <div class="col-md-12 col-sm-12">
338
- <strong>Answer:</strong> {out['generated_summaries'][0]}
339
- </div>
340
- </div>
341
- </div>
342
- """, unsafe_allow_html=True)
343
- st.markdown("<br /><br /><h5>Sources:</h5>", unsafe_allow_html=True)
344
-
345
 
346
  def run_query(query):
347
  # if use_query_exp == 'yes':
@@ -395,7 +318,7 @@ def run_query(query):
395
  context = '\n---'.join(contexts[:context_limit])
396
 
397
  results = []
398
- model_results = qa_model(question=query, context=context, top_k=10)
399
  for result in model_results:
400
  matched = matched_context(result['start'], result['end'], context)
401
  support = find_source(result['answer'], orig_docs, matched)
@@ -423,9 +346,6 @@ def run_query(query):
423
  sorted_result
424
  ))
425
 
426
- if use_mds == 'yes':
427
- gen_summary(query, sorted_result)
428
-
429
  for r in sorted_result:
430
  ctx = remove_html(r["context"])
431
  for answer in r['texts']:
 
78
  except:
79
  pass
80
 
 
81
  return (
82
  contexts,
83
  docs
 
148
  reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-2-v2', device=device)
149
  # queryexp_tokenizer = AutoTokenizer.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
150
  # queryexp_model = AutoModelWithLMHead.from_pretrained("doc2query/all-with_prefix-t5-base-v1")
151
+ return question_answerer, reranker, stop, device
 
 
152
 
153
+ qa_model, reranker, stop, device = init_models() # queryexp_model, queryexp_tokenizer
154
 
155
 
156
  def clean_query(query, strict=True, clean=True):
 
211
  """, unsafe_allow_html=True)
212
 
213
  with st.expander("Settings (strictness, context limit, top hits)"):
 
 
 
214
  support_all = st.radio(
215
  "Use abstracts and titles as a ranking signal (if the words are matched in the abstract then the document is more relevant)?",
216
  ('yes', 'no'))
 
265
  return contexts_string[doc_starts[i]:doc_starts[i+1]].replace(seperator, '')
266
  return None
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
 
269
  def run_query(query):
270
  # if use_query_exp == 'yes':
 
318
  context = '\n---'.join(contexts[:context_limit])
319
 
320
  results = []
321
+ model_results = qa_model(question=query, context=query+'---'+context, top_k=10)
322
  for result in model_results:
323
  matched = matched_context(result['start'], result['end'], context)
324
  support = find_source(result['answer'], orig_docs, matched)
 
346
  sorted_result
347
  ))
348
 
 
 
 
349
  for r in sorted_result:
350
  ctx = remove_html(r["context"])
351
  for answer in r['texts']: