Spaces:
Runtime error
Runtime error
carolanderson
commited on
Commit
·
b5792ea
1
Parent(s):
cfb1b86
use StreamlitChatMessageHistory
Browse files- .gitignore +1 -0
- app.py +18 -21
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
.ipynb_checkpoints
|
|
|
|
1 |
.ipynb_checkpoints
|
2 |
+
.DS_Store
|
app.py
CHANGED
@@ -9,10 +9,13 @@ SystemMessagePromptTemplate,
|
|
9 |
HumanMessagePromptTemplate,
|
10 |
)
|
11 |
from langchain.memory import ConversationBufferWindowMemory
|
|
|
12 |
from langchain.schema import AIMessage, HumanMessage
|
13 |
import streamlit as st
|
14 |
|
15 |
|
|
|
|
|
16 |
@st.cache_resource
|
17 |
def set_api_key(api_key):
|
18 |
os.environ["OPENAI_API_KEY"] = api_key
|
@@ -21,7 +24,10 @@ def set_api_key(api_key):
|
|
21 |
@st.cache_resource
|
22 |
def get_chain(model_name, temperature):
|
23 |
llm = ChatOpenAI(model_name=model_name, temperature=temperature)
|
24 |
-
|
|
|
|
|
|
|
25 |
prompt = ChatPromptTemplate(
|
26 |
messages=[
|
27 |
SystemMessagePromptTemplate.from_template(
|
@@ -40,26 +46,15 @@ def get_chain(model_name, temperature):
|
|
40 |
return conversation
|
41 |
|
42 |
|
43 |
-
def display_messages(chain):
|
44 |
-
"""
|
45 |
-
Show the messages in the conversation buffer.
|
46 |
-
"""
|
47 |
-
for message in chain.memory.buffer:
|
48 |
-
if isinstance(message, AIMessage):
|
49 |
-
role = "assistant"
|
50 |
-
elif isinstance(message, HumanMessage):
|
51 |
-
role = "user"
|
52 |
-
with st.chat_message(role):
|
53 |
-
st.write(message.content)
|
54 |
-
|
55 |
|
|
|
56 |
if __name__ == "__main__":
|
57 |
st.header("Basic chatbot")
|
58 |
with st.expander("How conversation history works"):
|
59 |
st.write("To keep input lengths down and costs reasonable,"
|
60 |
" this bot only 'remembers' the past three turns of conversation.")
|
61 |
-
st.write("
|
62 |
-
|
63 |
API_KEY = st.sidebar.text_input(
|
64 |
'API Key',
|
65 |
type='password',
|
@@ -86,14 +81,16 @@ if __name__ == "__main__":
|
|
86 |
chain = get_chain(model_name, temperature)
|
87 |
if st.button("Clear history"):
|
88 |
chain.memory.clear()
|
89 |
-
|
|
|
|
|
90 |
text = st.chat_input()
|
91 |
if text:
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
|
98 |
|
99 |
|
|
|
9 |
HumanMessagePromptTemplate,
|
10 |
)
|
11 |
from langchain.memory import ConversationBufferWindowMemory
|
12 |
+
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
|
13 |
from langchain.schema import AIMessage, HumanMessage
|
14 |
import streamlit as st
|
15 |
|
16 |
|
17 |
+
|
18 |
+
|
19 |
@st.cache_resource
|
20 |
def set_api_key(api_key):
|
21 |
os.environ["OPENAI_API_KEY"] = api_key
|
|
|
24 |
@st.cache_resource
|
25 |
def get_chain(model_name, temperature):
|
26 |
llm = ChatOpenAI(model_name=model_name, temperature=temperature)
|
27 |
+
msgs = StreamlitChatMessageHistory(key="basic_chat_app")
|
28 |
+
memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
|
29 |
+
chat_memory=msgs,
|
30 |
+
return_messages=True)
|
31 |
prompt = ChatPromptTemplate(
|
32 |
messages=[
|
33 |
SystemMessagePromptTemplate.from_template(
|
|
|
46 |
return conversation
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
+
|
51 |
if __name__ == "__main__":
|
52 |
st.header("Basic chatbot")
|
53 |
with st.expander("How conversation history works"):
|
54 |
st.write("To keep input lengths down and costs reasonable,"
|
55 |
" this bot only 'remembers' the past three turns of conversation.")
|
56 |
+
st.write("To clear all memory and start fresh, click 'Clear history'" )
|
57 |
+
|
58 |
API_KEY = st.sidebar.text_input(
|
59 |
'API Key',
|
60 |
type='password',
|
|
|
81 |
chain = get_chain(model_name, temperature)
|
82 |
if st.button("Clear history"):
|
83 |
chain.memory.clear()
|
84 |
+
st.cache_resource.clear()
|
85 |
+
for message in chain.memory.buffer:
|
86 |
+
st.chat_message(message.type).write(message.content)
|
87 |
text = st.chat_input()
|
88 |
if text:
|
89 |
+
with st.chat_message("user"):
|
90 |
+
st.write(text)
|
91 |
+
result = chain.predict(input=text)
|
92 |
+
with st.chat_message("assistant"):
|
93 |
+
st.write(result)
|
94 |
|
95 |
|
96 |
|