# llm-explorer / app.py
# Author: carolanderson
# Commit b5792ea: "use StreamlitChatMessageHistory" (file size 3.01 kB)
import os
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.schema import AIMessage, HumanMessage
import streamlit as st
def set_api_key(api_key):
    """Export *api_key* as ``OPENAI_API_KEY`` for the OpenAI client to pick up.

    This is deliberately *not* wrapped in ``st.cache_resource``. Caching a
    function whose only job is a side effect means a repeated argument skips
    the body entirely: switching key A -> B -> A would leave
    ``OPENAI_API_KEY`` set to B. Assigning an environment variable is cheap,
    so it simply runs on every script rerun.

    NOTE: ``os.environ`` is process-wide, so every session served by the
    same process sees the most recently set key.

    Args:
        api_key: The key entered by the user. Empty strings and ``None`` are
            ignored, so an existing ``OPENAI_API_KEY`` is not overwritten by
            a blank (and ``None`` would otherwise raise ``TypeError``).
    """
    if not api_key:
        return
    os.environ["OPENAI_API_KEY"] = api_key
def get_chain(model_name, temperature, api_key=None):
    """Build the conversational LLM chain used by the chat UI.

    The chain is intentionally rebuilt on every call instead of being cached
    with ``st.cache_resource``:

    * ``st.cache_resource`` shares one object across all sessions, and the
      cache key was only ``(model_name, temperature)``. ``ChatOpenAI``
      resolves its API key when it is constructed, so a cached chain kept
      using whichever key was active when it was first built. A later key
      change was ignored, and other users with the same settings reused it.
    * Conversation state is stored in ``st.session_state`` under the key
      ``"basic_chat_app"`` by ``StreamlitChatMessageHistory``, so rebuilding
      the chain does not discard the history.

    Args:
        model_name: OpenAI chat model name, e.g. ``"gpt-3.5-turbo"``.
        temperature: Sampling temperature passed to the model.
        api_key: Optional explicit OpenAI key. When given, it is passed
            straight to ``ChatOpenAI`` instead of relying on the
            process-wide ``OPENAI_API_KEY`` environment variable.

    Returns:
        An ``LLMChain`` whose ``memory`` keeps a window of the last 3
        conversation turns and feeds them to the prompt as ``chat_history``.
    """
    llm_kwargs = {"model_name": model_name, "temperature": temperature}
    if api_key:
        llm_kwargs["openai_api_key"] = api_key
    llm = ChatOpenAI(**llm_kwargs)

    # History is persisted per browser session in st.session_state.
    msgs = StreamlitChatMessageHistory(key="basic_chat_app")
    # Only the last k=3 exchanges are sent to the model, to bound prompt size.
    memory = ConversationBufferWindowMemory(
        k=3,
        memory_key="chat_history",
        chat_memory=msgs,
        return_messages=True,
    )

    prompt = ChatPromptTemplate(
        messages=[
            SystemMessagePromptTemplate.from_template(
                "You are a nice chatbot having a conversation with a human."
            ),
            # Must match memory_key above so the memory fills this slot.
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template("{input}"),
        ]
    )

    conversation = LLMChain(
        llm=llm,
        prompt=prompt,
        verbose=True,
        memory=memory,
    )
    return conversation
if __name__ == "__main__":
    # Streamlit re-executes this whole script on every user interaction.
    st.header("Basic chatbot")
    with st.expander("How conversation history works"):
        st.write("To keep input lengths down and costs reasonable,"
                 " this bot only 'remembers' the past three turns of conversation.")
        st.write("To clear all memory and start fresh, click 'Clear history'" )

    API_KEY = st.sidebar.text_input(
        'API Key',
        type='password',
        help="Enter your OpenAI API key to use this app",
        value=None,
    )
    model_name = st.sidebar.selectbox(
        label="Choose a model",
        options=["gpt-3.5-turbo", "gpt-4"],
        help="Which LLM to use",
    )
    temperature = st.sidebar.slider(
        label="Temperature",
        min_value=0.0,
        max_value=1.0,
        step=0.1,
        value=0.9,
        help="Set the decoding temperature. Lower temperatures give more predictable outputs."
    )

    # Truthiness rather than `is not None`: an empty string is not a usable key.
    if API_KEY:
        set_api_key(API_KEY)
        chain = get_chain(model_name, temperature)

        if st.button("Clear history"):
            chain.memory.clear()
            # Drop any cached resources so nothing built from the old
            # conversation survives the reset.
            st.cache_resource.clear()

        # Render the full stored transcript. `chain.memory.buffer` on a window
        # memory holds only the last k turns, which would make older messages
        # disappear from the UI even though they are still stored.
        for message in chain.memory.chat_memory.messages:
            st.chat_message(message.type).write(message.content)

        text = st.chat_input()
        if text:
            with st.chat_message("user"):
                st.write(text)
            # predict() also appends this exchange to the chain's memory.
            result = chain.predict(input=text)
            with st.chat_message("assistant"):
                st.write(result)
    else:
        st.info("Enter your OpenAI API key in the sidebar to start chatting.")