kiyer committed on
Commit ac72d36 · 1 Parent(s): ea7a22d

cleaning up files
app.py CHANGED
@@ -11,10 +11,17 @@ from datetime import datetime, date
11
  from datasets import load_dataset, load_from_disk
12
  from collections import Counter
13
 
14
- import yaml, json, requests, sys, os, time, hickle
15
  import concurrent.futures
16
  ts = time.time()
17
 
18
  from nltk.corpus import stopwords
19
  import nltk
20
  from openai import OpenAI
@@ -39,8 +46,6 @@ from bokeh.plotting import figure
39
  from bokeh.models import ColumnDataSource
40
  from bokeh.palettes import Spectral10
41
 
42
- # try to load the data, if it doesn't work, pull from huggingface and make the pickle files
43
-
44
  st.image('local_files/pathfinder_logo.png')
45
 
46
  st.expander("About", expanded=False).write(
@@ -75,16 +80,21 @@ st.expander("About", expanded=False).write(
75
 
76
 
77
 
 
78
  if 'arxiv_corpus' not in st.session_state:
79
  with st.spinner('loading data...'):
80
  try:
81
  arxiv_corpus = load_from_disk('data/')
 
82
  except:
83
  st.write('downloading data')
84
  arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data',split='train')
 
85
  arxiv_corpus.save_to_disk('data/')
86
  st.session_state.arxiv_corpus = arxiv_corpus
87
  st.toast('loaded arxiv corpus')
88
 
89
  if 'ids' not in st.session_state:
90
  st.session_state.ids = arxiv_corpus['ads_id']
@@ -92,24 +102,452 @@ if 'ids' not in st.session_state:
92
  st.session_state.abstracts = arxiv_corpus['abstract']
93
  st.session_state.cites = arxiv_corpus['cites']
94
  st.session_state.years = arxiv_corpus['date']
95
- st.toast('done caching. time:taken: {}'.format(time.time()-ts))
98
  else:
99
- arxiv_corpus = st.session_state.arxiv_corpus
100
  # Function to simulate question answering (replace with actual implementation)
101
  def answer_question(question, keywords, toggles, method, question_type):
102
  # Simulated answer (replace with actual logic)
103
- return f"Answer to '{question}' using method {method} for {question_type} question."
 
104
 
105
- # Function to simulate paper retrieval (replace with actual implementation)
106
- def get_papers():
107
- # Simulated paper data (replace with actual data retrieval)
108
  return pd.DataFrame({
109
- 'Title': ['Paper 1', 'Paper 2', 'Paper 3'],
110
- 'Relevance': [0.9, 0.7, 0.5]
 
111
  })
112
 
 
113
  # Function to create embedding plot (replace with actual implementation)
114
  def create_embedding_plot():
115
  # Simulated embedding data (replace with actual embedding calculation)
@@ -134,14 +572,24 @@ def estimate_consensus():
134
  # Simulated consensus estimation (replace with actual calculation)
135
  return 0.75
136
 
  # Streamlit app
138
  def main():
139
 
140
  # st.title("Question Answering App")
 
141
 
142
  # Sidebar (Inputs)
143
  st.sidebar.header("Inputs")
144
- question = st.sidebar.text_input("Enter your question:")
145
  extra_keywords = st.sidebar.text_input("Enter extra keywords (comma-separated):")
146
 
147
  st.sidebar.subheader("Toggles")
@@ -151,52 +599,65 @@ def main():
151
 
152
  method = st.sidebar.radio("Choose a method:", ["h1", "h2", "h3"])
153
  question_type = st.sidebar.selectbox("Select question type:", ["Type 1", "Type 2", "Type 3"])
154
- store_output = st.sidebar.checkbox("Store the output")
155
 
156
- submit_button = st.sidebar.button("Submit")
 
157
 
158
  # Main page (Outputs)
159
  if submit_button:
160
  # Process inputs
161
  keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
162
  toggles = {'A': toggle_a, 'B': toggle_b, 'C': toggle_c}
163
 
164
  # Generate outputs
165
- answer = answer_question(question, keywords, toggles, method, question_type)
166
- papers_df = get_papers()
167
  embedding_plot = create_embedding_plot()
168
  triggered_keywords = extract_keywords(question)
169
  consensus = estimate_consensus()
170
 
171
- # Display outputs
172
- st.header("Results")
173
 
174
  col1, col2 = st.columns(2)
175
 
176
  with col1:
177
- st.subheader("Answer")
178
- st.write(answer)
179
 
180
- st.subheader("Papers Used")
181
- st.dataframe(papers_df)
182
 
183
  st.subheader("Triggered Keywords")
184
  st.write(", ".join(triggered_keywords))
185
 
186
  with col2:
187
- st.subheader("Embedding Map")
188
- st.bokeh_chart(embedding_plot)
189
 
190
  st.subheader("Question Type")
191
  st.write(question_type)
192
 
193
  st.subheader("Consensus Estimate")
194
  st.write(f"{consensus:.2%}")
195
-
196
- if store_output:
197
- st.success("Output stored successfully!")
198
  else:
199
  st.info("Use the sidebar to input parameters and submit to see results.")
200
 
201
  if __name__ == "__main__":
202
  main()
 
11
  from datasets import load_dataset, load_from_disk
12
  from collections import Counter
13
 
14
+ import yaml, json, requests, sys, os, time
15
  import concurrent.futures
16
  ts = time.time()
17
 
18
+
19
+ anthropic_key = "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"
20
+ # anthropic_client = anthropic.Anthropic(api_key=anthropic_key)
21
+
22
+ openai_key = "sk-None-TMT98W6ksCIYY6w0UI66T3BlbkFJva1LamMQXbenkcnYqvs6"
23
+ # openai_client = EmbeddingClient(OpenAI(api_key=openai_key))
24
+
25
  from nltk.corpus import stopwords
26
  import nltk
27
  from openai import OpenAI
 
46
  from bokeh.models import ColumnDataSource
47
  from bokeh.palettes import Spectral10
48
 
 
 
49
  st.image('local_files/pathfinder_logo.png')
50
 
51
  st.expander("About", expanded=False).write(
 
80
 
81
 
82
 
83
+
84
  if 'arxiv_corpus' not in st.session_state:
85
  with st.spinner('loading data...'):
86
  try:
87
  arxiv_corpus = load_from_disk('data/')
88
+ arxiv_corpus.add_faiss_index('embed')
89
  except:
90
  st.write('downloading data')
91
  arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data',split='train')
92
+ arxiv_corpus.add_faiss_index('embed')
93
  arxiv_corpus.save_to_disk('data/')
94
  st.session_state.arxiv_corpus = arxiv_corpus
95
  st.toast('loaded arxiv corpus')
96
+ else:
97
+ arxiv_corpus = st.session_state.arxiv_corpus
98
 
99
  if 'ids' not in st.session_state:
100
  st.session_state.ids = arxiv_corpus['ads_id']
 
102
  st.session_state.abstracts = arxiv_corpus['abstract']
103
  st.session_state.cites = arxiv_corpus['cites']
104
  st.session_state.years = arxiv_corpus['date']
105
+ st.session_state.kws = arxiv_corpus['keywords']
106
+ st.toast('done caching. time taken: %.2f sec' %(time.time()-ts))
107
 
108
+
109
+ #----------------------------------------------------------------
110
+
111
+ class Filter():
112
+ def filter(self, query: str, arxiv_id: str) -> List[str]:
113
+ pass
114
+
115
+ class CitationFilter(Filter): # can do it with all metadata
116
+ def __init__(self, corpus):
117
+ self.corpus = corpus
118
+ ids = st.session_state.ids
119
+ cites = st.session_state.cites
120
+ self.citation_counts = {ids[i]: cites[i] for i in range(len(ids))}
121
+
122
+ def citation_weight(self, x, shift, scale):
123
+ return 1 / (1 + np.exp(-1 * (x - shift) / scale)) # sigmoid function
124
+
125
+ def filter(self, doc_scores, weight = 0.1): # additive weighting
126
+ citation_count = np.array([self.citation_counts[doc[0]] for doc in doc_scores])
127
+ cmean, cstd = np.median(citation_count), np.std(citation_count)
128
+ citation_score = self.citation_weight(citation_count, cmean, cstd)
129
+
130
+ for i, doc in enumerate(doc_scores):
131
+ doc_scores[i][2] += weight * citation_score[i]
132
+
133
+ class DateFilter(Filter): # include time weighting eventually
134
+ def __init__(self, document_dates):
135
+ self.document_dates = document_dates
136
+
137
+ def parse_date(self, arxiv_id: str) -> datetime: # only for documents
138
+ if arxiv_id.startswith('astro-ph'):
139
+ arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
140
+ try:
141
+ year = int("20" + arxiv_id[:2])
142
+ month = int(arxiv_id[2:4])
143
+ except:
144
+ year = 2023
145
+ month = 1
146
+ return date(year, month, 1)
147
+
148
+ def weight(self, time, shift, scale):
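+ # recency weight: decreasing sigmoid of document age in years, so older papers receive a smaller boost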
149
+ return 1 / (1 + np.exp((time - shift) / scale))
150
+
151
+ def evaluate_filter(self, year, filter_string):
152
+ try:
153
+ # Evaluate the boolean year expression with builtins disabled; only `year` is exposed to the expression
154
+ result = eval(filter_string, {"__builtins__": None}, {"year": year})
155
+ return result
156
+ except Exception as e:
157
+ print(f"Error evaluating filter: {e}")
158
+ return False
159
+
160
+ def filter(self, docs, boolean_date = None, min_date = None, max_date = None, time_score = 0):
161
+ filtered = []
162
+
163
+ if boolean_date is not None:
164
+ boolean_date = boolean_date.replace("AND", "and").replace("OR", "or")
165
+ for doc in docs:
166
+ if self.evaluate_filter(self.document_dates[doc[0]].year, boolean_date):
167
+ filtered.append(doc)
168
+
169
+ else:
170
+ if min_date == None: min_date = date(1990, 1, 1)
171
+ if max_date == None: max_date = date(2024, 7, 3)
172
+
173
+ for doc in docs:
174
+ if self.document_dates[doc[0]] >= min_date and self.document_dates[doc[0]] <= max_date:
175
+ filtered.append(doc)
176
+
177
+ if time_score is not None: # apply time weighting
178
+ for i, item in enumerate(filtered):
179
+ time_diff = (max_date - self.document_dates[filtered[i][0]]).days / 365
180
+ filtered[i][2] += time_score * 0.1 * self.weight(time_diff, 5, 5)
181
+
182
+ return filtered
183
+
184
+ class KeywordFilter(Filter):
185
+ def __init__(self, corpus,
186
+ remove_capitals: bool = True, metadata = None, ne_only = True, verbose = False):
187
+
188
+ self.index_path = 'keyword_index.json'
189
+ # self.metadata = metadata
190
+ self.remove_capitals = remove_capitals
191
+ self.ne_only = ne_only
192
+ self.stopwords = set(stopwords.words('english'))
193
+ self.verbose = verbose
194
+ self.index = None
195
+ self.kws = st.session_state.kws
196
+ self.ids = st.session_state.ids
197
+ self.titles = st.session_state.titles
198
+
199
+ self.load_or_build_index()
200
+
201
+ def preprocess_text(self, text: str) -> str:
202
+ text = ''.join(char for char in text if char.isalnum() or char.isspace())
203
+ if self.remove_capitals: text = text.lower()
204
+ return ' '.join(word for word in text.split() if word.lower() not in self.stopwords)
205
+
206
+ def build_index(self): # include the title in the index
207
+ print("Building index...")
208
+ self.index = {}
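+ # inverted index: stopword-filtered keyword -> list of paper ids that carry it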
209
+
210
+ for i in range(len(self.kws)):
211
+ paper = self.ids[i]
212
+ title = self.titles[i]
213
+ title_keywords = set()
214
+ for keyword in set(self.kws[i]) | title_keywords:
215
+ term = ' '.join(word for word in keyword.lower().split() if word.lower() not in self.stopwords)
216
+ if term not in self.index:
217
+ self.index[term] = []
218
+ self.index[term].append(self.ids[i])
219
+
220
+ with open(self.index_path, 'w') as f:
221
+ json.dump(self.index, f)
222
+
223
+ def load_index(self):
224
+ print("Loading existing index...")
225
+ with open(self.index_path, 'rb') as f:
226
+ self.index = json.load(f)
227
+
228
+ print("Index loaded successfully.")
229
+
230
+ def load_or_build_index(self):
231
+ if os.path.exists(self.index_path):
232
+ self.load_index()
233
+ else:
234
+ self.build_index()
235
+
236
+ def parse_doc(self, doc):
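+ # doc._.phrases assumes a phrase-ranking component (e.g. pytextrank) is registered on the spaCy pipeline nlp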
237
+ local_kws = []
238
+
239
+ for phrase in doc._.phrases:
240
+ local_kws.append(phrase.text.lower())
241
+
242
+ return [self.preprocess_text(word) for word in local_kws]
243
+
244
+ def get_propn(self, doc):
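+ # collect runs of consecutive proper-noun tokens, splitting on stop words and punctuation, to recover multi-word named entities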
245
+ result = []
246
+
247
+ working_str = ''
248
+ for token in doc:
249
+ if(token.text in nlp.Defaults.stop_words or token.text in punctuation):
250
+ if working_str != '':
251
+ result.append(working_str.strip())
252
+ working_str = ''
253
+
254
+ if(token.pos_ == "PROPN"):
255
+ working_str += token.text + ' '
256
+
257
+ if working_str != '': result.append(working_str.strip())
258
+
259
+ return [self.preprocess_text(word) for word in result]
260
+
261
+ def filter(self, query: str, doc_ids = None):
262
+ doc = nlp(query)
263
+ query_keywords = self.parse_doc(doc)
264
+ nouns = self.get_propn(doc)
265
+ if self.verbose: print('keywords:', query_keywords)
266
+ if self.verbose: print('proper nouns:', nouns)
267
+
268
+ filtered = set()
269
+ if len(query_keywords) > 0 and not self.ne_only:
270
+ for keyword in query_keywords:
271
+ if keyword != '' and keyword in self.index.keys(): filtered |= set(self.index[keyword])
272
+
273
+ if len(nouns) > 0:
274
+ ne_results = set()
275
+ for noun in nouns:
276
+ if noun in self.index.keys(): ne_results |= set(self.index[noun])
277
+
278
+ if self.ne_only: filtered = ne_results # keep only named entity results
279
+ else: filtered &= ne_results # take the intersection
280
+
281
+ if doc_ids is not None: filtered &= doc_ids # apply filter to results
282
+ return filtered
283
+
284
+ class EmbeddingRetrievalSystem():
285
+
286
+ def __init__(self, weight_citation = False, weight_date = False, weight_keywords = False):
287
+
288
+ self.ids = st.session_state.ids
289
+ self.years = st.session_state.years
290
+ self.abstract = st.session_state.abstracts
291
+ self.client = OpenAI(api_key = openai_key)
292
+ self.embed_model = "text-embedding-3-small"
293
+ self.dataset = arxiv_corpus
294
+ self.kws = st.session_state.kws
295
+
296
+ self.weight_citation = weight_citation
297
+ self.weight_date = weight_date
298
+ self.weight_keywords = weight_keywords
299
+ self.id_to_index = {self.ids[i]: i for i in range(len(self.ids))}
300
+
301
+ # self.citation_filter = CitationFilter(self.dataset)
302
+ # self.date_filter = DateFilter(self.dataset['date'])
303
+ self.keyword_filter = KeywordFilter(corpus=self.dataset, remove_capitals=True)
304
+
305
+ def parse_date(self, id):
306
+ # indexval = np.where(self.ids == id)[0][0]
307
+ indexval = id
308
+ return self.years[indexval]
309
+
310
+ def make_embedding(self, text):
311
+ str_embed = self.client.embeddings.create(input = [text], model = self.embed_model).data[0].embedding
312
+ return str_embed
313
+
314
+ def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
315
+ embeddings = self.client.embeddings.create(input=texts, model=self.embed_model).data
316
+ return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]
317
+
318
+ def init_filters(self):
319
+
320
+ self.citation_filter = []
321
+ self.date_filter = []
322
+ self.keyword_filter = []
323
+
324
+ def get_query_embedding(self, query):
325
+ return self.make_embedding(query)
326
+
327
+ def analyze_temporal_query(self, query):
328
+ return
329
+
330
+ def calc_faiss(self, query_embedding, top_k = 100):
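+ # nearest-neighbour search against the dataset's 'embed' FAISS index (added via add_faiss_index); returns (indices, scores)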
331
+ # xq = query_embedding.reshape(-1,1).T.astype('float32')
332
+ # D, I = self.index.search(xq, top_k)
333
+ # return I[0], D[0]
334
+ tmp = self.dataset.search('embed',query_embedding, k=top_k)
335
+ return [tmp.indices, tmp.scores]
336
+
337
+ def rank_and_filter(self, query, query_embedding, query_date, top_k = 10, return_scores=False, time_result=None):
338
+
339
+
340
+ topk_indices, similarities = self.calc_faiss(np.array(query_embedding), top_k = 300)
341
+
342
+ if self.weight_keywords:
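+ # boost keyword-matched documents: multiply their similarity by 10, then divide all scores by 10 so matches keep their score and non-matches are down-weighted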
343
+ keyword_matches = self.keyword_filter.filter(query)
344
+ kw_indices = np.zeros_like(similarities)
345
+ for s in keyword_matches:
346
+ if self.id_to_index[s] in topk_indices:
347
+ # print('yes', self.id_to_index[s], topk_indices[np.where(topk_indices == self.id_to_index[s])[0]])
348
+ similarities[np.where(topk_indices == self.id_to_index[s])[0]] = similarities[np.where(topk_indices == self.id_to_index[s])[0]] * 10.
349
+ similarities = similarities / 10.
350
+
351
+ filtered_results = [[topk_indices[i], similarities[i]] for i in range(len(similarities))]
352
+ top_results = sorted(filtered_results, key=lambda x: x[1], reverse=True)[:top_k]
353
+
354
+ if return_scores:
355
+ return {doc[0]: doc[1] for doc in top_results}
356
+
357
+ # Only keep the document IDs
358
+ top_results = [doc[0] for doc in top_results]
359
+ return top_results
360
+
361
+ def retrieve(self, query, top_k, time_result=None, query_date = None, return_scores = False):
362
+
363
+ query_embedding = self.get_query_embedding(query)
364
+
365
+ # Judge time relevance
366
+ if time_result is None:
367
+ if self.weight_date:
368
+ time_result, time_taken = self.analyze_temporal_query(query, self.anthropic_client)
369
+ else:
370
+ time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
371
+
372
+ top_results = self.rank_and_filter(query,
373
+ query_embedding,
374
+ query_date,
375
+ top_k,
376
+ return_scores = return_scores,
377
+ time_result = time_result)
378
+
379
+ return top_results
380
+
381
+ class HydeRetrievalSystem(EmbeddingRetrievalSystem):
382
+ def __init__(self, generation_model: str = "claude-3-haiku-20240307",
383
+ embedding_model: str = "text-embedding-3-small",
384
+ temperature: float = 0.5,
385
+ max_doclen: int = 500,
386
+ generate_n: int = 1,
387
+ embed_query = True,
388
+ conclusion = False, **kwargs):
389
+
390
+ # Handle the kwargs for the superclass init -- filters/citation weighting
391
+ super().__init__(**kwargs)
392
+
393
+ if max_doclen * generate_n > 8191:
394
+ raise ValueError("Too many tokens. Please reduce max_doclen or generate_n.")
395
+
396
+ self.embedding_model = embedding_model
397
+ self.generation_model = generation_model
398
+
399
+ # HYPERPARAMETERS
400
+ self.temperature = temperature # generation temperature
401
+ self.max_doclen = max_doclen # max tokens for generation
402
+ self.generate_n = generate_n # how many documents
403
+ self.embed_query = embed_query # embed the query vector?
404
+ self.conclusion = conclusion # generate conclusion as well?
405
+
406
+ self.anthropic_key = anthropic_key
407
+ self.generation_client = anthropic.Anthropic(api_key = self.anthropic_key)
408
+
409
+ def retrieve(self, query: str, top_k: int = 10, return_scores = False, time_result = None) -> List[Tuple[str, str, float]]:
410
+ if time_result is None:
411
+ if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
412
+ else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
413
+
414
+ docs = self.generate_docs(query)
415
+ doc_embeddings = self.embed_docs(docs)
416
+
417
+ if self.embed_query:
418
+ query_emb = self.embed_docs([query])[0]
419
+ doc_embeddings.append(query_emb)
420
+
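+ # HyDE: average the embeddings of the generated hypothetical abstracts (and the query itself, if embed_query) into a single search vector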
421
+ embedding = np.mean(np.array(doc_embeddings), axis = 0)
422
+
423
+ top_results = self.rank_and_filter(query, embedding, query_date=None, top_k = top_k, return_scores = return_scores, time_result = time_result)
424
+
425
+ return top_results
426
+
427
+ def generate_doc(self, query: str):
428
+ prompt = """You are an expert astronomer. Given a scientific query, generate the abstract"""
429
+ if self.conclusion:
430
+ prompt += " and conclusion"
431
+ prompt += """ of an expert-level research paper
432
+ that answers the question. Stick to a maximum length of {} tokens and return just the text of the abstract and conclusion.
433
+ Do not include labels for any section. Use research-specific jargon.""".format(self.max_doclen)
434
+
435
+
436
+ message = self.generation_client.messages.create(
437
+ model = self.generation_model,
438
+ max_tokens = self.max_doclen,
439
+ temperature = self.temperature,
440
+ system = prompt,
441
+ messages=[{ "role": "user",
442
+ "content": [{"type": "text", "text": query,}] }]
443
+ )
444
+
445
+ return message.content[0].text
446
+
447
+ def generate_docs(self, query: str):
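+ # generate generate_n hypothetical documents concurrently; any generation that raises an exception is silently dropped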
448
+ docs = []
449
+ with concurrent.futures.ThreadPoolExecutor() as executor:
450
+ future_to_query = {executor.submit(self.generate_doc, query): query for i in range(self.generate_n)}
451
+ for future in concurrent.futures.as_completed(future_to_query):
452
+ query = future_to_query[future]
453
+ try:
454
+ data = future.result()
455
+ docs.append(data)
456
+ except Exception as exc:
457
+ pass
458
+ return docs
459
+
460
+ def embed_docs(self, docs: List[str]):
461
+ return self.embed_batch(docs)
462
+
463
+ class HydeCohereRetrievalSystem(HydeRetrievalSystem):
464
+ def __init__(self, **kwargs):
465
+ super().__init__(**kwargs)
466
+
467
+ self.cohere_key = "Of1MjzFjGmvzBAqdvNHTQLkAjecPcOKpiIPAnFMn"
468
+ self.cohere_client = cohere.Client(self.cohere_key)
469
+
470
+ def retrieve(self, query: str,
471
+ top_k: int = 10,
472
+ rerank_top_k: int = 250,
473
+ return_scores = False, time_result = None,
474
+ reweight = False) -> List[Tuple[str, str, float]]:
475
+
476
+ if time_result is None:
477
+ if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
478
+ else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
479
+
480
+ top_results = super().retrieve(query, top_k = rerank_top_k, time_result = time_result)
481
+
482
+ # doc_texts = self.get_document_texts(top_results)
483
+ # docs_for_rerank = [f"Abstract: {doc['abstract']}\nConclusions: {doc['conclusions']}" for doc in doc_texts]
484
+ docs_for_rerank = [self.abstract[i] for i in top_results]
485
+
486
+ if len(docs_for_rerank) == 0:
487
+ return []
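+ # rerank the HyDE candidates with Cohere's rerank-english-v3.0 model, using the paper abstracts as the documents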
488
+
489
+ reranked_results = self.cohere_client.rerank(
490
+ query=query,
491
+ documents=docs_for_rerank,
492
+ model='rerank-english-v3.0',
493
+ top_n=top_k
494
+ )
495
+
496
+ final_results = []
497
+ for result in reranked_results.results:
498
+ doc_id = top_results[result.index]
499
+ doc_text = docs_for_rerank[result.index]
500
+ score = float(result.relevance_score)
501
+ final_results.append([doc_id, "", score])
502
+
503
+ if reweight:
504
+ if time_result['has_temporal_aspect']:
505
+ final_results = self.date_filter.filter(final_results, time_score = time_result['expected_recency_weight'])
506
+
507
+ if self.weight_citation: self.citation_filter.filter(final_results)
508
 
509
+ if return_scores:
510
+ return {result[0]: result[2] for result in final_results}
511
+
512
+ return [doc[0] for doc in final_results]
513
+
514
+ def embed_docs(self, docs: List[str]):
515
+ return self.embed_batch(docs)
516
+
517
+ # ----------------------------------------------------------------
518
+
519
+
520
+ if 'ec' not in st.session_state:
521
+ ec = EmbeddingRetrievalSystem(weight_keywords=True)
522
+ st.session_state.ec = ec
523
+ st.toast('loaded retrieval system')
524
  else:
525
+ ec = st.session_state.ec
526
+
527
+
528
+
529
  # Function to simulate question answering (replace with actual implementation)
530
  def answer_question(question, keywords, toggles, method, question_type):
531
  # Simulated answer (replace with actual logic)
532
+ # return f"Answer to '{question}' using method {method} for {question_type} question."
533
+ return run_ret(question, 10)
534
 
535
+
536
+ def get_papers(ids):
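+ # ids: dict mapping corpus row index -> relevance score, i.e. the output of retrieve(..., return_scores=True)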
537
+
538
+ papers, scores, links = [], [], []
539
+ for i in ids:
540
+ papers.append(st.session_state.titles[i])
541
+ scores.append(ids[i])
542
+ links.append('https://ui.adsabs.harvard.edu/abs/'+st.session_state.arxiv_corpus['bibcode'][i]+'/abstract')
543
+
544
  return pd.DataFrame({
545
+ 'Title': papers,
546
+ 'Relevance': scores,
547
+ 'Link': links
548
  })
549
 
550
+
551
  # Function to create embedding plot (replace with actual implementation)
552
  def create_embedding_plot():
553
  # Simulated embedding data (replace with actual embedding calculation)
 
572
  # Simulated consensus estimation (replace with actual calculation)
573
  return 0.75
574
 
575
+ def run_ret(query, top_k):
576
+ rs = ec.retrieve(query, top_k, return_scores=True)
577
+ output_str = ''
578
+ for i in rs:
579
+ if rs[i] > 0.5:
580
+ output_str = output_str + '---> ' + st.session_state.titles[i] + '(score: %.2f) \n' %rs[i]
581
+ else:
582
+ output_str = output_str + '---> ' + st.session_state.titles[i] + '(score: %.2f) \n' %rs[i]
583
+ return output_str, rs
584
+
585
  # Streamlit app
586
  def main():
587
 
588
  # st.title("Question Answering App")
589
+
590
 
591
  # Sidebar (Inputs)
592
  st.sidebar.header("Inputs")
 
593
  extra_keywords = st.sidebar.text_input("Enter extra keywords (comma-separated):")
594
 
595
  st.sidebar.subheader("Toggles")
 
599
 
600
  method = st.sidebar.radio("Choose a method:", ["h1", "h2", "h3"])
601
  question_type = st.sidebar.selectbox("Select question type:", ["Type 1", "Type 2", "Type 3"])
602
+ # store_output = st.sidebar.checkbox("Store the output")
603
 
604
+
605
+ store_output = st.sidebar.button("Save output")
606
 
607
  # Main page (Outputs)
608
+
609
+ question = st.text_input("Ask me anything:")
610
+ submit_button = st.button("Submit")
611
+
612
  if submit_button:
613
  # Process inputs
614
  keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
615
  toggles = {'A': toggle_a, 'B': toggle_b, 'C': toggle_c}
616
 
617
  # Generate outputs
618
+ answer, rs = answer_question(question, keywords, toggles, method, question_type)
619
+ papers_df = get_papers(rs)
620
  embedding_plot = create_embedding_plot()
621
  triggered_keywords = extract_keywords(question)
622
  consensus = estimate_consensus()
623
 
624
+ # Display outputs
625
+
626
+ st.subheader("Answer")
627
+ st.write(answer)
628
+
629
+ with st.expander("Papers used", expanded=True):
630
+ st.dataframe(papers_df)
631
+
632
 
633
  col1, col2 = st.columns(2)
634
 
635
  with col1:
 
 
636
 
637
+ st.subheader("Embedding Map")
638
+ st.bokeh_chart(embedding_plot)
639
 
640
  st.subheader("Triggered Keywords")
641
  st.write(", ".join(triggered_keywords))
642
 
643
  with col2:
 
 
644
 
645
  st.subheader("Question Type")
646
  st.write(question_type)
647
 
648
  st.subheader("Consensus Estimate")
649
  st.write(f"{consensus:.2%}")
650
+
651
+ # st.subheader("Papers Used")
652
+ # st.dataframe(papers_df)
653
+
654
+
655
+
656
  else:
657
  st.info("Use the sidebar to input parameters and submit to see results.")
658
+
659
+ if store_output:
660
+ st.toast("Output stored successfully!")
661
 
662
  if __name__ == "__main__":
663
  main()
arxiv_corpus/dataset_dict.json CHANGED
@@ -1 +1,3 @@
1
- {"splits": ["train"]}
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c172eebfc28c1400d6be4338ce7d00191507ffb4ae64c315f039585c894df5b7
3
+ size 21
arxiv_corpus/train/dataset_info.json CHANGED
@@ -1,204 +1,3 @@
1
- {
2
- "builder_name": "parquet",
3
- "citation": "",
4
- "config_name": "default",
5
- "dataset_name": "astro_paper_corpus",
6
- "dataset_size": 4128813829,
7
- "description": "",
8
- "download_checksums": {
9
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00000-of-00009.parquet": {
10
- "num_bytes": 240072323,
11
- "checksum": null
12
- },
13
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00001-of-00009.parquet": {
14
- "num_bytes": 235851056,
15
- "checksum": null
16
- },
17
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00002-of-00009.parquet": {
18
- "num_bytes": 236413937,
19
- "checksum": null
20
- },
21
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00003-of-00009.parquet": {
22
- "num_bytes": 237728419,
23
- "checksum": null
24
- },
25
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00004-of-00009.parquet": {
26
- "num_bytes": 236710419,
27
- "checksum": null
28
- },
29
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00005-of-00009.parquet": {
30
- "num_bytes": 239567004,
31
- "checksum": null
32
- },
33
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00006-of-00009.parquet": {
34
- "num_bytes": 234863979,
35
- "checksum": null
36
- },
37
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00007-of-00009.parquet": {
38
- "num_bytes": 232662046,
39
- "checksum": null
40
- },
41
- "hf://datasets/JSALT2024-Astro-LLMs/astro_paper_corpus@b957a28700badb3b5f5c7af06ea77a2560ab6e46/data/train-00008-of-00009.parquet": {
42
- "num_bytes": 237444927,
43
- "checksum": null
44
- }
45
- },
46
- "download_size": 2131314110,
47
- "features": {
48
- "id": {
49
- "dtype": "string",
50
- "_type": "Value"
51
- },
52
- "author": {
53
- "feature": {
54
- "dtype": "string",
55
- "_type": "Value"
56
- },
57
- "_type": "Sequence"
58
- },
59
- "bibcode": {
60
- "dtype": "string",
61
- "_type": "Value"
62
- },
63
- "title": {
64
- "feature": {
65
- "dtype": "string",
66
- "_type": "Value"
67
- },
68
- "_type": "Sequence"
69
- },
70
- "citation_count": {
71
- "dtype": "int64",
72
- "_type": "Value"
73
- },
74
- "aff": {
75
- "feature": {
76
- "dtype": "string",
77
- "_type": "Value"
78
- },
79
- "_type": "Sequence"
80
- },
81
- "citation": {
82
- "feature": {
83
- "dtype": "string",
84
- "_type": "Value"
85
- },
86
- "_type": "Sequence"
87
- },
88
- "database": {
89
- "feature": {
90
- "dtype": "string",
91
- "_type": "Value"
92
- },
93
- "_type": "Sequence"
94
- },
95
- "read_count": {
96
- "dtype": "int64",
97
- "_type": "Value"
98
- },
99
- "keyword": {
100
- "feature": {
101
- "dtype": "string",
102
- "_type": "Value"
103
- },
104
- "_type": "Sequence"
105
- },
106
- "reference": {
107
- "feature": {
108
- "dtype": "string",
109
- "_type": "Value"
110
- },
111
- "_type": "Sequence"
112
- },
113
- "doi": {
114
- "feature": {
115
- "dtype": "string",
116
- "_type": "Value"
117
- },
118
- "_type": "Sequence"
119
- },
120
- "subfolder": {
121
- "dtype": "string",
122
- "_type": "Value"
123
- },
124
- "filename": {
125
- "dtype": "string",
126
- "_type": "Value"
127
- },
128
- "introduction": {
129
- "dtype": "string",
130
- "_type": "Value"
131
- },
132
- "conclusions": {
133
- "dtype": "string",
134
- "_type": "Value"
135
- },
136
- "year": {
137
- "dtype": "int64",
138
- "_type": "Value"
139
- },
140
- "month": {
141
- "dtype": "int64",
142
- "_type": "Value"
143
- },
144
- "arxiv_id": {
145
- "dtype": "string",
146
- "_type": "Value"
147
- },
148
- "abstract": {
149
- "dtype": "string",
150
- "_type": "Value"
151
- },
152
- "failed_ids": {
153
- "dtype": "bool",
154
- "_type": "Value"
155
- },
156
- "keyword_search": {
157
- "feature": {
158
- "dtype": "string",
159
- "_type": "Value"
160
- },
161
- "_type": "Sequence"
162
- },
163
- "umap_x": {
164
- "dtype": "float32",
165
- "_type": "Value"
166
- },
167
- "umap_y": {
168
- "dtype": "float32",
169
- "_type": "Value"
170
- },
171
- "clust_id": {
172
- "dtype": "int64",
173
- "_type": "Value"
174
- }
175
- },
176
- "homepage": "",
177
- "license": "",
178
- "size_in_bytes": 6260127939,
179
- "splits": {
180
- "train": {
181
- "name": "train",
182
- "num_bytes": 4128813829,
183
- "num_examples": 271544,
184
- "shard_lengths": [
185
- 33172,
186
- 33172,
187
- 33172,
188
- 33172,
189
- 33172,
190
- 33171,
191
- 34171,
192
- 34171,
193
- 4171
194
- ],
195
- "dataset_name": "astro_paper_corpus"
196
- }
197
- },
198
- "version": {
199
- "version_str": "0.0.0",
200
- "major": 0,
201
- "minor": 0,
202
- "patch": 0
203
- }
204
- }
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cfcc45f82340c62a77d97d2d1ba131e629a71a885273309a668a281e59745e90
3
+ size 4859
arxiv_corpus/train/state.json CHANGED
@@ -1,37 +1,3 @@
1
- {
2
- "_data_files": [
3
- {
4
- "filename": "data-00000-of-00009.arrow"
5
- },
6
- {
7
- "filename": "data-00001-of-00009.arrow"
8
- },
9
- {
10
- "filename": "data-00002-of-00009.arrow"
11
- },
12
- {
13
- "filename": "data-00003-of-00009.arrow"
14
- },
15
- {
16
- "filename": "data-00004-of-00009.arrow"
17
- },
18
- {
19
- "filename": "data-00005-of-00009.arrow"
20
- },
21
- {
22
- "filename": "data-00006-of-00009.arrow"
23
- },
24
- {
25
- "filename": "data-00007-of-00009.arrow"
26
- },
27
- {
28
- "filename": "data-00008-of-00009.arrow"
29
- }
30
- ],
31
- "_fingerprint": "b9db3ec46232aa87",
32
- "_format_columns": null,
33
- "_format_kwargs": {},
34
- "_format_type": null,
35
- "_output_all_columns": false,
36
- "_split": "train"
37
- }
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c2f976bdcd9df0c87937fede4a771287e62e3dc62f1dec9ee12f066e3540043d
3
+ size 722
keyword_index.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7dce94a21caa4aafe87e4996ae2ffa24250b82884e8ad9cf4fa2b5f50e7329e1
3
+ size 140727900
local_files/pathfinder_logo.png ADDED
requirements.txt CHANGED
@@ -14,4 +14,5 @@ feedparser
14
  tiktoken
15
  chromadb
16
  streamlit-extras
17
- nltk
14
  tiktoken
15
  chromadb
16
  streamlit-extras
17
+ nltk
18
+ hickle