# pathfinder/pages/3_qa_sources_v2.py
# Last change: commit 0d72411 by kiyer, "update qa sources" (16.8 kB).
# set the environment variables needed for openai package to know to reach out to azure
import os
import datetime
import faiss
import streamlit as st
import feedparser
import urllib
import cloudpickle as cp
import pickle
from urllib.request import urlopen
from summa import summarizer
import numpy as np
import matplotlib.pyplot as plt
import requests
import json
from langchain.document_loaders import TextLoader
from langchain.indexes import VectorstoreIndexCreator
from langchain_openai import AzureOpenAIEmbeddings
from langchain.llms import OpenAI
from langchain_openai import AzureChatOpenAI
from langchain import hub
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
# Route the OpenAI client libraries to Azure. The embeddings client below is
# only given its endpoint explicitly; presumably it reads the API key
# (OPENAI_API_KEY) and API version (OPENAI_API_VERSION) from these variables
# at construction time, so they must be set first — TODO confirm against the
# installed langchain_openai version.
os.environ["OPENAI_API_TYPE"] = "azure"
# NOTE(review): langchain_openai documents AZURE_OPENAI_ENDPOINT as its
# endpoint variable, so this one may be unused; the endpoint is passed
# explicitly below anyway.
os.environ["AZURE_ENDPOINT"] = st.secrets["endpoint1"]
os.environ["OPENAI_API_KEY"] = st.secrets["key1"]
os.environ["OPENAI_API_VERSION"] = "2023-05-15"

# Embedding client: used to embed free-text / fetched-abstract queries in
# list_similar_papers_v2 (embed_query) and to build the Chroma store in run_rag.
embeddings = AzureOpenAIEmbeddings(
    deployment="embedding",
    model="text-embedding-ada-002",
    azure_endpoint=st.secrets["endpoint1"],
)

# Chat model used by the RAG chain in run_rag. It gets its own endpoint/key
# pair (endpoint2/key2), API version and type as explicit arguments;
# temperature 0 to keep answers as repeatable as the service allows.
llm = AzureChatOpenAI(
    deployment_name="gpt4_small",
    openai_api_version="2023-12-01-preview",
    azure_endpoint=st.secrets["endpoint2"],
    openai_api_key=st.secrets["key2"],
    openai_api_type="azure",
    temperature=0.
)
@st.cache_data
def get_feeds_data(url):
    """Unpickle and return the object stored at the local file path ``url``.

    Despite the parameter name, ``url`` is opened as a local file. The result
    is memoised by ``st.cache_data``, and ``st.sidebar.success`` is called once
    the file has been read.

    Only use this on trusted files: unpickling can execute arbitrary code.
    """
    # An earlier version streamed the pickle from a remote URL instead:
    #   data = cp.load(urlopen(url))
    with open(url, "rb") as handle:
        payload = pickle.loads(handle.read())
    st.sidebar.success("Loaded data")
    return payload
# Earlier remote locations of the same pickles, kept for reference:
# feeds_link = "https://drive.google.com/uc?export=download&id=1-IPk1voyUM9VqnghwyVrM1dY6rFnn1S_"
# embed_link = "https://dl.dropboxusercontent.com/s/ob2betm29qrtb8v/astro_ph_ga_feeds_ada_embedding_18-Apr-2023.pkl?dl=0"

# Snapshot date baked into the local data file names (also shown in the page header).
dateval = "27-Jun-2023"
feeds_link = f"local_files/astro_ph_ga_feeds_upto_{dateval}.pkl"
embed_link = f"local_files/astro_ph_ga_feeds_ada_embedding_{dateval}.pkl"

# Parsed arXiv feed chunks, and the embedding matrix for the same papers
# (presumably one row per paper, in feed order — TODO confirm).
gal_feeds = get_feeds_data(feeds_link)
arxiv_ada_embeddings = get_feeds_data(embed_link)
@st.cache_data
def get_embedding_data(url):
    """Unpickle and return the object stored at the local file path ``url``.

    Same loading logic as ``get_feeds_data`` but a separately cached function
    with its own sidebar message. Only use this on trusted files: unpickling
    can execute arbitrary code.
    """
    # Remote variant kept for reference: data = cp.load(urlopen(url))
    with open(url, "rb") as handle:
        payload = pickle.loads(handle.read())
    # NOTE(review): the message mentions an API, but the data comes from a local file.
    st.sidebar.success("Fetched data from API!")
    return payload
