kiyer committed on
Commit
0d72411
·
1 Parent(s): 9a6132f

update qa sources

Browse files
.DS_Store ADDED
Binary file (8.2 kB). View file
 
local_files/.DS_Store CHANGED
Binary files a/local_files/.DS_Store and b/local_files/.DS_Store differ
 
pages/{3_qa_sources.py → 3_qa_sources_v1.py} RENAMED
File without changes
pages/3_qa_sources_v2.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # set the environment variables needed for openai package to know to reach out to azure
2
+ import os
3
+ import datetime
4
+ import faiss
5
+ import streamlit as st
6
+ import feedparser
7
+ import urllib
8
+ import cloudpickle as cp
9
+ import pickle
10
+ from urllib.request import urlopen
11
+ from summa import summarizer
12
+ import numpy as np
13
+ import matplotlib.pyplot as plt
14
+ import requests
15
+ import json
16
+
17
+ from langchain.document_loaders import TextLoader
18
+ from langchain.indexes import VectorstoreIndexCreator
19
+ from langchain_openai import AzureOpenAIEmbeddings
20
+ from langchain.llms import OpenAI
21
+ from langchain_openai import AzureChatOpenAI
22
+ from langchain import hub
23
+ from langchain_core.prompts import PromptTemplate
24
+ from langchain_core.runnables import RunnablePassthrough
25
+ from langchain_core.output_parsers import StrOutputParser
26
+ from langchain_core.runnables import RunnableParallel
27
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
28
+ from langchain_community.vectorstores import Chroma
29
+
30
# Point the OpenAI client libraries at Azure. Endpoints and keys come from
# Streamlit secrets so they never live in the repository.
os.environ.update({
    "OPENAI_API_TYPE": "azure",
    "AZURE_ENDPOINT": st.secrets["endpoint1"],
    "OPENAI_API_KEY": st.secrets["key1"],
    "OPENAI_API_VERSION": "2023-05-15",
})

# Embedding model used for query vectors (and handed to Chroma in run_rag).
embeddings = AzureOpenAIEmbeddings(
    deployment="embedding",
    model="text-embedding-ada-002",
    azure_endpoint=st.secrets["endpoint1"],
)

# Chat model used for answering; note it uses the second endpoint/key pair.
llm = AzureChatOpenAI(
    deployment_name="gpt4_small",
    openai_api_version="2023-12-01-preview",
    azure_endpoint=st.secrets["endpoint2"],
    openai_api_key=st.secrets["key2"],
    openai_api_type="azure",
    temperature=0.0,
)
49
+
50
+
51
@st.cache_data
def get_feeds_data(url):
    """Unpickle and return the object stored at the local path ``url``.

    Wrapped in ``st.cache_data``; a success note is posted to the sidebar
    once the file has been read.

    NOTE: ``pickle.load`` executes arbitrary code from the file, so only
    point this at trusted files.
    """
    # Remote loading is disabled: data = cp.load(urlopen(url))
    with open(url, "rb") as handle:
        payload = pickle.load(handle)
    st.sidebar.success("Loaded data")
    return payload
58
+
59
# Remote copies (currently unused):
# feeds_link = "https://drive.google.com/uc?export=download&id=1-IPk1voyUM9VqnghwyVrM1dY6rFnn1S_"
# embed_link = "https://dl.dropboxusercontent.com/s/ob2betm29qrtb8v/astro_ph_ga_feeds_ada_embedding_18-Apr-2023.pkl?dl=0"

# Snapshot date embedded in the local data file names.
dateval = "27-Jun-2023"
feeds_link = f"local_files/astro_ph_ga_feeds_upto_{dateval}.pkl"
embed_link = f"local_files/astro_ph_ga_feeds_ada_embedding_{dateval}.pkl"

