kiyer committed on
Commit
2007c08
·
verified ·
1 Parent(s): 2c3e220

added beta deep research mode

Browse files
Files changed (1) hide show
  1. app_gradio.py +112 -12
app_gradio.py CHANGED
@@ -480,6 +480,100 @@ def make_embedding_plot(papers_df, top_k, consensus_answer, arxiv_corpus=arxiv_c
480
  plt.axis('off')
481
  return fig
482
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
483
  def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type, ec=ec, progress=gr.Progress()):
484
 
485
  yield None, None, None, None, None
@@ -507,21 +601,26 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
507
  ec.hyde = True
508
  ec.rerank = True
509
 
510
- progress(0.2, desc=search_text_list[np.random.choice(len(search_text_list))])
511
- rs, small_df = ec.retrieve(query, top_k = top_k, return_scores=True)
512
- formatted_df = ec.return_formatted_df(rs, small_df)
513
- yield formatted_df, None, None, None, None
514
-
515
- progress(0.4, desc=gen_text_list[np.random.choice(len(gen_text_list))])
516
- rag_answer = run_rag_qa(query, formatted_df, prompt_type)
517
- yield formatted_df, rag_answer['answer'], None, None, None
 
 
 
 
 
518
 
519
- progress(0.6, desc="Generating consensus")
520
  consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
521
  consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
522
  yield formatted_df, rag_answer['answer'], consensus, None, None
523
 
524
- progress(0.8, desc="Analyzing question type")
525
  question_type_gen = guess_question_type(query)
526
  if '<categorization>' in question_type_gen:
527
  question_type_gen = question_type_gen.split('<categorization>')[1]
@@ -531,7 +630,7 @@ def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type,
531
  qn_type = question_type_gen
532
  yield formatted_df, rag_answer['answer'], consensus, qn_type, None
533
 
534
- progress(1.0, desc="Visualizing embeddings")
535
  fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
536
 
537
  yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
@@ -551,6 +650,7 @@ def create_interface():
551
  with gr.Tab("pathfinder"):
552
  with gr.Accordion("What is Pathfinder? / How do I use it?", open=False):
553
  gr.Markdown(pathfinder_text)
 
554
 
555
  with gr.Row():
556
  query = gr.Textbox(label="Ask me anything")
@@ -559,7 +659,7 @@ def create_interface():
559
  top_k = gr.Slider(1, 30, step=1, value=10, label="top-k", info="Number of papers to retrieve")
560
  keywords = gr.Textbox(label="Optional Keywords (comma-separated)",value="")
561
  toggles = gr.CheckboxGroup(["Keywords", "Time", "Citations"], label="Weight by", info="weighting retrieved papers",value=['Keywords'])
562
- prompt_type = gr.Radio(choices=["Single-paper", "Multi-paper", "Bibliometric", "Broad but nuanced"], label="Prompt Specialization", value='Multi-paper')
563
  rag_type = gr.Radio(choices=["Semantic Search", "Semantic + HyDE", "Semantic + CoHERE", "Semantic + HyDE + CoHERE"], label="RAG Method",value='Semantic + HyDE + CoHERE')
564
  with gr.Column(scale=2, min_width=300):
565
  img1 = gr.Image("local_files/pathfinder_logo.png")
 
480
  plt.axis('off')
481
  return fig
482
 
483
+
484
+ def getsmallans(query, df):
485
+
486
+ allcontent = dr_smallans_prompt
487
+
488
+ smallauth = ''
489
+ linkstr = ''
490
+ for i, row in df.iterrows():
491
+ # content = f"Paper {i+1}: {row['title'].replace('\n',' ')}\n{row['abstract'].replace('\n',' ')}\n\n"
492
+ content = f"Paper ({row['authors'][0].split(',')[0]} et al. {row['date'].year}): {row['title']}\n{row['abstract']}\n\n"
493
+ smallauth = smallauth + f"({row['authors'][0].split(',')[0]} et al. {row['date'].year}) "
494
+ linkstr = linkstr + f"[{row['authors'][0].split(',')[0]} et al. {row['date'].year}](" + row['ADS Link'].split('](')[1] + ' \n\n'
495
+ allcontent = allcontent + content
496
+
497
+ # allcontent = allcontent + '\n Question: '+query
498
+
499
+ gen_client = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
500
+
501
+ messages = [("system",allcontent,),("human", query),]
502
+ smallans = gen_client.invoke(messages).content
503
+
504
+ tmplnk = linkstr.split(' \n\n')
505
+ linkdict = {}
506
+ for i in range(len(tmplnk)-1):
507
+ linkdict[tmplnk[i].split('](')[0][1:]] = tmplnk[i]
508
+
509
+ for key in linkdict.keys():
510
+ try:
511
+ smallans = smallans.replace(key, linkdict[key])
512
+ key2 = key[0:-4]+'('+key[-4:]+')'
513
+ smallans = smallans.replace(key2, linkdict[key])
514
+ except:
515
+ print('key not found', key)
516
+
517
+ return smallans, smallauth, linkstr
518
+
519
+ def compileinfo(query, atom_qns, atom_qn_ans, atom_qn_strs):
520
+
521
+ tmp = dr_compileinfo_prompt
522
+ links = ''
523
+ for i in range(len(atom_qn_ans)):
524
+ tmp = tmp + atom_qns[i] + '\n\n' + atom_qn_ans[i] + '\n\n'
525
+ links = links + atom_qn_strs[i] + '\n\n'
526
+
527
+ gen_client = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
528
+
529
+ messages = [("system",tmp,),("human", query),]
530
+ smallans = gen_client.invoke(messages).content
531
+ return smallans, links
532
+
533
+ def deep_research(question, top_k, ec):
534
+
535
+ full_answer = '## ' + question
536
+
537
+ gen_client = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
538
+ messages = [("system",prompt_qdec2,),("human", question),]
539
+ rscope_text = gen_client.invoke(messages).content
540
+
541
+ full_answer = full_answer +' \n'+ rscope_text
542
+
543
+ rscope_messages = [("system","""In the given text, what are the main atomic questions being asked? Please answer as a concise list.""",),("human", rscope_text),]
544
+ rscope_qns = gen_client.invoke(rscope_messages).content
545
+
546
+ atom_qns = []
547
+
548
+ temp = rscope_qns.split('\n')
549
+ for i in temp:
550
+ if i != '':
551
+ atom_qns.append(i)
552
+
553
+ atom_qn_dfs = []
554
+ atom_qn_ans = []
555
+ atom_qn_strs = []
556
+ for i in range(len(atom_qns)):
557
+ rs, small_df = ec.retrieve(atom_qns[i], top_k = top_k, return_scores=True)
558
+ formatted_df = ec.return_formatted_df(rs, small_df)
559
+ atom_qn_dfs.append(formatted_df)
560
+ smallans, smallauth, linkstr = getsmallans(atom_qns[i], atom_qn_dfs[i])
561
+
562
+ atom_qn_ans.append(smallans)
563
+ atom_qn_strs.append(linkstr)
564
+ full_answer = full_answer +' \n### '+atom_qns[i]
565
+ full_answer = full_answer +' \n'+smallans
566
+
567
+ finalans, finallinks = compileinfo(question, atom_qns, atom_qn_ans, atom_qn_strs)
568
+ full_answer = full_answer +' \n'+'### Summary:\n'+finalans
569
+
570
+ full_df = pd.concat(atom_qn_dfs)
571
+
572
+ rag_answer = {}
573
+ rag_answer['answer'] = full_answer
574
+
575
+ return full_df, rag_answer
576
+
577
  def run_pathfinder(query, top_k, extra_keywords, toggles, prompt_type, rag_type, ec=ec, progress=gr.Progress()):
