Fangrui Liu commited on
Commit
042a946
β€’
1 Parent(s): e1383d0

update session model

Browse files
Files changed (5) hide show
  1. app.py +10 -1
  2. chat.py +158 -14
  3. helper.py +23 -37
  4. lib/schemas.py +52 -0
  5. lib/sessions.py +68 -0
app.py CHANGED
@@ -10,13 +10,22 @@ from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
10
 
11
  from chat import chat_page
12
  from login import login, back_to_main
 
13
 
14
 
15
- from helper import build_tools, build_agents, build_all, sel_map, display
16
 
17
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
18
 
19
  st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
 
 
 
 
 
 
 
 
 
20
  st.header("ChatData")
21
 
22
  if 'retriever' not in st.session_state:
 
10
 
11
  from chat import chat_page
12
  from login import login, back_to_main
13
+ from helper import build_tools, build_agents, build_all, sel_map, display
14
 
15
 
 
16
 
17
  environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
18
 
19
  st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
20
+ st.markdown(
21
+ f"""
22
+ <style>
23
+ .st-e4 {{
24
+ max-width: 500px
25
+ }}
26
+ </style>""",
27
+ unsafe_allow_html=True,
28
+ )
29
  st.header("ChatData")
30
 
31
  if 'retriever' not in st.session_state:
chat.py CHANGED
@@ -1,20 +1,37 @@
1
  import pandas as pd
2
  from os import environ
 
3
  import datetime
4
  import streamlit as st
 
5
  from langchain.schema import HumanMessage, FunctionMessage
6
 
7
- from helper import build_agents
 
 
 
 
 
 
 
8
  from login import back_to_main
9
 
10
- environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
 
 
 
 
 
 
11
 
12
  def on_chat_submit():
13
- ret = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]({"input": st.session_state.chat_input})
14
  print(ret)
15
-
 
16
  def clear_history():
17
- st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.clear()
 
18
 
19
 
20
  def back_to_main():
@@ -25,29 +42,156 @@ def back_to_main():
25
  if "jump_query_ask" in st.session_state:
26
  del st.session_state.jump_query_ask
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def chat_page():
29
- st.session_state["agents"] = build_agents(f"{st.session_state.user_name}?default")
 
 
 
 
30
  with st.sidebar:
31
- st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
32
- st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  st.button("Clear Chat History", on_click=clear_history)
34
  st.button("Logout", on_click=back_to_main)
35
- for msg in st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.chat_memory.messages:
 
 
 
36
  speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
37
  if isinstance(msg, FunctionMessage):
38
  with st.chat_message("Knowledge Base", avatar="πŸ“–"):
39
- print(type(msg.content))
40
- st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
 
41
  st.write("Retrieved from knowledge base:")
42
  try:
43
- st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
 
 
44
  except:
45
  st.write(msg.content)
46
  else:
47
  if len(msg.content) > 0:
48
  with st.chat_message(speaker):
49
  print(type(msg), msg.dict())
50
- st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
 
 
51
  st.write(f"{msg.content}")
52
  st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
53
-
 
1
  import pandas as pd
2
  from os import environ
3
+ from time import sleep
4
  import datetime
5
  import streamlit as st
6
+ from lib.sessions import SessionManager
7
  from langchain.schema import HumanMessage, FunctionMessage
8
 
9
+ from helper import (
10
+ build_agents,
11
+ MYSCALE_HOST,
12
+ MYSCALE_PASSWORD,
13
+ MYSCALE_PORT,
14
+ MYSCALE_USER,
15
+ DEFAULT_SYSTEM_PROMPT,
16
+ )
17
  from login import back_to_main
18
 
19
+ environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
20
+
21
+ TOOL_NAMES = {
22
+ "langchain_retriever_tool": "Self-querying retriever",
23
+ "vecsql_retriever_tool": "Vector SQL",
24
+ }
25
+
26
 
27
  def on_chat_submit():
