Spaces:
Running
Running
File size: 10,674 Bytes
19bd5a9 c5b06e8 19bd5a9 c5b06e8 401cf68 19bd5a9 c5b06e8 19bd5a9 c5b06e8 19bd5a9 c5b06e8 19bd5a9 17c6622 19bd5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
import json
import time
import pandas as pd
from os import environ
import datetime
import streamlit as st
from langchain.schema import Document
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
ChatDataSQLAskCallBackHandler
from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
from auth0_component import login_button
from helper import build_tools, build_agents, build_all, sel_map, display
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
st.header("ChatData")
# Streamlit re-executes this script on every interaction. Build the retrievers /
# chains and the tools only once per session by guarding on the keys that are
# actually stored (the previous guard tested a key that was never written, so
# everything was rebuilt on every rerun).
if "sel_map_obj" not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()
if "tools" not in st.session_state:
    st.session_state["tools"] = build_tools()
def on_chat_submit():
    """Send the submitted chat input to the currently selected agent.

    Used as the ``on_submit`` callback of ``st.chat_input``. The agent is picked
    from ``st.session_state.agents`` by the selected knowledge base (``sel``) and
    retriever type (``ret_type``); its raw result is printed to stdout.
    """
    state = st.session_state
    agent = state.agents[state.sel][state.ret_type]
    result = agent({"input": state.chat_input})
    print(result)
def clear_history():
    """Clear the memory of the agent for the selected knowledge base / retriever."""
    state = st.session_state
    selected_agent = state.agents[state.sel][state.ret_type]
    selected_agent.memory.clear()
# Auth0 application settings passed to `login_button`, read from Streamlit secrets.
AUTH0_CLIENT_ID, AUTH0_DOMAIN = st.secrets['AUTH0_CLIENT_ID'], st.secrets['AUTH0_DOMAIN']
def login():
    """Render the landing page and report whether the app UI may be shown.

    Returns:
        True when a user is already logged in (``user_name`` is in the session)
        or the "Query / Ask" page was requested. Otherwise the landing page is
        rendered. If Auth0 has just returned a profile, the user is recorded in
        the session and the script is rerun. The script is also rerun if the
        "Query / Ask" button was clicked in this run. If neither happened,
        False is returned.
    """
    if "user_name" in st.session_state or st.session_state.get("jump_query_ask"):
        return True
    # NOTE(review): the emoji in the strings below appear mis-encoded (mojibake);
    # restore the intended characters.
    st.subheader("π€ Welcome to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! π€ ")
    st.write("You can now chat with ArXiv and Wikipedia! π\n")
    st.write("Built purely with streamlit π , LangChain π¦π and love β€οΈ for AI!")
    st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
    st.write("For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
    st.divider()
    col1, col2 = st.columns(2, gap='large')
    with col1.container():
        st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
        st.write("In this demo, you will be able to see how those retrievers "
                 "**digest** -> **translate** -> **retrieve** -> **answer** to your question!")
        # True only on the run in which the button was clicked.
        st.session_state["jump_query_ask"] = st.button("Query / Ask")
    with col2.container():
        st.write("Now with the power of LangChain's Conversational Agents, we are able to build "
                 "an RAG-enabled chatbot within one MyScale instance! ")
        st.write("Log in to Chat with RAG!")
        # The component's value is stored in st.session_state["auth0"].
        login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")
    st.divider()
    st.write("- [Privacy Policy](https://myscale.com/privacy/)\n"
             "- [Terms of Service](https://myscale.com/terms/)")
    if st.session_state.auth0 is not None:
        user_info = dict(st.session_state.auth0)
        st.session_state.user_info = user_info
        if 'email' in user_info:
            email = user_info["email"]
        else:
            # No e-mail in the profile: synthesize an identifier from the
            # nickname and the `sub` field instead.
            email = f"{user_info['nickname']}@{user_info['sub']}"
        st.session_state["user_name"] = email
        # Consume the login result so it is not processed again.
        del st.session_state.auth0
        st.experimental_rerun()
    if st.session_state.jump_query_ask:
        st.experimental_rerun()
    return False
def back_to_main():
    """Forget the logged-in user and page choice so the landing page shows again.

    Removes ``user_info``, ``user_name`` and ``jump_query_ask`` from the session
    state; keys that are absent are skipped.
    """
    for key in ("user_info", "user_name", "jump_query_ask"):
        if key in st.session_state:
            del st.session_state[key]
def _render_timestamp(msg):
    """Write a message's ``timestamp`` (from ``additional_kwargs``) as italic ISO-8601.

    Messages without a timestamp are shown without one, instead of a KeyError
    aborting the rendering of the whole chat history.
    """
    timestamp = msg.additional_kwargs.get('timestamp')
    if timestamp is not None:
        st.write(f"*{datetime.datetime.fromtimestamp(timestamp).isoformat()}*")


def _render_chat_history(agent):
    """Replay the messages stored in ``agent.memory.chat_memory`` as chat bubbles."""
    for msg in agent.memory.chat_memory.messages:
        if isinstance(msg, FunctionMessage):
            # Tool output: show what was retrieved from the knowledge base.
            # NOTE(review): this avatar string looks mis-encoded (probably an emoji
            # originally) — confirm Streamlit accepts it.
            with st.chat_message("Knowledge Base", avatar="π"):
                _render_timestamp(msg)
                st.write("Retrieved from knowledge base:")
                try:
                    # NOTE(security): `eval` executes the stored tool output as Python
                    # code. This is only safe while that content is produced by trusted
                    # code; prefer ast.literal_eval / JSON if it can hold anything else.
                    records = map(dict, eval(msg.content))
                    st.dataframe(pd.DataFrame.from_records(records))
                except Exception:
                    # Not a sequence of mappings: fall back to the raw text. Catching
                    # Exception (not a bare except) lets Streamlit's BaseException-based
                    # rerun/stop signals propagate.
                    st.write(msg.content)
        elif msg.content:
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                _render_timestamp(msg)
                st.write(f"{msg.content}")


def _docs_to_frame(docs):
    """Flatten retrieved documents into a DataFrame: metadata columns plus 'abstract'."""
    return pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])


