Spaces:
Sleeping
Sleeping
import time | |
s = time.time() | |
import os | |
import datetime | |
import faiss | |
import streamlit as st | |
import feedparser | |
import urllib | |
import cloudpickle as cp | |
import pickle | |
from urllib.request import urlopen | |
from summa import summarizer | |
import numpy as np | |
import matplotlib.pyplot as plt | |
import requests | |
import json | |
from scipy import ndimage | |
from langchain_openai import AzureOpenAIEmbeddings | |
# from langchain.llms import OpenAI | |
from langchain_community.llms import OpenAI | |
from langchain_openai import AzureChatOpenAI | |
from fns import * | |
# --- Page header, query controls and retrieval ----------------------------
st.image('local_files/synth_logo.png')
st.markdown("")

query = st.text_input(
    'Ask me anything:',
    value="What causes galaxy quenching at high redshifts?",
)
# Passed as the second argument to retrieve(); this page never sets it.
# NOTE(review): presumably lets a search be seeded from a specific paper --
# confirm against the retrieval system's signature.
arxiv_id = None
top_k = st.slider('How many papers should I show?', 1, 30, 6)

# The retriever is expected to already live in session state. Reading a
# missing key raises, so fail with a readable message instead of a traceback.
if 'retrieval_system' not in st.session_state:
    st.error('Retrieval system not found in session state; '
             'it must be initialised before using this page.')
    st.stop()

retrieval_system = st.session_state.retrieval_system
results = retrieval_system.retrieve(query, arxiv_id, top_k)
# --- Resolve retrieved paper ids to rows of the corpus --------------------
# Per-field columns of the dataset held in session state.
aids = st.session_state.dataset['id']
titles = st.session_state.dataset['title']
auths = st.session_state.dataset['author']
bibcodes = st.session_state.dataset['bibcode']
all_keywords = st.session_state.dataset['keyword_search']
allyrs = st.session_state.dataset['year']

# Row position of each retrieved id within the dataset. The retriever may
# return fewer than top_k hits, so only look up the entries that exist
# (indexing results[i] all the way to top_k would raise IndexError).
# dtype=int keeps an empty result usable as an index array downstream.
n_hits = min(top_k, len(results))
ret_indices = np.array(
    [aids.index(results[i]) for i in range(n_hits)],
    dtype=int,
)

# Expand stored years with a pivot at 50: values below 50 get 2000 added,
# everything else gets 1900.
# NOTE(review): presumably the column stores two-digit years -- confirm.
yrs = [
    yr + 2000 if yr < 50 else yr + 1900
    for yr in (allyrs[idx] for idx in ret_indices)
]

# Display strings, one per retrieved paper, in retrieval order.
# Only element [0] of each title / author entry is shown.
print_titles = [titles[idx][0] for idx in ret_indices]
print_auths = [
    f"{auths[idx][0]} et al. {yr}"
    for idx, yr in zip(ret_indices, yrs)
]
# Markdown links to the ADS abstract page, labelled with the bibcode itself.
print_links = [
    f"[{bibcodes[idx]}](https://ui.adsabs.harvard.edu/abs/{bibcodes[idx]}/abstract)"
    for idx in ret_indices
]
# --- Numbered list of the retrieved papers --------------------------------
st.divider()
st.header('top-k papers:')

# One entry per paper: ranked title as a subheader, then the author line
# followed by its link.
for rank, (title, byline, link) in enumerate(
        zip(print_titles, print_auths, print_links), start=1):
    st.subheader(f"{rank}. {title}")
    st.write(f"{byline} {link}")
# --- Map of the retrieved papers within the whole corpus ------------------
st.divider()
st.header('top-k papers in context:')

# get_keywords / load_umapcoords are not defined in this file; they are
# expected to come from the star import of `fns`.
gtkws = get_keywords(query, ret_indices, all_keywords)
umap, clbls, all_kws = load_umapcoords('local_files/arxiv_ads_corpus_coordsonly_v3.pkl')

fig = plt.figure(figsize=(12*1.8*1.2, 9*2.*1.2))
im = plt.imread('local_files/astro_worldmap.png')
plt.imshow(im)

# Column 1 of the coordinates is drawn along x and column 0 along y. Each is
# shifted to start at 0, scaled to a maximum of 1, then stretched and offset
# onto the background image.
# NOTE(review): the 1580/170 and 1700/30 constants are presumably tuned to
# the pixel size of astro_worldmap.png -- confirm before changing the image.
xax = (umap[0:, 1] - np.min(umap[0:, 1])) + 0.0
xax = xax / np.max(xax)
xax = xax * 1580 + 170

yax = (umap[0:, 0] - np.min(umap[0:, 0])) + 0.0
yax = yax / np.max(yax)
# Flip vertically: imshow puts row 0 at the top of the axes.
yax = (np.max(yax) - yax) * 1700 + 30

# Label each cluster with its keyword at the median position of its members.
# NOTE(review): range(np.amax(clbls)) stops one short of the largest label,
# so the highest-numbered cluster is never labelled -- confirm whether that
# is intended (it depends on the length of all_kws, which is not visible here).
for label in range(np.amax(clbls)):
    clust_ids = np.arange(len(clbls))[clbls == label]
    if clust_ids.size == 0:
        # No members: the median would be NaN and there is nothing to place.
        continue
    clust_centroid = (np.median(xax[clust_ids]), np.median(yax[clust_ids]))
    plt.text(
        clust_centroid[0], clust_centroid[1], all_kws[label],
        fontsize=9, ha="center", va="center",
        fontfamily='serif', color='w',
        bbox=dict(facecolor='k', edgecolor='none',
                  boxstyle='round,pad=0.1', alpha=0.3),
    )

# Retrieved papers: red dots with a black rim; the first hit is redrawn
# white-on-black. Skipped entirely when nothing was retrieved, since
# ret_indices[0] would otherwise raise.
if len(ret_indices) > 0:
    plt.scatter(xax[ret_indices], yax[ret_indices], c='k', s=300, zorder=100)
    plt.scatter(xax[ret_indices], yax[ret_indices], c='firebrick', s=100, zorder=101)
    plt.scatter(xax[ret_indices[0]], yax[ret_indices[0]], c='k', s=300, zorder=101)
    plt.scatter(xax[ret_indices[0]], yax[ret_indices[0]], c='w', s=100, zorder=101)

# Title, query and keyword text are positioned as fractions of the current
# axis limits (upper x limit and first y limit).
tempx = plt.xlim()
tempy = plt.ylim()
text_x = 0.012*tempx[1]
plt.text(text_x, (0.012+0.03)*tempy[0], 'The world of astronomy literature',
         fontsize=36, fontfamily='serif')
plt.text(text_x, (0.012+0.06)*tempy[0], 'Query: '+query,
         fontsize=18, fontfamily='serif')
plt.text(text_x, (0.012+0.08)*tempy[0], gtkws,
         fontsize=18, fontfamily='serif', va='top')
plt.axis('off')

st.pyplot(fig, transparent=True, bbox_inches='tight')
# st.pyplot does not clear an explicitly passed figure, and pyplot keeps every
# figure it creates alive; close it so reruns do not accumulate figures.
plt.close(fig)