carolanderson committed
Commit 9814d4c · 1 Parent(s): 4481118

change api key handling

Files changed (1): app.py +112 -24
app.py CHANGED
@@ -1,7 +1,9 @@
+import logging
 import os
 
 from langchain.chains import LLMChain
 from langchain.chat_models import ChatOpenAI
+from langchain.llms import HuggingFaceHub
 from langchain.prompts.chat import (
     ChatPromptTemplate,
     MessagesPlaceholder,
@@ -15,9 +17,73 @@ from openai.error import AuthenticationError
 import streamlit as st
 
 
-def set_api_key(USER_API_KEY):
-    os.environ["OPENAI_API_KEY"] = str(USER_API_KEY)
-    st.cache_resource.clear() # clear existing chains with old api key
+@st.cache_resource
+class KeyManager():
+    """
+    Stores the original API keys from environment variables, which
+    can be overwritten if user supplies keys.
+    Also stores the currently active API key for each model provider and updates
+    these based on user input.
+    """
+    def __init__(self):
+        self.provider_names = {"OpenAI" : "OPENAI_API_KEY",
+                               "HuggingFace" : "HUGGINGFACEHUB_API_TOKEN"}
+        self.original_keys = {k : os.environ.get(v) for k, v
+                              in self.provider_names.items()}
+        self.current_keys = {k: os.environ.get(v) for k, v in self.provider_names.items()}
+        self.user_keys = {}  # most recent key supplied by user for each provider
+
+    def set_key(self, api_key, model_provider, user_entered=False):
+        self.current_keys[model_provider] = api_key
+        os.environ[self.provider_names[model_provider]] = api_key
+        if user_entered:
+            self.user_keys[model_provider] = api_key
+        get_chain.clear()
+
+    def list_keys(self):
+        """
+        For debugging purposes only. Do not use in deployed app.
+        """
+        st.write("Active API keys:")
+        for k, v in self.provider_names.items():
+            st.write(k, " : ", os.environ.get(v))
+        st.write("Current API keys:")
+        for k, v in self.current_keys.items():
+            st.write(k, " : ", v)
+        st.write("User-supplied API keys:")
+        for k, v in self.user_keys.items():
+            st.write(k, " : ", v)
+        st.write("Original API keys:")
+        for k, v in self.original_keys.items():
+            st.write(k, " : ", v)
+
+    def configure_api_key(self, user_api_key, use_provided_key, model_provider):
+        """
+        Set the currently active API key(s) based on user input.
+        """
+        if user_api_key:
+            if use_provided_key:
+                st.warning("API key entered and 'use provided key' checked;"
+                           " using the key you entered", icon="⚠️")
+            self.set_key(str(user_api_key), model_provider, user_entered=True)
+            return True
+
+        if use_provided_key:
+            self.set_key(self.original_keys[model_provider], model_provider)
+            return True
+
+        if not user_api_key and not use_provided_key:
+            # check if user previously supplied a key for this provider
+            if model_provider in self.user_keys:
+                self.set_key(self.user_keys[model_provider], model_provider)
+                st.warning("No key entered and 'use provided key' not checked;"
+                           f" using previously entered {model_provider} key", icon="⚠️")
+                return True
+
+            else:
+                st.warning("Enter an API key or check 'use provided key'"
+                           " to get started", icon="⚠️")
+                return False
 
 
 @st.cache_resource
@@ -26,12 +92,18 @@ def setup_memory():
     memory = ConversationBufferWindowMemory(k=3, memory_key="chat_history",
                                             chat_memory=msgs,
                                             return_messages=True)
+    logging.info("setting up new chat memory")
    return memory
 
 
 @st.cache_resource
-def get_chain(model_name, _memory, temperature):
-    llm = ChatOpenAI(model_name=model_name, temperature=temperature)
+def get_chain(model_name, model_provider, _memory, temperature):
+    logging.info(f"setting up new chain with params {model_name}, {model_provider}, {temperature}")
+    if model_provider == "OpenAI":
+        llm = ChatOpenAI(model_name=model_name, temperature=temperature)
+    elif model_provider == "HuggingFace":
+        llm = HuggingFaceHub(repo_id=model_name,
+                             model_kwargs={"temperature": temperature, "max_length": 64})
     prompt = ChatPromptTemplate(
         messages=[
             SystemMessagePromptTemplate.from_template(
@@ -50,31 +122,42 @@ def get_chain(model_name, _memory, temperature):
     return conversation
 
 
+
+
 
 
 if __name__ == "__main__":
-
+    logging.basicConfig(level=logging.INFO)
+
     st.header("Basic chatbot")
-
     st.write("On small screens, click the `>` at top left to get started")
-
     with st.expander("How conversation history works"):
         st.write("To keep input lengths down and costs reasonable,"
                  " this bot only 'remembers' the past three turns of conversation.")
         st.write("To clear all memory and start fresh, click 'Clear history'" )
-
     st.sidebar.title("Choose options and enter API key")
 
-    USER_API_KEY = st.sidebar.text_input(
-        'API Key',
-        type='password',
-        help="Enter your OpenAI API key to use this app",
-        value=None)
-
+    #### USER INPUT ######
     model_name = st.sidebar.selectbox(
         label = "Choose a model",
-        options = ["gpt-3.5-turbo", "gpt-4"],
+        options = ["gpt-3.5-turbo (OpenAI)",
+                   "bigscience/bloom (HuggingFace)"
+                   ],
         help="Which LLM to use",
-        )
+        )
+
+    user_api_key = st.sidebar.text_input(
+        'Enter your API Key',
+        type='password',
+        help="Enter an API key for the appropriate model provider",
+        value="")
+
+    use_provided_key = st.sidebar.checkbox(
+        "Or use provided key",
+        help="If you don't have a key, you can use mine; usage limits apply.",
+        )
+
+    st.sidebar.write("Set the decoding temperature. Higher temperatures give "
+                     "more unpredictable outputs.")
 
     temperature = st.sidebar.slider(
         label="Temperature",
@@ -82,17 +165,20 @@ if __name__ == "__main__":
         max_value=1.0,
         step=0.1,
         value=0.9,
-        help="Set the decoding temperature. Lower temperatures give more predictable outputs."
+        help="Set the decoding temperature"
         )
-
+    ##########################
 
-    if USER_API_KEY is not None:
-        set_api_key(USER_API_KEY)
+    model = model_name.split("(")[0].rstrip() # remove name of model provider
+    model_provider = model_name.split("(")[-1].split(")")[0]
+    key_manager = KeyManager()
+    if key_manager.configure_api_key(user_api_key, use_provided_key, model_provider):
+        # key_manager.list_keys()
         memory = setup_memory()
-        chain = get_chain(model_name, memory, temperature)
+        chain = get_chain(model, model_provider, memory, temperature)
         if st.button("Clear history"):
             chain.memory.clear()
-            st.cache_resource.clear()
+            # st.cache_resource.clear()
         for message in chain.memory.buffer: # display chat history
             st.chat_message(message.type).write(message.content)
         text = st.chat_input()
@@ -103,11 +189,13 @@ if __name__ == "__main__":
                 result = chain.predict(input=text)
                 with st.chat_message("assistant"):
                     st.write(result)
-            except AuthenticationError:
+            except (AuthenticationError, ValueError):
                 st.warning("Enter a valid API key", icon="⚠️")
 
 
 
+
+
 
 
 
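Note on the new key handling (editor's illustration, not part of the commit): the commit replaces the old set_api_key helper with a cached KeyManager whose configure_api_key resolves keys in a fixed priority — a key the user types wins, then the app's own provided key, then a key the user entered earlier in the session; only if none of these exist does the app refuse to build a chain. A rough sketch of how that plays out, assuming it runs inside the Streamlit app where KeyManager and get_chain are defined and "sk-user-123" is a hypothetical key:

    # hypothetical walk-through of KeyManager.configure_api_key (illustration only)
    km = KeyManager()

    km.configure_api_key("sk-user-123", use_provided_key=False, model_provider="OpenAI")
    # True -> "sk-user-123" becomes the active OPENAI_API_KEY and is remembered in km.user_keys

    km.configure_api_key("", use_provided_key=True, model_provider="OpenAI")
    # True -> falls back to the key that was in the environment when the app started

    km.configure_api_key("", use_provided_key=False, model_provider="OpenAI")
    # True -> reuses the previously entered "sk-user-123" and shows a warning

    km.configure_api_key("", use_provided_key=False, model_provider="HuggingFace")
    # False -> no key has ever been supplied for this provider, so the app asks for one

Each successful set_key call also clears the cached get_chain resource, so chains built with an old key are rebuilt with the new one.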