Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import streamlit as st | |
# Page layout is configured first, before any other Streamlit call in the script.
st.set_page_config(layout="wide")
# API credentials are read from the app's Streamlit secrets.
openai_key, cohere_key = st.secrets["openai_key"], st.secrets['cohere_key']
import numpy as np | |
from abc import ABC, abstractmethod | |
from typing import List, Dict, Any, Tuple | |
from collections import defaultdict | |
from tqdm import tqdm | |
import pandas as pd | |
from datetime import datetime, date | |
from datasets import load_dataset, load_from_disk | |
from collections import Counter | |
import yaml, json, requests, sys, os, time | |
import concurrent.futures | |
from langchain import hub | |
from langchain_openai import ChatOpenAI as openai_llm | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_core.runnables import RunnableConfig, RunnablePassthrough, RunnableParallel | |
from langchain_core.prompts import PromptTemplate | |
from langchain_community.callbacks import StreamlitCallbackHandler | |
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper | |
from langchain_community.vectorstores import Chroma | |
from langchain_community.document_loaders import TextLoader | |
from langchain.agents import create_react_agent, Tool, AgentExecutor | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain.callbacks import FileCallbackHandler | |
from langchain.callbacks.manager import CallbackManager | |
from langchain.schema import Document | |
import instructor | |
from pydantic import BaseModel, Field | |
from typing import List, Literal | |
from nltk.corpus import stopwords | |
import nltk | |
from openai import OpenAI | |
# import anthropic | |
import cohere | |
import faiss | |
import spacy | |
from string import punctuation | |
import pytextrank | |
from prompts import * | |
# Wall-clock reference taken when the script starts; the loader toasts report
# elapsed time relative to it.
ts = time.time()


def load_nlp():
    """Build the spaCy pipeline used for keyword extraction.

    Loads the ``en_core_web_sm`` model, appends the ``textrank`` pipe, and
    makes sure the NLTK English stop-word list is available, downloading it
    when the lookup fails.

    Returns:
        The configured spaCy pipeline object.
    """
    nlp = spacy.load("en_core_web_sm")
    nlp.add_pipe("textrank")
    try:
        stopwords.words('english')
    except LookupError:
        # NLTK signals a missing corpus with LookupError; fetch it once and
        # re-check so a failed download still surfaces here.
        nltk.download('stopwords')
        stopwords.words('english')
    return nlp
# @st.cache_resource | |
# def load_embeddings(): | |
# return OpenAIEmbeddings(model="text-embedding-3-small", api_key=st.secrets["openai_key"]) | |
# | |
# @st.cache_resource | |
# def load_llm(): | |
# return ChatOpenAI(temperature=0, model_name='gpt-4o-mini', openai_api_key=st.secrets["openai_key"]) | |
# Shared OpenAI-backed clients, stored on the Streamlit session so the
# helper functions below can reach them.
embed_model = "text-embedding-3-small"

# Chat model used for answer generation (RAG chain and ReAct agent).
st.session_state.gen_llm = openai_llm(
    temperature=0,
    model_name='gpt-4o-mini',
    openai_api_key=openai_key,
)
# instructor-patched client used for structured consensus evaluation.
st.session_state.consensus_client = instructor.patch(OpenAI(api_key=openai_key))
# Plain OpenAI client and LangChain embedding wrapper.
st.session_state.embed_client = OpenAI(api_key=openai_key)
st.session_state.embeddings = OpenAIEmbeddings(model=embed_model, api_key=openai_key)
# @st.cache_data | |
def load_arxiv_corpus():
    """Load the astro-ph dataset from ``data/`` and attach its FAISS index.

    The index file ``data/astrophindex.faiss`` is registered under the
    ``embed`` column. A toast reports the time elapsed since script start.

    Returns:
        The loaded dataset with the FAISS index attached.
    """
    with st.spinner('loading astro-ph corpus'):
        corpus = load_from_disk('data/')
        corpus.load_faiss_index('embed', 'data/astrophindex.faiss')
        elapsed = time.time() - ts
        st.toast('loaded data. time taken: %.2f sec' % elapsed)
    return corpus
