File size: 5,197 Bytes
2bf46ef
a8d4e3d
2bf46ef
 
a8d4e3d
8387173
7888d98
8cfcf51
f310b8b
 
106ef8f
f310b8b
 
2ad51f4
a8d4e3d
f310b8b
 
adf54e4
f310b8b
adf54e4
 
 
f310b8b
 
a8d4e3d
f310b8b
 
 
 
 
 
 
 
 
a8d4e3d
f310b8b
 
 
 
 
 
a8d4e3d
 
2bf46ef
 
 
 
 
 
 
 
 
 
 
 
f310b8b
2bf46ef
 
 
 
 
 
 
 
8cfcf51
 
 
 
 
 
 
 
 
2bf46ef
 
 
8cfcf51
 
 
 
 
a8d4e3d
cea2a96
a8d4e3d
f310b8b
a8d4e3d
f310b8b
a8d4e3d
106ef8f
f310b8b
cea2a96
 
f310b8b
cea2a96
f310b8b
 
106ef8f
 
 
8cfcf51
106ef8f
a8d4e3d
106ef8f
 
 
 
 
8cfcf51
f310b8b
 
a8d4e3d
 
 
5727aa4
7888d98
da318a3
 
 
f310b8b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from typing import List, Optional
import streamlit as st
import streamlit_pydantic as sp
from pydantic import BaseModel, Field

from src.Surveyor import Surveyor
from temp_showcase_model import ShowcaseModel



@st.experimental_singleton(show_spinner=True, suppress_st_warning=True)
def get_surveyor_instance(_print_fn, _survey_print_fn):
    """Construct the Surveyor while showing a loading spinner.

    The function is wrapped in ``st.experimental_singleton``.
    NOTE(review): the leading underscores on the parameters presumably tell
    Streamlit's cache not to hash them -- confirm against the Streamlit docs.

    Args:
        _print_fn: Callable forwarded to ``Surveyor`` as ``print_fn``.
        _survey_print_fn: Callable forwarded to ``Surveyor`` as
            ``survey_print_fn``.

    Returns:
        The ``Surveyor`` instance, created with ``high_gpu=True``.
    """
    with st.spinner('Loading The-Surveyor ...'):
        surveyor = Surveyor(
            print_fn=_print_fn,
            survey_print_fn=_survey_print_fn,
            high_gpu=True,
        )
    return surveyor


def run_survey(surveyor, download_placeholder, research_keywords=None, arxiv_ids=None, max_search=None, num_papers=None):
    """Run one survey and expose its output files as download buttons.

    Args:
        surveyor: Object whose ``survey`` method performs the work.
        download_placeholder: Streamlit element that will host the buttons.
        research_keywords: Passed positionally as the first ``survey`` argument.
        arxiv_ids: Passed positionally as the second ``survey`` argument.
        max_search: Forwarded to ``survey`` as ``max_search``.
        num_papers: Forwarded to ``survey`` as ``num_papers``.
    """
    survey_outputs = surveyor.survey(
        research_keywords,
        arxiv_ids,
        max_search=max_search,
        num_papers=num_papers,
    )
    # ``survey`` must yield exactly two items: the zip name and the survey name.
    zip_file_name, survey_file_name = survey_outputs
    show_survey_download(zip_file_name, survey_file_name, download_placeholder)


def show_survey_download(zip_file_name, survey_file_name, download_placeholder):
    """Render download buttons for the zip archive and the survey file.

    Args:
        zip_file_name: Path of the zip archive; converted with ``str()``.
        survey_file_name: Path of the survey file; converted with ``str()``.
        download_placeholder: Streamlit element; ``.empty()`` is called on it
            and the buttons are drawn inside a fresh ``.container()``.
    """
    # (file, button label) pairs, in the order the buttons are rendered.
    downloads = (
        (zip_file_name, "Download extracted topic-clustered-highlights, images and tables as zip"),
        (survey_file_name, "Download detailed generated survey file"),
    )

    download_placeholder.empty()
    with download_placeholder.container():
        for target, button_label in downloads:
            path = str(target)
            # The open handle is passed straight to the button as its data.
            with open(path, "rb") as handle:
                st.download_button(
                    label=button_label,
                    data=handle,
                    file_name=path,
                )


