ChatData / chat.py
mpsk's picture
improve chat experience
06665fc
raw
history blame
8.9 kB
import pandas as pd
from os import environ
from time import sleep
import datetime
import streamlit as st
from lib.sessions import SessionManager
from langchain.schema import HumanMessage, FunctionMessage
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
from helper import (
build_agents,
MYSCALE_HOST,
MYSCALE_PASSWORD,
MYSCALE_PORT,
MYSCALE_USER,
DEFAULT_SYSTEM_PROMPT,
)
from login import back_to_main
# Copy the OpenAI endpoint configured in Streamlit secrets into the process
# environment (presumably so the OpenAI/LangChain clients created in
# ``helper`` pick it up from there -- confirm).
environ.update({"OPENAI_API_BASE": st.secrets["OPENAI_API_BASE"]})

# Human-readable names for the retriever tool identifiers.
TOOL_NAMES = dict(
    langchain_retriever_tool="Self-querying retriever",
    vecsql_retriever_tool="Vector SQL",
)
def on_chat_submit():
    """Callback of the chat input: echo the question, then run the agent on it.

    Both chat bubbles are drawn into the ``next_round`` placeholder. The
    assistant bubble gets its own container, which is wrapped in a
    ``ChatDataAgentCallBackHandler`` and passed as a callback to the agent
    call (presumably so intermediate steps render there). The raw agent
    result is printed to stdout.
    """
    question = st.session_state.chat_input
    placeholder = st.session_state.next_round
    with placeholder.container():
        with st.chat_message('user'):
            st.write(question)
        with st.chat_message('assistant'):
            handler = ChatDataAgentCallBackHandler(
                st.container(), collapse_completed_thoughts=False
            )
            answer = st.session_state.agent(
                {"input": question}, callbacks=[handler]
            )
            print(answer)
def clear_history():
    """Callback of "Clear Chat History": empty the current agent's memory.

    Does nothing when no agent has been built yet.
    """
    if "agent" not in st.session_state:
        return
    st.session_state.agent.memory.clear()
def back_to_main():
    """Callback of "Logout": forget the logged-in user's session-state entries.

    Removes ``user_info``, ``user_name`` and ``jump_query_ask`` when present.

    NOTE(review): this definition shadows the ``back_to_main`` imported from
    ``login`` at the top of the file, so this local version is the one the
    "Logout" button is bound to. Confirm whether the import is still needed.
    """
    for key in ("user_info", "user_name", "jump_query_ask"):
        if key in st.session_state:
            del st.session_state[key]
def on_session_change_submit():
    """Apply the pending edits of the session editor to the session store.

    Callback of the "Submit Change!" button. It only runs when both a
    session manager and the ``session_editor`` widget state exist.

    Every row in ``added_rows`` is created as a session. Every index in
    ``deleted_rows`` removes the matching entry of ``current_sessions``.
    The session list is then reloaded. Stored session ids have the form
    ``"<user_name>?<session_id>"``, which is presumably why ``?`` is rejected
    in new ids.

    All added rows are validated before anything is written, so one invalid
    row does not leave the rows before it half-applied. Any error is shown
    with ``st.error`` instead of being raised. The pending edits are always
    cleared and the agent is rebuilt afterwards.
    """
    if "session_manager" not in st.session_state or "session_editor" not in st.session_state:
        return
    editor = st.session_state.session_editor
    manager = st.session_state.session_manager
    print(editor)
    try:
        user_name = st.session_state.user_name
        # Validate every new row up front; nothing is persisted on failure.
        # ValueError (not KeyError) keeps the message unquoted in st.error.
        new_sessions = []
        for row in editor["added_rows"]:
            session_id = row.get("session_id")
            system_prompt = row.get("system_prompt")
            if session_id is None or system_prompt is None:
                raise ValueError(
                    "You should fill both `session_id` and `system_prompt` to add a row!"
                )
            if session_id == "" or "?" in session_id:
                raise ValueError(
                    "`session_id` should be neither empty nor contain question marks."
                )
            new_sessions.append((session_id, system_prompt))
        for session_id, system_prompt in new_sessions:
            manager.add_session(
                user_id=user_name,
                session_id=f"{user_name}?{session_id}",
                system_prompt=system_prompt,
            )
        # Deleted rows are indices into the list the editor was rendered from.
        for row_index in editor["deleted_rows"]:
            removed_id = st.session_state.current_sessions[row_index]["session_id"]
            manager.remove_session(session_id=f"{user_name}?{removed_id}")
        refresh_sessions()
    except Exception as e:
        # NOTE(review): this pause before showing the error has no evident
        # purpose here; it is kept to avoid changing the UI timing.
        sleep(2)
        st.error(f"{type(e)}: {str(e)}")
    finally:
        editor["added_rows"] = []
        editor["deleted_rows"] = []
        refresh_agent()
