|
import json |
|
import pandas as pd |
|
from os import environ |
|
from time import sleep |
|
import datetime |
|
import streamlit as st |
|
from lib.sessions import SessionManager |
|
from lib.private_kb import PrivateKnowledgeBase |
|
from langchain.schema import HumanMessage, FunctionMessage |
|
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler |
|
from lib.json_conv import CustomJSONDecoder |
|
|
|
from lib.helper import ( |
|
build_agents, |
|
MYSCALE_HOST, |
|
MYSCALE_PASSWORD, |
|
MYSCALE_PORT, |
|
MYSCALE_USER, |
|
DEFAULT_SYSTEM_PROMPT, |
|
UNSTRUCTURED_API, |
|
) |
|
from login import back_to_main |
|
|
|
# Export the OPENAI_API_BASE secret as an environment variable (presumably read
# by the OpenAI / LangChain clients created elsewhere).
environ.update(OPENAI_API_BASE=st.secrets["OPENAI_API_BASE"])

# Human-readable labels for the built-in retriever tools, presumably keyed by
# tool identifier.
# NOTE(review): not referenced in the code visible here; verify external users
# before removing.
TOOL_NAMES = dict(
    langchain_retriever_tool="Self-querying retriever",
    vecsql_retriever_tool="Vector SQL",
)
|
|
|
|
|
def on_chat_submit():
    """Echo the submitted message and run the agent on it.

    ``on_submit`` callback of the ``chat_input`` widget. Both the user message
    and the assistant reply are drawn into the ``next_round`` placeholder that
    ``chat_page`` keeps in session state. Agent reasoning is streamed through a
    ``ChatDataAgentCallBackHandler`` attached to a fresh container, and the raw
    agent result is printed to stdout.
    """
    user_text = st.session_state.chat_input
    slot = st.session_state.next_round
    with slot.container():
        with st.chat_message("user"):
            st.write(user_text)
        with st.chat_message("assistant"):
            handler = ChatDataAgentCallBackHandler(
                st.container(), collapse_completed_thoughts=False
            )
            result = st.session_state.agent(
                {"input": user_text}, callbacks=[handler]
            )
            print(result)
|
|
|
|
|
def clear_history():
    """Erase the current agent's conversation memory (no-op if no agent exists)."""
    if "agent" not in st.session_state:
        return
    st.session_state.agent.memory.clear()
|
|
|
|
|
def back_to_main():
    """Log out by removing the per-user keys from session state.

    Used as the ``on_click`` handler of the "Logout" button in ``chat_page``.
    Keys that are already absent are skipped.

    NOTE(review): this definition shadows the ``back_to_main`` imported from
    ``login`` at the top of the module, so callers here get this version.
    """
    per_user_keys = (
        "user_info",
        "user_name",
        "jump_query_ask",
        "sel_sess",
        "current_sessions",
    )
    for state_key in per_user_keys:
        if state_key in st.session_state:
            del st.session_state[state_key]