# Earlier remote location, kept for reference:
# url = "https://drive.google.com/uc?export=download&id=1133tynMwsfdR1wxbkFLhbES3FwDWTPjP"
url = f"local_files/astro_ph_ga_embedding_{dateval}.pkl"

# 2-D map coordinates plotted by run_rag / run_query (columns 0 and 1),
# presumably one row per paper aligned with all_text — TODO confirm.
# An older pickle format stored a 5-tuple: e2d, _, _, _, _ = get_embedding_data(url)
e2d = get_embedding_data(url)
# Flatten the feed chunks into parallel per-paper lists. A paper's position in
# these lists is its "doc id"; FAISS row indices from faiss_based_indices are
# used directly to index them, so the order is assumed to match the rows of
# arxiv_ada_embeddings (and e2d).
ctr = -1  # NOTE(review): not referenced anywhere else in this file
num_chunks = len(gal_feeds)
all_text, all_titles, all_arxivid, all_links, all_authors = [], [], [], [], []
for nc in range(num_chunks):
    for entry in gal_feeds[nc].entries:
        # Flatten the abstract onto one line and drop stray backslashes.
        text = entry.summary.replace('\n', ' ').replace('\\', '')
        all_text.append(text)
        all_titles.append(entry.title)
        # entry.id ends in e.g. ".../abs/2301.00001v2". Strip the version
        # suffix with rsplit rather than a fixed-width slice so multi-digit
        # versions (v10 and up) are removed completely.
        all_arxivid.append(entry.id.split('/')[-1].rsplit('v', 1)[0])
        # NOTE(review): links[1] is presumably the PDF link — confirm.
        all_links.append(entry.links[1].href)
        all_authors.append(entry.authors)
# Flat L2 FAISS index over the paper embeddings; faiss_based_indices() searches it.
xb = arxiv_ada_embeddings.astype('float32')  # FAISS takes float32 vectors
nb = xb.shape[0]  # number of papers in the database
d = xb.shape[1]  # embedding dimension
index = faiss.IndexFlatL2(d)
index.add(xb)
def run_simple_query(search_query = 'all:sed+fitting', max_results = 10, start = 0, sort_by = 'lastUpdatedDate', sort_order = 'descending'):
    """
    Query the arXiv API and return the parsed search results.

    Parameters
    ----------
    search_query: str, default = 'all:sed+fitting'
        query term, inserted into the URL as-is (not URL-encoded), so join
        words with '+'. use prefixes ti, au, abs, co, jr, cat, m, id, all as
        applicable.
    max_results: int, default = 10
        number of results to return. numbers > 1000 generally lead to timeouts
    start: int, default = 0
        start index for results reported. use this if you're interested in running chunks.
    sort_by: str, default = 'lastUpdatedDate'
        value for the API's sortBy parameter.
    sort_order: str, default = 'descending'
        value for the API's sortOrder parameter.

    Returns
    -------
    feed: feedparser.FeedParserDict
        object containing requested results parsed with feedparser

    Notes
    -----
    add functionality for chunk parsing, as well as storage and retrieval
    """
    api_root = 'http://export.arxiv.org/api/query?'
    params = 'search_query=%s&start=%i&max_results=%i&sortBy=%s&sortOrder=%s' % (
        search_query, start, max_results, sort_by, sort_order)
    with urllib.request.urlopen(api_root + params) as resp:
        payload = resp.read()
    return feedparser.parse(payload)
def find_papers_by_author(auth_name):
    """Return the doc ids of papers with an author whose name contains ``auth_name``.

    Matching is a case-insensitive substring test against each author's
    ``'name'`` in the module-level ``all_authors``. Every matching author entry
    is printed (doc id, arXiv id, title, author name) for interactive use, but
    each paper's doc id appears only once in the result, in ascending order.
    """
    needle = auth_name.lower()
    doc_ids = []
    for doc_id, authors in enumerate(all_authors):
        matched = False
        for author in authors:
            if needle in author['name'].lower():
                print('Doc ID: ',doc_id, ' | arXiv: ', all_arxivid[doc_id], '| ', all_titles[doc_id],' | Author entry: ', author['name'])
                matched = True
        # Append once per paper, even if several co-authors match.
        if matched:
            doc_ids.append(doc_id)
    return doc_ids
