mpsk committed on
Commit
06665fc
•
1 Parent(s): 9dd6716

improve chat experience

Browse files
Files changed (4) hide show
  1. app.py +1 -1
  2. callbacks/arxiv_callbacks.py +32 -3
  3. chat.py +23 -4
  4. helper.py +66 -8
app.py CHANGED
@@ -28,7 +28,7 @@ st.markdown(
28
  )
29
  st.header("ChatData")
30
 
31
- if 'retriever' not in st.session_state:
32
  st.session_state["sel_map_obj"] = build_all()
33
  st.session_state["tools"] = build_tools()
34
 
 
28
  )
29
  st.header("ChatData")
30
 
31
+ if 'sel_map_obj' not in st.session_state:
32
  st.session_state["sel_map_obj"] = build_all()
33
  st.session_state["tools"] = build_tools()
34
 
callbacks/arxiv_callbacks.py CHANGED
@@ -1,8 +1,11 @@
1
  import streamlit as st
2
- from typing import Dict, Any
 
 
3
  from sql_formatter.core import format_sql
4
- from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
5
  from langchain.schema.output import LLMResult
 
6
 
7
  class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
8
  def __init__(self) -> None:
@@ -91,4 +94,30 @@ class ChatDataSQLAskCallBackHandler(ChatDataSQLSearchCallBackHandler):
91
  self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
92
  self.status_bar = st.empty()
93
  self.prog_value = 0
94
- self.prog_interval = 0.1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import json
3
+ import textwrap
4
+ from typing import Dict, Any, List
5
  from sql_formatter.core import format_sql
6
+ from langchain.callbacks.streamlit.streamlit_callback_handler import LLMThought, StreamlitCallbackHandler
7
  from langchain.schema.output import LLMResult
8
+ from streamlit.delta_generator import DeltaGenerator
9
 
10
  class ChatDataSelfSearchCallBackHandler(StreamlitCallbackHandler):
11
  def __init__(self) -> None:
 
94
  self.progress_bar = st.progress(value=0.0, text='Writing SQL...')
95
  self.status_bar = st.empty()
96
  self.prog_value = 0
97
+ self.prog_interval = 0.1
98
+
99
+
100
+ class LLMThoughtWithKB(LLMThought):
101
+ def on_tool_end(self, output: str, color: str | None = None, observation_prefix: str | None = None, llm_prefix: str | None = None, **kwargs: Any) -> None:
102
+ try:
103
+ self._container.markdown("\n\n".join(["### Retrieved Documents:"] + \
104
+ [f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
105
+ for i, r in enumerate(json.loads(output))]))
106
+ except Exception as e:
107
+ super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)
108
+
109
+
110
+ class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
111
+
112
+ def on_llm_start(
113
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
114
+ ) -> None:
115
+ if self._current_thought is None:
116
+ self._current_thought = LLMThoughtWithKB(
117
+ parent_container=self._parent_container,
118
+ expanded=self._expand_new_thoughts,
119
+ collapse_on_complete=self._collapse_completed_thoughts,
120
+ labeler=self._thought_labeler,
121
+ )
122
+
123
+ self._current_thought.on_llm_start(serialized, prompts)
chat.py CHANGED
@@ -5,6 +5,8 @@ import datetime
5
  import streamlit as st
6
  from lib.sessions import SessionManager
7
  from langchain.schema import HumanMessage, FunctionMessage
 
 
8
 