|
|
|
|
|
def on_session_change_submit():
    """Apply the rows added/deleted in the session editor, then rebuild the agent.

    ``on_click`` callback of the "Submit Change!" button. It reads the
    ``session_editor`` state produced by ``st.data_editor``:

    * every added row must carry a ``session_id`` and a ``system_prompt``. The
      ``session_id`` must be a non-empty string without ``?``, because ``?``
      joins the user name and the session id in stored ids.
    * deleted rows are positional indices into ``current_sessions``.

    Any failure is shown with ``st.error``. Whether or not the changes
    succeeded, the editor's pending edits are cleared and the agent is
    rebuilt. Does nothing when the session manager or the editor state is
    missing.
    """
    state = st.session_state
    if "session_manager" not in state or "session_editor" not in state:
        return
    print(state.session_editor)
    try:
        for row in state.session_editor["added_rows"]:
            if not (len(row) > 0 and "system_prompt" in row and "session_id" in row):
                raise KeyError(
                    "You should fill both `session_id` and `system_prompt` to add a column!"
                )
            session_id = row["session_id"]
            # A cell may come back as None (or a non-string). Reject it here with
            # the proper message instead of letting `"?" in ...` raise TypeError.
            if not isinstance(session_id, str) or session_id == "" or "?" in session_id:
                raise KeyError(
                    "`session_id` should be neither empty nor contain question marks."
                )
            state.session_manager.add_session(
                user_id=state.user_name,
                session_id=f"{state.user_name}?{session_id}",
                system_prompt=row["system_prompt"],
            )
        for row_index in state.session_editor["deleted_rows"]:
            removed = state.current_sessions[row_index]
            state.session_manager.remove_session(
                session_id=f"{state.user_name}?{removed['session_id']}",
            )
        refresh_sessions()
    except Exception as e:
        # Report first, then pause so the message stays visible for a moment.
        # This is the same error-then-sleep order the other callbacks use.
        st.error(f"{type(e)}: {str(e)}")
        sleep(2)
    finally:
        st.session_state.session_editor["added_rows"] = []
        st.session_state.session_editor["deleted_rows"] = []
        refresh_agent()
|
|
|
|
|
def build_session_manager():
    """Return a ``SessionManager`` bound to ``st.session_state``.

    Connection parameters are the ``MYSCALE_*`` settings imported from
    ``lib.helper``.
    """
    connection = {
        "host": MYSCALE_HOST,
        "port": MYSCALE_PORT,
        "username": MYSCALE_USER,
        "password": MYSCALE_PASSWORD,
    }
    return SessionManager(st.session_state, **connection)
|
|
|
|
|
def refresh_sessions():
    """Reload the user's sessions, files and tools into session state.

    Creates a ``"<user_name>?default"`` session when the user has none. Fills
    ``current_sessions``, ``user_files``, ``user_tools`` and
    ``tools_with_users``. It then sets ``sel_sess`` to the previously
    selected session if that session still exists. Otherwise it falls back to
    the ``default`` session, and finally to the first session.
    """
    manager = st.session_state.session_manager
    user_name = st.session_state.user_name

    st.session_state["current_sessions"] = manager.list_sessions(user_name)
    # NOTE(review): the `is not dict` guard is kept as-is; why list_sessions
    # might return a dict is not visible here.
    if (
        type(st.session_state.current_sessions) is not dict
        and len(st.session_state.current_sessions) <= 0
    ):
        # The user has no sessions yet: give them a default one.
        manager.add_session(user_name, f"{user_name}?default", DEFAULT_SYSTEM_PROMPT)
        st.session_state["current_sessions"] = manager.list_sessions(user_name)

    private_kb = st.session_state.private_kb
    st.session_state["user_files"] = private_kb.list_files(user_name)
    st.session_state["user_tools"] = private_kb.list_tools(user_name)
    st.session_state["tools_with_users"] = {
        **st.session_state.tools,
        **private_kb.as_tools(user_name),
    }

    # Keep the current selection across refreshes instead of always jumping
    # back to "default". Callers such as add_file / build_kb_as_tool do not
    # rebuild the agent, so a forced jump would leave the sidebar showing a
    # different session than the loaded agent.
    session_ids = [x["session_id"] for x in st.session_state.current_sessions]
    preferred = []
    if "sel_sess" in st.session_state and st.session_state.sel_sess is not None:
        preferred.append(st.session_state.sel_sess["session_id"])
    preferred.append("default")
    dfl_indx = next(
        (session_ids.index(sid) for sid in preferred if sid in session_ids), 0
    )
    st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