def get_keywords(text):
    """Extract candidate keywords from ``text``.

    The text is lower-cased and run through the spaCy pipeline cached in
    ``st.session_state.nlp`` (created on first use). Tokens are kept when
    they are not stop words, are not found in ``string.punctuation``, and are
    tagged as proper noun, adjective or noun.

    Returns:
        List of token strings in document order (duplicates preserved).
    """
    if 'nlp' not in st.session_state:
        st.session_state.nlp = load_nlp()
    nlp = st.session_state.nlp

    wanted_pos = ('PROPN', 'ADJ', 'NOUN')
    stop_words = nlp.Defaults.stop_words

    keywords = []
    for tok in nlp(text.lower()):
        word = tok.text
        is_noise = word in stop_words or word in punctuation
        if not is_noise and tok.pos_ in wanted_pos:
            keywords.append(word)
    return keywords
class RetrievalSystem():
    """Semantic retrieval over the arXiv corpus with optional HyDE and reranking.

    Besides what ``__init__`` sets up, callers must assign these attributes
    before calling ``retrieve``:

    * ``toggles`` - dict with boolean entries ``"Keyword weighting"``,
      ``"Time weighting"`` and ``"Citation weighting"``.
    * ``query_input_keywords`` - list of extra user keywords.
    * ``hyde`` / ``rerank`` - booleans enabling HyDE and Cohere reranking.
    """

    def __init__(self):
        self.dataset = st.session_state.arxiv_corpus
        self.client = OpenAI(api_key = openai_key)
        self.embed_model = "text-embedding-3-small"
        self.generation_client = openai_llm(temperature=0,model_name='gpt-4o-mini', openai_api_key = openai_key)
        self.hyde_client = openai_llm(temperature=0.5,model_name='gpt-4o-mini', openai_api_key = openai_key)
        self.cohere_client = cohere.Client(cohere_key)

    def make_embedding(self, text):
        """Return the embedding of a single string as produced by the OpenAI client."""
        str_embed = self.client.embeddings.create(input = [text], model = self.embed_model).data[0].embedding
        return str_embed

    def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Embed several strings in one request; returns one float32 array per input."""
        embeddings = self.client.embeddings.create(input=texts, model=self.embed_model).data
        return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]

    def get_query_embedding(self, query):
        """Embed the user query (thin wrapper over ``make_embedding``)."""
        return self.make_embedding(query)

    def calc_faiss(self, query_embedding, top_k = 100):
        """Search the ``embed`` index.

        Returns:
            ``[indices, scores, rows]`` where ``rows`` is ``self.dataset``
            indexed with the hit indices.
        """
        tmp = self.dataset.search('embed', query_embedding, k=top_k)
        return [tmp.indices, tmp.scores, self.dataset[tmp.indices]]

    def rank_and_filter(self, query, query_embedding, top_k = 10, top_k_internal = 1000, return_scores=False):
        """Retrieve ``top_k_internal`` neighbours, re-weight them, keep the best ``top_k``.

        The inverse search score is multiplied by optional keyword, recency
        and citation weights (each normalised to a maximum of 1), as selected
        in ``self.toggles``.

        Returns:
            ``(top_results, small_df)``. ``top_results`` is a dict
            ``{corpus_index: score}`` when ``return_scores`` is true, else a
            list of corpus indices, best first. ``small_df`` is
            ``self.dataset`` indexed with those corpus indices.
        """
        self.weight_keywords = self.toggles["Keyword weighting"]
        self.weight_date = self.toggles["Time weighting"]
        self.weight_citation = self.toggles["Citation weighting"]

        topk_indices, similarities, small_corpus = self.calc_faiss(np.array(query_embedding), top_k = top_k_internal)
        # The index may return fewer than top_k_internal hits; size everything
        # below from what actually came back.
        n_hits = len(topk_indices)
        if n_hits == 0:
            return ({} if return_scores else []), small_corpus

        # converting from a distance (less is better) to a similarity (more is better)
        similarities = 1/np.asarray(similarities)

        if self.weight_keywords:
            query_kws = get_keywords(query) + self.query_input_keywords
            self.query_kws = query_kws
            kw_weight = np.zeros((n_hits,)) + 0.1
            for kw in query_kws:
                needle = kw.lower()
                for i, paper_kws in enumerate(small_corpus['keywords']):
                    for paper_kw in paper_kws:
                        # substring match, so "galaxy" also hits "galaxy evolution"
                        if needle in paper_kw.lower():
                            kw_weight[i] = kw_weight[i] + 0.1
            kw_weight = kw_weight / np.amax(kw_weight)
        else:
            kw_weight = np.ones((n_hits,))

        if self.weight_date:
            today = datetime.now().date()
            age_years = np.array([((today - d).days / 365.) for d in small_corpus['date']])
            # logistic fall-off with paper age
            age_weight = (1 + np.exp(age_years/0.7))**(-1)
            age_weight = age_weight / np.amax(age_weight)
        else:
            age_weight = np.ones((n_hits,))

        if self.weight_citation:
            # citation counts are capped at 300 before the logistic weighting
            capped = np.minimum(np.array(small_corpus['cites']), 300.)
            cite_weight = (1 + np.exp((300-capped)/42.0))**(-1.)
            cite_weight = cite_weight / np.amax(cite_weight)
        else:
            cite_weight = np.ones((n_hits,))

        similarities = similarities * (kw_weight) * (age_weight) * (cite_weight)

        # Stable descending sort over hit positions, then truncate.
        order = sorted(range(n_hits), key=lambda i: similarities[i], reverse=True)[:top_k]
        top_indices = [topk_indices[i] for i in order]
        small_df = self.dataset[top_indices]

        if return_scores:
            return {topk_indices[i]: similarities[i] for i in order}, small_df
        # Only keep the document IDs
        return top_indices, small_df

    def generate_doc(self, query: str):
        """Ask the HyDE model for one hypothetical abstract answering ``query``."""
        prompt = """You are an expert astronomer. Given a scientific query, generate the abstract of an expert-level research paper
