kiyer committed
Commit 60c8258 · 1 Parent(s): ac72d36

5pm update

Files changed (2)
  1. app.py +144 -20
  2. requirements.txt +2 -1
app.py CHANGED
@@ -13,6 +13,16 @@ from collections import Counter
 
 import yaml, json, requests, sys, os, time
 import concurrent.futures
+
+from langchain_community.chat_models import ChatOpenAI as openai_llm
+from langchain_core.runnables import RunnableConfig
+from langchain_community.callbacks import StreamlitCallbackHandler
+
+from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
+from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
+from langchain.chains import LLMChain
+
+
 ts = time.time()
 
 
@@ -105,6 +115,39 @@ if 'ids' not in st.session_state:
     st.session_state.kws = arxiv_corpus['keywords']
     st.toast('done caching. time taken: %.2f sec' %(time.time()-ts))
 
+
+#---------------------------------------------------------------
+
+
+# A hack to "clear" the previous result when submitting a new prompt. This avoids
+# the "previous run's text is grayed-out but visible during rerun" Streamlit behavior.
+class DirtyState:
+    NOT_DIRTY = "NOT_DIRTY"
+    DIRTY = "DIRTY"
+    UNHANDLED_SUBMIT = "UNHANDLED_SUBMIT"
+
+
+def get_dirty_state() -> str:
+    return st.session_state.get("dirty_state", DirtyState.NOT_DIRTY)
+
+
+def set_dirty_state(state: str) -> None:
+    st.session_state["dirty_state"] = state
+
+
+def with_clear_container(submit_clicked: bool) -> bool:
+    if get_dirty_state() == DirtyState.DIRTY:
+        if submit_clicked:
+            set_dirty_state(DirtyState.UNHANDLED_SUBMIT)
+            st.experimental_rerun()
+        else:
+            set_dirty_state(DirtyState.NOT_DIRTY)
+
+    if submit_clicked or get_dirty_state() == DirtyState.UNHANDLED_SUBMIT:
+        set_dirty_state(DirtyState.DIRTY)
+        return True
+
+    return False
 
 #----------------------------------------------------------------
 
@@ -527,10 +570,10 @@ else:
 
 
 # Function to simulate question answering (replace with actual implementation)
-def answer_question(question, keywords, toggles, method, question_type):
+def answer_question(question, top_k, keywords, toggles, method, question_type):
     # Simulated answer (replace with actual logic)
     # return f"Answer to '{question}' using method {method} for {question_type} question."
-    return run_ret(question, 10)
+    return run_ret(question, top_k)
 
 
 def get_papers(ids):
@@ -577,19 +620,84 @@ def run_ret(query, top_k):
    output_str = ''
    for i in rs:
        if rs[i] > 0.5:
-            output_str = output_str + '---> ' + st.session_state.titles[i] + '(score: %.2f) \n' %rs[i]
+            output_str = output_str + '---> ' + st.session_state.abstracts[i] + '(score: %.2f) \n' %rs[i]
        else:
-            output_str = output_str + '---> ' + st.session_state.titles[i] + '(score: %.2f) \n' %rs[i]
+            output_str = output_str + st.session_state.abstracts[i] + '(score: %.2f) \n' %rs[i]
    return output_str, rs
 
+def Library(query, top_k=3):
+    print('get called start')
+    rs = ec.retrieve(query, top_k, return_scores=True)
+    op_docs = ''
+    for i in rs:
+        # op_docs.append(abstracts[i])
+        op_docs = op_docs + st.session_state.abstracts[i] + '\n\n'
+        # st.write(op_docs)
+    print('get called end')
+    return op_docs
+
+search = DuckDuckGoSearchAPIWrapper()
+tools = [
+    Tool(
+        name="Library",
+        func=Library,
+        description="A source of information pertinent to your question. Do not answer a question without consulting this!"
+    ),
+    Tool(
+        name="Search",
+        func=search.run,
+        description="useful for when you need to look up knowledge about common topics or current events",
+    )
+]
+
+if 'tools' not in st.session_state:
+    st.session_state.tools = tools
+
+# for another question type:
+# First, find the quotes from the document that are most relevant to answering the question, and then print them in numbered order.
+# Quotes should be relatively short. If there are no relevant quotes, write “No relevant quotes” instead.
+
+gen_llm = openai_llm(temperature=0, model_name='gpt-4o-mini', openai_api_key=openai_key)
+
+prefix = """You are an expert astronomer and cosmologist.
+Answer the following question as best you can using information from the library, but speaking in a concise and factual manner.
+If you can not come up with an answer, say you do not know.
+Try to break the question down into smaller steps and solve it in a logical manner.
+
+You have access to the following tools:"""
+suffix = """Begin! Remember to speak in a pedagogical and factual manner.
+
+Question: {input}
+{agent_scratchpad}"""
+
+prompt = ZeroShotAgent.create_prompt(
+    st.session_state.tools, prefix=prefix, suffix=suffix, input_variables=["input", "agent_scratchpad"]
+)
+
+llm_chain = LLMChain(llm=gen_llm, prompt=prompt)
+
+tool_names = [tool.name for tool in st.session_state.tools]
+if 'agent' not in st.session_state:
+    agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names)
+    st.session_state.agent = agent
+
+if 'agent_executor' not in st.session_state:
+    agent_executor = AgentExecutor.from_agent_and_tools(
+        agent=st.session_state.agent, tools=st.session_state.tools, verbose=True, handle_parsing_errors=True
+    )
+    st.session_state.agent_executor = agent_executor
+
+
+
 # Streamlit app
 def main():
 
     # st.title("Question Answering App")
-
+
 
     # Sidebar (Inputs)
-    st.sidebar.header("Inputs")
+    st.sidebar.header("Fine-tune the search")
+    top_k = st.sidebar.slider("Number of papers to retrieve:", 3, 100, 10)
     extra_keywords = st.sidebar.text_input("Enter extra keywords (comma-separated):")
 
     st.sidebar.subheader("Toggles")
@@ -597,8 +705,8 @@ def main():
     toggle_b = st.sidebar.checkbox("Toggle B")
     toggle_c = st.sidebar.checkbox("Toggle C")
 
-    method = st.sidebar.radio("Choose a method:", ["h1", "h2", "h3"])
-    question_type = st.sidebar.selectbox("Select question type:", ["Type 1", "Type 2", "Type 3"])
+    method = st.sidebar.radio("Choose a method:", ["Semantic search", "Semantic search + HyDE", "Semantic search + HyDE + CoHERE"])
+    question_type = st.sidebar.selectbox("Select question type:", ["Single paper", "Multi-paper", "Summary"])
     # store_output = st.sidebar.checkbox("Store the output")
 
 
@@ -606,7 +714,7 @@ def main():
 
     # Main page (Outputs)
 
-    question = st.text_input("Ask me anything:")
+    query = st.text_input("Ask me anything:")
     submit_button = st.button("Submit")
 
     if submit_button:
@@ -615,36 +723,52 @@ def main():
         toggles = {'A': toggle_a, 'B': toggle_b, 'C': toggle_c}
 
         # Generate outputs
-        answer, rs = answer_question(question, keywords, toggles, method, question_type)
+        answer, rs = answer_question(query, top_k, keywords, toggles, method, question_type)
         papers_df = get_papers(rs)
         embedding_plot = create_embedding_plot()
-        triggered_keywords = extract_keywords(question)
+        triggered_keywords = extract_keywords(query)
         consensus = estimate_consensus()
 
-        # Display outputs
-
-        st.subheader("Answer")
+        # Display outputs
+        answer = st.session_state.agent_executor.run(input=query)
+        # st.write(answer["output"])
         st.write(answer)
+
+        # st.subheader("Answer")
+        # output_container = st.empty()
+        # if with_clear_container(submit_button):
+        #     output_container = output_container.container()
+        #     output_container.chat_message("user").write(query)
 
+        # answer_container = output_container.chat_message("pfdr", avatar="🦜")
+        # st_callback = StreamlitCallbackHandler(answer_container)
+        # # cfg = RunnableConfig()
+        # # cfg["callbacks"] = [st_callback]
+        # answer = st.session_state.agent_executor.run(input=query, callbacks=[st_callback])
+        # try:
+        #     answer_container.write(answer["output"])
+        # except:
+        #     answer_container.write('No final answer')
+        # st.write(answer)
+
-        with st.expander("Papers used", expanded=True):
+        with st.expander("Relevant papers", expanded=True):
             st.dataframe(papers_df)
 
+        with st.expander("Embedding map", expanded=True):
+            st.bokeh_chart(embedding_plot)
 
         col1, col2 = st.columns(2)
 
         with col1:
 
-            st.subheader("Embedding Map")
-            st.bokeh_chart(embedding_plot)
+            st.subheader("Question Type")
+            st.write(question_type)
 
             st.subheader("Triggered Keywords")
             st.write(", ".join(triggered_keywords))
 
         with col2:
 
-            st.subheader("Question Type")
-            st.write(question_type)
-
             st.subheader("Consensus Estimate")
             st.write(f"{consensus:.2%}")
 
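Note on the new DirtyState helpers: the commit currently calls st.session_state.agent_executor.run(input=query) directly and leaves the streaming path commented out. Below is a minimal sketch of the intended submit flow, reconstructed from the commented-out block in main(); only the final write differs, since AgentExecutor.run returns the answer as a plain string, so it is written directly rather than via answer["output"]:

    output_container = st.empty()
    if with_clear_container(submit_button):
        output_container = output_container.container()
        output_container.chat_message("user").write(query)

        answer_container = output_container.chat_message("pfdr", avatar="🦜")
        st_callback = StreamlitCallbackHandler(answer_container)
        # Stream the agent's intermediate tool calls into the chat message.
        answer = st.session_state.agent_executor.run(input=query, callbacks=[st_callback])
        answer_container.write(answer)

One caveat: with_clear_container relies on st.experimental_rerun(), which newer Streamlit releases deprecate in favor of st.rerun(), so this helper may warn or break depending on the pinned version.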
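The agent wiring itself can be exercised outside Streamlit. A self-contained sketch with the Library retriever stubbed out (the stub body and the sample question are illustrative, not from the commit; assumes OPENAI_API_KEY is set in the environment):

    from langchain.agents import ZeroShotAgent, Tool, AgentExecutor
    from langchain.chains import LLMChain
    from langchain_community.chat_models import ChatOpenAI
    from langchain_community.utilities import DuckDuckGoSearchAPIWrapper

    def library(query: str, top_k: int = 3) -> str:
        # Stub standing in for ec.retrieve(): the app joins the top_k
        # retrieved abstracts with blank lines.
        return "\n\n".join(f"(abstract {i} for: {query})" for i in range(top_k))

    tools = [
        Tool(name="Library", func=library,
             description="A source of information pertinent to your question."),
        Tool(name="Search", func=DuckDuckGoSearchAPIWrapper().run,
             description="useful for looking up common topics or current events"),
    ]

    # Same prompt assembly as the commit, with a generic persona.
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix="Answer as best you can. You have access to the following tools:",
        suffix="Begin!\n\nQuestion: {input}\n{agent_scratchpad}",
        input_variables=["input", "agent_scratchpad"],
    )
    llm_chain = LLMChain(llm=ChatOpenAI(temperature=0, model_name="gpt-4o-mini"), prompt=prompt)
    agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=[t.name for t in tools])
    executor = AgentExecutor.from_agent_and_tools(
        agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
    )
    print(executor.run(input="What is the Hubble tension?"))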
requirements.txt CHANGED
@@ -15,4 +15,5 @@ tiktoken
 chromadb
 streamlit-extras
 nltk
-hickle
+cohere
+duckduckgo-search
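The dependency changes line up with the new code paths: duckduckgo-search backs the DuckDuckGoSearchAPIWrapper behind the Search tool, and cohere presumably serves the new "Semantic search + HyDE + CoHERE" method option; hickle is no longer needed. The LangChain packages imported in app.py are not added here, so they are presumably already listed earlier in requirements.txt.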