File size: 5,101 Bytes
2bf46ef
a8d4e3d
2bf46ef
 
a8d4e3d
8387173
8cfcf51
f310b8b
 
106ef8f
f310b8b
 
2ad51f4
a8d4e3d
f310b8b
 
adf54e4
f310b8b
adf54e4
 
 
f310b8b
 
a8d4e3d
f310b8b
 
 
 
 
 
 
 
 
a8d4e3d
f310b8b
 
 
 
 
 
a8d4e3d
 
2bf46ef
 
 
 
 
 
 
 
 
 
 
 
f310b8b
2bf46ef
 
a8492e7
 
2bf46ef
 
 
 
8cfcf51
 
 
 
 
 
 
 
 
2bf46ef
 
 
a8492e7
 
 
8cfcf51
 
 
a8d4e3d
cea2a96
a8d4e3d
f310b8b
a8d4e3d
f310b8b
a8d4e3d
106ef8f
f310b8b
cea2a96
 
f310b8b
cea2a96
f310b8b
 
106ef8f
 
 
8cfcf51
106ef8f
a8d4e3d
106ef8f
 
 
 
 
8cfcf51
f310b8b
 
a8d4e3d
 
 
5727aa4
da318a3
 
 
f310b8b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from typing import List, Optional
import streamlit as st
import streamlit_pydantic as sp
from pydantic import BaseModel, Field

from src.Surveyor import Surveyor



@st.experimental_singleton(show_spinner=True, suppress_st_warning=True)
def get_surveyor_instance(_print_fn, _survey_print_fn):
    """Create the Surveyor backing this app while showing a loading spinner.

    The function is wrapped in ``st.experimental_singleton``; the leading
    underscores on the parameter names follow Streamlit's convention for
    arguments that should be left out of the cache key (see the Streamlit
    caching documentation).

    Args:
        _print_fn: Callable forwarded to ``Surveyor`` as ``print_fn``.
        _survey_print_fn: Callable forwarded to ``Surveyor`` as
            ``survey_print_fn``.

    Returns:
        The ``Surveyor`` instance, constructed with ``high_gpu=True``.
    """
    surveyor_options = {
        'print_fn': _print_fn,
        'survey_print_fn': _survey_print_fn,
        'high_gpu': True,
    }
    with st.spinner('Loading The-Surveyor ...'):
        instance = Surveyor(**surveyor_options)
    return instance


def run_survey(surveyor, download_placeholder, research_keywords=None, arxiv_ids=None, max_search=None, num_papers=None):
    """Run one survey and expose its output files for download.

    Args:
        surveyor: Object whose ``survey`` method performs the work and returns
            a ``(zip_file_name, survey_file_name)`` pair.
        download_placeholder: Streamlit element handed on to
            ``show_survey_download`` to host the download buttons.
        research_keywords: Passed positionally to ``surveyor.survey``.
        arxiv_ids: Passed positionally to ``surveyor.survey``.
        max_search: Forwarded to ``surveyor.survey`` as a keyword argument.
        num_papers: Forwarded to ``surveyor.survey`` as a keyword argument.
    """
    generated_files = surveyor.survey(
        research_keywords,
        arxiv_ids,
        max_search=max_search,
        num_papers=num_papers,
    )
    zip_path, survey_path = generated_files
    show_survey_download(zip_path, survey_path, download_placeholder)


def show_survey_download(zip_file_name, survey_file_name, download_placeholder):
    """Render download buttons for the zip archive and the survey file.

    Calls ``download_placeholder.empty()`` first, then opens a container on
    the same element and adds one ``st.download_button`` per file, zip first.

    Args:
        zip_file_name: Path of the zip archive; converted with ``str`` and
            used both to open the file and as the suggested download name.
        survey_file_name: Path of the generated survey file; handled the same
            way as ``zip_file_name``.
        download_placeholder: Streamlit element providing ``empty()`` and
            ``container()``.
    """
    downloads = (
        (
            str(zip_file_name),
            "Download extracted topic-clustered-highlights, images and tables as zip",
        ),
        (
            str(survey_file_name),
            "Download detailed generated survey file",
        ),
    )

    download_placeholder.empty()
    with download_placeholder.container():
        for path, label in downloads:
            # Each file is opened in binary mode and closed as soon as its
            # button has been created.
            with open(path, "rb") as handle:
                st.download_button(label=label, data=handle, file_name=path)


