"""Streamlit page: semantic search over the astronomy literature corpus.

Reads a free-text query, retrieves the top-k matching papers with the
retrieval system stored in ``st.session_state``, lists them with links to
their ADS abstract pages, and plots them on a 2-D UMAP "world map" of the
corpus drawn over a background image.
"""
import time
import os
import datetime
import faiss
import streamlit as st
import feedparser
import urllib
import cloudpickle as cp
import pickle
from urllib.request import urlopen
from summa import summarizer
import numpy as np
import matplotlib.pyplot as plt
import requests
import json
from scipy import ndimage
from langchain_openai import AzureOpenAIEmbeddings
# from langchain.llms import OpenAI
from langchain_community.llms import OpenAI
from langchain_openai import AzureChatOpenAI

# Presumably provides get_keywords and load_umapcoords (they are not defined
# in this file and fns is the only star import) — confirm in fns.py.
from fns import *


def _full_year(two_digit_year):
    """Expand a two-digit year: values below 50 become 20xx, others 19xx."""
    return two_digit_year + (2000 if two_digit_year < 50 else 1900)


def _ads_link(bibcode):
    """Return a Markdown link to the ADS abstract page for *bibcode*."""
    return '[' + bibcode + '](https://ui.adsabs.harvard.edu/abs/' + bibcode + '/abstract)'


st.image('local_files/synth_logo.png')
st.markdown("")

query = st.text_input('Ask me anything:',
                      value="What causes galaxy quenching at high redshifts?")
arxiv_id = None
top_k = st.slider('How many papers should I show?', 1, 30, 6)

# The retriever and dataset are read from the session but never created on
# this page; show a readable message instead of an AttributeError when they
# have not been loaded yet.
if 'retrieval_system' not in st.session_state or 'dataset' not in st.session_state:
    st.error('The paper index is not loaded yet. Please open the main page first.')
    st.stop()

retrieval_system = st.session_state.retrieval_system
results = retrieval_system.retrieve(query, arxiv_id, top_k)

dataset = st.session_state.dataset
aids = dataset['id']
titles = dataset['title']
auths = dataset['author']
bibcodes = dataset['bibcode']
all_keywords = dataset['keyword_search']
allyrs = dataset['year']

# Map retrieved paper ids back to row positions in the dataset. Slicing the
# results (instead of indexing range(top_k)) tolerates a retriever that
# returns fewer than top_k hits.
ret_indices = np.array([aids.index(paper_id) for paper_id in results[:top_k]], dtype=int)
if len(ret_indices) == 0:
    st.info('No matching papers found.')
    st.stop()

yrs = [_full_year(allyrs[idx]) for idx in ret_indices]
# Title and author fields are indexed with [0]; presumably each is a list
# whose first entry is the main title / first author — verify against dataset.
print_titles = [titles[idx][0] for idx in ret_indices]
print_auths = [auths[idx][0] + ' et al. ' + str(yr) for idx, yr in zip(ret_indices, yrs)]
print_links = [_ads_link(bibcodes[idx]) for idx in ret_indices]

st.divider()
st.header('top-k papers:')
for rank, (title, auth, link) in enumerate(zip(print_titles, print_auths, print_links), start=1):
    st.subheader(str(rank) + '. ' + title)
    st.write(auth + ' ' + link)

st.divider()
st.header('top-k papers in context:')

gtkws = get_keywords(query, ret_indices, all_keywords)
umap, clbls, all_kws = load_umapcoords('local_files/arxiv_ads_corpus_coordsonly_v3.pkl')

fig = plt.figure(figsize=(12*1.8*1.2, 9*2.*1.2))
im = plt.imread('local_files/astro_worldmap.png')
plt.imshow(im)

# Rescale the 2-D embedding into the pixel frame of the background image:
# column 1 -> x in [170, 1750]; column 0 -> y in [30, 1730], flipped so that
# larger embedding values sit nearer the top of the image.
xax = umap[:, 1] - np.amin(umap[:, 1]) + .0
xax = xax / np.amax(xax)
xax = xax * 1580 + 170
yax = umap[:, 0] - np.amin(umap[:, 0]) + .0
yax = yax / np.amax(yax)
yax = (np.amax(yax) - yax) * 1700 + 30

# Put each cluster's keyword label at the median position of its members.
# Assumes cluster labels run 0..max(clbls) inclusive and that all_kws is
# indexed by label — confirm against load_umapcoords.
n_labels = min(int(np.amax(clbls)) + 1, len(all_kws))
for label in range(n_labels):
    members = clbls == label
    if not np.any(members):
        continue  # np.median of an empty selection would give NaN
    centroid_x = np.median(xax[members])
    centroid_y = np.median(yax[members])
    plt.text(centroid_x, centroid_y, all_kws[label], fontsize=9, ha="center", va="center",
             fontfamily='serif', color='w',
             bbox=dict(facecolor='k', edgecolor='none', boxstyle='round,pad=0.1', alpha=0.3))

# Retrieved papers: red dots with a black ring, drawn above the labels.
plt.scatter(xax[ret_indices], yax[ret_indices], c='k', s=300, zorder=100)
plt.scatter(xax[ret_indices], yax[ret_indices], c='firebrick', s=100, zorder=101)
# The top-ranked paper gets a white centre instead.
top_hit = ret_indices[0]
plt.scatter(xax[top_hit], yax[top_hit], c='k', s=300, zorder=101)
plt.scatter(xax[top_hit], yax[top_hit], c='w', s=100, zorder=101)

# Caption text positioned as fractions of the current axis limits.
tempx = plt.xlim()
tempy = plt.ylim()
plt.text(0.012*tempx[1], (0.012+0.03)*tempy[0], 'The world of astronomy literature',
         fontsize=36, fontfamily='serif')
plt.text(0.012*tempx[1], (0.012+0.06)*tempy[0], 'Query: '+query,
         fontsize=18, fontfamily='serif')
plt.text(0.012*tempx[1], (0.012+0.08)*tempy[0], gtkws,
         fontsize=18, fontfamily='serif', va='top')
plt.axis('off')

st.pyplot(fig, transparent=True, bbox_inches='tight')
# Streamlit re-runs this script on every interaction and does not close a
# figure passed to st.pyplot; close it so open figures do not pile up.
plt.close(fig)