# Streamlit chatbot that uses LangChain ConversationSummaryMemory to manage
# user & assistant messages in the session state.

### 1. Import the libraries
import streamlit as st
import time
import os
from dotenv import load_dotenv

from langchain.memory import ConversationSummaryMemory
from langchain.chains import ConversationChain
from langchain_openai import ChatOpenAI
from langchain.prompts import PromptTemplate
from langchain_core.messages import HumanMessage, AIMessage

# This is to simplify local development
# Without this you will need to copy/paste the API key with every change
# NOTE: load_dotenv() returns False rather than raising when the file is
# missing, so check its return value instead of wrapping it in try/except
# CHANGE the location of the file
if load_dotenv('C:\\Users\\raj\\.jupyter\\.env'):
    # Add the API key to the session - use it for populating the interface
    if os.getenv('OPENAI_API_KEY'):
        st.session_state['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY')
else:
    print("Environment file not found! Copy & paste your OpenAI API key.")


### 2. Setup the title & input text element for the OpenAI API key
#    Set the title
#    Populate API key from session if it is available
st.title("LangChain ConversationSummaryMemory !!!")

# If the key is already available, initialize its value on the UI
if 'OPENAI_API_KEY' in st.session_state:
    openai_api_key = st.sidebar.text_input('OpenAI API key', value=st.session_state['OPENAI_API_KEY'])
else:
    openai_api_key = st.sidebar.text_input('OpenAI API key', placeholder='copy & paste your OpenAI API key')
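
# Hedged guard (sketch, not in the original): ChatOpenAI below needs a key,
# so the app could halt early until one is provided:
#
#   if not openai_api_key:
#       st.info('Please provide your OpenAI API key in the sidebar.')
#       st.stop()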

### 3. Define utility functions to invoke the LLM

# Create an instance of the LLM for summarization
# NOTE: st.cache_resource caches the client after the first call, so a key
# changed in the sidebar later will not be picked up until the cache is
# cleared or the app restarts
@st.cache_resource
def get_summarization_llm():
    model = 'gpt-3.5-turbo-0125'
    return ChatOpenAI(model=model, openai_api_key=openai_api_key)

# Create an instance of the LLM for chatbot responses
@st.cache_resource
def get_llm():
    model = 'gpt-3.5-turbo-0125'
    return ChatOpenAI(model=model, openai_api_key=openai_api_key)

# Wire the chat LLM and the session's memory into a ConversationChain;
# the chain saves each exchange into the memory automatically
@st.cache_resource
def get_llm_chain():
    memory = st.session_state['MEMORY']
    conversation = ConversationChain(
        llm=get_llm(),
        # prompt=prompt_template,
        # verbose=True,
        memory=memory
    )
    return conversation
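
# For reference (paraphrased from LangChain's source; wording may differ
# across versions), ConversationChain's default prompt looks roughly like:
#
#   The following is a friendly conversation between a human and an AI. ...
#   Current conversation:
#   {history}
#   Human: {input}
#   AI: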

# Return the current conversation context
# For ConversationSummaryMemory, `buffer` holds the running summary produced
# by the summarization LLM, not the raw concatenated messages
def get_chat_context():
    memory = st.session_state['MEMORY']
    return memory.buffer

# Generate the response and return it
def get_llm_response(prompt):
    # llm = get_llm()
    llm = get_llm_chain()

    # Show spinner while we are waiting for the response
    with st.spinner('Invoking LLM ... '):
        # Get the context (the running summary)
        chat_context = get_chat_context()

        # Prefix the query with the context
        # NOTE: ConversationChain also injects the memory into its default
        # prompt, so the summary may effectively appear twice in the payload
        query_payload = chat_context + '\n\nQuestion: ' + prompt

        response = llm.invoke(query_payload)

        return response
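
# Shape of the value returned by ConversationChain.invoke() (hedged; per
# LangChain's Chain interface, exact keys can vary by version):
#
#   {'input': <query_payload>, 'history': <memory contents>, 'response': <reply text>}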

# Initialize the session state memory
if 'MEMORY' not in st.session_state:
    memory = ConversationSummaryMemory(
        llm=get_summarization_llm(),
        human_prefix='user',
        ai_prefix='assistant',
        return_messages=True
    )
    # Add to the session
    st.session_state['MEMORY'] = memory
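
# Hedged usage sketch (for reference only; the ConversationChain calls these
# for you). save_context() appends an exchange to chat_memory and asks the
# summarization LLM to fold it into the running summary in `buffer`;
# load_memory_variables() returns the history. The sample texts below are
# hypothetical.
#
#   st.session_state['MEMORY'].save_context(
#       {'input': 'Hello!'}, {'output': 'Hi there, how can I help?'}
#   )
#   st.session_state['MEMORY'].load_memory_variables({})   # {'history': [...]}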

### 4. Write the messages to the chat_message container
# Re-draw the message history on every run
# This is needed as Streamlit re-runs the entire script when the user
# provides input in a widget
# https://docs.streamlit.io/develop/api-reference/chat/st.chat_message

for msg in st.session_state['MEMORY'].chat_memory.messages:
    if isinstance(msg, HumanMessage):
        st.chat_message('user').write(msg.content)
    elif isinstance(msg, AIMessage):
        st.chat_message('ai').write(msg.content)
    else:
        print('System message: ', msg.content)

### 5. Create the *chat_input* element to get the user query
# Interface for user input
prompt = st.chat_input(placeholder='Your input here')

### 6. Process the query received from the user
if prompt:
    # Write the user prompt as a chat message
    st.chat_message('user').write(prompt)

    # Invoke the LLM
    response = get_llm_response(prompt)

    # Write the reply as a chat message (the text is under the 'response' key)
    st.chat_message('ai').write(response['response'])

### 7. Write out the current content of the context
st.divider()
st.subheader('Context/Summary:')

# Print the state of the buffer
# (a bare expression like this is rendered by Streamlit's "magic" feature)
st.session_state['MEMORY'].buffer
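
# Optional extension (sketch, not in the original): also expose the raw
# message list alongside the summary, e.g. in an expander:
#
#   with st.expander('Raw messages'):
#       for m in st.session_state['MEMORY'].chat_memory.messages:
#           st.write(f'{m.type}: {m.content}')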