# Parsed arXiv feeds and their matching ada embedding matrix.
gal_feeds = get_feeds_data(feeds_link)
arxiv_ada_embeddings = get_feeds_data(embed_link)
66
+
67
@st.cache_data
def get_embedding_data(url):
    """Unpickle and return the object stored at the local path ``url``.

    Same loading logic as ``get_feeds_data``; only the sidebar message
    differs. Wrapped in ``st.cache_data``.

    NOTE: ``pickle.load`` executes arbitrary code from the file, so only
    point this at trusted files.
    """
    # Remote loading is disabled: data = cp.load(urlopen(url))
    with open(url, "rb") as handle:
        payload = pickle.load(handle)
    st.sidebar.success("Fetched data from API!")
    return payload
74
+
75
# Remote copy (currently unused):
# url = "https://drive.google.com/uc?export=download&id=1133tynMwsfdR1wxbkFLhbES3FwDWTPjP"
url = "local_files/astro_ph_ga_embedding_"+dateval+".pkl"
# e2d is indexed as e2d[:, 0] / e2d[:, 1] when plotting further down.
e2d = get_embedding_data(url)
# e2d, _, _, _, _ = get_embedding_data(url)

ctr = -1
num_chunks = len(gal_feeds)
all_text, all_titles, all_arxivid, all_links, all_authors = [], [], [], [], []

# Flatten every entry of every feed chunk into parallel per-paper lists.
for feed in gal_feeds:
    for entry in feed.entries:
        # Abstract on a single line with backslashes stripped.
        all_text.append(entry.summary.replace('\n', ' ').replace('\\', ''))
        all_titles.append(entry.title)
        # Entry ids end in "<arxiv id>v<version>". Cut at the last "v" rather
        # than dropping a fixed two characters, so versions >= 10 do not leave
        # a stray "v" and an id without a version suffix is kept intact.
        all_arxivid.append(entry.id.split('/')[-1].rsplit('v', 1)[0])
        all_links.append(entry.links[1].href)
        all_authors.append(entry.authors)

d = arxiv_ada_embeddings.shape[1]  # dimension
nb = arxiv_ada_embeddings.shape[0]  # database size
xb = arxiv_ada_embeddings.astype('float32')
# Exact (brute-force) L2 index over all paper embeddings.
index = faiss.IndexFlatL2(d)
index.add(xb)
101
+
102
def run_simple_query(search_query = 'all:sed+fitting', max_results = 10, start = 0, sort_by = 'lastUpdatedDate', sort_order = 'descending'):
    """
    Query ArXiv to return search results for a particular query

    Parameters
    ----------
    search_query: str
        query term. use prefixes ti, au, abs, co, jr, cat, m, id, all as applicable.
        It is inserted into the URL as-is, so it must already be URL-safe
        (e.g. use '+' instead of spaces).
    max_results: int, default = 10
        number of results to return. numbers > 1000 generally lead to timeouts
    start: int, default = 0
        start index for results reported. use this if you're interested in running chunks.
    sort_by: str, default = 'lastUpdatedDate'
        value passed as the ``sortBy`` query parameter.
    sort_order: str, default = 'descending'
        value passed as the ``sortOrder`` query parameter.

    Returns
    -------
    feed: dict
        object containing requested results parsed with feedparser

    Notes
    -----
    add functionality for chunk parsing, as well as storage and retrieval
    """

    base_url = 'http://export.arxiv.org/api/query?'
    query = 'search_query=%s&start=%i&max_results=%i&sortBy=%s&sortOrder=%s' % (
        search_query, start, max_results, sort_by, sort_order)

    # Close the HTTP response deterministically instead of leaking it.
    with urllib.request.urlopen(base_url + query) as response:
        payload = response.read()
    return feedparser.parse(payload)
130
+
131
def find_papers_by_author(auth_name):
    """Return indices of papers with an author whose name contains ``auth_name``.

    Matching is a case-insensitive substring test against each author's
    ``'name'`` field in the module-level ``all_authors`` list. One line is
    printed per matching author entry, and the paper index is appended once
    per match, so a paper with several matching authors appears repeatedly.
    """
    matches = []
    for paper_idx, author_list in enumerate(all_authors):
        for author in author_list:
            full_name = author['name']
            if auth_name.lower() not in full_name.lower():
                continue
            print('Doc ID: ', paper_idx, ' | arXiv: ', all_arxivid[paper_idx], '| ', all_titles[paper_idx], ' | Author entry: ', full_name)
            matches.append(paper_idx)

    return matches
