Spaces:
Build error
Build error
File size: 5,101 Bytes
2bf46ef a8d4e3d 2bf46ef a8d4e3d 8387173 8cfcf51 f310b8b 106ef8f f310b8b 2ad51f4 a8d4e3d f310b8b adf54e4 f310b8b adf54e4 f310b8b a8d4e3d f310b8b a8d4e3d f310b8b a8d4e3d 2bf46ef f310b8b 2bf46ef a8492e7 2bf46ef 8cfcf51 2bf46ef a8492e7 8cfcf51 a8d4e3d cea2a96 a8d4e3d f310b8b a8d4e3d f310b8b a8d4e3d 106ef8f f310b8b cea2a96 f310b8b cea2a96 f310b8b 106ef8f 8cfcf51 106ef8f a8d4e3d 106ef8f 8cfcf51 f310b8b a8d4e3d 5727aa4 da318a3 f310b8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from typing import List, Optional
import streamlit as st
import streamlit_pydantic as sp
from pydantic import BaseModel, Field
from src.Surveyor import Surveyor
@st.experimental_singleton(show_spinner=True, suppress_st_warning=True)
def get_surveyor_instance(_print_fn, _survey_print_fn):
    """Create the ``Surveyor`` used by the app, cached as a Streamlit singleton.

    Args:
        _print_fn: callable passed to ``Surveyor`` as ``print_fn``.
        _survey_print_fn: callable passed to ``Surveyor`` as
            ``survey_print_fn``.

    Returns:
        The ``Surveyor`` instance, constructed with ``high_gpu=True``.

    NOTE(review): the leading underscores on the parameter names presumably
    keep Streamlit from hashing these callables for the cache key -- confirm
    against the Streamlit caching documentation.
    """
    with st.spinner('Loading The-Surveyor ...'):
        surveyor = Surveyor(
            print_fn=_print_fn,
            survey_print_fn=_survey_print_fn,
            high_gpu=True,
        )
    return surveyor
def run_survey(surveyor, download_placeholder, research_keywords=None, arxiv_ids=None, max_search=None, num_papers=None):
    """Run one survey and show download buttons for the files it produces.

    Args:
        surveyor: object whose ``survey`` method is called.
        download_placeholder: Streamlit element handed to
            ``show_survey_download`` to host the download buttons.
        research_keywords: forwarded positionally to ``surveyor.survey``.
        arxiv_ids: forwarded positionally to ``surveyor.survey``.
        max_search: forwarded as the ``max_search`` keyword.
        num_papers: forwarded as the ``num_papers`` keyword.
    """
    survey_outputs = surveyor.survey(
        research_keywords,
        arxiv_ids,
        max_search=max_search,
        num_papers=num_papers,
    )
    # ``survey`` yields exactly two values: the zip archive name and the
    # survey file name, in that order.
    zip_path, survey_path = survey_outputs
    show_survey_download(zip_path, survey_path, download_placeholder)
def show_survey_download(zip_file_name, survey_file_name, download_placeholder):
    """Show one download button per generated file inside the placeholder.

    Args:
        zip_file_name: path of the zip archive; converted with ``str``.
        survey_file_name: path of the survey file; converted with ``str``.
        download_placeholder: Streamlit element on which ``empty()`` and then
            ``container()`` are called to host the buttons.
    """
    # (path, button label) pairs, in the order the buttons are rendered.
    downloads = (
        (zip_file_name, "Download extracted topic-clustered-highlights, images and tables as zip"),
        (survey_file_name, "Download detailed generated survey file"),
    )

    download_placeholder.empty()
    with download_placeholder.container():
        for path, label in downloads:
            name = str(path)
            # The open binary handle is passed to Streamlit as the payload.
            with open(name, "rb") as handle:
                st.download_button(label=label, data=handle, file_name=name)
# Form schema for a keyword-driven survey: the free-text keywords plus the
# two integer limits (how many papers to search, how many to select).
# Documented with comments rather than a docstring so the pydantic schema
# generated from this class is unchanged.
class KeywordsModel(BaseModel):
    # Empty string means "no keywords entered".
    research_keywords: Optional[str] = Field(
        default='',
        description="Enter your research keywords:",
    )
    # Integer in [1, 50], default 10.
    max_search: int = Field(
        default=10,
        ge=1,
        le=50,
        multiple_of=1,
        description="num_papers_to_search:",
    )
    # Integer in [1, 8], default 3.
    num_papers: int = Field(
        default=3,
        ge=1,
        le=8,
        multiple_of=1,
        description="num_papers_to_select:",
    )
