File size: 7,444 Bytes
e17c9f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6a5155
 
 
 
 
e17c9f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6336ac
e17c9f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import streamlit as st

# st.set_page_config(layout="wide", page_title="🦜🔗 Generate Idea Step-by-step")

## Pipeline global state
# 1.0: Input background is in progress
# 2.0: Brainstorming is in progress
#  2.5 Brainstorming is finished
# 3.0: Extracting entities is in progress
#  3.5 Extracting entities is finished
# 4.0: Retrieving literature is in progress
#  4.5 Retrieving literature is finished
# 5.0: Generating initial ideas is in progress
#  5.5 Generating initial ideas is finished
# 6.0: Generating final ideas is in progress
#  6.5 Generating final ideas is finished
# Seed the pipeline stage with 1.0 only when the session has no value yet;
# an existing value is left untouched on later reruns.
st.session_state.setdefault("global_state_one_click", 1.0)

def generate_sidebar():
    """Render the static sidebar of the one-click generation page.

    Writes four sections to ``st.sidebar``, in order: the SciPIP
    introduction, the list of pipeline stages, the supported research fields
    (as disabled checkboxes), and the feedback form link.
    """
    st.sidebar.header("SciPIP", divider="rainbow")
    st.sidebar.markdown(
        ("SciPIP will generate ideas in one click. The generation pipeline is the same as "
        "step-by-step generation, but you are free from caring about intermediate outputs.")
    )

    pipeline_list = ["1. Input Background", "2. Brainstorming", "3. Extracting Entities", "4. Retrieving Related Works",
                     "5. Generate Initial Ideas", "6. Generate Final Ideas"]
    st.sidebar.header("Pipeline", divider="red")
    # Iterate over the stages themselves so the loop cannot fall out of sync
    # with the list when a stage is added or removed.
    for stage in pipeline_list:
        st.sidebar.markdown(f"<font color='black'>{stage}</font>", unsafe_allow_html=True)

    st.sidebar.header("Supported Fields", divider="orange")
    st.sidebar.caption("The supported fields are temporarily limited because we only collect literature "
               "from ICML, ICLR, NeurIPS, ACL, and EMNLP. Support for other fields are in progress.")
    # (label, checked) pairs; every checkbox is disabled, so it only displays
    # a status and cannot be toggled.
    supported_fields = (
        ("Natural Language Processing (NLP)", True),
        ("Computer Vision (CV)", False),
        ("[Partial] Multimodal", True),
        ("Incoming Other Fields", False),
    )
    for label, checked in supported_fields:
        st.sidebar.checkbox(label, value=checked, disabled=True)

    st.sidebar.header("Help Us To Improve", divider="green")
    st.sidebar.markdown("https://forms.gle/YpLUrhqs1ahyCAe99", unsafe_allow_html=True)


def genrate_mainpage(backend):
    """Render the one-click chat page and run generation when input arrives.

    Replays the chat history kept in ``st.session_state["messages"]``, then
    calls ``generate_ideas`` either for text submitted through the chat box
    or for a demo background queued by one of the "Example" buttons. Below
    the chat it draws the example buttons, the buttons that post a stored
    intermediate output into the chat, and the reset button.

    Args:
        backend: Object providing ``get_demo_i(i)`` plus the callbacks that
            ``generate_ideas`` invokes.
    """
    st.title('💧 Generate Idea in One-click')
    # st.markdown("# 🐳 Background")
    # st.markdown("Available soon...")

    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "Please give me some key words or a background"}]
    if "intermediate_output" not in st.session_state:
        st.session_state["intermediate_output"] = {}

    for msg in st.session_state.messages:
        st.chat_message(msg["role"]).write(msg["content"])

    def disable_submit():
        # Runs as the chat-input callback: block further submissions.
        st.session_state["enable_submmit"] = False

    # Nothing in this script run changes the flag, so read it once and reuse
    # it for the chat box and the example buttons.
    can_submit = st.session_state.get("enable_submmit", True)

    if prompt := st.chat_input(disabled=not can_submit, on_submit=disable_submit):
        st.session_state.messages.append({"role": "user", "content": prompt})
        st.chat_message("user").write(prompt)
        generate_ideas(backend, prompt)
    elif st.session_state.get("use_demo_input", False):
        # Consume the queued demo request BEFORE starting the pipeline. If the
        # flags were cleared only afterwards, a backend error or an interrupted
        # run would leave them set and every later rerun (including the one
        # triggered by "Reset Chat") would start generation again.
        st.session_state["use_demo_input"] = False
        demo_background = st.session_state.pop("demo_input", None)
        if demo_background is not None:
            generate_ideas(backend, demo_background)

    def get_demo_n(i):
        # Button callback: queue demo background ``i`` for the next rerun and
        # record it in the chat history as the user's message.
        demo_input = backend.get_demo_i(i)
        st.session_state["enable_submmit"] = False
        st.session_state.messages.append({"role": "user", "content": demo_input})
        st.session_state["use_demo_input"] = True
        st.session_state["demo_input"] = demo_input

    example_cols = st.columns([1, 1, 1, 1])
    for i, col in enumerate(example_cols):
        col.button(f"Example {i + 1}", on_click=get_demo_n, args=(i,), use_container_width=True, disabled=not can_submit)

    def check_intermediate_outputs(output_key="brainstorms"):
        # Button callback: post the stored intermediate output into the chat,
        # or show a toast when that stage has not produced anything yet.
        msg = st.session_state["intermediate_output"].get(output_key, None)
        if msg is not None:
            st.session_state.messages.append(msg)
        else:
            st.toast(f"No {output_key} now!")

    def reset():
        # Button callback: drop the chat history (it is re-created with the
        # greeting on the next run), re-enable input, and rewind the stage.
        del st.session_state["messages"]
        st.session_state["enable_submmit"] = True
        st.session_state["global_state_one_click"] = 1.0
        st.toast("The chat has been reset!")

    # (button label, key in st.session_state["intermediate_output"])
    checks = (
        ("Check Brainstorms", "brainstorms"),
        ("Check Entities", "entities"),
        ("Check Retrieved Papers", "related_works"),
    )
    *check_cols, reset_col = st.columns([1, 1, 1, 1])
    for col, (label, output_key) in zip(check_cols, checks):
        col.button(label, on_click=check_intermediate_outputs, args=(output_key,), use_container_width=True)
    reset_col.button("Reset Chat", on_click=reset, use_container_width=True, type="primary")

