File size: 3,531 Bytes
a8d4e3d
 
 
 
8387173
cea2a96
f310b8b
 
106ef8f
f310b8b
 
2ad51f4
a8d4e3d
f310b8b
 
adf54e4
f310b8b
adf54e4
 
 
f310b8b
 
a8d4e3d
f310b8b
 
 
 
 
 
 
 
 
a8d4e3d
f310b8b
 
 
 
 
 
a8d4e3d
 
f310b8b
 
a8d4e3d
cea2a96
a8d4e3d
f310b8b
a8d4e3d
f310b8b
a8d4e3d
106ef8f
f310b8b
cea2a96
 
f310b8b
cea2a96
f310b8b
 
106ef8f
 
 
 
a8d4e3d
106ef8f
 
 
 
 
f310b8b
 
a8d4e3d
 
 
5727aa4
da318a3
 
 
f310b8b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import streamlit as st
import pandas as pd
import numpy as np

from src.Surveyor import Surveyor
from streamlit_tags import st_sidebar_tags


@st.experimental_singleton(show_spinner=True, suppress_st_warning=True)
def get_surveyor_instance(_print_fn, _survey_print_fn):
    """Create the shared Surveyor, cached by ``st.experimental_singleton``.

    The leading underscores on the parameters follow Streamlit's convention
    for arguments that are excluded from cache-key hashing.

    Args:
        _print_fn: callable handed to Surveyor as ``print_fn``
            (the app passes the execution-log column's ``write``).
        _survey_print_fn: callable handed to Surveyor as ``survey_print_fn``
            (the app passes the survey column's ``write``).

    Returns:
        A ``Surveyor`` constructed with ``high_gpu=True``.
    """
    loading_message = 'Loading The-Surveyor ...'
    with st.spinner(loading_message):
        instance = Surveyor(
            print_fn=_print_fn,
            survey_print_fn=_survey_print_fn,
            high_gpu=True,
        )
    return instance


def run_survey(surveyor, download_placeholder, research_keywords=None, arxiv_ids=None, max_search=None, num_papers=None):
    """Run one survey and then offer its output files for download.

    Args:
        surveyor: object exposing ``survey(research_keywords, arxiv_ids,
            max_search=..., num_papers=...)``, which must return a
            two-item (zip file name, survey file name) sequence.
        download_placeholder: Streamlit element handed on to
            ``show_survey_download`` to hold the download buttons.
        research_keywords: keyword query forwarded positionally to ``survey``.
        arxiv_ids: arxiv ids forwarded positionally to ``survey``.
        max_search: forwarded to ``survey`` as ``max_search``.
        num_papers: forwarded to ``survey`` as ``num_papers``.
    """
    survey_outputs = surveyor.survey(
        research_keywords,
        arxiv_ids,
        max_search=max_search,
        num_papers=num_papers,
    )
    zip_path, survey_path = survey_outputs
    show_survey_download(zip_path, survey_path, download_placeholder)


def show_survey_download(zip_file_name, survey_file_name, download_placeholder):
    """Clear *download_placeholder* and fill it with two download buttons.

    Args:
        zip_file_name: path to the zip produced by the survey; opened in
            binary mode and also used (as ``str``) for the download file name.
        survey_file_name: path to the generated survey file; handled the
            same way.
        download_placeholder: Streamlit element supporting ``.empty()`` and
            ``.container()``; intended to be an ``st.empty()`` slot.

    Raises:
        OSError: if either file cannot be opened.
    """
    # Order matters: the zip button is rendered above the survey button.
    downloads = (
        (zip_file_name, "Download extracted topic-clustered-highlights, images and tables as zip"),
        (survey_file_name, "Download detailed generated survey file"),
    )
    download_placeholder.empty()
    with download_placeholder.container():
        for path, label in downloads:
            # Keep the file open only while Streamlit reads it into the button.
            with open(str(path), "rb") as file:
                st.download_button(
                    label=label,
                    data=file,
                    file_name=str(path),
                )


def survey_space(surveyor, download_placeholder):
    """Render the sidebar survey inputs and run a survey when submitted.

    Keywords take precedence: if non-blank keywords are given, a keyword
    survey runs with the chosen ``max_search``/``num_papers``. Otherwise
    the entered arxiv ids are surveyed. With neither, a warning is shown and
    no survey runs.

    Args:
        surveyor: object passed through to ``run_survey``.
        download_placeholder: Streamlit element passed through to
            ``run_survey`` for the download buttons.
    """
    form = st.sidebar.form(key='survey_form')
    research_keywords = form.text_input("Enter your research keywords:", key='research_keywords', value='')
    max_search = form.number_input(
        "num_papers_to_search",
        help="maximum number of papers to glance through - defaults to 10",
        min_value=1, max_value=50, value=10, step=1, key='max_search',
    )
    num_papers = form.number_input(
        "num_papers_to_select",
        help="maximum number of papers to select and analyse - defaults to 2",
        min_value=1, max_value=8, value=2, step=1, key='num_papers',
    )

    form.write('or')

    # Called directly rather than through `form`, so this widget is not part
    # of the form submission.
    arxiv_ids = st_sidebar_tags(
        label='Enter arxiv ids for your curated set of papers (1-by-1):',
        value=[],
        text='Press enter to add more (e.g. 2205.12755, 2205.10937, 1605.08386v1 ...)',
        maxtags=6,
        key='arxiv_ids',
    )

    submit = form.form_submit_button('Submit')
    if not submit:
        return

    keywords = research_keywords.strip()
    if keywords:
        run_survey(
            surveyor, download_placeholder,
            research_keywords=keywords, max_search=max_search, num_papers=num_papers,
        )
    elif arxiv_ids:
        run_survey(surveyor, download_placeholder, arxiv_ids=arxiv_ids)
    else:
        # Nothing to survey: do not call the surveyor with empty inputs.
        st.sidebar.warning('Please enter research keywords or at least one arxiv id before submitting.')




if __name__ == '__main__':
    st.title('Auto-Research V0.1 - Automated Survey generation from research keywords')
    std_col, survey_col = st.columns(2)
    std_col.header('execution log:')
    survey_col.header('Generated_survey:')
    # Reserve a single-element slot under the survey header. show_survey_download
    # clears it with .empty() and refills it via .container(), which is the
    # documented way to reuse an st.empty() placeholder.
    download_placeholder = survey_col.empty()
    surveyor_obj = get_surveyor_instance(_print_fn=std_col.write, _survey_print_fn=survey_col.write)
    survey_space(surveyor_obj, download_placeholder)