|
|
|
|
|
def build_kb_as_tool():
    """Create a private knowledge-base tool from the "Build" form fields.

    Reads ``b_tool_name``, ``b_tool_desc`` and ``b_tool_files`` from session
    state. If any of them is missing or empty, it shows an error in the
    ``tool_status`` placeholder and pauses for two seconds. Otherwise it
    creates the tool from the selected files' ``file_name`` values and
    reloads session data.
    """
    form_keys = ("b_tool_name", "b_tool_desc", "b_tool_files")
    all_present = all(key in st.session_state for key in form_keys)
    if not all_present or not all(len(st.session_state[key]) > 0 for key in form_keys):
        st.session_state.tool_status.error(
            "You should fill all fields to build up a tool!"
        )
        sleep(2)
        return

    file_names = [picked["file_name"] for picked in st.session_state.b_tool_files]
    st.session_state.private_kb.create_tool(
        st.session_state.user_name,
        st.session_state.b_tool_name,
        st.session_state.b_tool_desc,
        file_names,
    )
    refresh_sessions()
|
|
|
|
|
def remove_kb():
    """Delete the knowledge-base tools chosen in the ``r_tool_names`` multiselect.

    When nothing is selected, it shows an error in the ``tool_status``
    placeholder and pauses for two seconds. Otherwise it removes the tools by
    ``tool_name`` and reloads session data.
    """
    nothing_chosen = (
        "r_tool_names" not in st.session_state
        or len(st.session_state.r_tool_names) <= 0
    )
    if nothing_chosen:
        st.session_state.tool_status.error(
            "You should specify at least one tool to delete!"
        )
        sleep(2)
        return

    doomed = [entry["tool_name"] for entry in st.session_state.r_tool_names]
    st.session_state.private_kb.remove_tools(st.session_state.user_name, doomed)
    refresh_sessions()
|
|
|
|
|
def refresh_agent():
    """(Re)build the chat agent for the selected session and tools.

    The agent is created by ``build_agents`` for the id
    ``"<user_name>?<session_id>"`` and stored in ``st.session_state["agent"]``.
    Tools come from the ``selected_tools`` multiselect. The system prompt
    comes from the selected session (``sel_sess``). When either is missing,
    the defaults that ``chat_page`` uses are applied.
    """
    # Same value as the default of chat_page's "Select a Knowledge Base Tool"
    # multiselect, whose options are the tool names, so it is a valid tool name.
    default_tools = ["Wikipedia + Self Querying"]
    with st.spinner("Initializing session..."):
        if "sel_sess" in st.session_state:
            session_id = st.session_state.sel_sess["session_id"]
            system_prompt = st.session_state.sel_sess["system_prompt"]
        else:
            # Mirror the default selection chat_page installs.
            session_id = "default"
            system_prompt = DEFAULT_SYSTEM_PROMPT
        full_session_id = f"{st.session_state.user_name}?{session_id}"
        print("??? Changed to ", full_session_id)
        tools = (
            st.session_state.selected_tools
            if "selected_tools" in st.session_state
            else default_tools
        )
        st.session_state["agent"] = build_agents(
            full_session_id,
            tools,
            system_prompt=system_prompt,
        )
|
|
|
|
|
def add_file():
    """Add the files from the ``uploaded_files`` widget to the private knowledge base.

    Status and errors are shown in the ``tool_status`` placeholder. With no
    files uploaded, it shows an error, pauses for two seconds and returns. A
    ``ValueError`` from ``private_kb.add_by_file`` is reported rather than
    raised. On success the session data is reloaded.
    """
    if (
        "uploaded_files" not in st.session_state
        or len(st.session_state.uploaded_files) == 0
    ):
        # `icon` must be a single valid emoji for Streamlit.
        st.session_state.tool_status.error("Please upload files!", icon="⚠️")
        sleep(2)
        return
    try:
        st.session_state.tool_status.info("Uploading...")
        st.session_state.private_kb.add_by_file(
            st.session_state.user_name, st.session_state.uploaded_files
        )
        refresh_sessions()
    except ValueError as e:
        st.session_state.tool_status.error("Failed to upload! " + str(e))
        sleep(2)
|
|
|
|
|
def clear_files():
    """Clear the user's private knowledge base via ``private_kb.clear``, then reload.

    Backs the "Clear Files and All Tools" button.
    """
    knowledge_base = st.session_state.private_kb
    knowledge_base.clear(st.session_state.user_name)
    refresh_sessions()
