import time

# Wall-clock timestamp taken as soon as this module starts importing.
s2 = time.time()

import ast
import json
import operator
import os
import pickle
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from datetime import date, datetime
from functools import lru_cache
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, Union

import anthropic
import nltk
import numpy as np
import streamlit as st
from datasets import load_dataset
from nltk.corpus import stopwords
from openai import OpenAI
from tqdm import tqdm

# Make sure the NLTK English stopword list is available; download it on first use.
try:
    stopwords.words('english')
except LookupError:  # NLTK raises LookupError when a corpus resource is missing
    nltk.download('stopwords')
    stopwords.words('english')

# API keys are read from Streamlit secrets. Never hard-code keys in source or comments.
openai_key = st.secrets['openai_key']
anthropic_key = st.secrets['anthropic_key']


@st.cache_data
def load_astro_meta():
    """Load the train split of the local ``arxiv_corpus/`` dataset (cached by Streamlit)."""
    print('load astro meta')
    return load_dataset('arxiv_corpus/', split="train")


@st.cache_data
def load_index_mapping(index_mapping_path):
    """Unpickle and return the document-id -> embedding-row mapping at *index_mapping_path*."""
    print("Loading index mapping...")
    # NOTE(review): pickle must only be used on trusted, locally produced files.
    with open(index_mapping_path, 'rb') as f:
        return pickle.load(f)


@st.cache_data
def load_embeddings(embeddings_path):
    """Load the embedding matrix saved with ``np.save`` at *embeddings_path*."""
    print("Loading embedding")
    return np.load(embeddings_path)


@st.cache_data
def load_metadata(meta_path):
    """Load the JSON metadata dictionary stored at *meta_path*."""
    print("Loading metadata...")
    with open(meta_path, 'r', encoding='utf-8') as f:
        return json.load(f)


# Deliberately not wrapped in @st.cache_data.
def load_umapcoords(umap_path):
    """Unpickle and return the UMAP coordinates stored at *umap_path*."""
    print('loading umap coords')
    with open(umap_path, "rb") as fp:
        return pickle.load(fp)


def _parse_arxiv_date(arxiv_id: str) -> date:
    """Return the first day of the submission month encoded in an arXiv id.

    Handles new-style ids (``YYMM.NNNNN``) and the local ``astro-ph<id>_arXiv``
    form. Two-digit years >= 90 are mapped to the 1900s (arXiv began in 1991),
    everything else to the 2000s. Ids that cannot be parsed, including ones with
    an out-of-range month, fall back to 2023-01-01 instead of raising.
    """
    if arxiv_id.startswith('astro-ph'):
        arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
    year_digits = arxiv_id[:2]
    if len(year_digits) != 2 or not year_digits.isdigit():
        return date(2023, 1, 1)
    two_digit_year = int(year_digits)
    year = (1900 if two_digit_year >= 90 else 2000) + two_digit_year
    try:
        return date(year, int(arxiv_id[2:4]), 1)
    except ValueError:  # non-numeric or out-of-range month
        return date(2023, 1, 1)


def _no_temporal_aspect() -> Dict[str, Any]:
    """Return a fresh "no temporal constraint" result in the shape used by ``retrieve``."""
    return {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}


# Operators permitted in year-filter expressions such as "year >= 2020 and year < 2023".
_YEAR_FILTER_COMPARISONS = {
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.In: lambda left, right: left in right,
    ast.NotIn: lambda left, right: left not in right,
}
_YEAR_FILTER_ARITHMETIC = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
}


@lru_cache(maxsize=128)
def _parse_year_filter(expression: str) -> ast.AST:
    """Parse a year-filter expression once; repeated per-document calls hit the cache."""
    # builtin eval() strips surrounding whitespace from string sources; mirror that.
    return ast.parse(expression.strip(), mode="eval").body


