File size: 6,313 Bytes
fab8405
 
a796108
 
 
 
fab8405
 
 
 
 
 
042a946
fab8405
 
 
 
e4853cf
fab8405
042a946
 
 
 
 
 
 
 
 
45180a0
a796108
 
45180a0
fab8405
45180a0
fab8405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9061790
fab8405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9061790
 
fab8405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9061790
fab8405
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import json
import time
import pandas as pd
from os import environ
import streamlit as st

from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
    ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
    ChatDataSQLAskCallBackHandler

from chat import chat_page
from login import login, back_to_main
from helper import build_tools, build_agents, build_all, sel_map, display



# Export the OpenAI endpoint configured in Streamlit secrets as an environment
# variable for the rest of the process.
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']

st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
# Inject a small CSS override capping the width of elements with class `.st-e4`.
# Plain (non-f) string: there is nothing to interpolate, so the braces no
# longer need doubling. The emitted markup is unchanged.
st.markdown(
    """
    <style>
        .st-e4 {
            max-width: 500px
        }
    </style>""",
    unsafe_allow_html=True,
)
st.header("ChatData")

# Build the per-knowledge-base objects and the tool set once per session.
# The guard tests the keys this block actually stores; the previous guard
# looked for a 'retriever' key that this file never writes, so the builders
# were re-run on every script rerun.
if "sel_map_obj" not in st.session_state or "tools" not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()
    st.session_state["tools"] = build_tools()

def _docs_to_frame(docs):
    """Build a DataFrame with one row per retrieved document.

    Each row holds the document's ``metadata`` entries plus its
    ``page_content`` stored under the ``'abstract'`` column.
    """
    return pd.DataFrame(
        [{**doc.metadata, 'abstract': doc.page_content} for doc in docs])


def _run_search(sel: str, retriever_key: str, query: str, callback_cls,
                placeholder, restrict_columns: bool = False):
    """Run a retrieval for ``query`` and show the hits inside a 'Query Log' expander.

    Args:
        sel: Knowledge base name chosen in the select box; indexes both
            ``sel_map`` and ``st.session_state.sel_map_obj``.
        retriever_key: Key of the retriever inside ``sel_map_obj[sel]``.
        query: Question text typed by the user.
        callback_cls: Zero-argument callable returning a callback handler that
            exposes ``progress_bar``.
        placeholder: ``st.empty()`` slot that receives the expander.
        restrict_columns: When true, pass ``sel_map[sel]["must_have_cols"]``
            to ``display``; otherwise call ``display`` with the frame only.

    Raises:
        Exception: Anything raised while retrieving or rendering is re-raised
            after an error notice has been written to the page.
    """
    print(query)
    with placeholder.expander('Query Log', expanded=True):
        # Instantiated inside the expander, as the original code did --
        # presumably the handler creates its progress bar on construction.
        callback = callback_cls()
        try:
            retriever = st.session_state.sel_map_obj[sel][retriever_key]
            docs = retriever.get_relevant_documents(query, callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            frame = _docs_to_frame(docs)
            if restrict_columns:
                display(frame, sel_map[sel]["must_have_cols"])
            else:
                display(frame)
        except Exception:
            st.write('Oops 😡 Something bad happened...')
            raise


def _run_ask(sel: str, chain_key: str, query: str, callback_cls, placeholder):
    """Run a QA chain for ``query`` and show answer + references in a 'Chat Log' expander.

    Args:
        sel: Knowledge base name chosen in the select box; indexes both
            ``sel_map`` and ``st.session_state.sel_map_obj``.
        chain_key: Key of the callable chain inside ``sel_map_obj[sel]``.
        query: Question text typed by the user.
        callback_cls: Zero-argument callable returning a callback handler that
            exposes ``progress_bar``.
        placeholder: ``st.empty()`` slot that receives the expander.

    Raises:
        Exception: Anything raised while running the chain or rendering is
            re-raised after an error notice has been written to the page.
    """
    print(query)
    with placeholder.expander('Chat Log', expanded=True):
        callback = callback_cls()
        try:
            chain = st.session_state.sel_map_obj[sel][chain_key]
            ret = chain(query, callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            st.markdown(
                f"### Answer from LLM\n{ret['answer']}\n### References")
            display(
                _docs_to_frame(ret['sources']),
                ['ref_id'] + sel_map[sel]["must_have_cols"],
                index='ref_id')
        except Exception:
            st.write('Oops 😡 Something bad happened...')
            raise


# Page routing: a truthy login() gates everything; a stored user name opens the
# chat page, otherwise the jump_query_ask flag opens the query/ask page.
if login():
    if "user_name" in st.session_state:
        chat_page()
    elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:

        sel = st.selectbox('Choose the knowledge base you want to ask with:',
                           options=['ArXiv Papers', 'Wikipedia'])
        sel_map[sel]['hint']()
        tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])

        with tab_sql:
            sel_map[sel]['hint_sql']()
            st.text_input("Ask a question:", key='query_sql')
            cols = st.columns([1, 1, 1, 4])
            # Button states are read back below through session_state by key.
            cols[0].button("Query", key='search_sql')
            cols[1].button("Ask", key='ask_sql')
            cols[2].button("Back", key='back_sql', on_click=back_to_main)
            plc_hldr = st.empty()
            if st.session_state.search_sql:
                _run_search(sel, "sql_retriever", st.session_state.query_sql,
                            ChatDataSQLSearchCallBackHandler, plc_hldr)
            if st.session_state.ask_sql:
                _run_ask(sel, "sql_chain", st.session_state.query_sql,
                         ChatDataSQLAskCallBackHandler, plc_hldr)

        with tab_self_query:
            # The icon is a single light-bulb emoji; the previous literal was a
            # mis-decoded byte sequence of that emoji.
            st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
            st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
            st.text_input("Ask a question:", key='query_self')
            cols = st.columns([1, 1, 1, 4])
            cols[0].button("Query", key='search_self')
            cols[1].button("Ask", key='ask_self')
            cols[2].button("Back", key='back_self', on_click=back_to_main)
            plc_hldr = st.empty()
            if st.session_state.search_self:
                _run_search(sel, "retriever", st.session_state.query_self,
                            ChatDataSelfSearchCallBackHandler, plc_hldr,
                            restrict_columns=True)
            if st.session_state.ask_self:
                _run_ask(sel, "chain", st.session_state.query_self,
                         ChatDataSelfAskCallBackHandler, plc_hldr)