|
|
|
|
|
def chat_page():
    """Render the chat page.

    Steps:

    1. Initialize the per-user objects in session state.
    2. Draw the sidebar: session management, session selection, tool
       settings, and the history/logout buttons.
    3. Build the agent if needed and replay its stored chat history.
    4. Wire the chat input to ``on_chat_submit``.
    """
    _init_chat_state()
    with st.sidebar:
        _render_session_management()
        _render_session_selection()
        _render_tool_settings()
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)
    if "agent" not in st.session_state:
        refresh_agent()
    print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
    _render_chat_history(st.session_state.agent.memory.chat_memory.messages)
    # Placeholder that on_chat_submit draws the next exchange into.
    st.session_state["next_round"] = st.empty()
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")


def _init_chat_state():
    """Create the default selection, private KB and session manager if missing."""
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    if "private_kb" not in st.session_state:
        st.session_state["private_kb"] = PrivateKnowledgeBase(
            host=MYSCALE_HOST,
            port=MYSCALE_PORT,
            username=MYSCALE_USER,
            password=MYSCALE_PASSWORD,
            embedding=st.session_state.embeddings["Wikipedia"],
            parser_api_key=UNSTRUCTURED_API,
        )
    if "session_manager" not in st.session_state:
        st.session_state["session_manager"] = build_session_manager()


# NOTE(review): the icon/avatar values in the helpers below replace strings
# that were mis-encoded in the original source (not valid emojis, which
# Streamlit is expected to reject). Apart from "🔧" and "❤️", whose byte
# patterns are unambiguous, these are best-guess reconstructions. Confirm the
# intended emoji.


def _render_session_management():
    """Sidebar expander with an editable table of the user's sessions."""
    with st.expander("Session Management"):
        if "current_sessions" not in st.session_state:
            refresh_sessions()
        st.info(
            "Here you can set up your session! \n\nYou can **change your prompt** here!",
            icon="🤖",
        )
        st.info(
            (
                "**Add columns by clicking the empty row**.\n"
                "And **delete columns by selecting rows with a press on `DEL` Key**"
            ),
            icon="💡",
        )
        st.info(
            "Don't forget to **click `Submit Change` to save your change**!",
            icon="📌",
        )
        st.data_editor(
            st.session_state.current_sessions,
            num_rows="dynamic",
            key="session_editor",
            use_container_width=True,
        )
        st.button("Submit Change!", on_click=on_session_change_submit)


def _selected_session_index():
    """Index in ``current_sessions`` of the selected session, else ``default``, else 0."""
    try:
        session_ids = [x["session_id"] for x in st.session_state.current_sessions]
        preferred = (
            [st.session_state.sel_sess["session_id"]]
            if "sel_sess" in st.session_state
            else []
        )
    except (KeyError, TypeError) as e:
        print("*** ", str(e))
        return 0
    preferred.append("default")
    return next((session_ids.index(sid) for sid in preferred if sid in session_ids), 0)


def _render_session_selection():
    """Sidebar expander with the session selectbox (bound to ``sel_sess``)."""
    with st.expander("Session Selection", expanded=True):
        st.info(
            "If no session is attach to your account, then we will add a default session to you!",
            icon="❤️",
        )
        st.selectbox(
            "Choose a session to chat:",
            options=st.session_state.current_sessions,
            index=_selected_session_index(),
            key="sel_sess",
            format_func=lambda x: x["session_id"],
            on_change=refresh_agent,
        )
        print(st.session_state.sel_sess)


def _render_tool_settings():
    """Sidebar expander with the knowledge-base and file-upload tabs."""
    with st.expander("Tool Settings", expanded=True):
        st.info(
            "We provides you several knowledge base tools for you. We are building more tools!",
            icon="🔧",
        )
        # Status placeholder that add_file / build_kb_as_tool / remove_kb write into.
        st.session_state["tool_status"] = st.empty()
        tab_kb, tab_file = st.tabs(
            [
                "Knowledge Bases",
                "File Upload",
            ]
        )
        with tab_kb:
            _render_knowledge_base_tab()
        with tab_file:
            _render_file_upload_tab()