28
+ ret = st.session_state.agent({"input": st.session_state.chat_input})
29
  print(ret)
30
+
31
+
32
  def clear_history():
33
+ if "agent" in st.session_state:
34
+ st.session_state.agent.memory.clear()
35
 
36
 
37
  def back_to_main():
 
42
  if "jump_query_ask" in st.session_state:
43
  del st.session_state.jump_query_ask
44
 
45
+
46
+ def on_session_change_submit():
47
+ if "session_manager" in st.session_state and "session_editor" in st.session_state:
48
+ print(st.session_state.session_editor)
49
+ try:
50
+ for elem in st.session_state.session_editor["added_rows"]:
51
+ if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
52
+ if elem["session_id"] != "" and "?" not in elem["session_id"]:
53
+ st.session_state.session_manager.add_session(
54
+ user_id=st.session_state.user_name,
55
+ session_id=f"{st.session_state.user_name}?{elem['session_id']}",
56
+ system_prompt=elem["system_prompt"],
57
+ )
58
+ else:
59
+ raise KeyError(
60
+ "`session_id` should NOT be neither empty nor contain question marks."
61
+ )
62
+ else:
63
+ raise KeyError(
64
+ "You should fill both `session_id` and `system_prompt` to add a column!"
65
+ )
66
+ for elem in st.session_state.session_editor["deleted_rows"]:
67
+ st.session_state.session_manager.remove_session(
68
+ session_id=f"{st.session_state.user_name}?{st.session_state.current_sessions[elem]['session_id']}",
69
+ )
70
+ refresh_sessions()
71
+ if len(st.session_state.session_editor["deleted_rows"]) > 0:
72
+ try:
73
+ dfl_indx = [
74
+ x["session_id"] for x in st.session_state.current_sessions
75
+ ].index("default")
76
+ except ValueError:
77
+ dfl_indx = 0
78
+ st.session_state.sel_sess = st.session_state.current_sessions[dfl_indx]
79
+ except Exception as e:
80
+ sleep(2)
81
+ st.error(f"{type(e)}: {str(e)}")
82
+ finally:
83
+ st.session_state.session_editor["added_rows"] = []
84
+ st.session_state.session_editor["deleted_rows"] = []
85
+ refresh_agent()
86
+
87
+
88
+ def build_session_manager():
89
+ return SessionManager(
90
+ host=MYSCALE_HOST,
91
+ port=MYSCALE_PORT,
92
+ username=MYSCALE_USER,
93
+ password=MYSCALE_PASSWORD,
94
+ )
95
+
96
+
97
+ def refresh_sessions():
98
+ st.session_state[
99
+ "current_sessions"
100
+ ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
101
+ if type(st.session_state.current_sessions) is not dict and len(st.session_state.current_sessions) <= 0:
102
+ st.session_state.session_manager.add_session(
103
+ st.session_state.user_name,
104
+ f"{st.session_state.user_name}?default",
105
+ DEFAULT_SYSTEM_PROMPT,
106
+ )
107
+ st.session_state[
108
+ "current_sessions"
109
+ ] = st.session_state.session_manager.list_sessions(st.session_state.user_name)
110
+
111
+
112
+ def refresh_agent():
113
+ with st.spinner("Initializing session..."):
114
+ print(
115
+ f"??? Changed to ",
116
+ f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}",
117
+ )
118
+ st.session_state["agent"] = build_agents(
119
+ f"{st.session_state.user_name}?{st.session_state.sel_sess['session_id']}",
120
+ ["LangChain Self Query Retriever For Wikipedia"]
121
+ if "selected_tools" not in st.session_state
122
+ else st.session_state.selected_tools,
123
+ system_prompt=DEFAULT_SYSTEM_PROMPT
124
+ if "sel_sess" not in st.session_state
125
+ else st.session_state.sel_sess["system_prompt"],
126
+ )
127
+ st.session_state["session_manager"] = build_session_manager()
128
+
129
+
130
  def chat_page():