def _eval_year_node(node: ast.AST, year: int) -> Any:
    """Evaluate a whitelisted AST node, binding the name ``year`` to *year*.

    Raises ValueError for any construct outside the whitelist (calls,
    attribute access, subscripts, unknown names, ...), so untrusted filter
    strings cannot execute arbitrary code.
    """
    if isinstance(node, ast.BoolOp):
        values = (_eval_year_node(value, year) for value in node.values)
        return all(values) if isinstance(node.op, ast.And) else any(values)
    if isinstance(node, ast.UnaryOp):
        operand = _eval_year_node(node.operand, year)
        if isinstance(node.op, ast.Not):
            return not operand
        if isinstance(node.op, ast.USub):
            return -operand
        if isinstance(node.op, ast.UAdd):
            return +operand
        raise ValueError(f"unsupported unary operator {type(node.op).__name__}")
    if isinstance(node, ast.Compare):
        left = _eval_year_node(node.left, year)
        for op, comparator in zip(node.ops, node.comparators):
            compare = _YEAR_FILTER_COMPARISONS.get(type(op))
            if compare is None:
                raise ValueError(f"unsupported comparison {type(op).__name__}")
            right = _eval_year_node(comparator, year)
            if not compare(left, right):
                return False
            left = right
        return True
    if isinstance(node, ast.BinOp):
        apply = _YEAR_FILTER_ARITHMETIC.get(type(node.op))
        if apply is None:
            raise ValueError(f"unsupported operator {type(node.op).__name__}")
        return apply(_eval_year_node(node.left, year), _eval_year_node(node.right, year))
    if isinstance(node, ast.Name):
        if node.id == "year":
            return year
        raise ValueError(f"unknown name {node.id!r}")
    if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
        return node.value
    if isinstance(node, (ast.Tuple, ast.List, ast.Set)):
        return tuple(_eval_year_node(element, year) for element in node.elts)
    raise ValueError(f"unsupported expression element {type(node).__name__}")


class EmbeddingClient:
    """Thin wrapper around the OpenAI embeddings endpoint returning float32 numpy vectors."""

    def __init__(self, client: OpenAI, model: str = "text-embedding-3-small"):
        self.client = client
        self.model = model

    def embed(self, text: str) -> np.ndarray:
        """Embed a single string."""
        embedding = self.client.embeddings.create(input=[text], model=self.model).data[0].embedding
        return np.array(embedding, dtype=np.float32)

    def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Embed several strings with one API call, one array per input in response order."""
        embeddings = self.client.embeddings.create(input=texts, model=self.model).data
        return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]


class RetrievalSystem(ABC):
    """Abstract retrieval interface plus a shared arXiv-id date parser."""

    @abstractmethod
    def retrieve(self, query: str, arxiv_id: str, top_k: int = 100) -> List[str]:
        pass

    def parse_date(self, arxiv_id: Optional[str]) -> date:
        """Return the month date encoded in *arxiv_id*, or today's date when it is None."""
        if arxiv_id is None:
            return date.today()
        return _parse_arxiv_date(arxiv_id)


