acloudfan commited on
Commit
df79503
·
verified ·
1 Parent(s): b378679

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Manages user & assistant messages in the session state.
2
+
3
+ ### 1. Import the libraries
4
+ import streamlit as st
5
+ import time
6
+ import os
7
+
8
+ from dataclasses import dataclass
9
+ from dotenv import load_dotenv
10
+ # https://api.python.langchain.com/en/latest/llms/langchain_community.llms.cohere.Cohere.html#langchain_community.llms.cohere.Cohere
11
+ from langchain_community.llms import Cohere
12
+
13
+ ### 2. Setup datastructure for holding the messages
14
+ # Define a Message class for holding the query/response
15
+ @dataclass
16
+ class Message:
17
+ role: str # identifies the actor (system, user or human, assistant or ai)
18
+ payload: str # instructions, query, response
19
+
20
+ # Streamlit knows about the common roles as a result, it is able to display the icons
21
+ USER = "user" # or human,
22
+ ASSISTANT = "assistant" # or ai,
23
+ SYSTEM = "system"
24
+
25
+ # This is to simplify local development
26
+ # Without this you will need to copy/paste the API key with every change
27
+ try:
28
+ # CHANGE the location of the file
29
+ load_dotenv('C:\\Users\\raj\\.jupyter\\.env')
30
+ # Add the API key to the session - use it for populating the interface
31
+ if os.getenv('COHERE_API_KEY'):
32
+ st.session_state['COHERE_API_KEY'] = os.getenv('COHERE_API_KEY')
33
+ except:
34
+ print("Environment file not found !! Copy & paste your Cohere API key.")
35
+
36
+
37
+ ### 3. Initialize the datastructure to hold the context
38
+ MESSAGES='messages'
39
+ if MESSAGES not in st.session_state:
40
+ system_message = Message(role=SYSTEM, payload='you are a polite assistant named "Ruby".')
41
+ st.session_state[MESSAGES] = [system_message]
42
+
43
+ ### 4. Setup the title & input text element for the Cohere API key
44
+ # Set the title
45
+ # Populate API key from session if it is available
46
+ st.title("Multi-Turn conversation interface !!!")
47
+
48
+ # If the key is already available, initialize its value on the UI
49
+ if 'COHERE_API_KEY' in st.session_state:
50
+ cohere_api_key = st.sidebar.text_input('Cohere API key',value=st.session_state['COHERE_API_KEY'])
51
+ else:
52
+ cohere_api_key = st.sidebar.text_input('Cohere API key',placeholder='copy & paste your API key')
53
+
54
+
55
+
56
+
57
+ ### 5. Define utility functions to invoke the LLM
58
+
59
+ # Create an instance of the LLM
60
+ @st.cache_resource
61
+ def get_llm():
62
+ return Cohere(model="command", cohere_api_key=cohere_api_key)
63
+
64
+ # Create the context by concatenating the messages
65
+ def get_chat_context():
66
+ context = ''
67
+ for msg in st.session_state[MESSAGES]:
68
+ context = context + '\n\n' + msg.role + ':' + msg.payload
69
+ return context
70
+
71
+ # Generate the response and return
72
+ def get_llm_response(prompt):
73
+ llm = get_llm()
74
+
75
+ # Show spinner, while we are waiting for the response
76
+ with st.spinner('Invoking LLM ... '):
77
+ # get the context
78
+ chat_context = get_chat_context()
79
+
80
+ # Prefix the query with context
81
+ query_payload = chat_context +'\n\n Question: ' + prompt
82
+
83
+ response = llm.invoke(query_payload)
84
+
85
+ return response
86
+
87
+ ### 6. Write the messages to chat_message container
88
+ # Write messages to the chat_message element
89
+ # This is needed as streamlit re-runs the entire script when user provides input in a widget
90
+ # https://docs.streamlit.io/develop/api-reference/chat/st.chat_message
91
+ for msg in st.session_state[MESSAGES]:
92
+ st.chat_message(msg.role).write(msg.payload)
93
+
94
+ ### 7. Create the *chat_input* element to get the user query
95
+ # Interface for user input
96
+ prompt = st.chat_input(placeholder='Your input here')
97
+
98
+ ### 8. Process the query received from user
99
+ if prompt:
100
+ # create user message and add to end of messages in the session
101
+ user_message = Message(role=USER, payload=prompt)
102
+ st.session_state[MESSAGES].append(user_message)
103
+
104
+ # Write the user prompt as chat message
105
+ st.chat_message(USER).write(prompt)
106
+
107
+ # Invoke the LLM
108
+ response = get_llm_response(prompt)
109
+
110
+ # Create message object representing the response
111
+ assistant_message = Message(role=ASSISTANT, payload=response)
112
+
113
+ # Add the response message to the mesages array in the session
114
+ st.session_state[MESSAGES].append(assistant_message)
115
+
116
+ # Write the response as chat_message
117
+ st.chat_message(ASSISTANT).write(response)
118
+
119
+ ### 9. Write out the current content of the context
120
+ st.divider()
121
+ st.subheader('st.session_state[MESSAGES] dump:')
122
+
123
+ # Print the state of the buffer
124
+ for msg in st.session_state[MESSAGES]:
125
+ st.text(msg.role + ' : ' + msg.payload)