# Input schema for a keyword-driven survey request.
# survey_space() renders it in the sidebar form through sp.pydantic_input;
# each Field description is the text attached to that input.
class KeywordsModel(BaseModel):
    # Free-text research keywords; defaults to the empty string, which
    # survey_space() treats as "no keywords supplied".
    research_keywords: Optional[str] =  Field(
        '', description="Enter your research keywords:"
    )
    # Number of papers to search: integer in [1, 50], default 10.
    max_search: int = Field(
        10, ge=1, le=50, multiple_of=1,
        description="num_papers_to_search:"
    )
    # Number of papers to select: integer in [1, 8], default 3.
    num_papers: int = Field(
        3, ge=1, le=8, multiple_of=1, 
        description="num_papers_to_select:"
    )


# Input schema for a survey over a hand-picked set of papers.
# survey_space() renders it in the sidebar form through sp.pydantic_input.
class ArxivIDsModel(BaseModel):
    # Comma-separated arXiv ids entered as one string; survey_space() splits it
    # on ',' and strips each piece. Defaults to the empty string.
    arxiv_ids: Optional[str] =  Field(
        '', description="Enter comma_separated arxiv ids for your curated set of papers (e.g. 2205.12755, 2205.10937, ...):"
    )


def survey_space(surveyor, download_placeholder):
    """Render the sidebar survey form and run a survey when it is submitted.

    The form combines the ``KeywordsModel`` inputs and the ``ArxivIDsModel``
    input. On submit, research keywords take precedence: when they are
    non-empty the survey is run with the keywords plus the ``max_search`` and
    ``num_papers`` values from the form. Otherwise, when arXiv ids were
    entered, the comma-separated string is split into a list of ids. The
    resulting keyword arguments are echoed with ``st.json`` and passed to
    ``run_survey``.

    Args:
        surveyor: Surveyor object forwarded to ``run_survey``.
        download_placeholder: Streamlit element forwarded to ``run_survey``
            for the download buttons.
    """
    with st.sidebar.form(key="survey_keywords_form"):
        session_data = sp.pydantic_input(key="keywords_input_model", model=KeywordsModel)
        st.write('or')
        session_data.update(sp.pydantic_input(key="arxiv_ids_input_model", model=ArxivIDsModel))
        submit = st.form_submit_button(label="Submit")

    if not submit:
        return

    run_kwargs = {'surveyor': surveyor, 'download_placeholder': download_placeholder}
    # Truthiness checks treat both '' and None (the fields are Optional) as
    # "not provided".
    research_keywords = session_data.get('research_keywords')
    arxiv_ids = session_data.get('arxiv_ids')
    if research_keywords:
        # Each option is read from its own form field; previously both numeric
        # limits were mistakenly filled with the keyword string.
        run_kwargs.update({
            'research_keywords': research_keywords,
            'max_search': session_data['max_search'],
            'num_papers': session_data['num_papers'],
        })
    elif arxiv_ids:
        # Drop blank pieces so a trailing or doubled comma does not produce an
        # empty id.
        run_kwargs['arxiv_ids'] = [
            piece.strip() for piece in arxiv_ids.split(',') if piece.strip()
        ]
    st.json(run_kwargs)
    run_survey(**run_kwargs)




# Script entry point: lay out the page, load the Surveyor and show the
# sidebar form.
if __name__ == '__main__':
    st.title('Auto-Research V0.1 - Automated Survey generation from research keywords')
    # Two side-by-side columns: one headed as the execution log, the other as
    # the generated survey.
    std_col, survey_col = st.columns(2)
    std_col.header('execution log:')
    survey_col.header('Generated_survey:')
    # NOTE(review): download_placeholder is assigned twice and never read
    # afterwards; survey_space() below is given survey_col instead. Confirm
    # which element was meant to host the download buttons.
    download_placeholder = survey_col.container()
    download_placeholder = st.empty()
    # The columns' write methods are handed to the Surveyor as its print
    # callbacks.
    surveyor_obj = get_surveyor_instance(_print_fn=std_col.write, _survey_print_fn=survey_col.write)
    survey_space(surveyor_obj, survey_col)