141
+
142
def faiss_based_indices(input_vector, nindex=10):
    """Search the module-level FAISS ``index`` for the ``nindex`` nearest rows.

    ``input_vector`` is flattened into a single float32 query row.
    Returns ``(indices, distances)`` for that one query, in the order FAISS
    reports them.
    """
    query_row = input_vector.reshape(1, -1).astype('float32')
    distances, neighbours = index.search(query_row, nindex)
    return neighbours[0], distances[0]
146
+
147
def list_similar_papers_v2(model_data,
                           doc_id = None, input_type = 'doc_id',
                           show_authors = False, show_summary = False,
                           return_n = 10):
    """Find the papers nearest to a query and format them as markdown.

    Parameters
    ----------
    model_data: sequence
        ``[arxiv_ada_embeddings, embeddings, all_titles, all_abstracts,
        all_authors]`` in exactly this order.
    doc_id: int or str
        the query itself; its meaning depends on ``input_type``.
    input_type: str, default = 'doc_id'
        'doc_id'   -> ``doc_id`` is a row index into the embedding matrix;
                      the closest hit is skipped (presumably the paper itself).
        'arxiv_id' -> ``doc_id`` is an arXiv identifier; its abstract is
                      fetched with ``run_simple_query`` and embedded.
        'keywords' -> ``doc_id`` is free text that is embedded directly.
    show_authors, show_summary: bool
        include the author list / an extractive summary for each hit.
    return_n: int, default = 10
        number of papers to list.

    Returns
    -------
    (textstr, abstracts_relevant, fhdrs, sims) or None
        markdown listing, the listed abstracts, one file stem per listed
        paper ("<first author's last name><yy>_<arxiv id>"), and all
        ``return_n + 2`` neighbour indices from FAISS. ``None`` is returned
        (after printing a message) for an unknown ``input_type`` or an arXiv
        id that yields no results.
    """
    arxiv_ada_embeddings, embeddings, all_titles, all_abstracts, all_authors = model_data

    if input_type == 'doc_id':
        print('Doc ID: ', doc_id, ', title: ', all_titles[doc_id])
        inferred_vector = arxiv_ada_embeddings[doc_id, 0:]
        start_range = 1
    elif input_type == 'arxiv_id':
        print('ArXiv id: ', doc_id)
        arxiv_query_feed = run_simple_query(search_query='id:'+str(doc_id))
        if len(arxiv_query_feed.entries) == 0:
            print('error: arxiv id not found.')
            return
        print('Title: '+arxiv_query_feed.entries[0].title)
        inferred_vector = np.array(embeddings.embed_query(arxiv_query_feed.entries[0].summary))
        start_range = 0
    elif input_type == 'keywords':
        inferred_vector = np.array(embeddings.embed_query(doc_id))
        start_range = 0
    else:
        print('unrecognized input type.')
        return

    # Ask for two spare neighbours so skipping the first hit still leaves
    # return_n results.
    sims, dists = faiss_based_indices(inferred_vector, return_n+2)
    abstracts_relevant = []
    fhdrs = []
    parts = []

    # `rank` is the displayed number and always starts at 1, including in
    # 'doc_id' mode where the first neighbour is skipped.
    for rank, i in enumerate(range(start_range, start_range+return_n), start=1):
        paper = sims[i]
        arxiv_id = all_arxivid[paper]
        abstract = all_abstracts[paper]
        authors = all_authors[paper]

        abstracts_relevant.append(abstract)
        fhdrs.append(authors[0]['name'].split()[-1] + arxiv_id[0:2] + '_' + arxiv_id)

        parts.append(str(rank) + '. **' + all_titles[paper] + '** (Distance: %.2f' % dists[i] + ') \n')
        parts.append('**ArXiv:** [' + arxiv_id + '](https://arxiv.org/abs/' + arxiv_id + ') \n')
        if show_authors:
            parts.append('**Authors:** ' + ', '.join(author['name'] for author in authors) + ' \n')
        if show_summary:
            parts.append('**Summary:** ' + summarizer.summarize(abstract.replace('\n', ' ')) + ' \n')
        if show_authors or show_summary:
            parts.append(' ')
        parts.append(' \n')

    return ''.join(parts), abstracts_relevant, fhdrs, sims