def faiss_based_indices(input_vector, nindex=10):
    """Search the module-level FAISS ``index`` for the ``nindex`` nearest rows.

    ``input_vector`` is flattened into a single float32 query row. Returns
    ``(indices, distances)`` for that one query, in the order FAISS reports them.
    """
    query_row = input_vector.reshape(1, -1).astype('float32')
    distances, neighbours = index.search(query_row, nindex)
    return neighbours[0], distances[0]
def list_similar_papers_v2(model_data,
                           doc_id = None, input_type = 'doc_id',
                           show_authors = False, show_summary = False,
                           return_n = 10):
    """Find the papers nearest to a query and format them as a markdown list.

    Parameters
    ----------
    model_data : sequence
        ``(arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors)``:
        the embedding matrix, the embedding client (``embed_query``), and the
        per-paper titles, abstracts and author lists.
    doc_id : int or str
        Interpreted according to ``input_type``: a row index into the embedding
        matrix (``'doc_id'``), an arXiv id fetched through ``run_simple_query``
        (``'arxiv_id'``), or free text to embed (``'keywords'``).
    input_type : {'doc_id', 'arxiv_id', 'keywords'}
    show_authors, show_summary : bool
        Include the author list / an extractive ``summa`` summary per paper.
    return_n : int
        Number of papers to list.

    Returns
    -------
    tuple or None
        ``(textstr, abstracts_relevant, fhdrs, sims)``: the markdown listing,
        the abstracts of the listed papers, a file-name stem per paper
        (``<first-author surname><first 2 chars of arXiv id>_<arXiv id>``), and
        the raw FAISS neighbour indices (``return_n + 2`` of them). Returns
        ``None`` if the arXiv id is not found or ``input_type`` is unrecognised.
    """
    arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data
    if input_type == 'doc_id':
        print('Doc ID: ',doc_id,', title: ',all_titles[doc_id])
        inferred_vector = arxiv_ada_embeddings[doc_id,0:]
        # The query paper is its own nearest neighbour, so skip the first hit.
        start_range = 1
    elif input_type == 'arxiv_id':
        print('ArXiv id: ',doc_id)
        arxiv_query_feed = run_simple_query(search_query='id:'+str(doc_id))
        if len(arxiv_query_feed.entries) == 0:
            print('error: arxiv id not found.')
            return
        print('Title: '+arxiv_query_feed.entries[0].title)
        inferred_vector = np.array(embeddings.embed_query(arxiv_query_feed.entries[0].summary))
        start_range = 0
    elif input_type == 'keywords':
        inferred_vector = np.array(embeddings.embed_query(doc_id))
        start_range = 0
    else:
        print('unrecognized input type.')
        return

    # Ask for two extra neighbours so the window below stays in range even
    # when it starts at 1.
    sims, dists = faiss_based_indices(inferred_vector, return_n+2)
    parts = []
    abstracts_relevant = []
    fhdrs = []
    for i in range(start_range, start_range+return_n):
        idx = sims[i]
        arxiv_id = all_arxivid[idx]
        authors = all_authors[idx]
        abstracts_relevant.append(all_abstracts[idx])
        # Fall back to a placeholder surname when the author list or first
        # name is empty, instead of raising IndexError.
        name_words = authors[0]['name'].split() if authors else []
        surname = name_words[-1] if name_words else 'Unknown'
        fhdrs.append(surname + arxiv_id[0:2] + '_' + arxiv_id)
        parts.append(str(i+1) + '. **' + all_titles[idx] + '** (Distance: %.2f' % dists[i] + ') \n')
        parts.append('**ArXiv:** [' + arxiv_id + '](https://arxiv.org/abs/' + arxiv_id + ') \n')
        if show_authors:
            parts.append('**Authors:** ')
            if authors:
                parts.append(', '.join(author.name for author in authors) + ' \n')
        if show_summary:
            parts.append('**Summary:** ')
            parts.append(summarizer.summarize(all_abstracts[idx].replace('\n', ' ')) + ' \n')
        if show_authors or show_summary:
            parts.append(' ')
        parts.append(' \n')
    return ''.join(parts), abstracts_relevant, fhdrs, sims
