# synthesist / pages/1 retrieval.py
# Author: kiyer — commit 6931cbb ("basic files and codebase", verified)
# File size: 3.84 kB
import time
s = time.time()
import os
import datetime
import faiss
import streamlit as st
import feedparser
import urllib
import cloudpickle as cp
import pickle
from urllib.request import urlopen
from summa import summarizer
import numpy as np
import matplotlib.pyplot as plt
import requests
import json
from scipy import ndimage
from langchain_openai import AzureOpenAIEmbeddings
# from langchain.llms import OpenAI
from langchain_community.llms import OpenAI
from langchain_openai import AzureChatOpenAI
from fns import *
# --- Page header: logo, then an empty markdown element used as a spacer. ---
st.image('local_files/synth_logo.png')
st.markdown("")

# --- User inputs: the free-text question and the number of papers to list. ---
query = st.text_input(
    label='Ask me anything:',
    value="What causes galaxy quenching at high redshifts?",
)
arxiv_id = None  # no specific paper is targeted; passed straight through to retrieve()
top_k = st.slider(
    'How many papers should I show?',
    min_value=1,
    max_value=30,
    value=6,
)

# --- Retrieval: the retrieval system object is expected to have been stored in
# session state elsewhere in the app (not visible in this file). ---
retrieval_system = st.session_state.retrieval_system
results = retrieval_system.retrieve(query, arxiv_id, top_k)
# --- Look up metadata for the retrieved papers in the corpus held in session state. ---
# NOTE(review): titles and authors are indexed with [0] below, so each entry is
# presumably a list (first title / first author) — confirm against the dataset schema.
dataset = st.session_state.dataset
aids = dataset['id']
titles = dataset['title']
auths = dataset['author']
bibcodes = dataset['bibcode']
all_keywords = dataset['keyword_search']
allyrs = dataset['year']

# Map each retrieved paper id back to its row in the corpus. Iterate over the
# results actually returned (capped at top_k) so a short result list cannot raise
# IndexError; dtype=int keeps the array valid for fancy indexing even when empty.
ret_indices = np.array([aids.index(paper_id) for paper_id in results[:top_k]], dtype=int)


def _full_year(two_digit_year):
    """Expand a two-digit year: values below 50 become 20xx, the rest 19xx."""
    return two_digit_year + 2000 if two_digit_year < 50 else two_digit_year + 1900


yrs = [_full_year(allyrs[idx]) for idx in ret_indices]

# Display strings, one per retrieved paper and in retrieval order:
# title, "<first author> et al. <year>", and a Markdown link to the ADS abstract page.
print_titles = [titles[idx][0] for idx in ret_indices]
print_auths = [auths[idx][0] + ' et al. ' + str(yr) for idx, yr in zip(ret_indices, yrs)]
print_links = [
    f'[{bibcodes[idx]}](https://ui.adsabs.harvard.edu/abs/{bibcodes[idx]}/abstract)'
    for idx in ret_indices
]
# --- Ranked list: a numbered title per paper, then its author/year line and ADS link. ---
st.divider()
st.header('top-k papers:')
ranked_rows = zip(print_titles, print_auths, print_links)
for rank, (title, author_line, link) in enumerate(ranked_rows, start=1):
    st.subheader(str(rank) + '. ' + title)
    st.write(author_line + ' ' + link)
# --- Map view: place the retrieved papers on a 2-D map of the whole corpus. ---
st.divider()
st.header('top-k papers in context:')

# Keyword summary for the query and its hits; drawn as a caption at the end.
gtkws = get_keywords(query, ret_indices, all_keywords)

# Corpus-wide 2-D coordinates, one cluster label per paper, and per-cluster keyword
# strings (indexed by cluster label in the labelling loop below).
umap, clbls, all_kws = load_umapcoords('local_files/arxiv_ads_corpus_coordsonly_v3.pkl')

fig = plt.figure(figsize=(12*1.8*1.2, 9*2.*1.2))
im = plt.imread('local_files/astro_worldmap.png')
implot = plt.imshow(im)


def _unit_scale(values):
    """Shift *values* so the minimum is 0, then divide by the new maximum (range [0, 1])."""
    shifted = values - np.amin(values) + .0
    return shifted / np.amax(shifted)


# Map coordinates onto the background image's pixel grid: column 1 -> x,
# column 0 -> y, flipped so larger values sit higher (imshow's default origin puts
# pixel row 0 at the top). NOTE(review): the 1580/170 and 1700/30 scale/offset pairs
# look tuned to this background PNG — revisit if the image changes.
xax = _unit_scale(umap[0:, 1]) * 1580 + 170
yax = (1.0 - _unit_scale(umap[0:, 0])) * 1700 + 30

# Label every cluster present, including the highest label, at the median position
# of its members. Labels with no keyword entry (e.g. negative ids) are skipped.
cluster_labels = np.asarray(clbls)
for label in np.unique(cluster_labels):
    label = int(label)
    if not 0 <= label < len(all_kws):
        continue
    members = cluster_labels == label
    plt.text(np.median(xax[members]), np.median(yax[members]), all_kws[label],
             fontsize=9, ha="center", va="center",
             fontfamily='serif', color='w',
             bbox=dict(facecolor='k', edgecolor='none', boxstyle='round,pad=0.1', alpha=0.3))

# Highlight the hits: black-ringed red dots for all of them, then a black-ringed
# white dot for the first entry of ret_indices (presumably the top-ranked result).
if len(ret_indices) > 0:
    plt.scatter(xax[ret_indices], yax[ret_indices], c='k', s=300, zorder=100)
    plt.scatter(xax[ret_indices], yax[ret_indices], c='firebrick', s=100, zorder=101)
    plt.scatter(xax[ret_indices[0]], yax[ret_indices[0]], c='k', s=300, zorder=101)
    plt.scatter(xax[ret_indices[0]], yax[ret_indices[0]], c='w', s=100, zorder=101)

# Captions are positioned as fractions of the current axis limits.
tempx = plt.xlim()
tempy = plt.ylim()
plt.text(0.012*tempx[1], (0.012+0.03)*tempy[0], 'The world of astronomy literature', fontsize=36, fontfamily='serif')
plt.text(0.012*tempx[1], (0.012+0.06)*tempy[0], 'Query: '+query, fontsize=18, fontfamily='serif')
plt.text(0.012*tempx[1], (0.012+0.08)*tempy[0], gtkws, fontsize=18, fontfamily='serif', va='top')
plt.axis('off')
st.pyplot(fig, transparent=True, bbox_inches='tight')
# pyplot keeps every figure created with plt.figure() until it is closed. Close it
# once rendered so figures do not pile up across Streamlit reruns.
plt.close(fig)