# Page residue from the hosting Space's file view:
# status Running; file size 3,839 bytes; revision 6931cbb.
import time
s = time.time()
import os
import datetime
import faiss
import streamlit as st
import feedparser
import urllib
import cloudpickle as cp
import pickle
from urllib.request import urlopen
from summa import summarizer
import numpy as np
import matplotlib.pyplot as plt
import requests
import json
from scipy import ndimage
from langchain_openai import AzureOpenAIEmbeddings
# from langchain.llms import OpenAI
from langchain_community.llms import OpenAI
from langchain_openai import AzureChatOpenAI
from fns import *
# --- Page header and query controls ---
st.image('local_files/synth_logo.png')
st.markdown("")
query = st.text_input('Ask me anything:',
                      value="What causes galaxy quenching at high redshifts?")
arxiv_id = None  # no arXiv id to anchor on; passed to retrieve() as None
top_k = st.slider('How many papers should I show?', 1, 30, 6)

# --- Retrieval ---
# The retrieval system and the dataset are read from st.session_state; this
# page assumes another part of the app has already stored them there.
retrieval_system = st.session_state.retrieval_system
results = retrieval_system.retrieve(query, arxiv_id, top_k)

# Dataset columns used below for display and plotting.
aids = st.session_state.dataset['id']
titles = st.session_state.dataset['title']
auths = st.session_state.dataset['author']
bibcodes = st.session_state.dataset['bibcode']
all_keywords = st.session_state.dataset['keyword_search']
allyrs = st.session_state.dataset['year']

# Map each returned paper id to its row position in the dataset, in rank order.
# The retriever may return fewer than top_k hits, so only walk the hits that
# exist. dtype=int keeps even an empty result usable as an index array.
n_hits = min(top_k, len(results))
ret_indices = np.array([aids.index(results[i]) for i in range(n_hits)], dtype=int)
# Build the per-hit display strings, one entry per row in ret_indices.
# Publication years are widened from two digits to four: values below 50 map
# to the 2000s, the rest to the 1900s. This assumes allyrs holds two-digit
# years — TODO confirm against the dataset.
yrs = [
    two_digit + (2000 if two_digit < 50 else 1900)
    for two_digit in (allyrs[row] for row in ret_indices)
]

# First title entry of each hit.
print_titles = [titles[row][0] for row in ret_indices]

# "<first author> et al. <year>" byline.
print_auths = [
    auths[row][0] + ' et al. ' + str(year)
    for row, year in zip(ret_indices, yrs)
]

# Markdown link labelled with the bibcode, pointing at its ADS abstract page.
print_links = []
for row in ret_indices:
    code = bibcodes[row]
    print_links.append('[' + code + '](https://ui.adsabs.harvard.edu/abs/' + code + '/abstract)')
# --- Ranked list of the retrieved papers ---
# Each hit gets a numbered title, followed by its byline and link.
st.divider()
st.header('top-k papers:')
for rank, (title, byline, link) in enumerate(zip(print_titles, print_auths, print_links), start=1):
    st.subheader(str(rank) + '. ' + title)
    st.write(byline + ' ' + link)
# --- "Papers in context": the retrieved papers on a 2-D map of the corpus ---
st.divider()
st.header('top-k papers in context:')
gtkws = get_keywords(query, ret_indices, all_keywords)
umap, clbls, all_kws = load_umapcoords('local_files/arxiv_ads_corpus_coordsonly_v3.pkl')

# Draw on an explicit Figure/Axes pair rather than pyplot's implicit "current"
# axes; Streamlit re-executes this whole script on every interaction.
fig, ax = plt.subplots(figsize=(12*1.8*1.2, 9*2.*1.2))
im = plt.imread('local_files/astro_worldmap.png')
implot = ax.imshow(im)

# Min-max rescale the 2-D map coordinates onto pixel positions over the
# background image: column 1 drives the horizontal axis, column 0 the vertical.
# The spans/offsets (x: *1580 + 170, y: *1700 + 30) are presumably fitted to
# astro_worldmap.png by hand — TODO confirm if the image changes. The vertical
# values are flipped so larger column-0 values land nearer the top
# (imshow's default origin='upper' makes pixel y grow downward).
xax = (umap[:, 1] - np.amin(umap[:, 1])) + .0
xax = xax / np.amax(xax)
xax = xax * 1580 + 170
yax = (umap[:, 0] - np.amin(umap[:, 0])) + .0
yax = yax / np.amax(yax)
yax = (np.amax(yax) - yax) * 1700 + 30

# Label each cluster with its keyword string at the median position of its members.
# NOTE(review): range(np.amax(clbls)) stops at max(clbls) - 1, so the cluster
# with the largest label is never labelled. If labels run 0..K-1 and all_kws has
# K entries this is an off-by-one — confirm against load_umapcoords.
for i in range(np.amax(clbls)):
    clust_ids = np.arange(len(clbls))[clbls == i]
    if clust_ids.size == 0:
        # The median of an empty selection is NaN (with a RuntimeWarning); nothing to place.
        continue
    clust_centroid = (np.median(xax[clust_ids]), np.median(yax[clust_ids]))
    ax.text(clust_centroid[0], clust_centroid[1], all_kws[i], fontsize=9, ha="center", va="center",
            fontfamily='serif', color='w',
            bbox=dict(facecolor='k', edgecolor='none', boxstyle='round,pad=0.1', alpha=0.3))

# Highlight the hits: every hit is a red dot with a black ring; the top-ranked
# hit is redrawn on top as a white dot with a black ring. Skipped when the
# retriever returned nothing, since there is no first hit to index.
hits = np.asarray(ret_indices, dtype=int)
if hits.size:
    ax.scatter(xax[hits], yax[hits], c='k', s=300, zorder=100)
    ax.scatter(xax[hits], yax[hits], c='firebrick', s=100, zorder=101)
    ax.scatter(xax[hits[0]], yax[hits[0]], c='k', s=300, zorder=101)
    ax.scatter(xax[hits[0]], yax[hits[0]], c='w', s=100, zorder=101)

# Captions are placed at fixed fractions of the current axis limits.
tempx = ax.get_xlim()
tempy = ax.get_ylim()
ax.text(0.012*tempx[1], (0.012+0.03)*tempy[0], 'The world of astronomy literature', fontsize=36, fontfamily='serif')
# The query is free user text: escape '$' so matplotlib never parses it as
# mathtext (a malformed $...$ pair would raise while the figure is rendered).
ax.text(0.012*tempx[1], (0.012+0.06)*tempy[0], 'Query: ' + query.replace('$', r'\$'), fontsize=18, fontfamily='serif')
ax.text(0.012*tempx[1], (0.012+0.08)*tempy[0], gtkws, fontsize=18, fontfamily='serif', va='top')
ax.axis('off')
st.pyplot(fig, transparent=True, bbox_inches='tight')
# pyplot keeps every figure it creates alive until it is closed. Without this,
# one figure would pile up per rerun (matplotlib warns past 20 open figures).
plt.close(fig)