"""Streamlit chat page: session management sidebar, tool selection and chat history."""

import datetime
from os import environ
from time import sleep

import pandas as pd
import streamlit as st
from langchain.schema import HumanMessage, FunctionMessage

from lib.sessions import SessionManager
from helper import (
    build_agents,
    MYSCALE_HOST,
    MYSCALE_PASSWORD,
    MYSCALE_PORT,
    MYSCALE_USER,
    DEFAULT_SYSTEM_PROMPT,
)

# NOTE(review): the module-level `back_to_main` defined below shadows this
# imported name. The import is kept so `login` is still imported exactly as
# before; confirm which implementation is actually intended.
from login import back_to_main

environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]

TOOL_NAMES = {
    "langchain_retriever_tool": "Self-querying retriever",
    "vecsql_retriever_tool": "Vector SQL",
}

# Tool selection used when the "selected_tools" widget has no value yet.
_DEFAULT_SELECTED_TOOLS = ("LangChain Self Query Retriever For Wikipedia",)


def _qualified_session_id(user_name, session_id):
    """Return the per-user session key in the form ``"<user_name>?<session_id>"``."""
    return f"{user_name}?{session_id}"


def _default_session():
    """Return a fresh placeholder session dict using the default id and prompt."""
    return {
        "session_id": "default",
        "system_prompt": DEFAULT_SYSTEM_PROMPT,
    }


def _selected_session():
    """Return ``st.session_state.sel_sess``, or the placeholder if unset/None."""
    if "sel_sess" in st.session_state and st.session_state.sel_sess is not None:
        return st.session_state.sel_sess
    return _default_session()


def _default_session_index(sessions):
    """Return the index of the session whose id is ``"default"``, or 0 if absent."""
    session_ids = [session["session_id"] for session in sessions]
    try:
        return session_ids.index("default")
    except ValueError:
        return 0


def _format_timestamp(msg):
    """Return ``msg.additional_kwargs['timestamp']`` as an italic ISO-8601 string."""
    stamp = datetime.datetime.fromtimestamp(msg.additional_kwargs["timestamp"])
    return f"*{stamp.isoformat()}*"


def on_chat_submit():
    """Chat-input callback: pass the typed message to the agent and print the result."""
    ret = st.session_state.agent({"input": st.session_state.chat_input})
    print(ret)


def clear_history():
    """Clear the current agent's memory, if an agent has been created."""
    if "agent" in st.session_state:
        st.session_state.agent.memory.clear()


def back_to_main():
    """Remove the login-related keys from the session state (used by "Logout")."""
    for key in ("user_info", "user_name", "jump_query_ask"):
        if key in st.session_state:
            del st.session_state[key]


def _validate_new_session_row(row):
    """Raise ``KeyError`` unless ``row`` has a usable id and a system prompt.

    ``KeyError`` is kept as the exception type because the caller displays
    ``type(e)`` and ``str(e)`` to the user.
    """
    if not ("system_prompt" in row and "session_id" in row):
        raise KeyError(
            "You should fill both `session_id` and `system_prompt` to add a column!"
        )
    session_id = row["session_id"]
    # A non-string id (e.g. None) would otherwise fail the "?" membership test
    # with an opaque TypeError; report it with the same validation message.
    if not isinstance(session_id, str) or session_id == "" or "?" in session_id:
        raise KeyError(
            "`session_id` should be neither empty nor contain question marks."
        )


def on_session_change_submit():
    """Apply the rows added/deleted in the session editor to the session store.

    Does nothing unless both ``session_manager`` and ``session_editor`` are in
    the session state. Any exception raised while applying the changes is shown
    with ``st.error`` instead of propagating. The editor's pending
    ``added_rows`` / ``deleted_rows`` are always reset, and the agent is rebuilt
    afterwards.
    """
    if not (
        "session_manager" in st.session_state
        and "session_editor" in st.session_state
    ):
        return

    print(st.session_state.session_editor)
    try:
        for row in st.session_state.session_editor["added_rows"]:
            _validate_new_session_row(row)
            st.session_state.session_manager.add_session(
                user_id=st.session_state.user_name,
                session_id=_qualified_session_id(
                    st.session_state.user_name, row["session_id"]
                ),
                system_prompt=row["system_prompt"],
            )

        # Deleted rows are given as indices into the currently listed sessions.
        for row_index in st.session_state.session_editor["deleted_rows"]:
            st.session_state.session_manager.remove_session(
                session_id=_qualified_session_id(
                    st.session_state.user_name,
                    st.session_state.current_sessions[row_index]["session_id"],
                ),
            )
        refresh_sessions()

        # The selected session may have just been removed, so fall back to the
        # "default" one (or the first listed) after any deletion.
        if st.session_state.session_editor["deleted_rows"]:
            sessions = st.session_state.current_sessions
            st.session_state.sel_sess = sessions[_default_session_index(sessions)]
    except Exception as e:
        sleep(2)
        st.error(f"{type(e)}: {str(e)}")
    finally:
        st.session_state.session_editor["added_rows"] = []
        st.session_state.session_editor["deleted_rows"] = []
    refresh_agent()


