import streamlit as st
st.set_page_config(layout="wide")

import numpy as np
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Tuple
from collections import defaultdict, Counter
from tqdm import tqdm
import pandas as pd
from datetime import datetime, date
from datasets import load_dataset, load_from_disk
import yaml, json, requests, sys, os, time
import concurrent.futures

ts = time.time()

anthropic_key = "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"
# anthropic_client = anthropic.Anthropic(api_key=anthropic_key)
openai_key = "sk-None-TMT98W6ksCIYY6w0UI66T3BlbkFJva1LamMQXbenkcnYqvs6"
# openai_client = EmbeddingClient(OpenAI(api_key=openai_key))

from nltk.corpus import stopwords
import nltk
from openai import OpenAI
import anthropic
import cohere
import faiss
import spacy
from string import punctuation
import pytextrank

nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("textrank")
try:
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords')
    stopwords.words('english')
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from bokeh.palettes import Spectral10

st.image('local_files/pathfinder_logo.png')
st.expander("About", expanded=False).write(
    """
Pathfinder v2.0 is a framework for searching and visualizing astronomy papers on the
[arXiv](https://arxiv.org/) and [ADS](https://ui.adsabs.harvard.edu/) using the context
sensitivity of modern large language models (LLMs) to better parse patterns in paper contexts.

This tool was built during the [JSALT workshop](https://www.clsp.jhu.edu/2024-jelinek-summer-workshop-on-speech-and-language-technology/) to do awesome things.

**👈 Select a tool from the sidebar** to see some examples
of what this framework can do!

### Tool summary:

- Please wait while the initial data loads and compiles; this takes about a minute the first time.
- `Paper search` looks for relevant papers given an arXiv ID or a question.

This is not meant to be a replacement for existing tools like
[ADS](https://ui.adsabs.harvard.edu/),
[arxivsorter](https://www.arxivsorter.org/), semantic search, or Google Scholar, but rather a supplement to find papers
that might otherwise be missed during a literature survey.
It is trained on astro-ph.GA (astrophysics of galaxies) papers up to roughly last year, mined from arXiv and supplemented with ADS metadata;
if you are interested in extending it, please reach out!

Planned additions: more pages, actual answer generation, separate toggles for retrieval/generation,
a feedback form, socials, literature, a contact page, copyright, collaboration info, etc.

The image below shows a representation of all the astro-ph.GA papers that can be explored in more detail
using the `Arxiv embedding` page. The papers tend to cluster together by similarity, resulting in an
atlas that shows well-studied areas (forests) and currently uncharted areas (water).
"""
)
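
# The block below loads the arXiv corpus once per session: it first tries a local
# copy under data/, otherwise downloads kiyer/pathfinder_arxiv_data from the
# Hugging Face Hub, attaches a FAISS index on the 'embed' column, and caches the
# dataset plus the frequently used columns in st.session_state.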
if 'arxiv_corpus' not in st.session_state:
    with st.spinner('loading data...'):
        try:
            arxiv_corpus = load_from_disk('data/')
            arxiv_corpus.add_faiss_index('embed')
        except Exception:
            st.write('downloading data')
            arxiv_corpus = load_dataset('kiyer/pathfinder_arxiv_data', split='train')
            arxiv_corpus.add_faiss_index('embed')
            arxiv_corpus.save_to_disk('data/')
        st.session_state.arxiv_corpus = arxiv_corpus
        st.toast('loaded arxiv corpus')
else:
    arxiv_corpus = st.session_state.arxiv_corpus

if 'ids' not in st.session_state:
    st.session_state.ids = arxiv_corpus['ads_id']
    st.session_state.titles = arxiv_corpus['title']
    st.session_state.abstracts = arxiv_corpus['abstract']
    st.session_state.cites = arxiv_corpus['cites']
    st.session_state.years = arxiv_corpus['date']
    st.session_state.kws = arxiv_corpus['keywords']
    st.toast('done caching. time taken: %.2f sec' % (time.time() - ts))
#----------------------------------------------------------------

class Filter(ABC):
    @abstractmethod
    def filter(self, query: str, arxiv_id: str) -> List[str]:
        pass
class CitationFilter(Filter):  # can do it with all metadata
    def __init__(self, corpus):
        self.corpus = corpus
        ids = st.session_state.ids
        cites = st.session_state.cites
        self.citation_counts = {ids[i]: cites[i] for i in range(len(ids))}

    def citation_weight(self, x, shift, scale):
        return 1 / (1 + np.exp(-1 * (x - shift) / scale))  # sigmoid function

    def filter(self, doc_scores, weight=0.1):  # additive weighting
        citation_count = np.array([self.citation_counts[doc[0]] for doc in doc_scores])
        cmean, cstd = np.median(citation_count), np.std(citation_count)
        citation_score = self.citation_weight(citation_count, cmean, cstd)
        for i, doc in enumerate(doc_scores):
            doc_scores[i][2] += weight * citation_score[i]
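
# CitationFilter adds a bounded bonus to each document score: citation counts are passed
# through a logistic curve centered on the median count (scale = standard deviation), so
# heavily cited papers gain close to +`weight` while sparsely cited papers gain almost
# nothing, without letting raw citation counts dominate the similarity score.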
class DateFilter(Filter):  # include time weighting eventually
    def __init__(self, document_dates):
        self.document_dates = document_dates

    def parse_date(self, arxiv_id: str) -> datetime:  # only for documents
        if arxiv_id.startswith('astro-ph'):
            arxiv_id = arxiv_id.split('astro-ph')[1].split('_arXiv')[0]
        try:
            year = int("20" + arxiv_id[:2])
            month = int(arxiv_id[2:4])
        except ValueError:
            year = 2023
            month = 1
        return date(year, month, 1)

    def weight(self, time, shift, scale):
        return 1 / (1 + np.exp((time - shift) / scale))

    def evaluate_filter(self, year, filter_string):
        try:
            # Evaluate the boolean expression with builtins disabled and only `year` in scope
            result = eval(filter_string, {"__builtins__": None}, {"year": year})
            return result
        except Exception as e:
            print(f"Error evaluating filter: {e}")
            return False
    def filter(self, docs, boolean_date=None, min_date=None, max_date=None, time_score=0):
        filtered = []
        if boolean_date is not None:
            boolean_date = boolean_date.replace("AND", "and").replace("OR", "or")
            for doc in docs:
                if self.evaluate_filter(self.document_dates[doc[0]].year, boolean_date):
                    filtered.append(doc)
        else:
            if min_date is None: min_date = date(1990, 1, 1)
            if max_date is None: max_date = date(2024, 7, 3)
            for doc in docs:
                if self.document_dates[doc[0]] >= min_date and self.document_dates[doc[0]] <= max_date:
                    filtered.append(doc)
            if time_score is not None:  # apply time weighting
                for i, item in enumerate(filtered):
                    time_diff = (max_date - self.document_dates[filtered[i][0]]).days / 365
                    filtered[i][2] += time_score * 0.1 * self.weight(time_diff, 5, 5)
        return filtered
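
# DateFilter either evaluates a boolean year expression against each document's year
# (e.g. "year >= 2020 and year < 2023" -- an illustrative expression, not one used by the app),
# or keeps documents inside a [min_date, max_date] window and then adds a small recency bonus
# via the same logistic weighting, so documents closer to max_date score slightly higher.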
class KeywordFilter(Filter):
    def __init__(self, corpus,
                 remove_capitals: bool = True, metadata=None, ne_only=True, verbose=False):
        self.index_path = 'keyword_index.json'
        # self.metadata = metadata
        self.remove_capitals = remove_capitals
        self.ne_only = ne_only
        self.stopwords = set(stopwords.words('english'))
        self.verbose = verbose
        self.index = None
        self.kws = st.session_state.kws
        self.ids = st.session_state.ids
        self.titles = st.session_state.titles
        self.load_or_build_index()

    def preprocess_text(self, text: str) -> str:
        text = ''.join(char for char in text if char.isalnum() or char.isspace())
        if self.remove_capitals: text = text.lower()
        return ' '.join(word for word in text.split() if word.lower() not in self.stopwords)
    def build_index(self):  # include the title in the index
        print("Building index...")
        self.index = {}
        for i in range(len(self.kws)):
            paper = self.ids[i]
            title = self.titles[i]
            title_keywords = set()  # placeholder: title keywords are not extracted yet
            for keyword in set(self.kws[i]) | title_keywords:
                term = ' '.join(word for word in keyword.lower().split() if word.lower() not in self.stopwords)
                if term not in self.index:
                    self.index[term] = []
                self.index[term].append(self.ids[i])
        with open(self.index_path, 'w') as f:
            json.dump(self.index, f)

    def load_index(self):
        print("Loading existing index...")
        with open(self.index_path, 'r') as f:
            self.index = json.load(f)
        print("Index loaded successfully.")

    def load_or_build_index(self):
        if os.path.exists(self.index_path):
            self.load_index()
        else:
            self.build_index()
    def parse_doc(self, doc):
        local_kws = []
        for phrase in doc._.phrases:
            local_kws.append(phrase.text.lower())
        return [self.preprocess_text(word) for word in local_kws]

    def get_propn(self, doc):
        result = []
        working_str = ''
        for token in doc:
            if token.text in nlp.Defaults.stop_words or token.text in punctuation:
                if working_str != '':
                    result.append(working_str.strip())
                    working_str = ''
            if token.pos_ == "PROPN":
                working_str += token.text + ' '
        if working_str != '': result.append(working_str.strip())
        return [self.preprocess_text(word) for word in result]
    def filter(self, query: str, doc_ids=None):
        doc = nlp(query)
        query_keywords = self.parse_doc(doc)
        nouns = self.get_propn(doc)
        if self.verbose: print('keywords:', query_keywords)
        if self.verbose: print('proper nouns:', nouns)
        filtered = set()
        if len(query_keywords) > 0 and not self.ne_only:
            for keyword in query_keywords:
                if keyword != '' and keyword in self.index.keys(): filtered |= set(self.index[keyword])
        if len(nouns) > 0:
            ne_results = set()
            for noun in nouns:
                if noun in self.index.keys(): ne_results |= set(self.index[noun])
            if self.ne_only: filtered = ne_results  # keep only named entity results
            else: filtered &= ne_results  # take the intersection
        if doc_ids is not None: filtered &= doc_ids  # apply filter to results
        return filtered
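
# KeywordFilter keeps a simple inverted index (keyword / proper-noun phrase -> list of paper IDs),
# persisted to keyword_index.json. At query time the query is run through spaCy + pytextrank to
# extract key phrases and proper nouns, and papers whose indexed keywords match are returned as a
# set of candidate IDs that the retriever can use to re-weight its similarity scores.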
class EmbeddingRetrievalSystem:
    def __init__(self, weight_citation=False, weight_date=False, weight_keywords=False):
        self.ids = st.session_state.ids
        self.years = st.session_state.years
        self.abstract = st.session_state.abstracts
        self.client = OpenAI(api_key=openai_key)
        self.embed_model = "text-embedding-3-small"
        self.dataset = arxiv_corpus
        self.kws = st.session_state.kws
        self.weight_citation = weight_citation
        self.weight_date = weight_date
        self.weight_keywords = weight_keywords
        self.id_to_index = {self.ids[i]: i for i in range(len(self.ids))}
        # self.citation_filter = CitationFilter(self.dataset)
        # self.date_filter = DateFilter(self.dataset['date'])
        self.keyword_filter = KeywordFilter(corpus=self.dataset, remove_capitals=True)

    def parse_date(self, id):
        # indexval = np.where(self.ids == id)[0][0]
        indexval = id
        return self.years[indexval]

    def make_embedding(self, text):
        str_embed = self.client.embeddings.create(input=[text], model=self.embed_model).data[0].embedding
        return str_embed

    def embed_batch(self, texts: List[str]) -> List[np.ndarray]:
        embeddings = self.client.embeddings.create(input=texts, model=self.embed_model).data
        return [np.array(embedding.embedding, dtype=np.float32) for embedding in embeddings]
    def init_filters(self):
        self.citation_filter = []
        self.date_filter = []
        self.keyword_filter = []

    def get_query_embedding(self, query):
        return self.make_embedding(query)

    def analyze_temporal_query(self, query):
        return

    def calc_faiss(self, query_embedding, top_k=100):
        # xq = query_embedding.reshape(-1,1).T.astype('float32')
        # D, I = self.index.search(xq, top_k)
        # return I[0], D[0]
        tmp = self.dataset.search('embed', query_embedding, k=top_k)
        return [tmp.indices, tmp.scores]
    def rank_and_filter(self, query, query_embedding, query_date, top_k=10, return_scores=False, time_result=None):
        topk_indices, similarities = self.calc_faiss(np.array(query_embedding), top_k=300)
        if self.weight_keywords:
            keyword_matches = self.keyword_filter.filter(query)
            kw_indices = np.zeros_like(similarities)
            for s in keyword_matches:
                if self.id_to_index[s] in topk_indices:
                    # print('yes', self.id_to_index[s], topk_indices[np.where(topk_indices == self.id_to_index[s])[0]])
                    similarities[np.where(topk_indices == self.id_to_index[s])[0]] = similarities[np.where(topk_indices == self.id_to_index[s])[0]] * 10.
            similarities = similarities / 10.
        filtered_results = [[topk_indices[i], similarities[i]] for i in range(len(similarities))]
        top_results = sorted(filtered_results, key=lambda x: x[1], reverse=True)[:top_k]
        if return_scores:
            return {doc[0]: doc[1] for doc in top_results}
        # Only keep the document IDs
        top_results = [doc[0] for doc in top_results]
        return top_results
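
    # Net effect of the keyword re-weighting above: documents that also match the keyword filter
    # keep their original FAISS similarity, while everything else is scaled down by a factor of 10
    # before the top_k candidates are selected by score.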
    def retrieve(self, query, top_k, time_result=None, query_date=None, return_scores=False):
        query_embedding = self.get_query_embedding(query)
        # Judge time relevance
        if time_result is None:
            if self.weight_date:
                time_result, time_taken = self.analyze_temporal_query(query, self.anthropic_client)
            else:
                time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
        top_results = self.rank_and_filter(query,
                                           query_embedding,
                                           query_date,
                                           top_k,
                                           return_scores=return_scores,
                                           time_result=time_result)
        return top_results
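
# Illustrative usage (not executed by the app; the query string is made up):
#   retriever = EmbeddingRetrievalSystem(weight_keywords=True)
#   top_ids = retriever.retrieve("What quenches star formation in massive galaxies?", top_k=10)
# retrieve() embeds the query with OpenAI, pulls the 300 nearest neighbours from the FAISS index,
# applies the keyword re-weighting, and returns the top_k corpus indices (or an {index: score}
# dict when return_scores=True).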
class HydeRetrievalSystem(EmbeddingRetrievalSystem):
    def __init__(self, generation_model: str = "claude-3-haiku-20240307",
                 embedding_model: str = "text-embedding-3-small",
                 temperature: float = 0.5,
                 max_doclen: int = 500,
                 generate_n: int = 1,
                 embed_query=True,
                 conclusion=False, **kwargs):
        # Handle the kwargs for the superclass init -- filters/citation weighting
        super().__init__(**kwargs)
        if max_doclen * generate_n > 8191:
            raise ValueError("Too many tokens. Please reduce max_doclen or generate_n.")
        self.embedding_model = embedding_model
        self.generation_model = generation_model

        # HYPERPARAMETERS
        self.temperature = temperature  # generation temperature
        self.max_doclen = max_doclen    # max tokens per generated document
        self.generate_n = generate_n    # how many documents to generate
        self.embed_query = embed_query  # also embed the query vector?
        self.conclusion = conclusion    # generate a conclusion as well?

        self.anthropic_key = anthropic_key
        self.generation_client = anthropic.Anthropic(api_key=self.anthropic_key)
    def retrieve(self, query: str, top_k: int = 10, return_scores=False, time_result=None) -> List[Tuple[str, str, float]]:
        if time_result is None:
            if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
            else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}
        docs = self.generate_docs(query)
        doc_embeddings = self.embed_docs(docs)
        if self.embed_query:
            query_emb = self.embed_docs([query])[0]
            doc_embeddings.append(query_emb)
        embedding = np.mean(np.array(doc_embeddings), axis=0)
        top_results = self.rank_and_filter(query, embedding, query_date=None, top_k=top_k, return_scores=return_scores, time_result=time_result)
        return top_results

    def generate_doc(self, query: str):
        prompt = """You are an expert astronomer. Given a scientific query, generate the abstract"""
        if self.conclusion:
            prompt += " and conclusion"
        prompt += """ of an expert-level research paper
        that answers the question. Stick to a maximum length of {} tokens and return just the text of the abstract and conclusion.
        Do not include labels for any section. Use research-specific jargon.""".format(self.max_doclen)
        message = self.generation_client.messages.create(
            model=self.generation_model,
            max_tokens=self.max_doclen,
            temperature=self.temperature,
            system=prompt,
            messages=[{"role": "user",
                       "content": [{"type": "text", "text": query}]}]
        )
        return message.content[0].text

    def generate_docs(self, query: str):
        docs = []
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future_to_query = {executor.submit(self.generate_doc, query): query for i in range(self.generate_n)}
            for future in concurrent.futures.as_completed(future_to_query):
                query = future_to_query[future]
                try:
                    data = future.result()
                    docs.append(data)
                except Exception as exc:
                    pass
        return docs

    def embed_docs(self, docs: List[str]):
        return self.embed_batch(docs)
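
# HydeRetrievalSystem implements HyDE-style retrieval: Claude is asked to write generate_n
# hypothetical abstracts answering the query, those drafts (optionally plus the raw query) are
# embedded, and the mean embedding is used for the FAISS search instead of the query embedding
# alone. This tends to pull the search vector closer to the region where real answer papers live.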
class HydeCohereRetrievalSystem(HydeRetrievalSystem):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.cohere_key = "Of1MjzFjGmvzBAqdvNHTQLkAjecPcOKpiIPAnFMn"
        self.cohere_client = cohere.Client(self.cohere_key)

    def retrieve(self, query: str,
                 top_k: int = 10,
                 rerank_top_k: int = 250,
                 return_scores=False, time_result=None,
                 reweight=False) -> List[Tuple[str, str, float]]:
        if time_result is None:
            if self.weight_date: time_result, time_taken = analyze_temporal_query(query, self.anthropic_client)
            else: time_result = {'has_temporal_aspect': False, 'expected_year_filter': None, 'expected_recency_weight': None}

        top_results = super().retrieve(query, top_k=rerank_top_k, time_result=time_result)

        # doc_texts = self.get_document_texts(top_results)
        # docs_for_rerank = [f"Abstract: {doc['abstract']}\nConclusions: {doc['conclusions']}" for doc in doc_texts]
        docs_for_rerank = [self.abstract[i] for i in top_results]
        if len(docs_for_rerank) == 0:
            return []

        reranked_results = self.cohere_client.rerank(
            query=query,
            documents=docs_for_rerank,
            model='rerank-english-v3.0',
            top_n=top_k
        )

        final_results = []
        for result in reranked_results.results:
            doc_id = top_results[result.index]
            doc_text = docs_for_rerank[result.index]
            score = float(result.relevance_score)
            final_results.append([doc_id, "", score])

        if reweight:
            if time_result['has_temporal_aspect']:
                final_results = self.date_filter.filter(final_results, time_score=time_result['expected_recency_weight'])
            if self.weight_citation: self.citation_filter.filter(final_results)

        if return_scores:
            return {result[0]: result[2] for result in final_results}
        return [doc[0] for doc in final_results]

    def embed_docs(self, docs: List[str]):
        return self.embed_batch(docs)
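
# HydeCohereRetrievalSystem adds a second stage: the top rerank_top_k HyDE/embedding hits are
# passed (by abstract) to Cohere's rerank-english-v3.0 endpoint, and the cross-encoder relevance
# scores replace the embedding similarities before the final top_k cut.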
# ----------------------------------------------------------------

if 'ec' not in st.session_state:
    ec = EmbeddingRetrievalSystem(weight_keywords=True)
    st.session_state.ec = ec
    st.toast('loaded retrieval system')
else:
    ec = st.session_state.ec
# Function to simulate question answering (replace with actual implementation)
def answer_question(question, keywords, toggles, method, question_type):
    # Simulated answer (replace with actual logic)
    # return f"Answer to '{question}' using method {method} for {question_type} question."
    return run_ret(question, 10)

def get_papers(ids):
    # `ids` is a {corpus index: relevance score} dict as returned by run_ret
    papers, scores, links = [], [], []
    for i in ids:
        papers.append(st.session_state.titles[i])
        scores.append(ids[i])
        links.append('https://ui.adsabs.harvard.edu/abs/' + st.session_state.arxiv_corpus['bibcode'][i] + '/abstract')
    return pd.DataFrame({
        'Title': papers,
        'Relevance': scores,
        'Link': links
    })
# Function to create embedding plot (replace with actual implementation)
def create_embedding_plot():
    # Simulated embedding data (replace with actual embedding calculation)
    source = ColumnDataSource(data=dict(
        x=[1, 2, 3, 4, 5],
        y=[6, 7, 2, 4, 5],
        colors=Spectral10[0:5],
        labels=['A', 'B', 'C', 'D', 'E']
    ))
    p = figure(width=400, height=400, title="Embedding Map")
    p.circle('x', 'y', size=20, source=source, color='colors', alpha=0.6)
    return p

# Function to simulate keyword extraction (replace with actual implementation)
def extract_keywords(question):
    # Simulated keyword extraction (replace with actual logic)
    return ['keyword1', 'keyword2', 'keyword3']

# Function to estimate consensus (replace with actual implementation)
def estimate_consensus():
    # Simulated consensus estimation (replace with actual calculation)
    return 0.75
def run_ret(query, top_k):
    rs = ec.retrieve(query, top_k, return_scores=True)
    output_str = ''
    for i in rs:
        output_str += '---> ' + st.session_state.titles[i] + ' (score: %.2f) \n' % rs[i]
    return output_str, rs
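
# run_ret drives the UI: it asks the cached EmbeddingRetrievalSystem for the top_k matches with
# scores, builds a '---> '-prefixed list of titles for display, and also returns the raw
# {index: score} dict so the papers table can link each hit to its ADS record.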
# Streamlit app
def main():
    # st.title("Question Answering App")

    # Sidebar (Inputs)
    st.sidebar.header("Inputs")
    extra_keywords = st.sidebar.text_input("Enter extra keywords (comma-separated):")
    st.sidebar.subheader("Toggles")
    toggle_a = st.sidebar.checkbox("Toggle A")
    toggle_b = st.sidebar.checkbox("Toggle B")
    toggle_c = st.sidebar.checkbox("Toggle C")
    method = st.sidebar.radio("Choose a method:", ["h1", "h2", "h3"])
    question_type = st.sidebar.selectbox("Select question type:", ["Type 1", "Type 2", "Type 3"])
    # store_output = st.sidebar.checkbox("Store the output")
    store_output = st.sidebar.button("Save output")

    # Main page (Outputs)
    question = st.text_input("Ask me anything:")
    submit_button = st.button("Submit")

    if submit_button:
        # Process inputs
        keywords = [kw.strip() for kw in extra_keywords.split(',')] if extra_keywords else []
        toggles = {'A': toggle_a, 'B': toggle_b, 'C': toggle_c}

        # Generate outputs
        answer, rs = answer_question(question, keywords, toggles, method, question_type)
        papers_df = get_papers(rs)
        embedding_plot = create_embedding_plot()
        triggered_keywords = extract_keywords(question)
        consensus = estimate_consensus()

        # Display outputs
        st.subheader("Answer")
        st.write(answer)
        with st.expander("Papers used", expanded=True):
            st.dataframe(papers_df)

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Embedding Map")
            st.bokeh_chart(embedding_plot)
            st.subheader("Triggered Keywords")
            st.write(", ".join(triggered_keywords))
        with col2:
            st.subheader("Question Type")
            st.write(question_type)
            st.subheader("Consensus Estimate")
            st.write(f"{consensus:.2%}")

        # st.subheader("Papers Used")
        # st.dataframe(papers_df)
    else:
        st.info("Use the sidebar to input parameters and submit to see results.")

    if store_output:
        st.toast("Output stored successfully!")

if __name__ == "__main__":
    main()