def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None,
                             api_key=None,
                             api_endpoint="https://api.openai.com/v1/chat/completions"):
    """Request a chat completion over plain HTTP and return the reply text.

    Not called anywhere in this file; the page itself goes through ``llm``
    and LangChain.

    Parameters
    ----------
    messages : list of dict
        Chat messages, presumably OpenAI-style
        ``[{"role": "user", "content": "..."}]``.
    model : str
    temperature : float
    max_tokens : int, optional
        Only sent when not ``None``.
    api_key : str, optional
        Bearer token. Defaults to the ``OPENAI_API_KEY`` environment variable.
    api_endpoint : str, optional
        Chat-completions URL. Defaults to the public OpenAI endpoint.

    Returns
    -------
    str
        ``choices[0]["message"]["content"]`` from the JSON response.

    Raises
    ------
    KeyError
        If ``api_key`` is not given and ``OPENAI_API_KEY`` is unset.
    RuntimeError
        If the endpoint answers with a non-200 status.

    NOTE(review): at import time this module sets ``OPENAI_API_KEY`` to an
    Azure key (``st.secrets["key1"]``). Azure endpoints presumably need a
    deployment URL and a different auth header, so pass ``api_key`` /
    ``api_endpoint`` explicitly when calling this.
    """
    if api_key is None:
        # No ``openai`` module object is imported here; read the key from the
        # environment instead.
        api_key = os.environ["OPENAI_API_KEY"]
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    data = {
        "model": model,
        "messages": messages,
        "temperature": temperature,
    }
    if max_tokens is not None:
        data["max_tokens"] = max_tokens
    response = requests.post(api_endpoint, headers=headers, data=json.dumps(data))
    if response.status_code == 200:
        return response.json()["choices"][0]["message"]["content"]
    # RuntimeError subclasses Exception, so existing `except Exception` callers still work.
    raise RuntimeError(f"Error {response.status_code}: {response.text}")
# Bundle consumed by list_similar_papers_v2, in the order it unpacks:
# embedding matrix, embedding client, titles, abstracts, author lists.
model_data = [
    arxiv_ada_embeddings,
    embeddings,
    all_titles,
    all_text,
    all_authors,
]
def format_docs(docs):
    """Concatenate the ``page_content`` of each document, separated by a blank line."""
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)
def get_textstr(i, show_authors=False, show_summary=False):
    """Return a markdown entry (title, arXiv link, optional authors/summary) for paper ``i``.

    Reads the module-level per-paper lists ``all_titles``, ``all_arxivid``,
    ``all_authors`` and ``all_text``. The summary is an extractive ``summa``
    summary of the abstract with newlines flattened to spaces.
    """
    wants_authors = show_authors == True
    wants_summary = show_summary == True
    title = all_titles[i]
    arxiv_id = all_arxivid[i]
    chunks = [
        '**' + title + '** \n',
        '**ArXiv:** [' + arxiv_id + '](https://arxiv.org/abs/' + arxiv_id + ') \n',
    ]
    if wants_authors:
        author_names = [author.name for author in all_authors[i]]
        chunks.append('**Authors:** ')
        if author_names:
            chunks.append(', '.join(author_names) + ' \n')
    if wants_summary:
        flattened = all_text[i].replace('\n', ' ')
        chunks.append('**Summary:** ')
        chunks.append(summarizer.summarize(flattened) + ' \n')
    if wants_authors or wants_summary:
        chunks.append(' ')
    chunks.append(' \n')
    return ''.join(chunks)
def _rag_source_link(source_path):
    """Turn a retrieved chunk's ``source`` path into ``(paper_index, markdown)``.

    ``run_rag`` writes abstracts to ``absts/<stem>.txt`` where the stem comes
    from ``list_similar_papers_v2`` as ``<Surname><yy>_<arxivid>``. The label
    becomes ``<Surname> et al. <19yy|20yy>`` and links to
    ``all_links[paper_index]``. Returns ``(None, label)`` when the arXiv id is
    not in ``all_arxivid``.
    """
    stem = os.path.basename(source_path)
    if stem.endswith('.txt'):
        stem = stem[:-4]
    label, _, arxiv_id = stem.rpartition('_')
    yy = label[-2:]
    if yy.isdigit():
        century = '20' if int(yy) < 40 else '19'
        label = label[:-2] + ' et al. ' + century + yy
    try:
        paper_index = all_arxivid.index(arxiv_id)
    except ValueError:
        return None, label
    return paper_index, '[' + label + '](' + all_links[paper_index] + ')'