131
+ if "sel_sess" not in st.session_state:
132
+ st.session_state["sel_sess"] = {
133
+ "session_id": "default",
134
+ "system_prompt": DEFAULT_SYSTEM_PROMPT,
135
+ }
136
  with st.sidebar:
137
+ with st.expander("Session Management"):
138
+ refresh_sessions()
139
+ st.data_editor(
140
+ st.session_state.current_sessions,
141
+ num_rows="dynamic",
142
+ key="session_editor",
143
+ use_container_width=True,
144
+ )
145
+ st.button("Submit Change!", on_click=on_session_change_submit)
146
+ with st.expander("Session Selection", expanded=True):
147
+ try:
148
+ dfl_indx = [
149
+ x["session_id"] for x in st.session_state.current_sessions
150
+ ].index("default")
151
+ except ValueError:
152
+ dfl_indx = 0
153
+ st.selectbox(
154
+ "Choose a session be chat:",
155
+ options=st.session_state.current_sessions,
156
+ index=dfl_indx,
157
+ key="sel_sess",
158
+ format_func=lambda x: x["session_id"],
159
+ on_change=refresh_agent,
160
+ )
161
+ print(st.session_state.sel_sess)
162
+ with st.expander("Tool Settings", expanded=True):
163
+ st.multiselect(
164
+ "Knowledge Base",
165
+ st.session_state.tools.keys(),
166
+ default=["LangChain Self Query Retriever For Wikipedia"],
167
+ key="selected_tools",
168
+ on_change=refresh_agent,
169
+ )
170
  st.button("Clear Chat History", on_click=clear_history)
171
  st.button("Logout", on_click=back_to_main)
172
+ if 'agent' not in st.session_state:
173
+ refresh_agent()
174
+ print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
175
+ for msg in st.session_state.agent.memory.chat_memory.messages:
176
  speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
177
  if isinstance(msg, FunctionMessage):
178
  with st.chat_message("Knowledge Base", avatar="πŸ“–"):
179
+ st.write(
180
+ f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
181
+ )
182
  st.write("Retrieved from knowledge base:")
183
  try:
184
+ st.dataframe(
185
+ pd.DataFrame.from_records(map(dict, eval(msg.content)))
186
+ )
187
  except:
188
  st.write(msg.content)
189
  else:
190
  if len(msg.content) > 0:
191
  with st.chat_message(speaker):
192
  print(type(msg), msg.dict())
193
+ st.write(
194
+ f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
195
+ )
196
  st.write(f"{msg.content}")
197
  st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
 
helper.py CHANGED
@@ -68,6 +68,12 @@ MYSCALE_PORT = st.secrets['MYSCALE_PORT']
68
  COMBINE_PROMPT = ChatPromptTemplate.from_strings(
69
  string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
70
  (HumanMessagePromptTemplate, '{question}')])
 
 
 
 
 
 
71
 
72
  def hint_arxiv():
73
  st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
@@ -415,7 +421,7 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
415
  return self.model_class
416
 
417
 
418
- def create_agent_executor(name, session_id, llm, tools, **kwargs):
419
  name = name.replace(" ", "_")
420
  conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