578
 
579
  yield None, None, None, None, None
 
601
  ec.hyde = True
602
  ec.rerank = True
603
 
604
+ if prompt_type == "Deep Research (BETA)":
605
+ formatted_df, rag_answer = deep_research(query, top_k = top_k, ec=ec)
606
+ yield formatted_df, rag_answer['answer'], None, None, None
607
+
608
+ else:
609
+ # progress(0.2, desc=search_text_list[np.random.choice(len(search_text_list))])
610
+ rs, small_df = ec.retrieve(query, top_k = top_k, return_scores=True)
611
+ formatted_df = ec.return_formatted_df(rs, small_df)
612
+ yield formatted_df, None, None, None, None
613
+
614
+ # progress(0.4, desc=gen_text_list[np.random.choice(len(gen_text_list))])
615
+ rag_answer = run_rag_qa(query, formatted_df, prompt_type)
616
+ yield formatted_df, rag_answer['answer'], None, None, None
617
 
618
+ # progress(0.6, desc="Generating consensus")
619
  consensus_answer = evaluate_overall_consensus(query, [formatted_df['abstract'][i+1] for i in range(len(formatted_df))])
620
  consensus = '## Consensus \n'+consensus_answer.consensus + '\n\n'+consensus_answer.explanation + '\n\n > Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score
621
  yield formatted_df, rag_answer['answer'], consensus, None, None
622
 
623
+ # progress(0.8, desc="Analyzing question type")
624
  question_type_gen = guess_question_type(query)
625
  if '<categorization>' in question_type_gen:
626
  question_type_gen = question_type_gen.split('<categorization>')[1]
 
630
  qn_type = question_type_gen
631
  yield formatted_df, rag_answer['answer'], consensus, qn_type, None
632
 
633
+ # progress(1.0, desc="Visualizing embeddings")
634
  fig = make_embedding_plot(formatted_df, top_k, consensus_answer)
635
 
636
  yield formatted_df, rag_answer['answer'], consensus, qn_type, fig
 
650
  with gr.Tab("pathfinder"):
651
  with gr.Accordion("What is Pathfinder? / How do I use it?", open=False):
652
  gr.Markdown(pathfinder_text)
653
+ img2 = gr.Image("local_files/galaxy_worldmap_kiyer-min.png")
654
 
655
  with gr.Row():
656
  query = gr.Textbox(label="Ask me anything")
 
659
  top_k = gr.Slider(1, 30, step=1, value=10, label="top-k", info="Number of papers to retrieve")
660
  keywords = gr.Textbox(label="Optional Keywords (comma-separated)",value="")
661
  toggles = gr.CheckboxGroup(["Keywords", "Time", "Citations"], label="Weight by", info="weighting retrieved papers",value=['Keywords'])
662
+ prompt_type = gr.Radio(choices=["Single-paper", "Multi-paper", "Bibliometric", "Broad but nuanced","Deep Research (BETA)"], label="Prompt Specialization", value='Multi-paper')
663
  rag_type = gr.Radio(choices=["Semantic Search", "Semantic + HyDE", "Semantic + CoHERE", "Semantic + HyDE + CoHERE"], label="RAG Method",value='Semantic + HyDE + CoHERE')
664
  with gr.Column(scale=2, min_width=300):
665
  img1 = gr.Image("local_files/pathfinder_logo.png")