class EmbeddingRetrievalSystem(RetrievalSystem):
    """Dense retrieval over precomputed abstract/conclusions embeddings.

    A document's score is the larger of the dot products between the query
    embedding and its abstract / conclusions embeddings. Candidates can be
    restricted to keyword matches, are date-filtered (optionally recency
    weighted), and can be boosted by citation count.
    """

    def __init__(self, embeddings_path: str = "local_files/embeddings_matrix.npy",
                 documents_path: str = "local_files/documents.pkl",
                 index_mapping_path: str = "local_files/index_mapping.pkl",
                 metadata_path: str = "local_files/metadata.json",
                 weight_citation=False, weight_date=False, weight_keywords=False):
        self.embeddings_path = embeddings_path
        self.documents_path = documents_path
        self.index_mapping_path = index_mapping_path
        self.metadata_path = metadata_path
        self.weight_citation = weight_citation
        self.weight_date = weight_date
        self.weight_keywords = weight_keywords

        self.embeddings = None
        self.documents = None  # legacy document objects; not loaded by load_data()
        self.index_mapping = None
        self.metadata = None
        self.document_dates = []
        self.dataset = None  # Hugging Face dataset holding abstract/conclusions text
        self._row_by_id = {}  # dataset 'id' value -> row number, for text lookups

        self.load_data()
        self.init_filters()

        self.client = EmbeddingClient(OpenAI(api_key=openai_key))
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)

    def generate_metadata(self):
        """Build metadata (every non-text column, keyed by arxiv_id) and write it to metadata_path."""
        astro_meta = load_astro_meta()
        text_columns = {'abstract', 'introduction', 'conclusions'}
        keys = [key for key in astro_meta[0].keys() if key not in text_columns]
        self.metadata = {paper['arxiv_id']: {key: paper[key] for key in keys} for paper in astro_meta}
        with open(self.metadata_path, 'w') as f:
            json.dump(self.metadata, f)
        st.markdown("Wrote metadata to {}".format(self.metadata_path))

    def load_data(self):
        """Load embeddings, index mapping, corpus, per-document dates and metadata."""
        self.embeddings = load_embeddings(self.embeddings_path)
        st.sidebar.success("Loaded embeddings")

        self.index_mapping = load_index_mapping(self.index_mapping_path)
        st.sidebar.success("Loaded index mapping")

        dataset = load_astro_meta()
        st.sidebar.success("Loaded documents")

        print("Processing document dates...")
        aids = dataset['arxiv_id']
        adsids = dataset['id']
        # Keep the dataset so get_document_texts() can fetch abstract/conclusions by id.
        self.dataset = dataset
        self._row_by_id = {doc_id: row for row, doc_id in enumerate(adsids)}
        self.document_dates = {doc_id: self.parse_date(aid) for doc_id, aid in zip(adsids, aids)}

        if os.path.exists(self.metadata_path):
            self.metadata = load_metadata(self.metadata_path)
            print("Loaded metadata.")
        else:
            print("Could not find path; generating metadata.")
            self.generate_metadata()
        print("Data loaded successfully.")

    def init_filters(self):
        """Create the citation, date and keyword filters from the loaded data."""
        print("Loading filters...")
        self.citation_filter = CitationFilter(metadata=self.metadata)
        self.date_filter = DateFilter(document_dates=self.document_dates)
        self.keyword_filter = KeywordFilter(index_path="local_files/keyword_index.json",
                                            metadata=self.metadata, remove_capitals=True)

    def retrieve(self, query: str, arxiv_id: Optional[str] = None, top_k: int = 10,
                 return_scores=False, time_result=None) -> Union[List[str], Dict[str, float]]:
        """Return the top_k document ids for *query* (or {id: score} if return_scores).

        *arxiv_id* fixes the latest allowed publication date (today when None).
        When *time_result* is None and weight_date is set, the query's temporal
        intent is obtained from ``analyze_temporal_query``; otherwise no
        temporal constraint is applied.
        """
        query_date = self.parse_date(arxiv_id)
        query_embedding = self.get_query_embedding(query)

        if time_result is None:
            if self.weight_date:
                # NOTE(review): analyze_temporal_query is not defined in this file's visible part.
                time_result, _ = analyze_temporal_query(query, self.anthropic_client)
            else:
                time_result = _no_temporal_aspect()

        return self.rank_and_filter(query, query_embedding, query_date, top_k,
                                    return_scores=return_scores, time_result=time_result)

    def rank_and_filter(self, query, query_embedding: np.ndarray, query_date, top_k: int = 10,
                        return_scores=False, time_result=None) -> Union[List[str], Dict[str, float]]:
        """Score every indexed document, apply filters/weights and return the top_k."""
        if time_result is None:
            time_result = _no_temporal_aspect()

        similarities = np.dot(self.embeddings, query_embedding)
        keyword_matches = self.keyword_filter.filter(query) if self.weight_keywords else None

        # Each result is a mutable [doc_id, best_section, score] list: the filters below
        # adjust the score in place.
        results = []
        for doc_id, mappings in self.index_mapping.items():
            if keyword_matches is not None and doc_id not in keyword_matches:
                continue
            abstract_sim = similarities[mappings['abstract']] if 'abstract' in mappings else -np.inf
            conclusions_sim = similarities[mappings['conclusions']] if 'conclusions' in mappings else -np.inf
            if abstract_sim > conclusions_sim:
                results.append([doc_id, "abstract", abstract_sim])
            else:
                results.append([doc_id, "conclusions", conclusions_sim])

        if time_result.get('has_temporal_aspect'):
            filtered_results = self.date_filter.filter(
                results,
                boolean_date=time_result.get('expected_year_filter'),
                time_score=time_result.get('expected_recency_weight'),
                max_date=query_date,
            )
        else:
            filtered_results = self.date_filter.filter(results, max_date=query_date)

        if self.weight_citation:
            self.citation_filter.filter(filtered_results)  # adds a citation bonus in place

        top_results = sorted(filtered_results, key=lambda x: x[2], reverse=True)[:top_k]

        if return_scores:
            return {doc[0]: doc[2] for doc in top_results}
        return [doc[0] for doc in top_results]

    def get_query_embedding(self, query: str) -> np.ndarray:
        """Embed *query* as a float32 vector."""
        return np.array(self.client.embed(query), dtype=np.float32)

    def _dataset_record(self, doc_id: str) -> Optional[Dict[str, str]]:
        """Return {'id', 'abstract', 'conclusions'} for *doc_id* from the dataset, or None."""
        row = self._row_by_id.get(doc_id)
        if row is None or self.dataset is None:
            return None
        paper = self.dataset[row]
        return {'id': doc_id, 'abstract': paper['abstract'], 'conclusions': paper['conclusions']}

    def get_document_texts(self, doc_ids: List[str]) -> List[Dict[str, str]]:
        """Return {'id', 'abstract', 'conclusions'} dicts for the given ids, in order.

        Unknown ids are skipped with a printed warning. Texts come from
        ``self.documents`` when it has been populated, otherwise from the loaded
        dataset (looked up by its 'id' column).
        """
        if self.documents is not None:
            by_id = {}
            for document in self.documents:
                by_id.setdefault(document.id, document)  # first match wins, as before

            def lookup(doc_id):
                document = by_id.get(doc_id)
                if not document:
                    return None
                return {'id': document.id, 'abstract': document.abstract,
                        'conclusions': document.conclusions}
        else:
            lookup = self._dataset_record

        results = []
        for doc_id in doc_ids:
            record = lookup(doc_id)
            if record is None:
                print(f"Warning: Document with ID {doc_id} not found.")
            else:
                results.append(record)
        return results

    def retrieve_context(self, query, top_k, sections=("abstract", "conclusions"), **kwargs):
        """Retrieve documents for *query* and format them as LLM context.

        Returns ``(context_str, doclist)``. ``context_str`` has one
        "<doc_id>: <text>" line per requested section, and ``doclist`` holds
        ``Document`` objects with metadata and retrieval scores.
        """
        scores = self.retrieve(query, top_k=top_k, return_scores=True, **kwargs)
        doctexts = self.get_document_texts(list(scores))

        context_lines = []
        doclist = []
        # Pair by the record's own id, so a missing document cannot shift texts onto the wrong id.
        for doctext in doctexts:
            docid = doctext['id']
            for section in sections:
                context_lines.append(f"{docid}: {doctext[section]}\n")
            # NOTE(review): generate_metadata keys metadata by arxiv_id while doc ids come
            # from the index mapping / dataset 'id' column; confirm these agree.
            meta_row = self.metadata[docid]
            # NOTE(review): Document is not defined in this file's visible part.
            doclist.append(Document(docid, doctext['abstract'], doctext['conclusions'], docid,
                                    title=meta_row['title'], score=scores[docid],
                                    n_citation=meta_row['citation_count'],
                                    keywords=meta_row['keyword_search']))
        return "".join(context_lines), doclist