205
+
206
+
207
def generate_chat_completion(messages, model="gpt-4", temperature=1, max_tokens=None):
    """POST a chat-completion request and return the first choice's text.

    ``max_tokens`` is only sent when it is not None. Any HTTP status other
    than 200 raises ``Exception`` carrying the status code and response body.

    NOTE(review): ``openai`` and ``API_ENDPOINT`` are not defined in the
    visible part of this file, so calling this as-is would raise NameError.
    Confirm where they are meant to come from.
    """
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai.api_key}",
    }
    payload = {"model": model, "messages": messages, "temperature": temperature}
    if max_tokens is not None:
        payload["max_tokens"] = max_tokens

    response = requests.post(API_ENDPOINT, headers=request_headers, data=json.dumps(payload))
    if response.status_code != 200:
        raise Exception(f"Error {response.status_code}: {response.text}")
    return response.json()["choices"][0]["message"]["content"]
226
+
227
# Bundle passed to list_similar_papers_v2; the order must match the
# unpacking at the top of that function.
model_data = [
    arxiv_ada_embeddings,
    embeddings,
    all_titles,
    all_text,
    all_authors,
]
228
+
229
def format_docs(docs):
    """Join the ``page_content`` of every document, separated by blank lines."""
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)
231
+
232
def get_textstr(i, show_authors=False, show_summary=False):
    """Build the markdown entry for paper number ``i`` of the loaded corpus.

    The entry always has the bold title and an arXiv link. The author list
    and an extractive summary of the abstract are added when the matching
    flag equals True. Reads the module-level ``all_titles``, ``all_arxivid``,
    ``all_authors`` and ``all_text`` lists.
    """
    # `== True` is kept on purpose so non-bool flag values behave as before.
    with_authors = show_authors == True
    with_summary = show_summary == True

    arxiv_id = all_arxivid[i]
    pieces = [
        '**' + all_titles[i] + '** \n',
        '**ArXiv:** [' + arxiv_id + '](https://arxiv.org/abs/' + arxiv_id + ') \n',
    ]

    if with_authors:
        pieces.append('**Authors:** ')
        names = [author.name for author in all_authors[i]]
        if names:
            pieces.append(', '.join(names) + ' \n')

    if with_summary:
        abstract = all_text[i].replace('\n', ' ')
        pieces.append('**Summary:** ' + summarizer.summarize(abstract) + ' \n')

    if with_authors or with_summary:
        pieces.append(' ')
    pieces.append(' \n')

    return ''.join(pieces)