that answers the question. Stick to a maximum length of {} tokens and return just the text of the abstract and conclusion.
Do not include labels for any section. Use research-specific jargon.""".format(self.max_doclen)
        messages = [("system",prompt,),("human", query),]
        return self.hyde_client.invoke(messages).content

    def generate_docs(self, query: str):
        """Generate ``self.generate_n`` hypothetical abstracts for ``query``."""
        return [self.generate_doc(query) for _ in range(self.generate_n)]

    def embed_docs(self, docs: List[str]):
        """Embed a list of documents (delegates to ``embed_batch``)."""
        return self.embed_batch(docs)

    def retrieve(self, query, top_k, return_scores = False,
                 embed_query=True, max_doclen=250,
                 generate_n=1, temperature=0.5,
                 rerank_top_k = 250):
        """Run the full retrieval pipeline for ``query``.

        With ``self.hyde`` set, the search vector is the mean of the
        generated-abstract embeddings (plus the query embedding when
        ``embed_query`` is true). With ``self.rerank`` set, ``rerank_top_k``
        candidates are reranked by Cohere and the best ``top_k`` returned.
        If the rerank step fails, the weighted similarity ranking is returned
        instead, so callers always get a ``(top_results, small_df)`` pair in
        the shape ``return_scores`` asks for.

        Raises:
            ValueError: if ``max_doclen * generate_n`` exceeds 8191.
        """
        if max_doclen * generate_n > 8191:
            raise ValueError("Too many tokens. Please reduce max_doclen or generate_n.")

        query_embedding = self.get_query_embedding(query)

        if self.hyde:
            self.max_doclen = max_doclen
            self.generate_n = generate_n
            self.hyde_client.temperature = temperature
            self.embed_query = embed_query
            docs = self.generate_docs(query)
            st.expander('Abstract generated with hyde', expanded=False).write(docs)
            doc_embeddings = self.embed_docs(docs)
            if self.embed_query:
                query_emb = self.embed_docs([query])[0]
                doc_embeddings.append(query_emb)
            query_embedding = np.mean(np.array(doc_embeddings), axis = 0)

        if not self.rerank:
            return self.rank_and_filter(query,
                                        query_embedding,
                                        top_k,
                                        return_scores = return_scores)

        candidates, small_df = self.rank_and_filter(query,
                                                    query_embedding,
                                                    rerank_top_k,
                                                    return_scores = False)
        docs_for_rerank = list(small_df['abstract'])
        if len(docs_for_rerank) == 0:
            # Keep the two-element return shape callers unpack.
            return ({} if return_scores else []), small_df

        try:
            reranked_results = self.cohere_client.rerank(
                query=query,
                documents=docs_for_rerank,
                model='rerank-english-v3.0',
                top_n=top_k
            )
        except Exception as err:
            # Best effort: the reranker is optional, so fall back to the
            # weighted similarity ranking rather than returning nothing.
            print('heavy load, please wait 10s and try again.')
            print('rerank failed, using un-reranked results: %r' % (err,))
            return self.rank_and_filter(query,
                                        query_embedding,
                                        top_k,
                                        return_scores = return_scores)

        # result.index points into docs_for_rerank / candidates.
        final_results = [(candidates[result.index], float(result.relevance_score))
                         for result in reranked_results.results]
        final_indices = [doc_id for doc_id, _ in final_results]
        if return_scores:
            return dict(final_results), self.dataset[final_indices]
        return final_indices, self.dataset[final_indices]

    def return_formatted_df(self, top_results, small_df):
        """Build the display table for retrieved papers.

        Args:
            top_results: dict ``{corpus_index: score}`` as returned by
                ``retrieve(..., return_scores=True)``, in the same order as
                the rows of ``small_df``.
            small_df: column-oriented rows for those papers.

        Returns:
            DataFrame with ADS links, relevance scores and corpus indices
            added, restricted to the display columns, with a 1-based index.
        """
        df = pd.DataFrame(small_df)
        df = df.drop(columns=['umap_x','umap_y','cite_bibcodes','ref_bibcodes'])
        links = ['https://ui.adsabs.harvard.edu/abs/'+i+'/abstract' for i in small_df['bibcode']]
        scores = [top_results[i] for i in top_results]
        indices = [i for i in top_results]
        df.insert(1,'ADS Link',links,True)
        df.insert(2,'Relevance',scores,True)
        df.insert(3,'Indices',indices,True)
        df = df[['ADS Link','Relevance','date','cites','title','authors','abstract','keywords','ads_id','Indices','embed']]
        # 1-based row labels for display; downstream code must not assume label 0 exists.
        df.index += 1
        return df
# @st.cache_resource | |
def load_ret_system():
    """Instantiate a ``RetrievalSystem`` behind a spinner and report timing.

    Returns:
        The new ``RetrievalSystem`` instance.
    """
    with st.spinner('loading retrieval system...'):
        retrieval_system = RetrievalSystem()
    elapsed = time.time() - ts
    st.toast('loaded retrieval system. time taken: %.2f sec' % elapsed)
    return retrieval_system
# Logo followed by a collapsed "about" panel rendered as markdown.
st.image('local_files/pathfinder_logo.png')

_about_text = """
Pathfinder v2.0 is a framework for searching and visualizing astronomy papers on the [arXiv](https://arxiv.org/) and [ADS](https://ui.adsabs.harvard.edu/) using the context
sensitivity from modern large language models (LLMs) to better parse patterns in paper contexts.
This tool was built during the [JSALT workshop](https://www.clsp.jhu.edu/2024-jelinek-summer-workshop-on-speech-and-language-technology/) to do awesome things.
**👈 Use the sidebar to tweak the search parameters to get better results**.
### Tool summary:
- Please wait while the initial data loads and compiles, this takes about a minute initially.
This is not meant to be a replacement to existing tools like the
[ADS](https://ui.adsabs.harvard.edu/),
[arxivsorter](https://www.arxivsorter.org/), semantic search or google scholar, but rather a supplement to find papers
that otherwise might be missed during a literature survey.
It is trained on astro-ph (astrophysics of galaxies) papers up to last-year-ish mined from arxiv and supplemented with ADS metadata,
if you are interested in extending it please reach out!
Also add: feedback form, socials, literature, contact us, copyright, collaboration, etc.
The image below shows a representation of all the astro-ph.GA papers that can be explored in more detail
using the `Arxiv embedding` page. The papers tend to cluster together by similarity, and result in an
atlas that shows well studied (forests) and currently uncharted areas (water).
"""
_about_panel = st.expander("What is Pathfinder / How do I use it?", expanded=False)
_about_panel.write(_about_text)
# ---- Sidebar: retrieval settings -------------------------------------------
st.sidebar.header("Fine-tune the search")
top_k = st.sidebar.slider("Number of papers to retrieve:", 1, 30, 10)
extra_keywords = st.sidebar.text_input("Enter extra keywords (comma-separated):")
if extra_keywords:
    keywords = [kw.strip() for kw in extra_keywords.split(',')]
else:
    keywords = []

st.sidebar.subheader("Toggles")
toggle_a = st.sidebar.toggle("Weight by keywords", value = False)
toggle_b = st.sidebar.toggle("Weight by date", value = False)
toggle_c = st.sidebar.toggle("Weight by citations", value = False)
toggles = {'Keyword weighting': toggle_a, 'Time weighting': toggle_b, 'Citation weighting': toggle_c}

method = st.sidebar.radio("Retrieval method:", ["Semantic search", "Semantic search + HyDE", "Semantic search + HyDE + CoHERE"], index=2)
method2 = st.sidebar.radio("Generation complexity:", ["Basic RAG","ReAct Agent"])

# Mirror the widget values into session state for the helper functions.
st.session_state.top_k = top_k
st.session_state.keywords = keywords
st.session_state.toggles = toggles
st.session_state.method = method
st.session_state.method2 = method2

# Retrieval method -> (use HyDE, use Cohere rerank)
_retrieval_flags = {
    "Semantic search": (False, False),
    "Semantic search + HyDE": (True, False),
    "Semantic search + HyDE + CoHERE": (True, True),
}
st.session_state.hyde, st.session_state.cohere = _retrieval_flags[method]

# Generation complexity -> internal generation method key
_generation_keys = {"Basic RAG": 'rag', "ReAct Agent": 'agent'}
st.session_state.gen_method = _generation_keys[method2]

question_type = st.sidebar.selectbox("Prompt specialization:", ["Multi-paper (Default)", "Single-paper", "Bibliometric", "Broad but nuanced"])
st.session_state.question_type = question_type
# store_output = st.sidebar.button("Save output")

# ---- Main query input -------------------------------------------------------
query = st.text_input("Ask me anything:")
st.session_state.query = query
st.write(query)
submit_button = st.button("Run pathfinder!", key='runpfdr')

# Spinner captions, one picked at random per run.
search_text_list = ['rooting around in the paper pile...','looking for clarity...','scanning the event horizon...','peering into the abyss...','potatoes power this ongoing search...']
gen_text_list = ['making the LLM talk to the papers...','invoking arcane rituals...','gone to library, please wait...','is there really an answer to this...']

# Load the corpus once per session.
if 'arxiv_corpus' not in st.session_state:
    st.session_state.arxiv_corpus = load_arxiv_corpus()
# @st.fragment() | |
def run_query_ret(query):
    """Retrieve the top-k papers for ``query`` using the current sidebar settings.

    A fresh retrieval system is configured from ``st.session_state``
    (keywords, toggles, HyDE and rerank flags) and queried with scores.

    Returns:
        The formatted results DataFrame from ``return_formatted_df``.
    """
    started = time.time()
    state = st.session_state

    retriever = load_ret_system()
    retriever.query_input_keywords = state.keywords
    retriever.toggles = state.toggles
    retriever.hyde = state.hyde
    retriever.rerank = state.cohere

    scores_by_id, rows = retriever.retrieve(query, top_k = state.top_k, return_scores=True)
    table = retriever.return_formatted_df(scores_by_id, rows)

    st.toast('got top-k papers. time taken: %.2f sec' % (time.time() - started))
    return table
def Library(query):
    """Agent tool: return titles and abstracts of the retrieved papers as text.

    Args:
        query: Tool input supplied by the agent. It is not used; retrieval
            always runs on ``st.session_state.query``.

    Returns:
        One string with a ``Paper N:<title>`` line and the abstract for each
        retrieved paper, numbered from 1.
    """
    papers_df = run_query_ret(st.session_state.query)
    # Iterate the columns by position: the frame's index labels start at 1,
    # so label-based lookups with 0..n-1 would raise KeyError on label 0.
    entries = [
        'Paper %.0f:' % n + title + '\n' + abstract + '\n\n'
        for n, (title, abstract) in enumerate(
            zip(papers_df['title'], papers_df['abstract']), start=1)
    ]
    return ''.join(entries)
def run_agent_qa(query):
    """Answer ``query`` with a ReAct agent that can use the Library and web search.

    The agent and its executor are cached in ``st.session_state``. A trace of
    the run is written to ``agent_trace.txt``, which is recreated on each call.

    Returns:
        The dict returned by ``AgentExecutor.invoke``.
    """
    search = DuckDuckGoSearchAPIWrapper()
    tools = [
        Tool(
            name="Library",
            func=Library,
            description="A source of information pertinent to your question. Do not answer a question without consulting this!"
        ),
        Tool(
            name="Search",
            func=search.run,
            description="useful for when you need to look up knowledge about common topics or current events",
        )
    ]
    if 'tools' not in st.session_state:
        st.session_state.tools = tools

    prompt = hub.pull("hwchase17/react")
    prompt.template = react_prompt

    # Start every run with an empty trace file; a missing or unremovable file
    # is fine (best effort, as before).
    file_path = "agent_trace.txt"
    try:
        os.remove(file_path)
    except OSError:
        pass
    file_handler = FileCallbackHandler(file_path)

    if 'agent' not in st.session_state:
        st.session_state.agent = create_react_agent(llm=st.session_state.gen_llm, tools=tools, prompt=prompt)

    if 'agent_executor' not in st.session_state:
        # No callbacks here: the executor is cached across reruns, while the
        # trace handler is recreated per call. Binding the handler at
        # construction would keep later runs attached to the first run's
        # (since removed) file.
        st.session_state.agent_executor = AgentExecutor(
            agent=st.session_state.agent,
            tools=st.session_state.tools,
            verbose=True,
            handle_parsing_errors=True,
        )

    # Attach this call's handler so the trace lands in the fresh file.
    answer = st.session_state.agent_executor.invoke(
        {"input": query,},
        config={"callbacks": [file_handler]},
    )
    return answer
def run_rag_qa(query, papers_df):
    """Answer ``query`` with a retrieval-augmented chain over the retrieved papers.

    Each paper's title and abstract becomes a document, is split into small
    chunks, indexed in a temporary Chroma collection, and the six most
    similar chunks are passed to the LLM with the prompt selected by
    ``st.session_state.question_type``.

    Args:
        query: The user question.
        papers_df: Results table with ``title``, ``abstract`` and ``ads_id``
            columns.

    Returns:
        Dict with ``context``, ``question`` and ``answer`` keys. If the chain
        fails, a notice is shown and ``answer`` carries the same notice so
        callers can still read ``result['answer']``.
    """
    failure_notice = 'heavy load! please wait 10 seconds and try again.'
    rag_answer = {"context": [], "question": query, "answer": failure_notice}
    vectorstore = None
    try:
        documents = []
        n_papers = len(papers_df)
        my_bar = st.progress(0, text='adding documents to LLM context')
        # Count rows ourselves: the frame's index labels start at 1, so using
        # them for numbering skipped "Paper 1" and pushed the progress
        # fraction above 1.0 on the last row.
        for n, (_, row) in enumerate(papers_df.iterrows(), start=1):
            content = f"Paper {n}: {row['title']}\n{row['abstract']}\n\n"
            metadata = {"source": row['ads_id']}
            documents.append(Document(page_content=content, metadata=metadata))
            my_bar.progress(n / n_papers, text='adding documents to LLM context')

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=150, chunk_overlap=50, add_start_index=True)
        splits = text_splitter.split_documents(documents)
        vectorstore = Chroma.from_documents(documents=splits, embedding=st.session_state.embeddings, collection_name='retdoc4')
        retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 6})

        if st.session_state.question_type == 'Bibliometric':
            template = bibliometric_prompt
        elif st.session_state.question_type == 'Single-paper':
            template = single_paper_prompt
        elif st.session_state.question_type == 'Broad but nuanced':
            template = deep_knowledge_prompt
        else:
            template = regular_prompt
        prompt = PromptTemplate.from_template(template)

        def format_docs(docs):
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain_from_docs = (
            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
            | prompt
            | st.session_state.gen_llm
            | StrOutputParser()
        )
        rag_chain_with_source = RunnableParallel(
            {"context": retriever, "question": RunnablePassthrough()}
        ).assign(answer=rag_chain_from_docs)
        rag_answer = rag_chain_with_source.invoke(query, )
    except Exception as err:
        st.subheader(failure_notice)
        print('run_rag_qa failed: %r' % (err,))
    finally:
        # Drop the temporary collection on failure too, so a later run does
        # not add its chunks to leftovers under the same collection name.
        if vectorstore is not None:
            vectorstore.delete_collection()
    return rag_answer
def guess_question_type(query: str):
    """Ask the LLM to categorise ``query`` using ``question_categorization_prompt``.

    Returns:
        The raw text content of the model's reply.
    """
    categorizer = openai_llm(
        temperature=0,
        model_name='gpt-4o-mini',
        openai_api_key=openai_key,
    )
    conversation = [
        ("system", question_categorization_prompt),
        ("human", query),
    ]
    reply = categorizer.invoke(conversation)
    return reply.content
class OverallConsensusEvaluation(BaseModel):
    """Structured response describing how well a set of abstracts agrees with a query."""

    # One of seven levels, from strong agreement to strong disagreement.
    consensus: Literal[
        "Strong Agreement",
        "Moderate Agreement",
        "Weak Agreement",
        "No Clear Consensus",
        "Weak Disagreement",
        "Moderate Disagreement",
        "Strong Disagreement",
    ] = Field(..., description="The overall level of consensus between the query and the abstracts")
    # Free-text justification for the chosen level.
    explanation: str = Field(..., description="A detailed explanation of the consensus evaluation")
    # Overall relevance, constrained to the closed interval [0, 1].
    relevance_score: float = Field(
        ...,
        description="A score from 0 to 1 indicating how relevant the abstracts are to the query overall",
        ge=0,
        le=1,
    )
def evaluate_overall_consensus(query: str, abstracts: List[str]) -> OverallConsensusEvaluation:
    """Rate, in one LLM call, how strongly ``abstracts`` agree with ``query``.

    The abstracts are numbered from 1 and embedded in a single user prompt;
    the reply is parsed into an ``OverallConsensusEvaluation`` by the
    instructor-patched client stored in ``st.session_state.consensus_client``.
    """
    numbered_abstracts = ' '.join(
        f"Abstract {position}: {text}" for position, text in enumerate(abstracts, start=1)
    )
    prompt = f"""
Query: {query}
You will be provided with {len(abstracts)} scientific abstracts. Your task is to:
1. Evaluate the overall consensus between the query and the abstracts.
2. Provide a detailed explanation of your consensus evaluation.
3. Assign an overall relevance score from 0 to 1, where 0 means completely irrelevant and 1 means highly relevant.
For the consensus evaluation, use one of the following levels:
Strong Agreement, Moderate Agreement, Weak Agreement, No Clear Consensus, Weak Disagreement, Moderate Disagreement, Strong Disagreement
Here are the abstracts:
{numbered_abstracts}
Provide your evaluation in a structured format.
"""
    system_message = """You are an assistant with expertise in astrophysics for question-answering tasks.
Evaluate the overall consensus of the retrieved scientific abstracts in relation to a given query.
If you don't know the answer, just say that you don't know.
Use six sentences maximum and keep the answer concise."""

    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": prompt},
    ]
    client = st.session_state.consensus_client
    return client.chat.completions.create(
        model="gpt-4o-mini",  # used to be "gpt-4",
        response_model=OverallConsensusEvaluation,
        messages=conversation,
        temperature=0,
    )
def calc_outlier_flag(papers_df, top_k, cutoff_adjust = 0.1):
    """Flag papers whose embedding lies far from the centroid of the result set.

    Cutoff distances are read from ``pfdr_arxiv_cutoff_distances.npy`` and
    lowered by ``cutoff_adjust``; the entry at position ``top_k - 1`` is the
    threshold.

    Returns:
        Boolean array, one entry per row of ``papers_df``, true where the
        Euclidean distance to the centroid exceeds the threshold.
    """
    thresholds = np.load('pfdr_arxiv_cutoff_distances.npy') - cutoff_adjust
    embeddings = np.array(papers_df['embed'].tolist())
    offsets = embeddings - embeddings.mean(axis=0)
    radii = np.sqrt((offsets ** 2).sum(axis=1))
    return radii > thresholds[top_k - 1]
def make_embedding_plot(papers_df, consensus_answer):
    """Plot the retrieved papers on the corpus-wide UMAP map.

    All corpus papers are drawn as faint black dots; the retrieved papers are
    drawn on top in blue, with outliers (per ``calc_outlier_flag``) smaller
    and more transparent. The title shows the query, the outlier count and
    the consensus verdict.

    Args:
        papers_df: Results table with ``Indices`` (corpus row numbers) and
            ``embed`` columns.
        consensus_answer: Object exposing ``consensus`` and
            ``relevance_score``.
    """
    # matplotlib is not imported among the visible file-level imports, so
    # bring it into scope here; harmless if `plt` is also provided elsewhere.
    import matplotlib.pyplot as plt

    plt_indices = np.array(papers_df['Indices'].tolist())

    if 'arxiv_corpus' not in st.session_state:
        st.session_state.arxiv_corpus = load_arxiv_corpus()

    xax = np.array(st.session_state.arxiv_corpus['umap_x'])
    yax = np.array(st.session_state.arxiv_corpus['umap_y'])

    # NOTE(review): `top_k` here is the module-level sidebar value, not the
    # number of rows in papers_df - confirm that is intended.
    outlier_flag = calc_outlier_flag(papers_df, top_k, cutoff_adjust=0.25)
    alphas = np.ones((len(plt_indices),)) * 0.9
    alphas[outlier_flag] = 0.5

    fig = plt.figure(figsize=(9,12))
    plt.scatter(xax,yax, s=1, alpha=0.01, c='k')
    # white halo under each retrieved paper, then the coloured marker
    plt.scatter(xax[plt_indices], yax[plt_indices], s=300*alphas**2, alpha=alphas, c='w')
    plt.scatter(xax[plt_indices], yax[plt_indices], s=100*alphas**2, alpha=alphas, c='dodgerblue')
    plt.axis([0,20,-4.2,18])
    plt.axis('off')
    plt.title('Query: '+st.session_state.query+'\n'+r'N$_{\rm outliers}: %.0f/%.0f$, Consensus: ' %(np.sum(outlier_flag), len(outlier_flag)) + consensus_answer.consensus + ' (%.1f)' %consensus_answer.relevance_score)
    st.pyplot(fig)
# --------------------------------------- | |
# Main flow, run when the "Run pathfinder!" button was pressed on this rerun:
# retrieve papers, generate an answer, show diagnostics, offer a JSON download.
if st.session_state.get('runpfdr'):
    with st.spinner(search_text_list[np.random.choice(len(search_text_list))]):
        st.write('Settings: [Kw:',toggle_a, 'Time:',toggle_b, 'Cite:',toggle_c, '] top_k:',top_k, 'retrieval:',method)
        papers_df = run_query_ret(st.session_state.query)
        st.header(st.session_state.query)
        st.subheader('top-k relevant papers:')
        st.data_editor(papers_df, column_config = {'ADS Link':st.column_config.LinkColumn(display_text= 'https://ui.adsabs.harvard.edu/abs/(.*?)/abstract')})

    with st.spinner(gen_text_list[np.random.choice(len(gen_text_list))]):
        if st.session_state.gen_method == 'agent':
            answer = run_agent_qa(st.session_state.query)
            answer_text = answer['output']
            st.subheader('Answer with '+method2)
            st.write(answer_text)
            file_path = "agent_trace.txt"
            with open(file_path, 'r') as file:
                intermediate_steps = file.read()
            st.expander('Intermediate steps', expanded=False).write(intermediate_steps)
        elif st.session_state.gen_method == 'rag':
            answer = run_rag_qa(query, papers_df)
            st.subheader('Answer with '+method2)
            answer_text = answer['answer']
            st.write(answer_text)

    # Keywords extracted from the query plus the user's extra keywords
    # (previously the extra keywords were appended twice).
    triggered_keywords = get_keywords(query) + st.session_state.keywords
    st.write('**Triggered keywords:** `'+ "`, `".join(triggered_keywords)+'`')

    col1, col2 = st.columns(2)
    with col1:
        with st.spinner("Evaluating question type"):
            with st.expander("Question type", expanded=True):
                st.subheader("Question type suggestion")
                question_type_gen = guess_question_type(query)
                # Keep only the text inside <categorization>...</categorization>, when present.
                if '<categorization>' in question_type_gen:
                    question_type_gen = question_type_gen.split('<categorization>')[1]
                if '</categorization>' in question_type_gen:
                    question_type_gen = question_type_gen.split('</categorization>')[0]
                question_type_gen = question_type_gen.replace('\n',' \n')
                st.markdown(question_type_gen)
        with st.spinner("Evaluating abstract consensus"):
            with st.expander("Abstract consensus", expanded=True):
                # papers_df carries 1-based index labels, so take the column
                # by position rather than looking up labels 0..n-1.
                consensus_answer = evaluate_overall_consensus(query, papers_df['abstract'].tolist())
                st.subheader("Consensus: "+consensus_answer.consensus)
                st.markdown(consensus_answer.explanation)
                st.markdown('Relevance of retrieved papers to answer: %.1f' %consensus_answer.relevance_score)
    with col2:
        make_embedding_plot(papers_df, consensus_answer)

    session_vars = {
        "runtime": "pathfinder_v1_online",
        "query": query,
        # The literal used to list "question_type" twice, so only the LLM
        # suggestion survived; that key keeps its effective value and the
        # user's sidebar choice gets its own key.
        "question_type": question_type_gen,
        "question_type_selected": question_type,
        'Keyword weighting': toggle_a,
        'Time weighting': toggle_b,
        'Citation weighting': toggle_c,
        "rag_method" : method,
        "gen_method" : method2,
        "answer" : answer_text,
        "consensus": consensus_answer.explanation,
        # .tolist() converts numpy scalars to plain Python values for json.dumps
        "topk" : papers_df['ads_id'].tolist(),
        "topk_scores" : papers_df['Relevance'].tolist(),
        "topk_papers": papers_df['ADS Link'].tolist(),
    }

    def download_op(data):
        """Offer ``data`` as a JSON file through a Streamlit download button."""
        json_string = json.dumps(data)
        st.download_button(
            label='Download output',
            file_name="pathfinder_data.json",
            mime="application/json",
            data=json_string,)

    # with st.sidebar:
    download_op(session_vars)
else:
    st.info("Use the sidebar to tweak the search parameters to get better results.")