class Filter():
    """Base class for retrieval filters; each subclass defines its own ``filter`` signature."""

    def filter(self, query: str, arxiv_id: str) -> List[str]:
        pass


class CitationFilter(Filter):
    """Additively boosts scores using a sigmoid of each document's citation count."""

    def __init__(self, metadata):
        self.metadata = metadata
        self.citation_counts = {doc_id: paper['citation_count'] for doc_id, paper in self.metadata.items()}

    def citation_weight(self, x, shift, scale):
        """Logistic sigmoid centred at *shift* with width *scale*."""
        return 1 / (1 + np.exp(-1 * (x - shift) / scale))

    def filter(self, doc_scores, weight=0.1):
        """Add ``weight * sigmoid(citations)`` to each [doc_id, section, score] entry in place.

        The sigmoid is centred on the median citation count of *doc_scores* and
        scaled by their standard deviation.
        """
        if not doc_scores:
            return
        citation_count = np.array([self.citation_counts[doc[0]] for doc in doc_scores])
        cmean, cstd = np.median(citation_count), np.std(citation_count)
        # std 0 means all counts are equal (== median); any positive scale then gives 0.5,
        # the limiting value, instead of 0/0 = NaN.
        scale = cstd if cstd > 0 else 1.0
        citation_score = self.citation_weight(citation_count, cmean, scale)
        for doc, bonus in zip(doc_scores, citation_score):
            doc[2] += weight * bonus


