Commit 6931cbb (verified) · kiyer committed · 1 parent: 5f2a5af

basic files and codebase
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__
+ .ipynb_checkpoints/
.streamlit/config.toml ADDED
@@ -0,0 +1,2 @@
+ [theme]
+ backgroundColor="#C4C4C4"
app.py ADDED
@@ -0,0 +1,52 @@
+ import streamlit as st
+ from fns import *
+
+ st.set_page_config(
+     page_title="Synthesist",
+     page_icon="👋",
+ )
+
+ # st.write("# Welcome to Pathfinder! 👋")
+ st.image('local_files/synth_logo.png')
+
+ st.sidebar.success("Select a function above.")
+ st.sidebar.markdown("Current functions include visualizing papers in the arXiv embedding, searching for papers similar to an input paper or prompt phrase, and answering quick questions.")
+
+
+ st.markdown("")
+ st.markdown(
+     """
+     **Synthesist** (from Peter Watts's [Blindsight](https://scalar.usc.edu/works/network-ecologies/on-peter-watts-blindsight)) is a framework for searching and visualizing papers on the [arXiv](https://arxiv.org/), using the context
+     sensitivity of modern large language models (LLMs) to better parse patterns in paper content.
+
+     This tool was built during the [JSALT workshop](https://www.clsp.jhu.edu/2024-jelinek-summer-workshop-on-speech-and-language-technology/) to do awesome things.
+
+     **👈 Select a tool from the sidebar** to see some examples
+     of what this framework can do!
+
+     ### Tool summary:
+     - Please wait while the initial data loads and compiles; this takes about a minute.
+     - `Paper search` looks for relevant papers given an arXiv ID or a question.
+
+     This is not meant to be a replacement for existing tools like the
+     [ADS](https://ui.adsabs.harvard.edu/),
+     [arxivsorter](https://www.arxivsorter.org/), semantic search, or Google Scholar, but rather a supplement for finding papers
+     that might otherwise be missed during a literature survey.
+     It is trained on astro-ph.GA (astrophysics of galaxies) papers up to roughly last year, mined from arXiv and supplemented with ADS metadata;
+     if you are interested in extending it, please reach out!
+
+
+     Still to add: more pages, actual generation, different toggles for retrieval/generation, a feedback form, socials, literature, contact us, copyright, collaboration, etc.
+
+     The image below shows a representation of all the astro-ph.GA papers, which can be explored in more detail
+     on the `Arxiv embedding` page. Papers tend to cluster together by similarity, resulting in an
+     atlas that shows well-studied areas (forests) and currently uncharted ones (water).
+     """
+ )
+
+
+ s = time.time()
+ st.markdown('Loading data for the retrieval system; please wait before jumping to one of the pages...')
+ st.session_state.retrieval_system = EmbeddingRetrievalSystem()
+ st.session_state.dataset = load_dataset('arxiv_corpus/', split="train")
+ st.markdown('Loaded retrieval system, time taken: %.1f sec' % (time.time() - s))
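
app.py stashes the heavy objects (the `EmbeddingRetrievalSystem` and the Hugging Face dataset) in `st.session_state` so the pages under `pages/` can reuse them without reloading. A minimal sketch of how a page consumes them, mirroring what `pages/1 retrieval.py` does further down; the query string and `top_k` value here are purely illustrative:

```python
import streamlit as st

# Assumes app.py has already run and populated st.session_state.
retrieval_system = st.session_state.retrieval_system   # EmbeddingRetrievalSystem from fns.py
dataset = st.session_state.dataset                      # the arxiv_corpus/ 'train' split

# Retrieve document IDs for an example query, then map them back to dataset rows.
doc_ids = retrieval_system.retrieve("galaxy quenching at high redshift", top_k=5)
for doc_id in doc_ids:
    row = dataset['id'].index(doc_id)
    st.write(dataset['title'][row][0])
```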
fns.py ADDED
@@ -0,0 +1,483 @@
+ import time
+ s2 = time.time()
+ import numpy as np
+ import streamlit as st
+
+ import json
+ from abc import ABC, abstractmethod
+ from typing import List, Dict, Any, Tuple
+ from collections import defaultdict
+ # import wandb
+ import numpy as np
+ from tqdm import tqdm
+ from datetime import datetime, date
+ import pickle
+ from datasets import load_dataset
+ import os
+ from nltk.corpus import stopwords
+ import nltk
+ from openai import OpenAI
+ import anthropic
+ import time
+ from collections import Counter
+
+ try:
+     stopwords.words('english')
+ except LookupError:
+     nltk.download('stopwords')
+     stopwords.words('english')
+
+
+ openai_key = st.secrets['openai_key']
+ anthropic_key = st.secrets['anthropic_key']
+ # (a hardcoded Anthropic API key was removed here; keys should only come from st.secrets)
+
+ @st.cache_data
+ def load_astro_meta():
+     print('load astro meta')
+     return load_dataset('arxiv_corpus/', split="train")
+
+ @st.cache_data
+ def load_index_mapping(index_mapping_path):
+     print("Loading index mapping...")
+     with open(index_mapping_path, 'rb') as f:
+         temp = pickle.load(f)
+     return temp
+
+ @st.cache_data
+ def load_embeddings(embeddings_path):
+     print("Loading embeddings...")
+     return np.load(embeddings_path)
+
+ @st.cache_data
+ def load_metadata(meta_path):
+     print("Loading metadata...")
+     with open(meta_path, 'r') as f:
+         metadata = json.load(f)
+     return metadata
+
+ # @st.cache_data
+ def load_umapcoords(umap_path):
+     print('loading umap coords')
+     with open(umap_path, "rb") as fp:
+         umap = pickle.load(fp)
+     return umap
+
+
+ class EmbeddingClient:
+     def __init__(self, client: OpenAI, model: str = "text-embedding-3-small"):
+         self.client = client
+         self.model = model
+
+     def embed(self, text: str) -> np.ndarray:
+         embedding = self.client.embeddings.create(input=[text], model=self.model).data[0].embedding
+         return np.array(embedding, dtype=np.float32)
+
+     def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
+         embeddings = self.client.embeddings.create(input=texts, model=self.model).data
+         return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]
+
+ class RetrievalSystem(ABC):
+     @abstractmethod
+     def retrieve(self, query: str, arxiv_id: str, top_k: int = 100) -> List[str]:
+         pass
+
+     def parse_date(self, arxiv_id: str) -> date:
+         if arxiv_id is None:
+             return date.today()
+
+         if arxiv_id.startswith('astro-ph'):
+             arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
+         try:
+             year = int("20" + arxiv_id[:2])
+             month = int(arxiv_id[2:4])
+         except ValueError:
+             year = 2023
+             month = 1
+         return date(year, month, 1)
+
+ class EmbeddingRetrievalSystem(RetrievalSystem):
+     def __init__(self, embeddings_path: str = "local_files/embeddings_matrix.npy",
+                  documents_path: str = "local_files/documents.pkl",
+                  index_mapping_path: str = "local_files/index_mapping.pkl",
+                  metadata_path: str = "local_files/metadata.json",
+                  weight_citation=False, weight_date=False, weight_keywords=False):
+
+         self.embeddings_path = embeddings_path
+         self.documents_path = documents_path
+         self.index_mapping_path = index_mapping_path
+         self.metadata_path = metadata_path
+         self.weight_citation = weight_citation
+         self.weight_date = weight_date
+         self.weight_keywords = weight_keywords
+
+         self.embeddings = None
+         self.documents = None
+         self.index_mapping = None
+         self.metadata = None
+         self.document_dates = []
+
+         self.load_data()
+         self.init_filters()
+
+         # config = yaml.safe_load(open('../config.yaml', 'r'))
+         self.client = EmbeddingClient(OpenAI(api_key=openai_key))
+         self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)
+
+     def generate_metadata(self):
+         astro_meta = load_astro_meta()
+         # dataset = load_dataset('arxiv_corpus/')
+         keys = list(astro_meta[0].keys())
+         keys.remove('abstract')
+         keys.remove('introduction')
+         keys.remove('conclusions')
+
+         self.metadata = {}
+         for paper in astro_meta:
+             id_str = paper['arxiv_id']
+             self.metadata[id_str] = {key: paper[key] for key in keys}
+
+         with open(self.metadata_path, 'w') as f:
+             json.dump(self.metadata, f)
+         st.markdown("Wrote metadata to {}".format(self.metadata_path))
+
+     def load_data(self):
+         self.embeddings = load_embeddings(self.embeddings_path)
+         st.sidebar.success("Loaded embeddings")
+
+         self.index_mapping = load_index_mapping(self.index_mapping_path)
+         st.sidebar.success("Loaded index mapping")
+
+         # NOTE: self.documents is never populated (the pickle load below is disabled),
+         # so get_document_texts() will not find any documents until it is.
+         # with open(self.documents_path, 'rb') as f:
+         #     self.documents = pickle.load(f)
+         dataset = load_astro_meta()
+         st.sidebar.success("Loaded documents")
+
+         print("Processing document dates...")
+         aids = dataset['arxiv_id']
+         adsids = dataset['id']
+         self.document_dates = {adsids[i]: self.parse_date(aids[i]) for i in range(len(aids))}
+
+         if os.path.exists(self.metadata_path):
+             self.metadata = load_metadata(self.metadata_path)
+             print("Loaded metadata.")
+         else:
+             print("Could not find path; generating metadata.")
+             self.generate_metadata()
+
+         print("Data loaded successfully.")
+
+     def init_filters(self):
+         print("Loading filters...")
+         self.citation_filter = CitationFilter(metadata=self.metadata)
+
+         self.date_filter = DateFilter(document_dates=self.document_dates)
+
+         self.keyword_filter = KeywordFilter(index_path="local_files/keyword_index.json", metadata=self.metadata, remove_capitals=True)
+
+     def retrieve(self, query: str, arxiv_id: str = None, top_k: int = 10, return_scores=False, time_result=None) -> List[Tuple[str, str, float]]:
+         query_date = self.parse_date(arxiv_id)
+         query_embedding = self.get_query_embedding(query)
+
+         # Judge time relevance (analyze_temporal_query is not defined in this file;
+         # this branch is only reached when weight_date=True).
+         if time_result is None:
+             if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
+             else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
+
+         top_results = self.rank_and_filter(query, query_embedding, query_date, top_k, return_scores=return_scores, time_result=time_result)
+
+         return top_results
+
+     def rank_and_filter(self, query, query_embedding: np.ndarray, query_date, top_k: int = 10, return_scores=False, time_result=None) -> List[Tuple[str, str, float]]:
+         # Calculate similarities
+         similarities = np.dot(self.embeddings, query_embedding)
+
+         # Filter and rank results
+         if self.weight_keywords: keyword_matches = self.keyword_filter.filter(query)
+
+         results = []
+         for doc_id, mappings in self.index_mapping.items():
+             if not self.weight_keywords or doc_id in keyword_matches:
+                 abstract_sim = similarities[mappings['abstract']] if 'abstract' in mappings else -np.inf
+                 conclusions_sim = similarities[mappings['conclusions']] if 'conclusions' in mappings else -np.inf
+
+                 if abstract_sim > conclusions_sim:
+                     results.append([doc_id, "abstract", abstract_sim])
+                 else:
+                     results.append([doc_id, "conclusions", conclusions_sim])
+
+         # Sort, weight, and keep the top-k results
+         if time_result['has_temporal_aspect']:
+             filtered_results = self.date_filter.filter(results, boolean_date=time_result['expected_year_filter'], time_score=time_result['expected_recency_weight'], max_date=query_date)
+         else:
+             filtered_results = self.date_filter.filter(results, max_date=query_date)
+
+         if self.weight_citation: self.citation_filter.filter(filtered_results)
+
+         top_results = sorted(filtered_results, key=lambda x: x[2], reverse=True)[:top_k]
+
+         if return_scores:
+             return {doc[0]: doc[2] for doc in top_results}
+
+         # Only keep the document IDs
+         top_results = [doc[0] for doc in top_results]
+         return top_results
+
+     def get_query_embedding(self, query: str) -> np.ndarray:
+         embedding = self.client.embed(query)
+         return np.array(embedding, dtype=np.float32)
+
+     def get_document_texts(self, doc_ids: List[str]) -> List[Dict[str, str]]:
+         results = []
+         for doc_id in doc_ids:
+             doc = next((d for d in self.documents if d.id == doc_id), None)
+             if doc:
+                 results.append({
+                     'id': doc.id,
+                     'abstract': doc.abstract,
+                     'conclusions': doc.conclusions
+                 })
+             else:
+                 print(f"Warning: Document with ID {doc_id} not found.")
+         return results
+
+     def retrieve_context(self, query, top_k, sections=["abstract", "conclusions"], **kwargs):
+         docs = self.retrieve(query, top_k=top_k, return_scores=True, **kwargs)
+         docids = docs.keys()
+         doctexts = self.get_document_texts(docids)  # avoid having to do this repetitively?
+         context_str = ""
+         doclist = []
+
+         for docid, doctext in zip(docids, doctexts):
+             for section in sections:
+                 context_str += f"{docid}: {doctext[section]}\n"
+
+             meta_row = self.metadata[docid]
+             doclist.append(Document(docid, doctext['abstract'], doctext['conclusions'], docid, title=meta_row['title'],
+                                     score=docs[docid], n_citation=meta_row['citation_count'], keywords=meta_row['keyword_search']))
+
+         return context_str, doclist
+
+
+ class Filter():
+     def filter(self, query: str, arxiv_id: str) -> List[str]:
+         pass
+
+ class CitationFilter(Filter):  # can do it with all metadata
+     def __init__(self, metadata):
+         self.metadata = metadata
+         self.citation_counts = {doc_id: self.metadata[doc_id]['citation_count'] for doc_id in self.metadata}
+
+     def citation_weight(self, x, shift, scale):
+         return 1 / (1 + np.exp(-1 * (x - shift) / scale))  # sigmoid function
+
+     def filter(self, doc_scores, weight=0.1):  # additive weighting
+         citation_count = np.array([self.citation_counts[doc[0]] for doc in doc_scores])
+         cmean, cstd = np.median(citation_count), np.std(citation_count)  # shift = median, scale = std
+         citation_score = self.citation_weight(citation_count, cmean, cstd)
+
+         for i, doc in enumerate(doc_scores):
+             doc_scores[i][2] += weight * citation_score[i]
+
+ class DateFilter(Filter):  # include time weighting eventually
+     def __init__(self, document_dates):
+         self.document_dates = document_dates
+
+     def parse_date(self, arxiv_id: str) -> date:  # only for documents
+         if arxiv_id.startswith('astro-ph'):
+             arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
+         try:
+             year = int("20" + arxiv_id[:2])
+             month = int(arxiv_id[2:4])
+         except ValueError:
+             year = 2023
+             month = 1
+         return date(year, month, 1)
+
+     def weight(self, time, shift, scale):
+         return 1 / (1 + np.exp((time - shift) / scale))
+
+     def evaluate_filter(self, year, filter_string):
+         try:
+             # Evaluate the year-filter expression with builtins disabled
+             # (note: eval is still not a fully safe sandbox).
+             result = eval(filter_string, {"__builtins__": None}, {"year": year})
+             return result
+         except Exception as e:
+             print(f"Error evaluating filter: {e}")
+             return False
+
+     def filter(self, docs, boolean_date=None, min_date=None, max_date=None, time_score=0):
+         filtered = []
+
+         if boolean_date is not None:
+             boolean_date = boolean_date.replace("AND", "and").replace("OR", "or")
+             for doc in docs:
+                 if self.evaluate_filter(self.document_dates[doc[0]].year, boolean_date):
+                     filtered.append(doc)
+
+         else:
+             if min_date is None: min_date = date(1990, 1, 1)
+             if max_date is None: max_date = date(2024, 7, 3)
+
+             for doc in docs:
+                 if self.document_dates[doc[0]] >= min_date and self.document_dates[doc[0]] <= max_date:
+                     filtered.append(doc)
+
+         if time_score is not None:  # apply time weighting
+             for i, item in enumerate(filtered):
+                 time_diff = (max_date - self.document_dates[filtered[i][0]]).days / 365
+                 filtered[i][2] += time_score * 0.1 * self.weight(time_diff, 5, 5)
+
+         return filtered
+
+ class KeywordFilter(Filter):
+     def __init__(self, index_path: str = "local_files/keyword_index.json",
+                  remove_capitals: bool = True, metadata=None, ne_only=True, verbose=False):
+
+         self.index_path = index_path
+         self.metadata = metadata
+         self.remove_capitals = remove_capitals
+         self.ne_only = ne_only
+         self.stopwords = set(stopwords.words('english'))
+         self.verbose = verbose
+         self.index = None
+
+         self.load_or_build_index()
+
+     def preprocess_text(self, text: str) -> str:
+         text = ''.join(char for char in text if char.isalnum() or char.isspace())
+         if self.remove_capitals: text = text.lower()
+         return ' '.join(word for word in text.split() if word.lower() not in self.stopwords)
+
+     def build_index(self):  # include the title in the index
+         print("Building index...")
+         self.index = {}
+
+         for i, index in tqdm(enumerate(self.metadata)):
+             paper = self.metadata[index]
+             title = paper['title'][0]
+             title_keywords = set()  # set(self.parse_doc(title) + self.get_propn(title))
+             for keyword in set(paper['keyword_search']) | title_keywords:
+                 term = ' '.join(word for word in keyword.lower().split() if word.lower() not in self.stopwords)
+                 if term not in self.index:
+                     self.index[term] = []
+
+                 self.index[term].append(paper['arxiv_id'])
+
+         with open(self.index_path, 'w') as f:
+             json.dump(self.index, f)
+
+     def load_index(self):
+         print("Loading existing index...")
+         with open(self.index_path, 'rb') as f:
+             self.index = json.load(f)
+
+         print("Index loaded successfully.")
+
+     def load_or_build_index(self):
+         if os.path.exists(self.index_path):
+             self.load_index()
+         else:
+             self.build_index()
+
+     # NOTE: parse_doc, get_propn, and filter below assume a spaCy pipeline bound to `nlp`
+     # (with pyTextRank providing doc._.phrases) and a `punctuation` set (e.g. string.punctuation);
+     # neither is defined or imported in this file.
+     def parse_doc(self, doc):
+         local_kws = []
+
+         for phrase in doc._.phrases:
+             local_kws.append(phrase.text.lower())
+
+         return [self.preprocess_text(word) for word in local_kws]
+
+     def get_propn(self, doc):
+         result = []
+
+         working_str = ''
+         for token in doc:
+             if (token.text in nlp.Defaults.stop_words or token.text in punctuation):
+                 if working_str != '':
+                     result.append(working_str.strip())
+                     working_str = ''
+
+             if (token.pos_ == "PROPN"):
+                 working_str += token.text + ' '
+
+         if working_str != '': result.append(working_str.strip())
+
+         return [self.preprocess_text(word) for word in result]
+
+     def filter(self, query: str, doc_ids=None):
+         doc = nlp(query)
+         query_keywords = self.parse_doc(doc)
+         nouns = self.get_propn(doc)
+         if self.verbose: print('keywords:', query_keywords)
+         if self.verbose: print('proper nouns:', nouns)
+
+         filtered = set()
+         if len(query_keywords) > 0 and not self.ne_only:
+             for keyword in query_keywords:
+                 if keyword != '' and keyword in self.index.keys(): filtered |= set(self.index[keyword])
+
+         if len(nouns) > 0:
+             ne_results = set()
+             for noun in nouns:
+                 if noun in self.index.keys(): ne_results |= set(self.index[noun])
+
+             if self.ne_only: filtered = ne_results  # keep only named-entity results
+             else: filtered &= ne_results  # take the intersection
+
+         if doc_ids is not None: filtered &= doc_ids  # apply filter to results
+         return filtered
+
+ def get_cluster_keywords(clust_ids, all_keywords):
+     # Collect the keywords of every paper in the cluster and keep the 30 most common ones.
+     tagstr = ''
+     clust_tags = []
+     for i in range(len(clust_ids)):
+         for j in range(len(all_keywords[clust_ids[i]])):
+             clust_tags.append(all_keywords[clust_ids[i]][j])
+     tags = Counter(clust_tags).most_common(30)
+     for i in range(len(tags)):
+         if len(tags[i][0]) > 2:
+             tagstr = tagstr + tags[i][0] + ', '
+     return tagstr
+
+ def get_keywords(query, ret_indices, all_keywords):
+
+     kws = get_cluster_keywords(ret_indices, all_keywords)
+
+     kw_prompt = """You are an expert research assistant. Here is a list of keywords corresponding to the topics that a query and its answer are about, which you need to synthesize into a succinct summary:
+ [""" + kws + """]
+
+ First, find the keywords that are most relevant to answering the question, and then print them in numbered order. Keywords should be a few words at most. Do not list more than five keywords.
+
+ If there are no relevant keywords, write "No relevant keywords" instead.
+
+ Thus, the format of your overall response should look like the example shown below. Make sure to follow the formatting and spacing exactly.
+
+ Keywords:
+ [1] Milky Way galaxy
+ [2] Good agreement
+ [3] Bayesian
+ [4] Observational constraints
+ [5] Globular clusters
+ [6] Kinematic data
+
+ If the question cannot be answered by the document, say so."""
+
+     client = anthropic.Anthropic(api_key=anthropic_key)
+     message = client.messages.create(model="claude-3-haiku-20240307", max_tokens=200, temperature=0, system=kw_prompt,
+                                      messages=[{"role": "user", "content": [{"type": "text", "text": query}]}])
+
+     return message.content[0].text
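
The ranking in `rank_and_filter` is a plain dot product between the query embedding and the precomputed embedding matrix, with citations and recency folded in as additive sigmoid terms (`CitationFilter.citation_weight` and `DateFilter.weight`). A standalone numeric sketch of those two weighting curves, NumPy only; the shift/scale choices mirror the code above, while the citation counts and paper ages are made up:

```python
import numpy as np

def citation_weight(x, shift, scale):
    # Same sigmoid as CitationFilter.citation_weight: rises toward 1 for highly cited papers.
    return 1 / (1 + np.exp(-(x - shift) / scale))

def recency_weight(age_years, shift=5, scale=5):
    # Same curve as DateFilter.weight: near 1 for recent papers, decaying with age.
    return 1 / (1 + np.exp((age_years - shift) / scale))

citations = np.array([2, 15, 80, 600])                  # made-up citation counts
shift, scale = np.median(citations), np.std(citations)  # as in CitationFilter.filter
print(citation_weight(citations, shift, scale))         # ~[0.45, 0.47, 0.53, 0.90]
print(recency_weight(np.array([0.5, 3.0, 10.0])))       # ~[0.71, 0.60, 0.27]

# In rank_and_filter these terms are added to the dot-product similarity with a small
# weight (0.1 for citations, time_score * 0.1 for recency), so they nudge the ordering
# rather than dominate it.
```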
pages/1 retrieval.py ADDED
@@ -0,0 +1,124 @@
+ import time
+ s = time.time()
+
+ import os
+ import datetime
+ import faiss
+ import streamlit as st
+ import feedparser
+ import urllib
+ import cloudpickle as cp
+ import pickle
+ from urllib.request import urlopen
+ from summa import summarizer
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import requests
+ import json
+ from scipy import ndimage
+
+ from langchain_openai import AzureOpenAIEmbeddings
+ # from langchain.llms import OpenAI
+ from langchain_community.llms import OpenAI
+ from langchain_openai import AzureChatOpenAI
+
+ from fns import *
+
+ st.image('local_files/synth_logo.png')
+ st.markdown("")
+
+ query = st.text_input('Ask me anything:',
+                       value="What causes galaxy quenching at high redshifts?")
+
+ arxiv_id = None
+ top_k = st.slider('How many papers should I show?', 1, 30, 6)
+
+ retrieval_system = st.session_state.retrieval_system
+ results = retrieval_system.retrieve(query, arxiv_id, top_k)
+
+ aids = st.session_state.dataset['id']
+ titles = st.session_state.dataset['title']
+ auths = st.session_state.dataset['author']
+ bibcodes = st.session_state.dataset['bibcode']
+ all_keywords = st.session_state.dataset['keyword_search']
+ allyrs = st.session_state.dataset['year']
+ ret_indices = np.array([aids.index(results[i]) for i in range(top_k)])
+ yrs = []
+ for i in range(len(ret_indices)):
+     yr = allyrs[ret_indices[i]]
+     # two-digit years: 0-49 -> 2000s, 50-99 -> 1900s
+     if yr < 50:
+         yr = yr + 2000
+     else:
+         yr = yr + 1900
+     yrs.append(yr)
+ print_titles = [titles[ret_indices[i]][0] for i in range(len(ret_indices))]
+ print_auths = [auths[ret_indices[i]][0] + ' et al. ' + str(yrs[i]) for i in range(len(ret_indices))]
+ print_links = ['[' + bibcodes[ret_indices[i]] + '](https://ui.adsabs.harvard.edu/abs/' + bibcodes[ret_indices[i]] + '/abstract)' for i in range(len(ret_indices))]
+
+ st.divider()
+ st.header('top-k papers:')
+
+ for i in range(len(ret_indices)):
+     st.subheader(str(i + 1) + '. ' + print_titles[i])
+     st.write(print_auths[i] + ' ' + print_links[i])
+
+
+ st.divider()
+ st.header('top-k papers in context:')
+
+ gtkws = get_keywords(query, ret_indices, all_keywords)
+
+ umap, clbls, all_kws = load_umapcoords('local_files/arxiv_ads_corpus_coordsonly_v3.pkl')
+
+ fig = plt.figure(figsize=(12 * 1.8 * 1.2, 9 * 2. * 1.2))
+ im = plt.imread('local_files/astro_worldmap.png')
+ implot = plt.imshow(im)
+
+ # Min-max scale the UMAP coordinates onto the pixel grid of the background image,
+ # flipping the vertical axis because image coordinates run top-to-bottom.
+ xax = (umap[:, 1] - np.amin(umap[:, 1])) + .0
+ xax = xax / np.amax(xax)
+ xax = xax * 1580 + 170
+ yax = (umap[:, 0] - np.amin(umap[:, 0])) + .0
+ yax = yax / np.amax(yax)
+ yax = (np.amax(yax) - yax) * 1700 + 30
+ # plt.scatter(xax, yax, s=2, alpha=0.7, c='k')
+
+ for i in range(np.amax(clbls)):
+
+     clust_ids = np.arange(len(clbls))[clbls == i]
+     clust_centroid = (np.median(xax[clust_ids]), np.median(yax[clust_ids]))
+     # plt.text(clust_centroid[1], clust_centroid[0], all_kws[i], fontsize=9, ha="center", va="center",
+     #          bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.3', alpha=0.3))
+     plt.text(clust_centroid[0], clust_centroid[1], all_kws[i], fontsize=9, ha="center", va="center",
+              fontfamily='serif', color='w',
+              bbox=dict(facecolor='k', edgecolor='none', boxstyle='round,pad=0.1', alpha=0.3))
+
+ plt.scatter(xax[ret_indices], yax[ret_indices], c='k', s=300, zorder=100)
+ plt.scatter(xax[ret_indices], yax[ret_indices], c='firebrick', s=100, zorder=101)
+ plt.scatter(xax[ret_indices[0]], yax[ret_indices[0]], c='k', s=300, zorder=101)
+ plt.scatter(xax[ret_indices[0]], yax[ret_indices[0]], c='w', s=100, zorder=101)
+
+ tempx = plt.xlim(); tempy = plt.ylim()
+ plt.text(0.012 * tempx[1], (0.012 + 0.03) * tempy[0], 'The world of astronomy literature', fontsize=36, fontfamily='serif')
+ plt.text(0.012 * tempx[1], (0.012 + 0.06) * tempy[0], 'Query: ' + query, fontsize=18, fontfamily='serif')
+ plt.text(0.012 * tempx[1], (0.012 + 0.08) * tempy[0], gtkws, fontsize=18, fontfamily='serif', va='top')
+ plt.axis('off')
+ st.pyplot(fig, transparent=True, bbox_inches='tight')
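
The page overlays the retrieved papers on a pre-rendered map image by min-max scaling the stored UMAP coordinates onto the image's pixel grid and flipping the vertical axis. A minimal standalone version of that transform (NumPy only; the 1580/170 and 1700/30 constants are the same ones used above and are tied to the dimensions of `astro_worldmap.png`):

```python
import numpy as np

def umap_to_pixels(coords):
    """Map raw UMAP coordinates onto the background image's pixel grid,
    mirroring the xax/yax computation in pages/1 retrieval.py."""
    x = coords[:, 1] - np.amin(coords[:, 1])
    x = x / np.amax(x) * 1580 + 170     # horizontal span of the map image
    y = coords[:, 0] - np.amin(coords[:, 0])
    y = y / np.amax(y)
    y = (np.amax(y) - y) * 1700 + 30    # flip: image y runs top-to-bottom
    return x, y

# Toy example with three fake embedding positions.
x, y = umap_to_pixels(np.array([[0.0, 0.0], [1.0, 2.0], [2.0, 1.0]]))
print(x, y)  # pixel positions inside the map image
```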
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ matplotlib
+ bokeh==2.4.3
+ cloudpickle
+ scipy
+ summa
+ faiss-cpu
+ langchain
+ langchain_openai
+ langchain_community
+ langchain_core
+ openai
+ feedparser
+ tiktoken
+ chromadb
+ streamlit-extras
+ nltk
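
As a setup note: this is a standard multi-page Streamlit app, so locally it would presumably be run with `pip install -r requirements.txt` followed by `streamlit run app.py`, with `openai_key` and `anthropic_key` supplied via `st.secrets` (e.g. `.streamlit/secrets.toml`) and the `local_files/` assets and `arxiv_corpus/` dataset in place. A few packages imported by `fns.py` (`datasets`, `anthropic`, `tqdm`) are not listed in this requirements file, nor is `streamlit` itself if the hosting environment does not already provide it.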