def run_rag(query, return_n = 10, show_authors = True, show_summary = True):
    """Answer ``query`` with a retrieval-augmented chain over similar abstracts.

    The ``return_n`` papers nearest to ``query`` are written to
    ``absts/*.txt``, split into chunks, embedded into a Chroma store, and
    retrieved as context for ``llm``. The answer, the cited source papers, a
    scatter plot of ``e2d`` (similar papers and cited sources highlighted) and
    the similar-paper listing are rendered with Streamlit.

    Returns the chain output dict (keys ``'context'``, ``'question'``,
    ``'answer'``).
    """
    sims, absts, fhdrs, simids = list_similar_papers_v2(model_data,
                                                        doc_id = query,
                                                        input_type='keywords',
                                                        show_authors = show_authors, show_summary = show_summary,
                                                        return_n = return_n)

    # One text file per abstract, so TextLoader records the path as each
    # document's 'source' metadata (parsed back in _rag_source_link).
    os.makedirs('absts', exist_ok=True)
    loaders = []
    for abstract, fhdr in zip(absts, fhdrs):
        path = "absts/" + fhdr + ".txt"
        with open(path, "w", encoding="utf-8") as text_file:
            text_file.write(abstract)
        loaders.append(TextLoader(path, encoding="utf-8"))

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
    splits = text_splitter.split_documents([loader.load()[0] for loader in loaders])
    # NOTE(review): depending on the chromadb version, from_documents may reuse
    # a shared in-process collection across calls — confirm that chunks from
    # earlier queries do not leak into this retriever.
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    retriever = vectorstore.as_retriever()
    template = """You are an assistant with expertise in astrophysics for question-answering tasks.
Use the following pieces of retrieved context from the literature to answer the question.
If you don't know the answer, just say that you don't know.
Use six sentences maximum and keep the answer concise.
{context}
Question: {question}
Answer:"""
    custom_rag_prompt = PromptTemplate.from_template(template)
    rag_chain_from_docs = (
        RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
        | custom_rag_prompt
        | llm
        | StrOutputParser()
    )
    # Run retrieval and the answer chain together so the retrieved documents
    # come back alongside the answer.
    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough()}
    ).assign(answer=rag_chain_from_docs)
    rag_answer = rag_chain_with_source.invoke(query)

    st.markdown('### User query: '+query)
    st.markdown(rag_answer['answer'])
    st.markdown('#### Primary sources: \n')
    # Several chunks can come from the same abstract; list each file once.
    srcnames = sorted({doc.metadata['source'] for doc in rag_answer['context']})
    abs_indices = []
    for srcname in srcnames:
        paper_index, link = _rag_source_link(srcname)
        if paper_index is not None:
            abs_indices.append(paper_index)
        st.markdown(link)

    fig = plt.figure(figsize=(9,9))
    plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
    plt.scatter(e2d[simids,0], e2d[simids,1],s=30)
    plt.scatter(e2d[abs_indices,0], e2d[abs_indices,1],s=100,color='k',marker='d')
    st.pyplot(fig)
    # Release the figure; Streamlit reruns would otherwise accumulate open figures.
    plt.close(fig)
    st.markdown('\n #### List of relevant papers:')
    st.markdown(sims)
    return rag_answer
def _format_query_source(position, abs_index):
    """Return run_query's markdown entry for cited paper ``abs_index``, numbered ``position``.

    Shows the title, arXiv link, at most four author names (followed by
    "et al." when there are more), and an extractive ``summa`` summary.
    """
    arxiv_id = all_arxivid[abs_index]
    names = [author.name for author in all_authors[abs_index]]
    if len(names) > 4:
        author_line = ', '.join(names[:4]) + ', et al. \n'
    else:
        author_line = ', '.join(names) + ' \n'
    summary = summarizer.summarize(all_text[abs_index].replace('\n', ' '))
    return (str(position) + '. **' + all_titles[abs_index] + '** \n'
            + '**ArXiv:** [' + arxiv_id + '](https://arxiv.org/abs/' + arxiv_id + ') \n'
            + '**Authors:** ' + author_line
            + '**Summary:** ' + summary + ' \n')