class DateFilter(Filter):
    """Filters results by publication date and optionally adds a recency bonus."""

    def __init__(self, document_dates):
        self.document_dates = document_dates

    def parse_date(self, arxiv_id: str) -> date:
        """Return the month date encoded in a document's arXiv id (see _parse_arxiv_date)."""
        return _parse_arxiv_date(arxiv_id)

    def weight(self, time, shift, scale):
        """Decreasing sigmoid: close to 1 for small *time*, 0.5 at *shift*."""
        return 1 / (1 + np.exp((time - shift) / scale))

    def evaluate_filter(self, year, filter_string):
        """Evaluate a year-filter expression (e.g. "year >= 2020 and year <= 2023").

        The expression may come from an LLM, so it is evaluated by a whitelisted
        AST walker rather than ``eval``. Any error is printed and treated as False.
        """
        try:
            return _eval_year_node(_parse_year_filter(filter_string), year)
        except Exception as e:
            print(f"Error evaluating filter: {e}")
            return False

    def filter(self, docs, boolean_date=None, min_date=None, max_date=None, time_score=0):
        """Keep the docs that pass the date test and optionally add a recency bonus in place.

        With *boolean_date* set, docs are kept when that expression holds for
        their publication year. Otherwise they must fall within
        [min_date, max_date] (defaults 1990-01-01 and 2024-07-03). When
        *time_score* is not None, each kept doc gains
        ``time_score * 0.1 * weight(age_in_years, 5, 5)``, with age measured
        from *max_date*.
        """
        # Defaults are applied up front: the recency weighting needs max_date even on
        # the boolean_date path (previously None there raised TypeError).
        if min_date is None:
            min_date = date(1990, 1, 1)
        if max_date is None:
            max_date = date(2024, 7, 3)

        if boolean_date is not None:
            expression = boolean_date.replace("AND", "and").replace("OR", "or")
            filtered = [doc for doc in docs
                        if self.evaluate_filter(self.document_dates[doc[0]].year, expression)]
        else:
            filtered = [doc for doc in docs if min_date <= self.document_dates[doc[0]] <= max_date]

        if time_score is not None:
            for item in filtered:
                age_years = (max_date - self.document_dates[item[0]]).days / 365
                item[2] += time_score * 0.1 * self.weight(age_years, 5, 5)
        return filtered