def _post_assistant_message(content):
    """Append ``content`` to the chat history as an assistant message and render it."""
    st.session_state.messages.append({"role": "assistant", "content": content})
    st.chat_message("assistant").write(content)


def generate_ideas(backend, background):
    """Run the whole idea-generation pipeline for one background.

    Calls the backend stage by stage, each under its own spinner:
    brainstorming, entity extraction, literature retrieval, initial ideas,
    final ideas. The first three results are stored in
    ``st.session_state["intermediate_output"]`` (keys ``"brainstorms"``,
    ``"entities"``, ``"related_works"``) without being shown. The initial and
    final ideas are appended to the chat history and rendered. After each
    stage ``st.session_state["global_state_one_click"]`` is set to that
    stage's "finished" value (2.5, 3.5, 4.5, 5.5, 6.5).

    Args:
        backend: Object providing ``background2brainstorm_callback``,
            ``brainstorm2entities_callback``, ``entities2literature_callback``,
            ``literature2initial_ideas_callback`` and ``initial2final_callback``.
        background: The user's keywords or background text.
    """
    # Expected to have been created by the page before this is called.
    outputs = st.session_state["intermediate_output"]

    with st.spinner(text="Brainstorming..."):
        brainstorms = backend.background2brainstorm_callback(background)
        outputs["brainstorms"] = {"role": "assistant", "content": brainstorms}
        st.session_state["global_state_one_click"] = 2.5

    with st.spinner(text="Extracting entities..."):
        entities = backend.brainstorm2entities_callback(background, brainstorms)
        outputs["entities"] = {"role": "assistant", "content": entities}
        st.session_state["global_state_one_click"] = 3.5

    with st.spinner(text="Retrieving related works..."):
        # The first return value is what gets stored for display; the second
        # one is what the next stage consumes.
        related_works, related_works_intact = backend.entities2literature_callback(background, entities)
        outputs["related_works"] = {"role": "assistant", "content": related_works}
        st.session_state["global_state_one_click"] = 4.5

    with st.spinner(text="Generating initial ideas..."):
        initial_ideas, final_ideas = backend.literature2initial_ideas_callback(background, brainstorms, related_works_intact)
        _post_assistant_message("My initial ideas are:")
        _post_assistant_message(initial_ideas)
        st.session_state["global_state_one_click"] = 5.5

    with st.spinner(text="Generating final ideas..."):
        # The second value returned by the previous stage is passed in here
        # and replaced by this stage's result.
        final_ideas = backend.initial2final_callback(initial_ideas, final_ideas)
        _post_assistant_message("My final ideas after refinement are:")
        _post_assistant_message(final_ideas)
        st.session_state["global_state_one_click"] = 6.5

def one_click_generation(backend):
    """Render the one-click generation page: the sidebar first, then the main chat area.

    Args:
        backend: Object passed through unchanged to ``genrate_mainpage``.
    """
    generate_sidebar()
    genrate_mainpage(backend)