9
  from helper import (
10
  build_agents,
@@ -25,8 +27,14 @@ TOOL_NAMES = {
25
 
26
 
27
  def on_chat_submit():
28
- ret = st.session_state.agent({"input": st.session_state.chat_input})
29
- print(ret)
 
 
 
 
 
 
30
 
31
 
32
  def clear_history():
@@ -136,6 +144,12 @@ def chat_page():
136
  with st.sidebar:
137
  with st.expander("Session Management"):
138
  refresh_sessions()
 
 
 
 
 
 
139
  st.data_editor(
140
  st.session_state.current_sessions,
141
  num_rows="dynamic",
@@ -144,6 +158,8 @@ def chat_page():
144
  )
145
  st.button("Submit Change!", on_click=on_session_change_submit)
146
  with st.expander("Session Selection", expanded=True):
 
 
147
  try:
148
  dfl_indx = [
149
  x["session_id"] for x in st.session_state.current_sessions
@@ -152,7 +168,7 @@ def chat_page():
152
  print("*** ", str(e))
153
  dfl_indx = 0
154
  st.selectbox(
155
- "Choose a session be chat:",
156
  options=st.session_state.current_sessions,
157
  index=dfl_indx,
158
  key="sel_sess",
@@ -161,10 +177,12 @@ def chat_page():
161
  )
162
  print(st.session_state.sel_sess)
163
  with st.expander("Tool Settings", expanded=True):
 
 
164
  st.multiselect(
165
  "Knowledge Base",
166
  st.session_state.tools.keys(),
167
- default=["LangChain Self Query Retriever For Wikipedia"],
168
  key="selected_tools",
169
  on_change=refresh_agent,
170
  )
@@ -195,4 +213,5 @@ def chat_page():
195
  f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
196
  )
197
  st.write(f"{msg.content}")
 
198
  st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
 
5
  import streamlit as st
6
  from lib.sessions import SessionManager
7
  from langchain.schema import HumanMessage, FunctionMessage
8
+ from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
9
+ from langchain.callbacks.streamlit.streamlit_callback_handler import StreamlitCallbackHandler
10
 
11
  from helper import (
12
  build_agents,
 
27
 
28
 
29
  def on_chat_submit():
30
+ with st.session_state.next_round.container():
31
+ with st.chat_message('user'):
32
+ st.write(st.session_state.chat_input)
33
+ with st.chat_message('assistant'):
34
+ container = st.container()
35
+ st_callback = ChatDataAgentCallBackHandler(container, collapse_completed_thoughts=False)
36
+ ret = st.session_state.agent({"input": st.session_state.chat_input}, callbacks=[st_callback])
37
+ print(ret)
38
 
39
 
40
  def clear_history():
 
144
  with st.sidebar:
145
  with st.expander("Session Management"):
146
  refresh_sessions()
147
+ st.info("Here you can set up your session! \n\nYou can **change your prompt** here!",
148
+ icon="🤖")
149
+ st.info(("**Add columns by clicking the empty row**.\n"
150
+ "And **delete columns by selecting rows with a press on `DEL` Key**"),
151
+ icon="💡")
152
+ st.info("Don't forget to **click `Submit Change` to save your change**!", icon="📢")
153
  st.data_editor(
154
  st.session_state.current_sessions,
155
  num_rows="dynamic",
 
158
  )
159
  st.button("Submit Change!", on_click=on_session_change_submit)
160
  with st.expander("Session Selection", expanded=True):
161
+ st.info("Here you can select your session!", icon="🤖")
162
+ st.info("If no session is attach to your account, then we will add a default session to you!", icon="❤️")
163
  try:
164
  dfl_indx = [
165
  x["session_id"] for x in st.session_state.current_sessions
 
168
  print("*** ", str(e))
169
  dfl_indx = 0
170
  st.selectbox(
171
+ "Choose a session to chat:",
172
  options=st.session_state.current_sessions,
173
  index=dfl_indx,
174
  key="sel_sess",
 
177
  )
178
  print(st.session_state.sel_sess)
179
  with st.expander("Tool Settings", expanded=True):
180
+ st.info("Here you can select your tools.", icon="🔧")
181
+ st.info("We provides you several knowledge base tools for you. We are building more tools!", icon="👷‍♂️")
182
  st.multiselect(
183
  "Knowledge Base",
184
  st.session_state.tools.keys(),
185
+ default=["Wikipedia + Self Querying"],
186
  key="selected_tools",
187
  on_change=refresh_agent,
188
  )
 
213
  f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
214
  )
215
  st.write(f"{msg.content}")
216
+ st.session_state["next_round"] = st.empty()
217
  st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
helper.py CHANGED
@@ -2,12 +2,15 @@
2
  import json
3
  import time
4
  import hashlib
5
- from typing import Dict, Any
6
  import re
7
  import pandas as pd
8
  from os import environ
9
  import streamlit as st
10
  import datetime
 
 
 
11
 
12
  from sqlalchemy import Column, Text, create_engine, MetaData
13
  from langchain.agents import AgentExecutor
@@ -28,7 +31,7 @@ from langchain.prompts import PromptTemplate, ChatPromptTemplate, \
28
  SystemMessagePromptTemplate, HumanMessagePromptTemplate
29
  from langchain.prompts.prompt import PromptTemplate
30
  from langchain.chat_models import ChatOpenAI
31
- from langchain.schema import BaseRetriever
32
  from langchain import OpenAI
33
  from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
34
  from langchain.retrievers.self_query.base import SelfQueryRetriever
@@ -36,12 +39,12 @@ from langchain.retrievers.self_query.myscale import MyScaleTranslator
36
  from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
37
  from langchain.vectorstores import MyScaleSettings
38
  from chains.arxiv_chains import MyScaleWithoutMetadataJson
39
- from langchain.schema import Document
40
  from langchain.prompts.prompt import PromptTemplate
41
  from langchain.prompts.chat import MessagesPlaceholder
42
  from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
43
  from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
44
- from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
 
45
  from langchain.memory import SQLChatMessageHistory
46
  from langchain.memory.chat_message_histories.sql import \
47
  BaseMessageConverter, DefaultMessageConverter
@@ -389,6 +392,26 @@ def create_message_model(table_name, DynamicBase): # type: ignore
389
 
390
  return Message
391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
392
  class DefaultClickhouseMessageConverter(DefaultMessageConverter):
393
  """The default message converter for SQLChatMessageHistory."""
394
 
@@ -411,9 +434,10 @@ class DefaultClickhouseMessageConverter(DefaultMessageConverter):
411
  "additional_kwargs": {"timestamp": tstamp},
412
  "data": message.dict()})
413
  )
 