def build_session_manager():
    """Return a new ``SessionManager`` using the MyScale settings from ``helper``."""
    connection = dict(
        host=MYSCALE_HOST,
        port=MYSCALE_PORT,
        username=MYSCALE_USER,
        password=MYSCALE_PASSWORD,
    )
    return SessionManager(**connection)
def refresh_sessions():
    """Reload the user's sessions and decide which one is selected.

    Stores ``session_manager.list_sessions(user_name)`` in
    ``st.session_state.current_sessions``. When that list is empty, a
    ``default`` session with ``DEFAULT_SYSTEM_PROMPT`` is created and the
    list is reloaded.

    ``st.session_state.sel_sess`` is the key of the session selectbox. It is
    set to the reloaded entry of the session that was already selected, so a
    page rerun keeps the user's choice instead of snapping back to
    ``default``. If that session no longer exists, the ``default`` session
    is selected, or the first session when there is no ``default`` one.

    Raises:
        RuntimeError: if the user still has no session after the default one
            was created.
    """
    manager = st.session_state.session_manager
    user_name = st.session_state.user_name
    sessions = manager.list_sessions(user_name)
    st.session_state["current_sessions"] = sessions
    if not isinstance(sessions, dict) and len(sessions) == 0:
        manager.add_session(user_name, f"{user_name}?default", DEFAULT_SYSTEM_PROMPT)
        sessions = manager.list_sessions(user_name)
        st.session_state["current_sessions"] = sessions
    if len(sessions) == 0:
        raise RuntimeError(
            f"No chat session is available for user {user_name!r}, "
            "even after creating the default one."
        )
    session_ids = [session["session_id"] for session in sessions]
    previous = st.session_state.get("sel_sess")
    previous_id = previous.get("session_id") if isinstance(previous, dict) else None
    if previous_id in session_ids:
        index = session_ids.index(previous_id)
    elif "default" in session_ids:
        index = session_ids.index("default")
    else:
        index = 0
    st.session_state.sel_sess = sessions[index]
def refresh_agent():
    """(Re)build the chat agent for the currently selected session.

    The agent is stored in ``st.session_state.agent``. It is keyed by
    ``"<user_name>?<session_id>"`` of ``st.session_state.sel_sess`` and uses
    that session's system prompt. It uses the tools chosen in the
    "Knowledge Base" multiselect, falling back to that widget's default
    selection when the widget state does not exist yet. A fresh session
    manager is created afterwards.
    """
    with st.spinner("Initializing session..."):
        session = st.session_state.sel_sess
        session_key = f"{st.session_state.user_name}?{session['session_id']}"
        print("??? Changed to ", session_key)
        # Same default as the "Knowledge Base" multiselect in chat_page; that
        # value is guaranteed to be one of st.session_state.tools' keys.
        tools = st.session_state.get("selected_tools", ["Wikipedia + Self Querying"])
        st.session_state["agent"] = build_agents(
            session_key,
            tools,
            system_prompt=session["system_prompt"],
        )
    st.session_state["session_manager"] = build_session_manager()
