carolanderson commited on
Commit
b5792ea
·
1 Parent(s): cfb1b86

use StreamlitChatMessageHistory

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. app.py +18 -21
.gitignore CHANGED
@@ -1 +1,2 @@
1
  .ipynb_checkpoints
 
 
1
  .ipynb_checkpoints
2
+ .DS_Store
app.py CHANGED
@@ -9,10 +9,13 @@ SystemMessagePromptTemplate,
9
  HumanMessagePromptTemplate,
10
  )
11
  from langchain.memory import ConversationBufferWindowMemory
 
12
  from langchain.schema import AIMessage, HumanMessage
13
  import streamlit as st
14
 
15
 
 
 
16
  @st.cache_resource
17
  def set_api_key(api_key):
18
  os.environ["OPENAI_API_KEY"] = api_key
@@ -21,7 +24,10 @@ def set_api_key(api_key):
21
  @st.cache_resource
22
  def get_chain(model_name, temperature):
23
  llm = ChatOpenAI(model_name=model_name, temperature=temperature)
24
- memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history", return_messages=True)
 
 
 
25
  prompt = ChatPromptTemplate(
26
  messages=[
27
  SystemMessagePromptTemplate.from_template(
@@ -40,26 +46,15 @@ def get_chain(model_name, temperature):
40
  return conversation
41
 
42
 
43
- def display_messages(chain):
44
- """
45
- Show the messages in the conversation buffer.
46
- """
47
- for message in chain.memory.buffer:
48
- if isinstance(message, AIMessage):
49
- role = "assistant"
50
- elif isinstance(message, HumanMessage):
51
- role = "user"
52
- with st.chat_message(role):
53
- st.write(message.content)
54
-
55
 
 
56
  if __name__ == "__main__":
57
  st.header("Basic chatbot")
58
  with st.expander("How conversation history works"):
59
  st.write("To keep input lengths down and costs reasonable,"
60
  " this bot only 'remembers' the past three turns of conversation.")
61
- st.write("Each combination of model type and temperature has its own unique chat history.")
62
- st.write("To clear the current model's memory and start fresh, click 'Clear history'" )
63
  API_KEY = st.sidebar.text_input(
64
  'API Key',
65
  type='password',
@@ -86,14 +81,16 @@ if __name__ == "__main__":
86
  chain = get_chain(model_name, temperature)
87
  if st.button("Clear history"):
88
  chain.memory.clear()
89
- display_messages(chain)
 
 
90
  text = st.chat_input()
91
  if text:
92
- if text.lower() == "clear":
93
- chain.memory.clear()
94
- else:
95
- result = chain.predict(input=text)
96
- display_messages(chain)
97
 
98
 
99
 
 
9
  HumanMessagePromptTemplate,
10
  )
11
  from langchain.memory import ConversationBufferWindowMemory
12
+ from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
13
  from langchain.schema import AIMessage, HumanMessage
14
  import streamlit as st
15
 
16
 
17
+
18
+
19
  @st.cache_resource
20
  def set_api_key(api_key):
21
  os.environ["OPENAI_API_KEY"] = api_key
 
24
  @st.cache_resource
25
  def get_chain(model_name, temperature):
26
  llm = ChatOpenAI(model_name=model_name, temperature=temperature)
27
+ msgs = StreamlitChatMessageHistory(key="basic_chat_app")
28
+ memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
29
+ chat_memory=msgs,
30
+ return_messages=True)
31
  prompt = ChatPromptTemplate(
32
  messages=[
33
  SystemMessagePromptTemplate.from_template(
 
46
  return conversation
47
 
48
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
+
51
  if __name__ == "__main__":
52
  st.header("Basic chatbot")
53
  with st.expander("How conversation history works"):
54
  st.write("To keep input lengths down and costs reasonable,"
55
  " this bot only 'remembers' the past three turns of conversation.")
56
+ st.write("To clear all memory and start fresh, click 'Clear history'" )
57
+
58
  API_KEY = st.sidebar.text_input(
59
  'API Key',
60
  type='password',
 
81
  chain = get_chain(model_name, temperature)
82
  if st.button("Clear history"):
83
  chain.memory.clear()
84
+ st.cache_resource.clear()
85
+ for message in chain.memory.buffer:
86
+ st.chat_message(message.type).write(message.content)
87
  text = st.chat_input()
88
  if text:
89
+ with st.chat_message("user"):
90
+ st.write(text)
91
+ result = chain.predict(input=text)
92
+ with st.chat_message("assistant"):
93
+ st.write(result)
94
 
95
 
96