def _run_query(sel, retriever_key, query, callback_cls, columns=None):
    """Retrieve documents for `query` and display them in a 'Query Log' expander.

    `retriever_key` selects the retriever in ``st.session_state.sel_map_obj[sel]``.
    `columns`, when given, is passed to `display` as its second argument.
    On failure a message is shown in the UI and the exception is re-raised.
    """
    print(query)
    placeholder = st.empty()
    with placeholder.expander('Query Log', expanded=True):
        # Create the handler inside the expander, as before — presumably it draws
        # its progress bar where it is instantiated.
        callback = callback_cls()
        try:
            retriever = st.session_state.sel_map_obj[sel][retriever_key]
            docs = retriever.get_relevant_documents(query, callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            frame = _docs_to_frame(docs)
            if columns is None:
                display(frame)
            else:
                display(frame, columns)
        except Exception:
            st.write('Oops π΅ Something bad happened...')
            raise


def _run_ask(sel, chain_key, query, callback_cls):
    """Answer `query` with a chain and show the answer and its sources in a 'Chat Log'.

    `chain_key` selects the chain in ``st.session_state.sel_map_obj[sel]``; its
    result must provide 'answer' and 'sources'. On failure a message is shown in
    the UI and the exception is re-raised.
    """
    print(query)
    placeholder = st.empty()
    with placeholder.expander('Chat Log', expanded=True):
        callback = callback_cls()
        try:
            ret = st.session_state.sel_map_obj[sel][chain_key](query, callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            st.markdown(f"### Answer from LLM\n{ret['answer']}\n### References")
            display(_docs_to_frame(ret['sources']),
                    ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
        except Exception:
            st.write('Oops π΅ Something bad happened...')
            raise


if login():
    if "user_name" in st.session_state:
        # Chat mode (logged-in user).
        st.session_state["agents"] = build_agents(st.session_state.user_name)
        with st.sidebar:
            st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
            st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
            st.button("Clear Chat History", on_click=clear_history)
            st.button("Logout", on_click=back_to_main)
        _render_chat_history(st.session_state.agents[st.session_state.sel][st.session_state.ret_type])
        st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
    elif st.session_state.get("jump_query_ask"):
        # Query / Ask mode (no login required).
        sel = st.selectbox('Choose the knowledge base you want to ask with:',
                           options=['ArXiv Papers', 'Wikipedia'])
        sel_map[sel]['hint']()
        tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
        with tab_sql:
            sel_map[sel]['hint_sql']()
            st.text_input("Ask a question:", key='query_sql')
            cols = st.columns([1, 1, 1, 4])
            cols[0].button("Query", key='search_sql')
            cols[1].button("Ask", key='ask_sql')
            cols[2].button("Back", key='back_sql', on_click=back_to_main)
            if st.session_state.search_sql:
                _run_query(sel, "sql_retriever", st.session_state.query_sql,
                           ChatDataSQLSearchCallBackHandler)
            if st.session_state.ask_sql:
                _run_ask(sel, "sql_chain", st.session_state.query_sql,
                         ChatDataSQLAskCallBackHandler)
        with tab_self_query:
            st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='π‘')
            st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
            st.text_input("Ask a question:", key='query_self')
            cols = st.columns([1, 1, 1, 4])
            cols[0].button("Query", key='search_self')
            cols[1].button("Ask", key='ask_self')
            cols[2].button("Back", key='back_self', on_click=back_to_main)
            if st.session_state.search_self:
                _run_query(sel, "retriever", st.session_state.query_self,
                           ChatDataSelfSearchCallBackHandler, sel_map[sel]["must_have_cols"])
            if st.session_state.ask_self:
                _run_ask(sel, "chain", st.session_state.query_self,
                         ChatDataSelfAskCallBackHandler)