421
  chat_memory = SQLChatMessageHistory(
@@ -425,12 +431,7 @@ def create_agent_executor(name, session_id, llm, tools, **kwargs):
425
  memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
426
 
427
  _system_message = SystemMessage(
428
- content=(
429
- "Do your best to answer the questions. "
430
- "Feel free to use any tools available to look up "
431
- "relevant information. Please keep all details in query "
432
- "when calling search functions."
433
- )
434
  )
435
  prompt = OpenAIFunctionsAgent.create_prompt(
436
  system_message=_system_message,
@@ -463,38 +464,23 @@ def build_tools():
463
  st.session_state["sel_map_obj"][k] = {}
464
  if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
465
  st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
466
- sel_map_obj[k] = {
467
- "langchain_retriever_tool": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
468
- "vecsql_retriever_tool": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
469
- }
470
  return sel_map_obj
471
 
472
- @st.cache_resource(max_entries=1)
473
- def build_agents(session_id):
474
- chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=0.6, openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY)
475
- agents = {}
476
- cnt = 0
477
- p = st.progress(0.0, "Building agents with different knowledge base...")
478
- for k in [*sel_map.keys(), 'ArXiv + Wikipedia']:
479
- for m, n in [("langchain_retriever_tool", "Self-querying retriever"), ("vecsql_retriever_tool", "Vector SQL")]:
480
- if k == 'ArXiv + Wikipedia':
481
- tools = [st.session_state.tools[k][m] for k in sel_map.keys()]
482
- elif k == 'Null':
483
- tools = []
484
- else:
485
- tools = [st.session_state.tools[k][m]]
486
- if k not in agents:
487
- agents[k] = {}
488
- agents[k][n] = create_agent_executor(
489
- "chat_memory",
490
- session_id,
491
- chat_llm,
492
- tools=tools,
493
- )
494
- cnt += 1/6
495
- p.progress(cnt, f"Building with Knowledge Base {k} via Retriever {n}...")
496
- p.empty()
497
- return agents
498
 
499
 
500
  def display(dataframe, columns_=None, index=None):
 
68
  COMBINE_PROMPT = ChatPromptTemplate.from_strings(
69
  string_messages=[(SystemMessagePromptTemplate, combine_prompt_template),
70
  (HumanMessagePromptTemplate, '{question}')])
71
+ DEFAULT_SYSTEM_PROMPT = (
72
+ "Do your best to answer the questions. "
73
+ "Feel free to use any tools available to look up "
74
+ "relevant information. Please keep all details in query "
75
+ "when calling search functions."
76
+ )
77
 
78
  def hint_arxiv():
79
  st.info("We provides you metadata columns below for query. Please choose a natural expression to describe filters on those columns.\n\n"
 
421
  return self.model_class
422
 
423
 
424
+ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs):
425
  name = name.replace(" ", "_")
426
  conn_str = f'clickhouse://{MYSCALE_USER}:{MYSCALE_PASSWORD}@{MYSCALE_HOST}:{MYSCALE_PORT}'
427
  chat_memory = SQLChatMessageHistory(
 
431
  memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)
432
 
433
  _system_message = SystemMessage(
434
+ content=system_prompt
 
 
 
 
 
435
  )
436
  prompt = OpenAIFunctionsAgent.create_prompt(
437
  system_message=_system_message,
 
464
  st.session_state["sel_map_obj"][k] = {}
465
  if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
466
  st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
467
+ sel_map_obj.update({
468
+ f"LangChain Self Query Retriever For {k}": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
469
+ f"Vector SQL Retriever For {k}": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
470
+ })
471
  return sel_map_obj
472
 
473
+ def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
474
+ chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature, openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY)
475
+ tools = [st.session_state.tools[k] for k in tool_names]
476
+ agent = create_agent_executor(
477
+ "chat_memory",
478
+ session_id,
479
+ chat_llm,
480
+ tools=tools,
481
+ system_prompt=system_prompt
482
+ )
483
+ return agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
 
485
 
486
  def display(dataframe, columns_=None, index=None):
lib/schemas.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlalchemy import Column, Text
2
+ from clickhouse_sqlalchemy import types, engines
3
+
4
+
5
+ def create_message_model(table_name, DynamicBase): # type: ignore
6
+ """
7
+ Create a message model for a given table name.
8
+
9
+ Args:
10
+ table_name: The name of the table to use.
11
+ DynamicBase: The base class to use for the model.
12
+
13
+ Returns:
14
+ The model class.
15
+
16
+ """
17
+
18
+ # Model decleared inside a function to have a dynamic table name
19
+ class Message(DynamicBase):
20
+ __tablename__ = table_name
21
+ id = Column(types.Float64)
22
+ session_id = Column(Text)
23
+ user_id = Column(Text)
24
+ msg_id = Column(Text, primary_key=True)
25
+ type = Column(Text)
26
+ addtionals = Column(Text)
27
+ message = Column(Text)
28
+ __table_args__ = (
29
+ engines.ReplacingMergeTree(
30
+ partition_by='session_id',
31
+ order_by=('id', 'msg_id')),
32
+ {'comment': 'Store Chat History'}
33
+ )
34
+
35
+ return Message
36
+
37
+
38
+ def create_session_table(table_name, DynamicBase): # type: ignore
39
+ # Model decleared inside a function to have a dynamic table name
40
+ class Session(DynamicBase):
41
+ __tablename__ = table_name
42
+ user_id = Column(Text)
43
+ session_id = Column(Text, primary_key=True)
44
+ system_prompt = Column(Text)
45
+ create_by = Column(types.DateTime)
46
+ additionals = Column(Text)
47
+ __table_args__ = (
48
+ engines.ReplacingMergeTree(
49
+ order_by=('session_id')),
50
+ {'comment': 'Store Session and Prompts'}
51
+ )
52
+ return Session
lib/sessions.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ try:
3
+ from sqlalchemy.orm import declarative_base
4
+ except ImportError:
5
+ from sqlalchemy.ext.declarative import declarative_base
6
+ from datetime import datetime
7
+ from sqlalchemy import Column, Text, orm, create_engine
8
+ from clickhouse_sqlalchemy import types, engines
9
+ from .schemas import create_message_model, create_session_table
10
+
11
+ def get_sessions(engine, model_class, user_id):
12
+ with orm.sessionmaker(engine)() as session:
13
+ result = (
14
+ session.query(model_class)
15
+ .where(
16
+ model_class.session_id == user_id
17
+ )
18
+ .order_by(model_class.create_by.desc())
19
+ )
20
+ return json.loads(result)
21
+
22
+ class SessionManager:
23
+ def __init__(self, host, port, username, password, db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
24
+ conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
25
+ self.engine = create_engine(conn_str, echo=False)
26
+ self.sess_model_class = create_session_table(sess_table, declarative_base())
27
+ self.sess_model_class.metadata.create_all(self.engine)
28
+ self.msg_model_class = create_message_model(msg_table, declarative_base())
29
+ self.msg_model_class.metadata.create_all(self.engine)
30
+ self.Session = orm.sessionmaker(self.engine)
31
+
32
+ def list_sessions(self, user_id):
33
+ with self.Session() as session:
34
+ result = (
35
+ session.query(self.sess_model_class)
36
+ .where(
37
+ self.sess_model_class.user_id == user_id
38
+ )
39
+ .order_by(self.sess_model_class.create_by.desc())
40
+ )
41
+ sessions = []
42
+ for r in result:
43
+ sessions.append({
44
+ "session_id": r.session_id.split("?")[-1],
45
+ "system_prompt": r.system_prompt,
46
+ })
47
+ return sessions
48
+
49
+ def modify_system_prompt(self, session_id, sys_prompt):
50
+ with self.Session() as session:
51
+ session.update(self.sess_model_class).where(self.sess_model_class==session_id).value(system_prompt=sys_prompt)
52
+ session.commit()
53
+
54
+ def add_session(self, user_id, session_id, system_prompt, **kwargs):
55
+ with self.Session() as session:
56
+ elem = self.sess_model_class(
57
+ user_id=user_id, session_id=session_id, system_prompt=system_prompt,
58
+ create_by=datetime.now(), additionals=json.dumps(kwargs)
59
+ )
60
+ session.add(elem)
61
+ session.commit()
62
+
63
+ def remove_session(self, session_id):
64
+ with self.Session() as session:
65
+ session.query(self.sess_model_class).where(self.sess_model_class.session_id==session_id).delete()
66
+ session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
67
+
68
+