# Form schema for a survey over a hand-picked set of papers, entered as one
# comma-separated string of arXiv ids. Documented with comments rather than a
# docstring so the pydantic schema generated from this class is unchanged.
class ArxivIDsModel(BaseModel):
    # Empty string means "no ids entered".
    arxiv_ids: Optional[str] = Field(
        default='',
        description="Enter comma_separated arxiv ids for your curated set of papers (e.g. 2205.12755, 2205.10937, ...):",
    )
def survey_space(surveyor, download_placeholder):
    """Render the sidebar input form and start a survey when it is submitted.

    The form collects either research keywords (with the search and selection
    limits) or a comma-separated list of arXiv ids. When both are filled in,
    the keywords take precedence.

    Args:
        surveyor: object forwarded to ``run_survey`` as ``surveyor``.
        download_placeholder: Streamlit element forwarded to ``run_survey``
            as ``download_placeholder``.
    """
    with st.sidebar.form(key="survey_keywords_form"):
        session_data = sp.pydantic_input(key="keywords_input_model", model=KeywordsModel)
        st.write('or')
        session_data.update(sp.pydantic_input(key="arxiv_ids_input_model", model=ArxivIDsModel))
        submit = st.form_submit_button(label="Submit")

    if not submit:
        return

    run_kwargs = {'surveyor': surveyor, 'download_placeholder': download_placeholder}
    research_keywords = session_data['research_keywords']
    arxiv_ids_text = session_data['arxiv_ids']

    if research_keywords:
        # Each limit is read from its own form field. Previously both were
        # filled with the keywords string, so the surveyor received text where
        # it expected integers.
        run_kwargs.update({
            'research_keywords': research_keywords,
            'max_search': session_data['max_search'],
            'num_papers': session_data['num_papers'],
        })
    elif arxiv_ids_text:
        # Skip blank entries so a trailing or doubled comma does not produce
        # an empty arXiv id.
        stripped_ids = (raw_id.strip() for raw_id in arxiv_ids_text.split(','))
        run_kwargs.update({'arxiv_ids': [arxiv_id for arxiv_id in stripped_ids if arxiv_id]})

    # Echo the arguments the survey is about to run with.
    st.json(run_kwargs)
    run_survey(**run_kwargs)
# NOTE(review): the triple-quoted block below is a bare string literal, not
# live code. It holds an earlier, widget-based version of the sidebar form
# (st.sidebar.form + st_sidebar_tags) and has no effect at runtime.
'''
form = st.sidebar.form(key='survey_form')
research_keywords = form.text_input("Enter your research keywords:", key='research_keywords', value='')
max_search = form.number_input("num_papers_to_search", help="maximium number of papers to glance through - defaults to 20",
min_value=1, max_value=50, value=10, step=1, key='max_search')
num_papers = form.number_input("num_papers_to_select", help="maximium number of papers to select and analyse - defaults to 8",
min_value=1, max_value=8, value=2, step=1, key='num_papers')
form.write('or')
arxiv_ids = st_sidebar_tags(
label='Enter arxiv ids for your curated set of papers (1-by-1):',
value=[],
text='Press enter to add more (e.g. 2205.12755, 2205.10937, 1605.08386v1 ...)',
maxtags = 6,
key='arxiv_ids')
submit = form.form_submit_button('Submit')
run_kwargs = {'surveyor':surveyor, 'download_placeholder':download_placeholder}
if submit:
if research_keywords != '':
run_kwargs.update({'research_keywords':research_keywords, 'max_search':max_search, 'num_papers':num_papers})
elif len(arxiv_ids):
run_kwargs.update({'arxiv_ids':arxiv_ids})
run_survey(**run_kwargs)
'''
if __name__ == '__main__':
    # Page layout: a title, then two side-by-side columns -- the left one for
    # the execution log, the right one for the generated survey.
    st.title('Auto-Research V0.1 - Automated Survey generation from research keywords')
    log_col, survey_col = st.columns(2)
    log_col.header('execution log:')
    survey_col.header('Generated_survey:')

    # NOTE(review): ``download_placeholder`` is assigned twice and then never
    # used; ``survey_col`` itself is what gets passed to ``survey_space``
    # below. Confirm which element was meant to host the download buttons.
    download_placeholder = survey_col.container()
    download_placeholder = st.empty()

    # The surveyor writes its log to the left column and the survey text to
    # the right column.
    surveyor_obj = get_surveyor_instance(
        _print_fn=log_col.write,
        _survey_print_fn=survey_col.write,
    )
    survey_space(surveyor_obj, survey_col)
|