414
  def from_sql_model(self, sql_message: Any) -> BaseMessage:
415
  msg_dump = json.loads(sql_message.message)
416
- msg = messages_from_dict([msg_dump])[0]
417
  msg.additional_kwargs = msg_dump["additional_kwargs"]
418
  return msg
419
 
@@ -447,6 +471,38 @@ def create_agent_executor(name, session_id, llm, tools, system_prompt, **kwargs)
447
  **kwargs
448
  )
449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  @st.cache_resource
451
  def build_tools():
452
  """build all resources
@@ -465,13 +521,15 @@ def build_tools():
465
  if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
466
  st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
467
  sel_map_obj.update({
468
- f"LangChain Self Query Retriever For {k}": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
469
- f"Vector SQL Retriever For {k}": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
470
  })
471
  return sel_map_obj
472
 
473
  def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
474
- chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature, openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY)
 
 
475
  tools = [st.session_state.tools[k] for k in tool_names]
476
  agent = create_agent_executor(
477
  "chat_memory",
 
2
  import json
3
  import time
4
  import hashlib
5
+ from typing import Dict, Any, List
6
  import re
7
  import pandas as pd
8
  from os import environ
9
  import streamlit as st
10
  import datetime
11
+ from langchain.schema import BaseRetriever
12
+ from langchain.tools import Tool
13
+ from langchain.pydantic_v1 import BaseModel, Field
14
 
15
  from sqlalchemy import Column, Text, create_engine, MetaData
16
  from langchain.agents import AgentExecutor
 
31
  SystemMessagePromptTemplate, HumanMessagePromptTemplate
32
  from langchain.prompts.prompt import PromptTemplate
33
  from langchain.chat_models import ChatOpenAI
34
+ from langchain.schema import BaseRetriever, Document
35
  from langchain import OpenAI
36
  from langchain.chains.query_constructor.base import AttributeInfo, VirtualColumnName
37
  from langchain.retrievers.self_query.base import SelfQueryRetriever
 
39
  from langchain.embeddings import HuggingFaceInstructEmbeddings, SentenceTransformerEmbeddings
40
  from langchain.vectorstores import MyScaleSettings
41
  from chains.arxiv_chains import MyScaleWithoutMetadataJson
 
42
  from langchain.prompts.prompt import PromptTemplate
43
  from langchain.prompts.chat import MessagesPlaceholder
44
  from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
45
  from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
46
+ from langchain.schema.messages import BaseMessage, HumanMessage, AIMessage, FunctionMessage,\
47
+ SystemMessage, ChatMessage, ToolMessage
48
  from langchain.memory import SQLChatMessageHistory
49
  from langchain.memory.chat_message_histories.sql import \
50
  BaseMessageConverter, DefaultMessageConverter
 
392
 
393
  return Message
394
 
395
+ def _message_from_dict(message: dict) -> BaseMessage:
396
+ _type = message["type"]
397
+ if _type == "human":
398
+ return HumanMessage(**message["data"])
399
+ elif _type == "ai":
400
+ return AIMessage(**message["data"])
401
+ elif _type == "system":
402
+ return SystemMessage(**message["data"])
403
+ elif _type == "chat":
404
+ return ChatMessage(**message["data"])
405
+ elif _type == "function":
406
+ return FunctionMessage(**message["data"])
407
+ elif _type == "tool":
408
+ return ToolMessage(**message["data"])
409
+ elif _type == "AIMessageChunk":
410
+ message["data"]["type"] = "ai"
411
+ return AIMessage(**message["data"])
412
+ else:
413
+ raise ValueError(f"Got unexpected message type: {_type}")
414
+
415
  class DefaultClickhouseMessageConverter(DefaultMessageConverter):
416
  """The default message converter for SQLChatMessageHistory."""
417
 
 
434
  "additional_kwargs": {"timestamp": tstamp},
435
  "data": message.dict()})
436
  )
437
+
438
  def from_sql_model(self, sql_message: Any) -> BaseMessage:
439
  msg_dump = json.loads(sql_message.message)
440
+ msg = _message_from_dict(msg_dump)
441
  msg.additional_kwargs = msg_dump["additional_kwargs"]
442
  return msg
443
 
 
471
  **kwargs
472
  )
473
 
474
+ class RetrieverInput(BaseModel):
475
+ query: str = Field(description="query to look up in retriever")
476
+
477
+ def create_retriever_tool(
478
+ retriever: BaseRetriever, name: str, description: str
479
+ ) -> Tool:
480
+ """Create a tool to do retrieval of documents.
481
+
482
+ Args:
483
+ retriever: The retriever to use for the retrieval
484
+ name: The name for the tool. This will be passed to the language model,
485
+ so should be unique and somewhat descriptive.
486
+ description: The description for the tool. This will be passed to the language
487
+ model, so should be descriptive.
488
+
489
+ Returns:
490
+ Tool class to pass to an agent
491
+ """
492
+ def wrap(func):
493
+ def wrapped_retrieve(*args, **kwargs):
494
+ docs: List[Document] = func(*args, **kwargs)
495
+ return json.dumps([d.dict() for d in docs])
496
+ return wrapped_retrieve
497
+
498
+ return Tool(
499
+ name=name,
500
+ description=description,
501
+ func=wrap(retriever.get_relevant_documents),
502
+ coroutine=retriever.aget_relevant_documents,
503
+ args_schema=RetrieverInput,
504
+ )
505
+
506
  @st.cache_resource
507
  def build_tools():
508
  """build all resources
 
521
  if "langchain_retriever" not in st.session_state.sel_map_obj[k] or "vecsql_retriever" not in st.session_state.sel_map_obj[k]:
522
  st.session_state.sel_map_obj[k].update(build_chains_retrievers(k))
523
  sel_map_obj.update({
524
+ f"{k} + Self Querying": create_retriever_tool(st.session_state.sel_map_obj[k]["retriever"], *sel_map[k]["tool_desc"],),
525
+ f"{k} + Vector SQL": create_retriever_tool(st.session_state.sel_map_obj[k]["sql_retriever"], *sel_map[k]["tool_desc"],),
526
  })
527
  return sel_map_obj
528
 
529
  def build_agents(session_id, tool_names, chat_model_name=chat_model_name, temperature=0.6, system_prompt=DEFAULT_SYSTEM_PROMPT):
530
+ chat_llm = ChatOpenAI(model_name=chat_model_name, temperature=temperature,
531
+ openai_api_base=OPENAI_API_BASE, openai_api_key=OPENAI_API_KEY, streaming=True,
532
+ )
533
  tools = [st.session_state.tools[k] for k in tool_names]
534
  agent = create_agent_executor(
535
  "chat_memory",