File size: 3,013 Bytes
a8a9ff0
 
 
 
 
 
 
 
 
 
 
b5792ea
a8a9ff0
 
 
 
b5792ea
 
a8a9ff0
 
 
 
 
 
 
 
b5792ea
 
 
 
a8a9ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5792ea
a8a9ff0
 
 
 
 
b5792ea
 
a8a9ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5792ea
 
 
a8a9ff0
 
b5792ea
 
 
 
 
a8a9ff0
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os

from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import (
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.memory import ConversationBufferWindowMemory
from langchain.memory.chat_message_histories import StreamlitChatMessageHistory
from langchain.schema import AIMessage, HumanMessage
import streamlit as st




def set_api_key(api_key: str) -> None:
    """Export *api_key* as ``OPENAI_API_KEY`` for the OpenAI chat model.

    ``ChatOpenAI`` in this script is created without an explicit key, so it
    falls back to this environment variable.

    Deliberately NOT wrapped in ``st.cache_resource``: a cached call is
    skipped on a cache hit, so switching keys A -> B -> A would leave B in
    the environment. Assigning an environment variable is cheap, so it is
    simply repeated on every rerun.

    NOTE: ``os.environ`` is shared by the whole server process, so sessions
    that enter different keys concurrently overwrite each other's value.

    Args:
        api_key: The OpenAI API key entered by the user.
    """
    os.environ["OPENAI_API_KEY"] = api_key


def get_chain(model_name: str, temperature: float):
    """Build the conversational LLM chain for the current Streamlit session.

    The chain is rebuilt on every rerun rather than cached with
    ``st.cache_resource``. That cache is shared by all users and sessions.
    The chain's memory is backed by ``StreamlitChatMessageHistory``, which
    stores the conversation in Streamlit session state. A cached chain would
    therefore be reused by sessions other than the one it was built for.
    Rebuilding also makes ``ChatOpenAI`` pick up the current
    ``OPENAI_API_KEY`` and sidebar settings. The conversation itself
    survives reruns because it lives in session state, not in this object.

    Args:
        model_name: OpenAI chat model name, e.g. ``"gpt-3.5-turbo"``.
        temperature: Sampling temperature passed to the model.

    Returns:
        An ``LLMChain`` whose windowed memory (``k=3``) is injected into the
        prompt under the ``chat_history`` variable.
    """
    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
    history = StreamlitChatMessageHistory(key="basic_chat_app")
    memory = ConversationBufferWindowMemory(
        k=3,
        memory_key="chat_history",  # must match the MessagesPlaceholder name
        chat_memory=history,
        return_messages=True,  # MessagesPlaceholder needs message objects
    )
    prompt = ChatPromptTemplate(
        messages=[
            SystemMessagePromptTemplate.from_template(
                "You are a nice chatbot having a conversation with a human."
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template("{input}"),
        ]
    )
    return LLMChain(llm=llm, prompt=prompt, verbose=True, memory=memory)



            
if __name__ == "__main__":
    st.header("Basic chatbot")
    with st.expander("How conversation history works"):
        st.write("To keep input lengths down and costs reasonable,"
                 " this bot only 'remembers' the past three turns of conversation.")
        st.write("To clear all memory and start fresh, click 'Clear history'")

    api_key = st.sidebar.text_input(
        "API Key",
        type="password",
        help="Enter your OpenAI API key to use this app",
        value=None,
    )

    model_name = st.sidebar.selectbox(
        label="Choose a model",
        options=["gpt-3.5-turbo", "gpt-4"],
        help="Which LLM to use",
    )

    temperature = st.sidebar.slider(
        label="Temperature",
        min_value=0.0,
        max_value=1.0,
        step=0.1,
        value=0.9,
        help="Set the decoding temperature. Lower temperatures give more predictable outputs.",
    )

    # Truthiness rejects both None (field never filled) and "" (field
    # cleared, or older Streamlit versions that coerce the value to str);
    # an empty key would otherwise only fail inside the OpenAI request.
    if api_key:
        set_api_key(api_key)
        chain = get_chain(model_name, temperature)
        if st.button("Clear history"):
            # Clear only this conversation. st.cache_resource.clear() is
            # intentionally not called: it empties the server-wide resource
            # cache shared by every session.
            chain.memory.clear()
        # Redraw the remembered messages on each rerun; message.type is the
        # LangChain role name, used here as the chat_message author.
        for message in chain.memory.buffer:
            st.chat_message(message.type).write(message.content)
        text = st.chat_input()
        if text:
            with st.chat_message("user"):
                st.write(text)
            result = chain.predict(input=text)
            with st.chat_message("assistant"):
                st.write(result)