import time

s2 = time.time()

import json
import os
import pickle
from abc import ABC, abstractmethod
from collections import Counter, defaultdict
from datetime import datetime, date
from typing import List, Dict, Any, Tuple

import numpy as np
import streamlit as st
import anthropic
import nltk
from datasets import load_dataset
from nltk.corpus import stopwords
from openai import OpenAI
from tqdm import tqdm
# import wandb

# Make sure the NLTK stopword list is available before it is used below.
try:
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords')
    stopwords.words('english')

openai_key = st.secrets['openai_key']
anthropic_key = st.secrets['anthropic_key']
def load_astro_meta():
    print('load astro meta')
    return load_dataset('arxiv_corpus/', split="train")

def load_index_mapping(index_mapping_path):
    print("Loading index mapping...")
    with open(index_mapping_path, 'rb') as f:
        temp = pickle.load(f)
    return temp

def load_embeddings(embeddings_path):
    print("Loading embeddings...")
    return np.load(embeddings_path)

def load_metadata(meta_path):
    print("Loading metadata...")
    with open(meta_path, 'r') as f:
        metadata = json.load(f)
    return metadata

# @st.cache_data
def load_umapcoords(umap_path):
    print('loading umap coords')
    with open(umap_path, "rb") as fp:
        umap = pickle.load(fp)
    return umap
class EmbeddingClient:
    def __init__(self, client: OpenAI, model: str = "text-embedding-3-small"):
        self.client = client
        self.model = model

    def embed(self, text: str) -> np.ndarray:
        embedding = self.client.embeddings.create(input=[text], model=self.model).data[0].embedding
        return np.array(embedding, dtype=np.float32)

    def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
        embeddings = self.client.embeddings.create(input=texts, model=self.model).data
        return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]
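# A minimal usage sketch for EmbeddingClient (not part of the original app flow;
# assumes a valid OpenAI key and the default text-embedding-3-small model):
#
#     client = EmbeddingClient(OpenAI(api_key=openai_key))
#     vec = client.embed("dark matter halos in dwarf galaxies")
#     vecs = client.embed_batch(["first query", "second query"])
#     # vec is a float32 numpy array; the ranking code below scores it against the
#     # precomputed embedding matrix with a plain dot product.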
class RetrievalSystem(ABC):
    @abstractmethod
    def retrieve(self, query: str, arxiv_id: str, top_k: int = 100) -> List[str]:
        pass

    def parse_date(self, arxiv_id: str) -> date:
        if arxiv_id is None:
            return date.today()
        if arxiv_id.startswith('astro-ph'):
            arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
        try:
            year = int("20" + arxiv_id[:2])
            month = int(arxiv_id[2:4])
        except ValueError:
            year = 2023
            month = 1
        return date(year, month, 1)
class EmbeddingRetrievalSystem(RetrievalSystem):
    def __init__(self, embeddings_path: str = "local_files/embeddings_matrix.npy",
                 documents_path: str = "local_files/documents.pkl",
                 index_mapping_path: str = "local_files/index_mapping.pkl",
                 metadata_path: str = "local_files/metadata.json",
                 weight_citation=False, weight_date=False, weight_keywords=False):
        self.embeddings_path = embeddings_path
        self.documents_path = documents_path
        self.index_mapping_path = index_mapping_path
        self.metadata_path = metadata_path
        self.weight_citation = weight_citation
        self.weight_date = weight_date
        self.weight_keywords = weight_keywords

        self.embeddings = None
        self.documents = None
        self.index_mapping = None
        self.metadata = None
        self.document_dates = []

        self.load_data()
        self.init_filters()

        # config = yaml.safe_load(open('../config.yaml', 'r'))
        self.client = EmbeddingClient(OpenAI(api_key=openai_key))
        self.anthropic_client = anthropic.Anthropic(api_key=anthropic_key)
    def generate_metadata(self):
        astro_meta = load_astro_meta()
        # dataset = load_dataset('arxiv_corpus/')
        keys = list(astro_meta[0].keys())
        keys.remove('abstract')
        keys.remove('introduction')
        keys.remove('conclusions')

        self.metadata = {}
        for paper in astro_meta:
            id_str = paper['arxiv_id']
            self.metadata[id_str] = {key: paper[key] for key in keys}

        with open(self.metadata_path, 'w') as f:
            json.dump(self.metadata, f)
        st.markdown("Wrote metadata to {}".format(self.metadata_path))
    def load_data(self):
        self.embeddings = load_embeddings(self.embeddings_path)
        st.sidebar.success("Loaded embeddings")

        self.index_mapping = load_index_mapping(self.index_mapping_path)
        st.sidebar.success("Loaded index mapping")

        # Note: the pickled document list (self.documents) is not loaded here; the
        # HF dataset is used instead for arXiv IDs and dates.
        dataset = load_astro_meta()
        st.sidebar.success("Loaded documents")

        print("Processing document dates...")
        aids = dataset['arxiv_id']
        adsids = dataset['id']
        self.document_dates = {adsids[i]: self.parse_date(aids[i]) for i in range(len(aids))}

        if os.path.exists(self.metadata_path):
            self.metadata = load_metadata(self.metadata_path)
            print("Loaded metadata.")
        else:
            print("Could not find path; generating metadata.")
            self.generate_metadata()

        print("Data loaded successfully.")
    def init_filters(self):
        print("Loading filters...")
        self.citation_filter = CitationFilter(metadata=self.metadata)
        self.date_filter = DateFilter(document_dates=self.document_dates)
        self.keyword_filter = KeywordFilter(index_path="local_files/keyword_index.json", metadata=self.metadata, remove_capitals=True)
    def retrieve(self, query: str, arxiv_id: str = None, top_k: int = 10, return_scores=False, time_result=None) -> List[Tuple[str, str, float]]:
        query_date = self.parse_date(arxiv_id)
        query_embedding = self.get_query_embedding(query)

        # Judge time relevance (analyze_temporal_query is expected to be defined elsewhere in the app).
        if time_result is None:
            if self.weight_date:
                time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
            else:
                time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}

        top_results = self.rank_and_filter(query, query_embedding, query_date, top_k, return_scores=return_scores, time_result=time_result)
        return top_results
    def rank_and_filter(self, query, query_embedding: np.ndarray, query_date, top_k: int = 10, return_scores=False, time_result=None) -> List[Tuple[str, str, float]]:
        # Calculate similarities
        similarities = np.dot(self.embeddings, query_embedding)

        # Filter and rank results
        if self.weight_keywords:
            keyword_matches = self.keyword_filter.filter(query)

        results = []
        for doc_id, mappings in self.index_mapping.items():
            if not self.weight_keywords or doc_id in keyword_matches:
                abstract_sim = similarities[mappings['abstract']] if 'abstract' in mappings else -np.inf
                conclusions_sim = similarities[mappings['conclusions']] if 'conclusions' in mappings else -np.inf
                if abstract_sim > conclusions_sim:
                    results.append([doc_id, "abstract", abstract_sim])
                else:
                    results.append([doc_id, "conclusions", conclusions_sim])

        # Weight, sort, and take the top-k results
        if time_result['has_temporal_aspect']:
            filtered_results = self.date_filter.filter(results, boolean_date=time_result['expected_year_filter'], time_score=time_result['expected_recency_weight'], max_date=query_date)
        else:
            filtered_results = self.date_filter.filter(results, max_date=query_date)

        if self.weight_citation:
            self.citation_filter.filter(filtered_results)

        top_results = sorted(filtered_results, key=lambda x: x[2], reverse=True)[:top_k]

        if return_scores:
            return {doc[0]: doc[2] for doc in top_results}

        # Only keep the document IDs
        top_results = [doc[0] for doc in top_results]
        return top_results
    def get_query_embedding(self, query: str) -> np.ndarray:
        embedding = self.client.embed(query)
        return np.array(embedding, dtype=np.float32)

    def get_document_texts(self, doc_ids: List[str]) -> List[Dict[str, str]]:
        results = []
        for doc_id in doc_ids:
            doc = next((d for d in self.documents if d.id == doc_id), None)
            if doc:
                results.append({
                    'id': doc.id,
                    'abstract': doc.abstract,
                    'conclusions': doc.conclusions
                })
            else:
                print(f"Warning: Document with ID {doc_id} not found.")
        return results
    def retrieve_context(self, query, top_k, sections=["abstract", "conclusions"], **kwargs):
        docs = self.retrieve(query, top_k=top_k, return_scores=True, **kwargs)
        docids = docs.keys()
        doctexts = self.get_document_texts(docids)  # avoid having to do this repetitively?

        context_str = ""
        doclist = []
        for docid, doctext in zip(docids, doctexts):
            for section in sections:
                context_str += f"{docid}: {doctext[section]}\n"
            meta_row = self.metadata[docid]
            # Document is expected to be defined elsewhere in the app.
            doclist.append(Document(docid, doctext['abstract'], doctext['conclusions'], docid, title=meta_row['title'],
                                    score=docs[docid], n_citation=meta_row['citation_count'], keywords=meta_row['keyword_search']))

        return context_str, doclist
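# A minimal usage sketch for the retrieval system (not part of the original app flow;
# assumes the local_files/ artifacts and the arxiv_corpus/ dataset are available):
#
#     ret_system = EmbeddingRetrievalSystem()
#     top_ids = ret_system.retrieve("What is the mass of the Milky Way?", top_k=10)
#     id_to_score = ret_system.retrieve("What is the mass of the Milky Way?", top_k=10, return_scores=True)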
class Filter():
    def filter(self, query: str, arxiv_id: str) -> List[str]:
        pass
class CitationFilter(Filter):  # can do it with all metadata
    def __init__(self, metadata):
        self.metadata = metadata
        self.citation_counts = {doc_id: self.metadata[doc_id]['citation_count'] for doc_id in self.metadata}

    def citation_weight(self, x, shift, scale):
        return 1 / (1 + np.exp(-1 * (x - shift) / scale))  # sigmoid function

    def filter(self, doc_scores, weight=0.1):  # additive weighting
        citation_count = np.array([self.citation_counts[doc[0]] for doc in doc_scores])
        # Center the sigmoid on the median citation count and scale by the standard deviation.
        cmean, cstd = np.median(citation_count), np.std(citation_count)
        citation_score = self.citation_weight(citation_count, cmean, cstd)

        for i, doc in enumerate(doc_scores):
            doc_scores[i][2] += weight * citation_score[i]
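# A small worked example of the additive citation weighting above (illustrative only,
# with made-up scores and citation counts):
#
#     cf = CitationFilter(metadata={'a': {'citation_count': 5}, 'b': {'citation_count': 500}})
#     scores = [['a', 'abstract', 0.80], ['b', 'abstract', 0.78]]
#     cf.filter(scores, weight=0.1)
#     # 'b' gains up to ~0.1 from its high citation count, so the ranking can flip.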
class DateFilter(Filter):  # include time weighting eventually
    def __init__(self, document_dates):
        self.document_dates = document_dates

    def parse_date(self, arxiv_id: str) -> date:  # only for documents
        if arxiv_id.startswith('astro-ph'):
            arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
        try:
            year = int("20" + arxiv_id[:2])
            month = int(arxiv_id[2:4])
        except ValueError:
            year = 2023
            month = 1
        return date(year, month, 1)

    def weight(self, time, shift, scale):
        return 1 / (1 + np.exp((time - shift) / scale))

    def evaluate_filter(self, year, filter_string):
        try:
            # Evaluate the year-filter expression with builtins disabled.
            # Note: this uses eval(), not a fully sandboxed parser.
            result = eval(filter_string, {"__builtins__": None}, {"year": year})
            return result
        except Exception as e:
            print(f"Error evaluating filter: {e}")
            return False
    def filter(self, docs, boolean_date=None, min_date=None, max_date=None, time_score=0):
        filtered = []

        if boolean_date is not None:
            boolean_date = boolean_date.replace("AND", "and").replace("OR", "or")
            for doc in docs:
                if self.evaluate_filter(self.document_dates[doc[0]].year, boolean_date):
                    filtered.append(doc)
        else:
            if min_date is None: min_date = date(1990, 1, 1)
            if max_date is None: max_date = date(2024, 7, 3)

            for doc in docs:
                if self.document_dates[doc[0]] >= min_date and self.document_dates[doc[0]] <= max_date:
                    filtered.append(doc)

            if time_score is not None:  # apply time weighting
                for i, item in enumerate(filtered):
                    time_diff = (max_date - self.document_dates[filtered[i][0]]).days / 365
                    filtered[i][2] += time_score * 0.1 * self.weight(time_diff, 5, 5)

        return filtered
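# Example of the boolean date filter and recency weighting (illustrative only; the
# expression format matches what evaluate_filter expects, i.e. a Python boolean over `year`):
#
#     df = DateFilter(document_dates=document_dates)
#     recent = df.filter(results, boolean_date="year >= 2020 and year <= 2023")
#     weighted = df.filter(results, max_date=date(2024, 1, 1), time_score=1.0)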
class KeywordFilter(Filter):
    def __init__(self, index_path: str = "local_files/keyword_index.json",
                 remove_capitals: bool = True, metadata=None, ne_only=True, verbose=False):
        self.index_path = index_path
        self.metadata = metadata
        self.remove_capitals = remove_capitals
        self.ne_only = ne_only
        self.stopwords = set(stopwords.words('english'))
        self.verbose = verbose
        self.index = None

        self.load_or_build_index()

    def preprocess_text(self, text: str) -> str:
        text = ''.join(char for char in text if char.isalnum() or char.isspace())
        if self.remove_capitals:
            text = text.lower()
        return ' '.join(word for word in text.split() if word.lower() not in self.stopwords)

    def build_index(self):  # include the title in the index
        print("Building index...")
        self.index = {}

        for i, index in tqdm(enumerate(self.metadata)):
            paper = self.metadata[index]
            title = paper['title'][0]
            title_keywords = set()  # set(self.parse_doc(title) + self.get_propn(title))
            for keyword in set(paper['keyword_search']) | title_keywords:
                term = ' '.join(word for word in keyword.lower().split() if word.lower() not in self.stopwords)
                if term not in self.index:
                    self.index[term] = []
                self.index[term].append(paper['arxiv_id'])

        with open(self.index_path, 'w') as f:
            json.dump(self.index, f)

    def load_index(self):
        print("Loading existing index...")
        with open(self.index_path, 'r') as f:
            self.index = json.load(f)
        print("Index loaded successfully.")

    def load_or_build_index(self):
        if os.path.exists(self.index_path):
            self.load_index()
        else:
            self.build_index()
    def parse_doc(self, doc):
        # Extract key phrases from a parsed doc; doc._.phrases assumes a phrase-extraction
        # component is registered on the nlp pipeline used below.
        local_kws = []
        for phrase in doc._.phrases:
            local_kws.append(phrase.text.lower())
        return [self.preprocess_text(word) for word in local_kws]

    def get_propn(self, doc):
        # Collect runs of consecutive proper nouns into multi-word entities.
        # nlp and punctuation are expected to be defined elsewhere in the app.
        result = []
        working_str = ''
        for token in doc:
            if token.text in nlp.Defaults.stop_words or token.text in punctuation:
                if working_str != '':
                    result.append(working_str.strip())
                    working_str = ''
            if token.pos_ == "PROPN":
                working_str += token.text + ' '
        if working_str != '':
            result.append(working_str.strip())
        return [self.preprocess_text(word) for word in result]

    def filter(self, query: str, doc_ids=None):
        doc = nlp(query)
        query_keywords = self.parse_doc(doc)
        nouns = self.get_propn(doc)

        if self.verbose:
            print('keywords:', query_keywords)
            print('proper nouns:', nouns)

        filtered = set()
        if len(query_keywords) > 0 and not self.ne_only:
            for keyword in query_keywords:
                if keyword != '' and keyword in self.index.keys():
                    filtered |= set(self.index[keyword])

        if len(nouns) > 0:
            ne_results = set()
            for noun in nouns:
                if noun in self.index.keys():
                    ne_results |= set(self.index[noun])

            if self.ne_only:
                filtered = ne_results  # keep only named entity results
            else:
                filtered &= ne_results  # take the intersection

        if doc_ids is not None:
            filtered &= doc_ids  # apply filter to results

        return filtered
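# A minimal sketch of how the keyword filter can be driven (illustrative; it assumes a
# spaCy pipeline with a phrase extractor such as pyTextRank is loaded elsewhere as `nlp`,
# along with string.punctuation):
#
#     import spacy, pytextrank
#     from string import punctuation
#     nlp = spacy.load("en_core_web_sm")
#     nlp.add_pipe("textrank")
#
#     kf = KeywordFilter(metadata=metadata, ne_only=True)
#     matching_ids = kf.filter("globular clusters in the Milky Way")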
def get_cluster_keywords(clust_ids, all_keywords):
    # Gather the keywords of all papers in the cluster and keep the 30 most common.
    tagstr = ''
    clust_tags = []
    for i in range(len(clust_ids)):
        for j in range(len(all_keywords[clust_ids[i]])):
            clust_tags.append(all_keywords[clust_ids[i]][j])

    tags = Counter(clust_tags).most_common(30)
    for i in range(len(tags)):
        if len(tags[i][0]) > 2:
            tagstr = tagstr + tags[i][0] + ', '

    return tagstr
def get_keywords(query, ret_indices, all_keywords):
    kws = get_cluster_keywords(ret_indices, all_keywords)

    kw_prompt = """You are an expert research assistant. Here is a list of keywords corresponding to the topics that a query and its answer are about, which you need to synthesize into a succinct summary:
    [""" + kws + """]
    First, find the keywords that are most relevant to answering the question, and then print them in numbered order. Keywords should be a few words at most. Do not list more than five keywords.
    If there are no relevant keywords, write “No relevant keywords” instead.
    The format of your overall response should look like what’s shown below. Make sure to follow the formatting and spacing exactly.
    Keywords:
    [1] Milky Way galaxy
    [2] Good agreement
    [3] Bayesian
    [4] Observational constraints
    [5] Globular clusters
    [6] Kinematic data
    If the question cannot be answered by the document, say so."""

    client = anthropic.Anthropic(api_key=anthropic_key)
    message = client.messages.create(
        model="claude-3-haiku-20240307",
        max_tokens=200,
        temperature=0,
        system=kw_prompt,
        messages=[{"role": "user", "content": [{"type": "text", "text": query}]}],
    )
    return message.content[0].text
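# Example call (illustrative; ret_indices would come from the retrieval step and
# all_keywords is the per-paper keyword list passed in by the app):
#
#     summary = get_keywords("How massive is the Milky Way?", ret_indices, all_keywords)
#     st.markdown(summary)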