def build_session_manager():
    """Create a ``SessionManager`` from the MyScale connection settings."""
    return SessionManager(
        host=MYSCALE_HOST,
        port=MYSCALE_PORT,
        username=MYSCALE_USER,
        password=MYSCALE_PASSWORD,
    )


def refresh_sessions():
    """Reload ``current_sessions`` for the user, creating "default" if none exist."""
    state = st.session_state
    # In this file the manager is only assigned by refresh_agent(), while
    # chat_page() calls refresh_sessions() first; create it on demand so this
    # function does not depend on it having been set elsewhere.
    if "session_manager" not in state:
        state["session_manager"] = build_session_manager()

    state["current_sessions"] = state.session_manager.list_sessions(state.user_name)
    sessions = state.current_sessions
    if not isinstance(sessions, dict) and len(sessions) == 0:
        state.session_manager.add_session(
            state.user_name,
            _qualified_session_id(state.user_name, "default"),
            DEFAULT_SYSTEM_PROMPT,
        )
        state["current_sessions"] = state.session_manager.list_sessions(
            state.user_name
        )


def refresh_agent():
    """Rebuild the agent and the session manager for the selected session.

    Uses the placeholder default session when ``sel_sess`` is missing or None,
    and the default tool selection when ``selected_tools`` is not set.
    """
    state = st.session_state
    with st.spinner("Initializing session..."):
        selected = _selected_session()
        agent_session_id = _qualified_session_id(
            state.user_name, selected["session_id"]
        )
        print("??? Changed to ", agent_session_id)
        if "selected_tools" in state:
            tools = state.selected_tools
        else:
            tools = list(_DEFAULT_SELECTED_TOOLS)
        state["agent"] = build_agents(
            agent_session_id,
            tools,
            system_prompt=selected["system_prompt"],
        )
        state["session_manager"] = build_session_manager()


def _render_function_message(msg):
    """Render a ``FunctionMessage`` as a knowledge-base result, tabular if possible."""
    with st.chat_message("Knowledge Base", avatar="📖"):
        st.write(_format_timestamp(msg))
        st.write("Retrieved from knowledge base:")
        try:
            # NOTE(security): eval() executes whatever expression is stored in
            # the message content. Kept as-is because the exact content format
            # is not visible here; review whether ast.literal_eval or JSON
            # parsing can replace it.
            st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
        except Exception:
            # Content that cannot be turned into records is shown verbatim.
            # `Exception` rather than a bare `except:` so KeyboardInterrupt,
            # SystemExit and other BaseException subclasses are not swallowed.
            st.write(msg.content)


def _render_chat_message(msg):
    """Render a non-function message under the "user" or "assistant" speaker."""
    speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
    with st.chat_message(speaker):
        print(type(msg), msg.dict())
        st.write(_format_timestamp(msg))
        st.write(f"{msg.content}")


def chat_page():
    """Render the chat page: sidebar controls, message history and chat input."""
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = _default_session()

    with st.sidebar:
        with st.expander("Session Management"):
            refresh_sessions()
            st.data_editor(
                st.session_state.current_sessions,
                num_rows="dynamic",
                key="session_editor",
                use_container_width=True,
            )
            st.button("Submit Change!", on_click=on_session_change_submit)
        with st.expander("Session Selection", expanded=True):
            st.selectbox(
                "Choose a session be chat:",
                options=st.session_state.current_sessions,
                index=_default_session_index(st.session_state.current_sessions),
                key="sel_sess",
                format_func=lambda x: x["session_id"],
                on_change=refresh_agent,
            )
            print(st.session_state.sel_sess)
        with st.expander("Tool Settings", expanded=True):
            st.multiselect(
                "Knowledge Base",
                st.session_state.tools.keys(),
                default=list(_DEFAULT_SELECTED_TOOLS),
                key="selected_tools",
                on_change=refresh_agent,
            )
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)

    if "agent" not in st.session_state:
        refresh_agent()
    print("!!! ", st.session_state.agent.memory.chat_memory.session_id)

    for msg in st.session_state.agent.memory.chat_memory.messages:
        if isinstance(msg, FunctionMessage):
            _render_function_message(msg)
        elif msg.content:
            # Messages with empty content are skipped.
            _render_chat_message(msg)

    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")