Joshua Sundance Bailey committed on
Commit
1dd02fc
·
1 Parent(s): da76025
Files changed (1) hide show
  1. langchain-streamlit-demo/app.py +130 -124
langchain-streamlit-demo/app.py CHANGED
@@ -7,7 +7,7 @@ import openai
7
  import streamlit as st
8
  from langchain import LLMChain
9
  from langchain.callbacks.base import BaseCallbackHandler
10
- from langchain.callbacks.tracers.langchain import wait_for_all_tracers
11
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
12
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
13
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
@@ -39,6 +39,7 @@ st_init_null(
39
  "feedback_update",
40
  "full_response",
41
  "llm",
 
42
  "model",
43
  "prompt",
44
  "provider",
@@ -54,9 +55,9 @@ st_init_null(
54
  )
55
 
56
  # --- Memory ---
57
- _STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
58
- _MEMORY = ConversationBufferMemory(
59
- chat_memory=_STMEMORY,
60
  return_messages=True,
61
  memory_key="chat_history",
62
  )
@@ -73,11 +74,11 @@ class StreamHandler(BaseCallbackHandler):
73
  self.container.markdown(self.text)
74
 
75
 
76
- st.session_state.run_collector = RunCollectorCallbackHandler()
77
 
78
 
79
  # --- Model Selection Helpers ---
80
- _MODEL_DICT = {
81
  "gpt-3.5-turbo": "OpenAI",
82
  "gpt-4": "OpenAI",
83
  "claude-instant-v1": "Anthropic",
@@ -86,20 +87,31 @@ _MODEL_DICT = {
86
  "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
87
  "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
88
  }
89
- _SUPPORTED_MODELS = list(_MODEL_DICT.keys())
90
 
91
 
92
def api_key_from_env(provider_name: str) -> Union[str, None]:
    """Look up the API key for *provider_name* in the process environment.

    Returns the key string for a recognized provider, or ``None`` when the
    provider is unknown or its environment variable is unset.
    """
    # Each supported provider maps to exactly one environment variable.
    env_var_by_provider = {
        "OpenAI": "OPENAI_API_KEY",
        "Anthropic": "ANTHROPIC_API_KEY",
        "Anyscale Endpoints": "ANYSCALE_API_KEY",
        "LANGSMITH": "LANGCHAIN_API_KEY",
    }
    env_var = env_var_by_provider.get(provider_name)
    # Unknown provider -> None, mirroring the unset-variable case.
    return os.environ.get(env_var) if env_var is not None else None
 
 
 
 
 
 
 
 
 
 
 
103
 
104
 
105
  # --- Sidebar ---
@@ -107,14 +119,10 @@ sidebar = st.sidebar
107
  with sidebar:
108
  st.markdown("# Menu")
109
 
110
- st.session_state.model = st.selectbox(
111
  label="Chat Model",
112
- options=_SUPPORTED_MODELS,
113
- index=_SUPPORTED_MODELS.index(
114
- st.session_state.model
115
- or os.environ.get("DEFAULT_MODEL")
116
- or "gpt-3.5-turbo",
117
- ),
118
  )
119
 
120
  # document_chat = st.checkbox(
@@ -124,94 +132,91 @@ with sidebar:
124
  # )
125
 
126
  if st.button("Clear message history"):
127
- _STMEMORY.clear()
128
  st.session_state.trace_link = None
129
  st.session_state.run_id = None
130
 
131
  # --- Advanced Options ---
132
  with st.expander("Advanced Options", expanded=False):
133
  st.markdown("## Feedback Scale")
134
- st.session_state.feedback_option = (
135
- "faces" if st.toggle(label="`Faces` ⇄ `Thumbs`", value=False) else "thumbs"
136
- )
137
 
138
- st.session_state.system_prompt = (
139
  st.text_area(
140
  "Custom Instructions",
141
- st.session_state.system_prompt
142
- or os.environ.get("DEFAULT_SYSTEM_PROMPT")
143
- or "You are a helpful chatbot.",
144
  help="Custom instructions to provide the language model to determine style, personality, etc.",
145
  )
146
  .strip()
147
  .replace("{", "{{")
148
  .replace("}", "}}")
149
  )
150
-
151
  temperature = st.slider(
152
  "Temperature",
153
- min_value=float(os.environ.get("MIN_TEMPERATURE", 0.0)),
154
- max_value=float(os.environ.get("MIN_TEMPERATURE", 1.0)),
155
- value=float(os.environ.get("DEFAULT_TEMPERATURE", 0.7)),
156
  help="Higher values give more random results.",
157
  )
158
 
159
  max_tokens = st.slider(
160
  "Max Tokens",
161
- min_value=int(os.environ.get("MIN_MAX_TOKENS", 1)),
162
- max_value=int(os.environ.get("MAX_MAX_TOKENS", 100000)),
163
- value=int(os.environ.get("DEFAULT_MAX_TOKENS", 1000)),
164
  help="Higher values give longer results.",
165
  )
166
 
167
  # --- API Keys ---
168
- st.session_state.provider = _MODEL_DICT[st.session_state.model]
169
 
170
- st.session_state.provider_api_key = st.text_input(
171
- f"{st.session_state.provider} API key",
172
- value=api_key_from_env(st.session_state.provider) or "",
173
  type="password",
174
  )
175
 
176
- langsmith_api_key = st.text_input(
177
  "LangSmith API Key (optional)",
178
- value=api_key_from_env("LANGSMITH") or "",
179
  type="password",
180
  )
181
- langsmith_project = st.text_input(
182
  "LangSmith Project Name",
183
- value=os.environ.get("LANGCHAIN_PROJECT") or "langchain-streamlit-demo",
184
  )
185
- if langsmith_api_key:
186
- os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
187
- os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
188
- os.environ["LANGCHAIN_TRACING_V2"] = "true"
189
- os.environ["LANGCHAIN_PROJECT"] = langsmith_project
190
- st.session_state.client = Client(api_key=langsmith_api_key)
 
 
 
191
 
192
 
193
  # --- LLM Instantiation ---
194
- if st.session_state.provider_api_key:
195
- if st.session_state.provider == "OpenAI":
196
  st.session_state.llm = ChatOpenAI(
197
- model=st.session_state.model,
198
- openai_api_key=st.session_state.provider_api_key,
199
  temperature=temperature,
200
  streaming=True,
201
  max_tokens=max_tokens,
202
  )
203
- elif st.session_state.provider == "Anthropic":
204
  st.session_state.llm = ChatAnthropic(
205
- model_name=st.session_state.model,
206
- anthropic_api_key=st.session_state.provider_api_key,
207
  temperature=temperature,
208
  streaming=True,
209
  max_tokens_to_sample=max_tokens,
210
  )
211
- elif st.session_state.provider == "Anyscale Endpoints":
212
  st.session_state.llm = ChatAnyscale(
213
- model=st.session_state.model,
214
- anyscale_api_key=st.session_state.provider_api_key,
215
  temperature=temperature,
216
  streaming=True,
217
  max_tokens=max_tokens,
@@ -219,10 +224,10 @@ if st.session_state.provider_api_key:
219
 
220
 
221
  # --- Chat History ---
222
- if len(_STMEMORY.messages) == 0:
223
- _STMEMORY.add_ai_message("Hello! I'm a helpful AI chatbot. Ask me a question!")
224
 
225
- for msg in _STMEMORY.messages:
226
  st.chat_message(
227
  msg.type,
228
  avatar="🦜" if msg.type in ("ai", "assistant") else None,
@@ -240,72 +245,75 @@ if st.session_state.llm:
240
  # )
241
  # else:
242
  # --- Regular Chat ---
243
- st.session_state.prompt = ChatPromptTemplate.from_messages(
244
  [
245
  (
246
  "system",
247
- st.session_state.system_prompt + "\nIt's currently {time}.",
248
  ),
249
  MessagesPlaceholder(variable_name="chat_history"),
250
  ("human", "{input}"),
251
  ],
252
  ).partial(time=lambda: str(datetime.now()))
253
  st.session_state.chain = LLMChain(
254
- prompt=st.session_state.prompt,
255
  llm=st.session_state.llm,
256
- memory=_MEMORY,
257
  )
258
 
259
  # --- Chat Input ---
260
- st.session_state.prompt = st.chat_input(placeholder="Ask me a question!")
261
- if st.session_state.prompt:
262
- st.chat_message("user").write(st.session_state.prompt)
263
- st.session_state.feedback_update = None
264
- st.session_state.feedback = None
265
 
266
  # --- Chat Output ---
267
  with st.chat_message("assistant", avatar="🦜"):
268
  message_placeholder = st.empty()
269
- st.session_state.stream_handler = StreamHandler(message_placeholder)
270
- st.session_state.runnable_config = RunnableConfig(
271
- callbacks=[
272
- st.session_state.run_collector,
273
- st.session_state.stream_handler,
274
- ],
 
275
  tags=["Streamlit Chat"],
276
  )
277
  try:
278
- st.session_state.full_response = st.session_state.chain.invoke(
279
- {"input": st.session_state.prompt},
280
- config=st.session_state.runnable_config,
281
  )["text"]
282
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
283
  st.error(
284
- f"Please enter a valid {st.session_state.provider} API key.",
285
  icon="❌",
286
  )
287
- st.stop()
288
- message_placeholder.markdown(st.session_state.full_response)
289
-
290
- # --- Tracing ---
291
- if st.session_state.client:
292
- st.session_state.run = st.session_state.run_collector.traced_runs[0]
293
- st.session_state.run_id = st.session_state.run.id
294
- st.session_state.run_collector.traced_runs = []
295
- wait_for_all_tracers()
296
- st.session_state.trace_link = st.session_state.client.read_run(
297
- st.session_state.run_id,
298
- ).url
299
- with sidebar:
300
- st.markdown(
301
- f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: πŸ› οΈ</button></a>',
302
- unsafe_allow_html=True,
303
- )
 
 
304
 
305
  # --- Feedback ---
306
- if st.session_state.client and st.session_state.get("run_id"):
307
- st.session_state.feedback = streamlit_feedback(
308
- feedback_type=st.session_state.feedback_option,
309
  optional_text_label="[Optional] Please provide an explanation",
310
  key=f"feedback_{st.session_state.run_id}",
311
  )
@@ -317,36 +325,34 @@ if st.session_state.llm:
317
  }
318
 
319
  # Get the score mapping based on the selected feedback option
320
- st.session_state.scores = score_mappings[st.session_state.feedback_option]
321
 
322
- if st.session_state.feedback:
323
  # Get the score from the selected feedback option's score mapping
324
- st.session_state.score = st.session_state.scores.get(
325
- st.session_state.feedback["score"],
326
  )
327
 
328
- if st.session_state.score is not None:
329
  # Formulate feedback type string incorporating the feedback option
330
  # and score value
331
- st.session_state.feedback_type_str = f"{st.session_state.feedback_option} {st.session_state.feedback['score']}"
332
 
333
  # Record the feedback with the formulated feedback type string
334
  # and optional comment
335
- st.session_state.feedback_record = (
336
- st.session_state.client.create_feedback(
337
- st.session_state.run_id,
338
- st.session_state.feedback_type_str,
339
- score=st.session_state.score,
340
- comment=st.session_state.feedback.get("text"),
341
- )
342
  )
343
- st.session_state.feedback = {
344
- "feedback_id": str(st.session_state.feedback_record.id),
345
- "score": st.session_state.score,
346
- }
347
  st.toast("Feedback recorded!", icon="πŸ“")
348
  else:
349
  st.warning("Invalid feedback score.")
350
 
351
  else:
352
- st.error(f"Please enter a valid {st.session_state.provider} API key.", icon="❌")
 
7
  import streamlit as st
8
  from langchain import LLMChain
9
  from langchain.callbacks.base import BaseCallbackHandler
10
+ from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
11
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
12
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
13
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 
39
  "feedback_update",
40
  "full_response",
41
  "llm",
42
+ "ls_tracer",
43
  "model",
44
  "prompt",
45
  "provider",
 
55
  )
56
 
57
  # --- Memory ---
58
+ STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
59
+ MEMORY = ConversationBufferMemory(
60
+ chat_memory=STMEMORY,
61
  return_messages=True,
62
  memory_key="chat_history",
63
  )
 
74
  self.container.markdown(self.text)
75
 
76
 
77
+ RUN_COLLECTOR = RunCollectorCallbackHandler()
78
 
79
 
80
  # --- Model Selection Helpers ---
81
+ MODEL_DICT = {
82
  "gpt-3.5-turbo": "OpenAI",
83
  "gpt-4": "OpenAI",
84
  "claude-instant-v1": "Anthropic",
 
87
  "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
88
  "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
89
  }
90
+ SUPPORTED_MODELS = list(MODEL_DICT.keys())
91
 
92
 
93
+ # --- Constants from Environment Variables ---
94
+ DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")
95
+ DEFAULT_SYSTEM_PROMPT = os.environ.get(
96
+ "DEFAULT_SYSTEM_PROMPT",
97
+ "You are a helpful chatbot.",
98
+ )
99
+ MIN_TEMP = float(os.environ.get("MIN_TEMPERATURE", 0.0))
100
+ MAX_TEMP = float(os.environ.get("MAX_TEMPERATURE", 1.0))
101
+ DEFAULT_TEMP = float(os.environ.get("DEFAULT_TEMPERATURE", 0.7))
102
+ MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
103
+ MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
104
+ DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
105
+ DEFAULT_LANGSMITH_PROJECT = os.environ.get(
106
+ "LANGCHAIN_PROJECT",
107
+ "langchain-streamlit-demo",
108
+ )
109
+ PROVIDER_KEY_DICT = {
110
+ "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
111
+ "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
112
+ "Anyscale Endpoints": os.environ.get("ANYSCALE_API_KEY", ""),
113
+ "LANGSMITH": os.environ.get("LANGCHAIN_API_KEY", ""),
114
+ }
115
 
116
 
117
  # --- Sidebar ---
 
119
  with sidebar:
120
  st.markdown("# Menu")
121
 
122
+ model = st.selectbox(
123
  label="Chat Model",
124
+ options=SUPPORTED_MODELS,
125
+ index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
 
 
 
 
126
  )
127
 
128
  # document_chat = st.checkbox(
 
132
  # )
133
 
134
  if st.button("Clear message history"):
135
+ STMEMORY.clear()
136
  st.session_state.trace_link = None
137
  st.session_state.run_id = None
138
 
139
  # --- Advanced Options ---
140
  with st.expander("Advanced Options", expanded=False):
141
  st.markdown("## Feedback Scale")
142
+ use_faces = st.toggle(label="`Thumbs` ⇄ `Faces`", value=False)
143
+ feedback_option = "faces" if use_faces else "thumbs"
 
144
 
145
+ system_prompt = (
146
  st.text_area(
147
  "Custom Instructions",
148
+ DEFAULT_SYSTEM_PROMPT,
 
 
149
  help="Custom instructions to provide the language model to determine style, personality, etc.",
150
  )
151
  .strip()
152
  .replace("{", "{{")
153
  .replace("}", "}}")
154
  )
 
155
  temperature = st.slider(
156
  "Temperature",
157
+ min_value=MIN_TEMP,
158
+ max_value=MAX_TEMP,
159
+ value=DEFAULT_TEMP,
160
  help="Higher values give more random results.",
161
  )
162
 
163
  max_tokens = st.slider(
164
  "Max Tokens",
165
+ min_value=MIN_MAX_TOKENS,
166
+ max_value=MAX_MAX_TOKENS,
167
+ value=DEFAULT_MAX_TOKENS,
168
  help="Higher values give longer results.",
169
  )
170
 
171
  # --- API Keys ---
172
+ provider = MODEL_DICT[model]
173
 
174
+ provider_api_key = PROVIDER_KEY_DICT.get(provider) or st.text_input(
175
+ f"{provider} API key",
 
176
  type="password",
177
  )
178
 
179
+ LANGSMITH_API_KEY = PROVIDER_KEY_DICT.get("LANGSMITH") or st.text_input(
180
  "LangSmith API Key (optional)",
 
181
  type="password",
182
  )
183
+ LANGSMITH_PROJECT = DEFAULT_LANGSMITH_PROJECT or st.text_input(
184
  "LangSmith Project Name",
185
+ value="langchain-streamlit-demo",
186
  )
187
+ if st.session_state.client is None and LANGSMITH_API_KEY:
188
+ st.session_state.client = Client(
189
+ api_url="https://api.smith.langchain.com",
190
+ api_key=LANGSMITH_API_KEY,
191
+ )
192
+ st.session_state.ls_tracer = LangChainTracer(
193
+ project_name=LANGSMITH_PROJECT,
194
+ client=st.session_state.client,
195
+ )
196
 
197
 
198
  # --- LLM Instantiation ---
199
+ if provider_api_key:
200
+ if provider == "OpenAI":
201
  st.session_state.llm = ChatOpenAI(
202
+ model=model,
203
+ openai_api_key=provider_api_key,
204
  temperature=temperature,
205
  streaming=True,
206
  max_tokens=max_tokens,
207
  )
208
+ elif provider == "Anthropic":
209
  st.session_state.llm = ChatAnthropic(
210
+ model_name=model,
211
+ anthropic_api_key=provider_api_key,
212
  temperature=temperature,
213
  streaming=True,
214
  max_tokens_to_sample=max_tokens,
215
  )
216
+ elif provider == "Anyscale Endpoints":
217
  st.session_state.llm = ChatAnyscale(
218
+ model=model,
219
+ anyscale_api_key=provider_api_key,
220
  temperature=temperature,
221
  streaming=True,
222
  max_tokens=max_tokens,
 
224
 
225
 
226
  # --- Chat History ---
227
+ if len(STMEMORY.messages) == 0:
228
+ STMEMORY.add_ai_message("Hello! I'm a helpful AI chatbot. Ask me a question!")
229
 
230
+ for msg in STMEMORY.messages:
231
  st.chat_message(
232
  msg.type,
233
  avatar="🦜" if msg.type in ("ai", "assistant") else None,
 
245
  # )
246
  # else:
247
  # --- Regular Chat ---
248
+ chat_prompt = ChatPromptTemplate.from_messages(
249
  [
250
  (
251
  "system",
252
+ system_prompt + "\nIt's currently {time}.",
253
  ),
254
  MessagesPlaceholder(variable_name="chat_history"),
255
  ("human", "{input}"),
256
  ],
257
  ).partial(time=lambda: str(datetime.now()))
258
  st.session_state.chain = LLMChain(
259
+ prompt=chat_prompt,
260
  llm=st.session_state.llm,
261
+ memory=MEMORY,
262
  )
263
 
264
  # --- Chat Input ---
265
+ prompt = st.chat_input(placeholder="Ask me a question!")
266
+ if prompt:
267
+ st.chat_message("user").write(prompt)
268
+ feedback_update = None
269
+ feedback = None
270
 
271
  # --- Chat Output ---
272
  with st.chat_message("assistant", avatar="🦜"):
273
  message_placeholder = st.empty()
274
+ stream_handler = StreamHandler(message_placeholder)
275
+ callbacks = [RUN_COLLECTOR, stream_handler]
276
+ if st.session_state.ls_tracer:
277
+ callbacks.append(st.session_state.ls_tracer)
278
+
279
+ runnable_config = RunnableConfig(
280
+ callbacks=callbacks,
281
  tags=["Streamlit Chat"],
282
  )
283
  try:
284
+ full_response = st.session_state.chain.invoke(
285
+ {"input": prompt},
286
+ config=runnable_config,
287
  )["text"]
288
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
289
  st.error(
290
+ f"Please enter a valid {provider} API key.",
291
  icon="❌",
292
  )
293
+ full_response = None
294
+ if full_response:
295
+ message_placeholder.markdown(full_response)
296
+
297
+ # --- Tracing ---
298
+ if st.session_state.client:
299
+ st.session_state.run = RUN_COLLECTOR.traced_runs[0]
300
+ st.session_state.run_id = st.session_state.run.id
301
+ RUN_COLLECTOR.traced_runs = []
302
+ wait_for_all_tracers()
303
+ st.session_state.trace_link = st.session_state.client.read_run(
304
+ st.session_state.run_id,
305
+ ).url
306
+ if st.session_state.trace_link:
307
+ with sidebar:
308
+ st.markdown(
309
+ f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: πŸ› οΈ</button></a>',
310
+ unsafe_allow_html=True,
311
+ )
312
 
313
  # --- Feedback ---
314
+ if st.session_state.client and st.session_state.run_id:
315
+ feedback = streamlit_feedback(
316
+ feedback_type=feedback_option,
317
  optional_text_label="[Optional] Please provide an explanation",
318
  key=f"feedback_{st.session_state.run_id}",
319
  )
 
325
  }
326
 
327
  # Get the score mapping based on the selected feedback option
328
+ scores = score_mappings[feedback_option]
329
 
330
+ if feedback:
331
  # Get the score from the selected feedback option's score mapping
332
+ score = scores.get(
333
+ feedback["score"],
334
  )
335
 
336
+ if score is not None:
337
  # Formulate feedback type string incorporating the feedback option
338
  # and score value
339
+ feedback_type_str = f"{feedback_option} {feedback['score']}"
340
 
341
  # Record the feedback with the formulated feedback type string
342
  # and optional comment
343
+ feedback_record = st.session_state.client.create_feedback(
344
+ st.session_state.run_id,
345
+ feedback_type_str,
346
+ score=score,
347
+ comment=feedback.get("text"),
 
 
348
  )
349
+ # feedback = {
350
+ # "feedback_id": str(feedback_record.id),
351
+ # "score": score,
352
+ # }
353
  st.toast("Feedback recorded!", icon="πŸ“")
354
  else:
355
  st.warning("Invalid feedback score.")
356
 
357
  else:
358
+ st.error(f"Please enter a valid {provider} API key.", icon="❌")