class KeywordsModel(BaseModel):
    # Input schema for a keyword-driven survey request.
    # (Kept as '#' comments: a class docstring would become part of the
    # model's generated schema description.)

    # Free-text query; defaults to the empty string.
    research_keywords: Optional[str] = Field(
        default='',
        description="Enter your research keywords:",
    )

    # Integer in [1, 50], default 10.
    max_search: int = Field(
        default=10,
        description="num_papers_to_search:",
        ge=1,
        le=50,
        multiple_of=1,
    )

    # Integer in [1, 8], default 3.
    num_papers: int = Field(
        default=3,
        description="num_papers_to_select:",
        ge=1,
        le=8,
        multiple_of=1,
    )


class ArxivIDsModel(BaseModel):
    # Input schema for a hand-picked list of arXiv ids.
    # (Kept as '#' comments: a class docstring would become part of the
    # model's generated schema description.)

    # At most 8 entries; the regex accepts "<digits>.<digits>" only.
    arxiv_ids: Optional[List[str]] = Field(
        default=[],
        description="Enter arxiv ids for your curated set of papers (e.g. 2205.12755, 2205.10937, ...):",
        max_items=8,
        regex=r"^[0-9]+\.[0-9]+$",
    )


def survey_space(surveyor, download_placeholder):
    """Render the sidebar survey form and launch a survey when it is submitted.

    The form collects the ``KeywordsModel`` fields and, below an "or"
    separator, the ``ArxivIDsModel`` field. On submit, non-empty research
    keywords take precedence and are forwarded together with ``max_search``
    and ``num_papers``; otherwise a non-empty list of arXiv ids is forwarded.
    ``run_survey`` is called on every submit, even when both inputs are
    empty (it then receives only its defaults).

    Args:
        surveyor: Surveyor instance handed to ``run_survey``.
        download_placeholder: Streamlit element handed to ``run_survey`` to
            host the download buttons.
    """
    with st.sidebar.form(key="survey_keywords_form"):
        session_data = sp.pydantic_input(key="keywords_input_model", model=KeywordsModel)
        st.write('or')
        session_data.update(sp.pydantic_input(key="arxiv_ids_input_model", model=ArxivIDsModel))
        submit = st.form_submit_button(label="Submit")

    if not submit:
        return

    run_kwargs = {'surveyor': surveyor, 'download_placeholder': download_placeholder}

    # Both fields are Optional in their models, so treat None like "empty".
    research_keywords = session_data.get('research_keywords')
    arxiv_ids = session_data.get('arxiv_ids')

    if research_keywords:
        # Each option is read from its own form field (previously both
        # max_search and num_papers were filled with the keyword string).
        run_kwargs.update({
            'research_keywords': research_keywords,
            'max_search': session_data['max_search'],
            'num_papers': session_data['num_papers'],
        })
    elif arxiv_ids:
        run_kwargs.update({'arxiv_ids': arxiv_ids})

    run_survey(**run_kwargs)




if __name__ == '__main__':
    # Script entry point: lays out the page, builds the Surveyor and wires up
    # the sidebar form.
    st.title('Auto-Research V0.1 - Automated Survey generation from research keywords')
    # NOTE(review): the returned value is never read below, and ShowcaseModel
    # comes from temp_showcase_model -- this looks like leftover demo/debug
    # input; confirm whether it should stay.
    dummy_session_data = sp.pydantic_input(key="dummy_model", model=ShowcaseModel)
    # Two side-by-side columns: one headed as the execution log, one as the
    # generated survey.
    std_col, survey_col = st.columns(2)
    std_col.header('execution log:')
    survey_col.header('Generated_survey:')
    # NOTE(review): download_placeholder is assigned twice in a row and then
    # never used -- survey_space() below is given survey_col instead. Confirm
    # which element was meant to host the download buttons.
    download_placeholder = survey_col.container()
    download_placeholder = st.empty()
    # The Surveyor's two print callbacks write into the respective columns.
    surveyor_obj = get_surveyor_instance(_print_fn=std_col.write, _survey_print_fn=survey_col.write)
    survey_space(surveyor_obj, survey_col)