254
+
255
+
256
def run_rag(query, return_n = 10, show_authors = True, show_summary = True):
    """Answer ``query`` with a retrieval-augmented chain and render the result.

    Steps:
      1. find the nearest abstracts with ``list_similar_papers_v2``;
      2. write each abstract to ``absts/<stem>.txt`` and load it back through
         ``TextLoader`` so every document carries its path as ``source``;
      3. split, embed into a Chroma vector store, and run the prompt + ``llm``
         chain with the retriever supplying the context;
      4. render the answer, the primary sources, a 2-D map of the corpus and
         the list of similar papers in Streamlit.

    Parameters
    ----------
    query: str
        the user's question; also used as the keyword search text.
    return_n: int, default = 10
        number of similar papers to retrieve and list.
    show_authors, show_summary: bool
        forwarded to ``list_similar_papers_v2``.

    Returns
    -------
    dict
        the chain output, with keys 'context', 'question' and 'answer'.
    """
    sims, absts, fhdrs, simids = list_similar_papers_v2(
        model_data,
        doc_id = query,
        input_type = 'keywords',
        show_authors = show_authors, show_summary = show_summary,
        return_n = return_n)

    os.makedirs('absts', exist_ok=True)
    loaders = []
    # Map each written path to (file stem, corpus index) so sources can be
    # resolved later without parsing file names. absts[k] belongs to
    # simids[k] because keyword searches start at the first neighbour.
    source_info = {}
    for rank, (abstract, fhdr) in enumerate(zip(absts, fhdrs)):
        path = "absts/"+fhdr+".txt"
        with open(path, "w") as text_file:
            text_file.write(abstract)
        loaders.append(TextLoader(path))
        source_info[path] = (fhdr, int(simids[rank]))

    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
    splits = text_splitter.split_documents([loader.load()[0] for loader in loaders])
    vectorstore = Chroma.from_documents(documents=splits, embedding=embeddings)
    retriever = vectorstore.as_retriever()

    template = """You are an assistant with expertise in astrophysics for question-answering tasks.
Use the following pieces of retrieved context from the literature to answer the question.
If you don't know the answer, just say that you don't know.
Use six sentences maximum and keep the answer concise.

{context}

Question: {question}

Answer:"""
    custom_rag_prompt = PromptTemplate.from_template(template)

    rag_chain_from_docs = (
        RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
        | custom_rag_prompt
        | llm
        | StrOutputParser()
    )

    rag_chain_with_source = RunnableParallel(
        {"context": retriever, "question": RunnablePassthrough()}
    ).assign(answer=rag_chain_from_docs)

    rag_answer = rag_chain_with_source.invoke(query)

    st.markdown('### User query: '+query)
    st.markdown(rag_answer['answer'])
    st.markdown('#### Primary sources: \n')

    # Unique sources, sorted. Every retrieved document is inspected, not
    # just the first one.
    srcnames = sorted({doc.metadata['source'] for doc in rag_answer['context']})
    srcindices = []
    for src in srcnames:
        if src not in source_info:
            # NOTE(review): a source not written during this call (possibly
            # left in the vector store by an earlier query) is shown as-is.
            st.markdown(src)
            continue
        fhdr, paper_index = source_info[src]
        label = fhdr.rpartition('_')[0]  # "<last name><yy>"
        surname, yy = label[0:-2], label[-2:]
        century = '20' if int(yy) < 40 else '19'
        st.markdown('['+surname+' et al. '+century+yy+']('+all_links[paper_index]+')')
        srcindices.append(paper_index)

    # Corpus map: all papers, the retrieved neighbours, and the papers the
    # answer actually drew on (black diamonds).
    srcindices = np.array(srcindices, dtype=int)
    fig = plt.figure(figsize=(9,9))
    plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
    plt.scatter(e2d[simids,0], e2d[simids,1],s=30)
    plt.scatter(e2d[srcindices,0], e2d[srcindices,1],s=100,color='k',marker='d')
    st.pyplot(fig)

    st.markdown('\n #### List of relevant papers:')
    st.markdown(sims)

    return rag_answer