class KeywordFilter(Filter):
    """Matches query keywords / proper nouns against an inverted keyword index of papers."""

    def __init__(self, index_path: str = "local_files/keyword_index.json",
                 remove_capitals: bool = True, metadata=None, ne_only=True, verbose=False):
        self.index_path = index_path
        self.metadata = metadata
        self.remove_capitals = remove_capitals
        self.ne_only = ne_only
        self.stopwords = set(stopwords.words('english'))
        self.verbose = verbose
        self.index = None

        self.load_or_build_index()

    def preprocess_text(self, text: str) -> str:
        """Drop non-alphanumeric characters and stopwords, optionally lower-casing."""
        text = ''.join(char for char in text if char.isalnum() or char.isspace())
        if self.remove_capitals:
            text = text.lower()
        return ' '.join(word for word in text.split() if word.lower() not in self.stopwords)

    def build_index(self):
        """Build the term -> [arxiv_id] index from metadata['keyword_search'] and save it as JSON."""
        print("Building index...")
        self.index = {}
        for paper in tqdm(self.metadata.values()):
            for keyword in set(paper['keyword_search']):
                term = ' '.join(word for word in keyword.lower().split() if word.lower() not in self.stopwords)
                self.index.setdefault(term, []).append(paper['arxiv_id'])

        with open(self.index_path, 'w') as f:
            json.dump(self.index, f)

    def load_index(self):
        """Load the JSON keyword index from index_path."""
        print("Loading existing index...")
        with open(self.index_path, 'rb') as f:
            self.index = json.load(f)
        print("Index loaded successfully.")

    def load_or_build_index(self):
        """Load the index when its file exists, otherwise build it."""
        if os.path.exists(self.index_path):
            self.load_index()
        else:
            self.build_index()

    def parse_doc(self, doc):
        """Return the preprocessed, lower-cased phrases attached to *doc* as ``doc._.phrases``."""
        local_kws = [phrase.text.lower() for phrase in doc._.phrases]
        return [self.preprocess_text(word) for word in local_kws]

    def get_propn(self, doc):
        """Return runs of consecutive PROPN tokens, split at stopwords/punctuation, preprocessed."""
        # NOTE(review): `nlp` and `punctuation` are not defined in this file's visible part;
        # they must be module-level names defined elsewhere.
        result = []
        working_str = ''
        for token in doc:
            if token.text in nlp.Defaults.stop_words or token.text in punctuation:
                if working_str != '':
                    result.append(working_str.strip())
                    working_str = ''
            if token.pos_ == "PROPN":
                working_str += token.text + ' '
        if working_str != '':
            result.append(working_str.strip())
        return [self.preprocess_text(word) for word in result]

    def filter(self, query: str, doc_ids=None):
        """Return the set of arxiv ids whose indexed keywords match *query*.

        With ne_only (the default) only proper-noun matches count, and a query
        without proper nouns yields an empty set. Otherwise keyword matches are
        intersected with proper-noun matches when proper nouns are present. If
        *doc_ids* (any iterable) is given, the result is restricted to it.
        """
        doc = nlp(query)
        query_keywords = self.parse_doc(doc)
        nouns = self.get_propn(doc)
        if self.verbose:
            print('keywords:', query_keywords)
            print('proper nouns:', nouns)

        filtered = set()
        if query_keywords and not self.ne_only:
            for keyword in query_keywords:
                if keyword != '' and keyword in self.index:
                    filtered |= set(self.index[keyword])

        if nouns:
            ne_results = set()
            for noun in nouns:
                if noun in self.index:
                    ne_results |= set(self.index[noun])

            if self.ne_only:
                filtered = ne_results  # keep only named-entity results
            else:
                filtered &= ne_results  # intersect keyword and named-entity results

        if doc_ids is not None:
            filtered &= set(doc_ids)
        return filtered


def get_cluster_keywords(clust_ids, all_keywords):
    """Return the most common keywords of a cluster as "kw1, kw2, " (trailing separator kept).

    The 30 most frequent keywords across ``all_keywords[i]`` for i in
    *clust_ids* are taken first; those of 2 characters or fewer are then dropped.
    """
    tag_counts = Counter(chain.from_iterable(all_keywords[cid] for cid in clust_ids))
    return ''.join(f"{tag}, " for tag, _ in tag_counts.most_common(30) if len(tag) > 2)


def get_keywords(query, ret_indices, all_keywords):
    """Ask Claude Haiku to pick the keywords relevant to *query* from the retrieved cluster.

    Returns the text of the model's first response block.
    """
    kws = get_cluster_keywords(ret_indices, all_keywords)
    kw_prompt = """You are an expert research assistant. Here are a list of keywords corresponding to the topics that a query and its answer are about that you need to synthesize into a succinct summary:
[""" + kws + """]

First, find the keywords that are most relevant to answering the question, and then print them in numbered order. Keywords should be a few words at most. Do not list more than five keywords.

If there are no relevant quotes, write “No relevant keywords” instead.

Thus, the format of your overall response should look like what’s shown between the tags. Make sure to follow the formatting and spacing exactly.

Keywords:
[1] Milky Way galaxy
[2] Good agreement
[3] Bayesian
[4] Observational constraints
[5] Globular clusters
[6] Kinematic data

If the question cannot be answered by the document, say so."""

    client = anthropic.Anthropic(api_key=anthropic_key)
    message = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=200,
        temperature=0,
        system=kw_prompt,
        messages=[{"role": "user", "content": [{"type": "text", "text": query}]}],
    )
    return message.content[0].text