def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources = True):
    """Answer ``query`` from the abstracts of the ``return_n`` most similar papers.

    The similar abstracts are written to ``absts/*.txt``, indexed with
    ``VectorstoreIndexCreator``, and queried with sources. The answer, the
    cited papers, a scatter plot of ``e2d`` (similar papers and cited sources
    highlighted) and optionally the full similar-paper listing are rendered
    with Streamlit.

    Parameters
    ----------
    query : str
    return_n : int
        Number of similar papers to retrieve and index.
    show_pure_answer : bool
        Also render the plain ``lc_index.query`` answer.
    show_all_sources : bool
        Render the full similar-paper listing at the end.

    Returns
    -------
    dict
        The ``query_with_sources`` output (``'answer'`` and ``'sources'`` are used here).
    """
    show_authors = True
    show_summary = True
    sims, absts, fhdrs, simids = list_similar_papers_v2(model_data,
                                                        doc_id = query,
                                                        input_type='keywords',
                                                        show_authors = show_authors, show_summary = show_summary,
                                                        return_n = return_n)

    # One text file per abstract; the file path is what comes back as a "source".
    os.makedirs('absts', exist_ok=True)
    loaders = []
    for abstract, fhdr in zip(absts, fhdrs):
        path = "absts/" + fhdr + ".txt"
        with open(path, "w", encoding="utf-8") as text_file:
            text_file.write(abstract)
        loaders.append(TextLoader(path, encoding="utf-8"))
    lc_index = VectorstoreIndexCreator().from_loaders(loaders)

    st.markdown('### User query: '+query)
    if show_pure_answer:
        st.markdown('pure answer:')
        st.markdown(lc_index.query(query))
        st.markdown(' ')
    st.markdown('#### context-based answer from sources:')
    # zero-shot in-context prompting from Zhou+22, Kojima+22
    output = lc_index.query_with_sources(query + ' Let\'s work this out in a step by step way to be sure we have the right answer.' )
    st.markdown(output['answer'])
    st.markdown('#### Primary sources: \n')

    # output['sources'] is free text from the model, presumably a
    # comma-separated list like "absts/Smith23_2301.00001.txt, absts/...".
    # Split on commas/whitespace and recover the arXiv id after the last '_'.
    entries = []
    abs_indices = []
    source_tokens = output['sources'].replace(',', ' ').split()
    for position, token in enumerate(source_tokens, start=1):
        stem = os.path.basename(token)
        if stem.endswith('.txt'):
            stem = stem[:-4]
        arxiv_id = stem.rpartition('_')[2]
        try:
            abs_index = all_arxivid.index(arxiv_id)
        except ValueError:
            # Not one of our abstract files: show the raw source text instead.
            entries.append(token + '  \n')
            continue
        abs_indices.append(abs_index)
        entries.append(_format_query_source(position, abs_index) + '  \n')
    st.markdown(''.join(entries))

    fig = plt.figure(figsize=(9,9))
    plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
    plt.scatter(e2d[simids,0], e2d[simids,1],s=30)
    plt.scatter(e2d[abs_indices,0], e2d[abs_indices,1],s=100,color='k',marker='d')
    st.pyplot(fig)
    # Release the figure; Streamlit reruns would otherwise accumulate open figures.
    plt.close(fig)
    if show_all_sources:
        st.markdown('\n #### Other interesting papers:')
        st.markdown(sims)
    return output
# ---- Page layout ----------------------------------------------------------
st.title('ArXiv-based question answering')
st.markdown(f'[Includes papers up to: `{dateval}`]')
st.markdown('Concise answers for questions using arxiv abstracts + GPT-4. Please use sparingly because it costs me money right now. You might need to wait for a few seconds for the GPT-4 query to return an answer (check top right corner to see if it is still running).')

query = st.text_input('Your question here:', value="What sersic index does a disk galaxy have?")
return_n = st.slider('How many papers should I show?', min_value=1, max_value=20, value=10)

# NOTE(review): Streamlit re-executes this script on every interaction, so the
# (paid) query below also runs on first load with the default question and on
# every slider change.
# `sims` holds run_query's output dict, not similarity indices.
sims = run_query(query, return_n=return_n)