def _render_knowledge_base_tab():
    """Widgets to build, select and delete knowledge-base tools."""
    st.markdown("#### Build You Own Knowledge")
    st.multiselect(
        "Select Files to Build up",
        st.session_state.user_files,
        placeholder="You should upload files first",
        key="b_tool_files",
        format_func=lambda x: x["file_name"],
    )
    st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
    st.text_input(
        "Tool Description",
        "Searches among user's private files and returns related documents",
        key="b_tool_desc",
    )
    st.button("Build!", on_click=build_kb_as_tool)
    st.markdown("### Knowledge Base Selection")
    if "user_tools" in st.session_state and len(st.session_state.user_tools) > 0:
        st.markdown("***User Created Knowledge Bases***")
        st.dataframe(st.session_state.user_tools)
    st.multiselect(
        "Select a Knowledge Base Tool",
        st.session_state.tools.keys()
        if "tools_with_users" not in st.session_state
        else st.session_state.tools_with_users,
        default=["Wikipedia + Self Querying"],
        key="selected_tools",
        on_change=refresh_agent,
    )
    st.markdown("### Delete Knowledge Base")
    st.multiselect(
        "Choose Knowledge Base to Remove",
        st.session_state.user_tools,
        format_func=lambda x: x["tool_name"],
        key="r_tool_names",
    )
    st.button("Delete", on_click=remove_kb)


def _render_file_upload_tab():
    """Widgets to upload files and list/clear the user's uploaded files."""
    st.info(
        (
            "We adopted [Unstructured API](https://unstructured.io/api-key) "
            "here and we only store the processed texts from your documents. "
            "For privacy concerns, please refer to "
            "[our policy issue](https://myscale.com/privacy/)."
        ),
        icon="🔒",
    )
    st.file_uploader("Upload files", key="uploaded_files", accept_multiple_files=True)
    st.markdown("### Uploaded Files")
    st.dataframe(
        st.session_state.private_kb.list_files(st.session_state.user_name),
        use_container_width=True,
    )
    col_1, col_2 = st.columns(2)
    with col_1:
        st.button("Add Files", on_click=add_file)
    with col_2:
        st.button("Clear Files and All Tools", on_click=clear_files)


def _write_message_timestamp(msg):
    """Write ``msg.additional_kwargs["timestamp"]`` (a POSIX timestamp) as ISO text, if present."""
    timestamp = msg.additional_kwargs.get("timestamp")
    # Skip instead of raising KeyError, which would abort the whole history replay.
    if timestamp is not None:
        st.write(f"*{datetime.datetime.fromtimestamp(timestamp).isoformat()}*")


def _render_function_message(msg):
    """Show a knowledge-base (function) result as a table, or as raw text if that fails."""
    with st.chat_message("Knowledge Base", avatar="📚"):
        _write_message_timestamp(msg)
        st.write("Retrieved from knowledge base:")
        try:
            st.dataframe(
                pd.DataFrame.from_records(
                    json.loads(msg.content, cls=CustomJSONDecoder)
                ),
                use_container_width=True,
            )
        except Exception:
            # Content is not JSON records: fall back to the raw text. This is
            # `except Exception` rather than a bare `except:`, so BaseException
            # subclasses (KeyboardInterrupt, SystemExit, ...) still propagate.
            st.write(msg.content)


def _render_chat_history(messages):
    """Replay stored messages. Empty non-function messages are skipped."""
    for msg in messages:
        if isinstance(msg, FunctionMessage):
            _render_function_message(msg)
        elif len(msg.content) > 0:
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                print(type(msg), msg.dict())
                _write_message_timestamp(msg)
                st.write(f"{msg.content}")
|
|