340
+
341
def run_query(query, return_n = 3, show_pure_answer = False, show_all_sources = True):
    """Answer ``query`` from the nearest abstracts and render it in Streamlit.

    The nearest abstracts are written to ``absts/<stem>.txt``, indexed with
    ``VectorstoreIndexCreator`` and queried with sources. The answer, the
    cited papers, a 2-D map of the corpus and (optionally) the full list of
    similar papers are then displayed.

    Parameters
    ----------
    query: str
        the user's question; also used as the keyword search text.
    return_n: int, default = 3
        number of similar papers to retrieve.
    show_pure_answer: bool, default = False
        also show the answer of a plain ``lc_index.query`` call.
    show_all_sources: bool, default = True
        also show the full list of similar papers.

    Returns
    -------
    dict
        result of ``query_with_sources`` ('answer' and 'sources' are read here).
    """
    show_authors = True
    show_summary = True
    sims, absts, fhdrs, simids = list_similar_papers_v2(
        model_data,
        doc_id = query,
        input_type = 'keywords',
        show_authors = show_authors, show_summary = show_summary,
        return_n = return_n)

    os.makedirs('absts', exist_ok=True)
    loaders = []
    for abstract, fhdr in zip(absts, fhdrs):
        path = "absts/"+fhdr+".txt"
        with open(path, "w") as text_file:
            text_file.write(abstract)
        loaders.append(TextLoader(path))

    lc_index = VectorstoreIndexCreator().from_loaders(loaders)

    st.markdown('### User query: '+query)
    if show_pure_answer:
        st.markdown('pure answer:')
        st.markdown(lc_index.query(query))
        st.markdown(' ')
    st.markdown('#### context-based answer from sources:')
    # zero-shot in-context prompting from Zhou+22, Kojima+22
    output = lc_index.query_with_sources(query + " Let's work this out in a step by step way to be sure we have the right answer.")
    st.markdown(output['answer'])
    st.markdown('#### Primary sources: \n')

    max_authors = 4
    textstr = ''
    abs_indices = []

    for i, source in enumerate(output['sources'].split()):
        # Sources look like "absts/<last name><yy>_<arxiv id>.txt", separated
        # by commas. Recover the arXiv id without relying on fixed slice
        # lengths.
        cleaned = source.rstrip(',')
        if cleaned.endswith('.txt'):
            cleaned = cleaned[:-len('.txt')]
        tempid = cleaned.rpartition('_')[2]

        try:
            abs_index = all_arxivid.index(tempid)
        except ValueError:
            # Not one of our abstract files: show the raw source token.
            textstr = textstr + source
        else:
            abs_indices.append(abs_index)
            textstr = textstr + str(i+1)+'. **'+ all_titles[abs_index] +'** \n'
            textstr = textstr + '**ArXiv:** ['+all_arxivid[abs_index]+'](https://arxiv.org/abs/'+all_arxivid[abs_index]+') \n'

            # Up to `max_authors` names, with "et al." only when some were
            # left out. Short author lists no longer raise IndexError.
            authors = all_authors[abs_index]
            textstr = textstr + '**Authors:** ' + ', '.join(a['name'] for a in authors[:max_authors])
            if len(authors) > max_authors:
                textstr = textstr + ' et al.'
            textstr = textstr + ' \n'

            text = all_text[abs_index].replace('\n', ' ')
            textstr = textstr + '**Summary:** ' + summarizer.summarize(text) + ' \n'

        textstr = textstr + ' '
        textstr = textstr + ' \n'
    st.markdown(textstr)

    # Corpus map: all papers, the retrieved neighbours, and the cited
    # sources (black diamonds).
    abs_indices = np.array(abs_indices, dtype=int)
    fig = plt.figure(figsize=(9,9))
    plt.scatter(e2d[0:,0], e2d[0:,1],s=2)
    plt.scatter(e2d[simids,0], e2d[simids,1],s=30)
    plt.scatter(e2d[abs_indices,0], e2d[abs_indices,1],s=100,color='k',marker='d')
    st.pyplot(fig)

    if show_all_sources:
        st.markdown('\n #### Other interesting papers:')
        st.markdown(sims)
    return output
429
+
430
# ---- Page layout: title, notes, inputs, then run the query on every rerun ----
st.title('ArXiv-based question answering')
st.markdown(f'[Includes papers up to: `{dateval}`]')
st.markdown(
    'Concise answers for questions using arxiv abstracts + GPT-4. '
    'Please use sparingly because it costs me money right now. '
    'You might need to wait for a few seconds for the GPT-4 query to return an answer '
    '(check top right corner to see if it is still running).'
)

query = st.text_input(
    'Your question here:',
    value="What sersic index does a disk galaxy have?",
)
return_n = st.slider('How many papers should I show?', min_value=1, max_value=20, value=10)

# `sims` holds the dict returned by run_query (the name is kept as-is).
sims = run_query(query, return_n=return_n)