def _write_message_time(msg):
    """Write ``msg.additional_kwargs['timestamp']`` as italic ISO-8601 text.

    Messages without a timestamp are rendered without one, instead of raising
    ``KeyError`` and aborting the whole history view.
    """
    timestamp = msg.additional_kwargs.get("timestamp")
    if timestamp is not None:
        st.write(f"*{datetime.datetime.fromtimestamp(timestamp).isoformat()}*")


def _render_sidebar():
    """Render the session management, session selection and tool settings panels."""
    with st.sidebar:
        with st.expander("Session Management"):
            refresh_sessions()
            st.info("Here you can set up your session! \n\nYou can **change your prompt** here!",
                    icon="🤖")
            st.info(("**Add rows by clicking the empty row**.\n"
                     "And **delete rows by selecting them and pressing the `DEL` key**"),
                    icon="💡")
            st.info("Don't forget to **click `Submit Change` to save your change**!", icon="📢")
            st.data_editor(
                st.session_state.current_sessions,
                num_rows="dynamic",
                key="session_editor",
                use_container_width=True,
            )
            st.button("Submit Change!", on_click=on_session_change_submit)
        with st.expander("Session Selection", expanded=True):
            st.info("Here you can select your session!", icon="🤖")
            st.info("If no session is attached to your account, then we will add a default session for you!",
                    icon="❤️")
            session_ids = [s["session_id"] for s in st.session_state.current_sessions]
            default_index = session_ids.index("default") if "default" in session_ids else 0
            st.selectbox(
                "Choose a session to chat:",
                options=st.session_state.current_sessions,
                index=default_index,
                key="sel_sess",
                format_func=lambda x: x["session_id"],
                on_change=refresh_agent,
            )
            print(st.session_state.sel_sess)
        with st.expander("Tool Settings", expanded=True):
            st.info("Here you can select your tools.", icon="🔧")
            st.info("We provide several knowledge base tools for you. We are building more tools!",
                    icon="👷‍♂️")
            st.multiselect(
                "Knowledge Base",
                st.session_state.tools.keys(),
                default=["Wikipedia + Self Querying"],
                key="selected_tools",
                on_change=refresh_agent,
            )
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)


def _render_history():
    """Render every message stored in the agent's chat memory."""
    chat_memory = st.session_state.agent.memory.chat_memory
    print("!!! ", chat_memory.session_id)
    for msg in chat_memory.messages:
        if isinstance(msg, FunctionMessage):
            with st.chat_message("Knowledge Base", avatar="📖"):
                _write_message_time(msg)
                st.write("Retrieved from knowledge base:")
                try:
                    # SECURITY(review): eval() executes the stored message text.
                    # That text is tool output, which may include retrieved
                    # documents. Prefer a safe format (JSON or
                    # ast.literal_eval) once the producer's format is confirmed.
                    st.dataframe(
                        pd.DataFrame.from_records(map(dict, eval(msg.content)))
                    )
                except Exception:
                    # Not a list of records: show the raw text instead. A bare
                    # `except:` would also swallow KeyboardInterrupt/SystemExit.
                    st.write(msg.content)
        elif len(msg.content) > 0:
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                print(type(msg), msg.dict())
                _write_message_time(msg)
                st.write(f"{msg.content}")


def chat_page():
    """Render the chat page: sidebar controls, chat history and the input box.

    On first use this seeds the ``default`` session selection and a session
    manager. If no agent exists yet, one is built before the history is
    drawn. A ``next_round`` placeholder is left for ``on_chat_submit`` to
    render the next exchange into.
    """
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    if "session_manager" not in st.session_state:
        st.session_state["session_manager"] = build_session_manager()
    _render_sidebar()
    if "agent" not in st.session_state:
        refresh_agent()
    _render_history()
    st.session_state["next_round"] = st.empty()
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")