Joshua Sundance Bailey commited on
Commit
da76025
Β·
1 Parent(s): 38e6840

session_state

Browse files
Files changed (1) hide show
  1. langchain-streamlit-demo/app.py +94 -49
langchain-streamlit-demo/app.py CHANGED
@@ -30,15 +30,27 @@ def st_init_null(*variable_names) -> None:
30
 
31
 
32
  st_init_null(
33
- "trace_link",
34
- "run_id",
 
 
 
 
 
 
 
35
  "model",
 
36
  "provider",
37
- "system_prompt",
38
- "llm",
39
- "chain",
40
  "retriever",
41
- "client",
 
 
 
 
 
 
42
  )
43
 
44
  # --- Memory ---
@@ -118,6 +130,11 @@ with sidebar:
118
 
119
  # --- Advanced Options ---
120
  with st.expander("Advanced Options", expanded=False):
 
 
 
 
 
121
  st.session_state.system_prompt = (
122
  st.text_area(
123
  "Custom Instructions",
@@ -150,7 +167,7 @@ with sidebar:
150
  # --- API Keys ---
151
  st.session_state.provider = _MODEL_DICT[st.session_state.model]
152
 
153
- provider_api_key = st.text_input(
154
  f"{st.session_state.provider} API key",
155
  value=api_key_from_env(st.session_state.provider) or "",
156
  type="password",
@@ -170,31 +187,31 @@ with sidebar:
170
  os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
171
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
172
  os.environ["LANGCHAIN_PROJECT"] = langsmith_project
173
- client = Client(api_key=langsmith_api_key)
174
 
175
 
176
  # --- LLM Instantiation ---
177
- if provider_api_key:
178
  if st.session_state.provider == "OpenAI":
179
- llm = ChatOpenAI(
180
  model=st.session_state.model,
181
- openai_api_key=provider_api_key,
182
  temperature=temperature,
183
  streaming=True,
184
  max_tokens=max_tokens,
185
  )
186
  elif st.session_state.provider == "Anthropic":
187
- llm = ChatAnthropic(
188
  model_name=st.session_state.model,
189
- anthropic_api_key=provider_api_key,
190
  temperature=temperature,
191
  streaming=True,
192
  max_tokens_to_sample=max_tokens,
193
  )
194
  elif st.session_state.provider == "Anyscale Endpoints":
195
- llm = ChatAnyscale(
196
  model=st.session_state.model,
197
- anyscale_api_key=provider_api_key,
198
  temperature=temperature,
199
  streaming=True,
200
  max_tokens=max_tokens,
@@ -217,13 +234,13 @@ if st.session_state.llm:
217
  # if isinstance(retriever, BaseRetriever):
218
  # # --- Document Chat ---
219
  # chain = ConversationalRetrievalChain.from_llm(
220
- # llm,
221
  # retriever,
222
  # memory=_MEMORY,
223
  # )
224
  # else:
225
  # --- Regular Chat ---
226
- prompt = ChatPromptTemplate.from_messages(
227
  [
228
  (
229
  "system",
@@ -234,30 +251,33 @@ if st.session_state.llm:
234
  ],
235
  ).partial(time=lambda: str(datetime.now()))
236
  st.session_state.chain = LLMChain(
237
- prompt=prompt,
238
  llm=st.session_state.llm,
239
  memory=_MEMORY,
240
  )
241
 
242
  # --- Chat Input ---
243
- prompt = st.chat_input(placeholder="Ask me a question!")
244
- if prompt:
245
- st.chat_message("user").write(prompt)
246
  st.session_state.feedback_update = None
247
  st.session_state.feedback = None
248
 
249
  # --- Chat Output ---
250
  with st.chat_message("assistant", avatar="🦜"):
251
  message_placeholder = st.empty()
252
- stream_handler = StreamHandler(message_placeholder)
253
- runnable_config = RunnableConfig(
254
- callbacks=[st.session_state.run_collector, stream_handler],
 
 
 
255
  tags=["Streamlit Chat"],
256
  )
257
  try:
258
- full_response = st.session_state.chain.invoke(
259
- {"input": prompt},
260
- config=runnable_config,
261
  )["text"]
262
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
263
  st.error(
@@ -265,16 +285,17 @@ if st.session_state.llm:
265
  icon="❌",
266
  )
267
  st.stop()
268
- message_placeholder.markdown(full_response)
269
 
270
  # --- Tracing ---
271
- if client:
272
- run = st.session_state.run_collector.traced_runs[0]
 
273
  st.session_state.run_collector.traced_runs = []
274
- st.session_state.run_id = run.id
275
  wait_for_all_tracers()
276
- url = client.read_run(run.id).url
277
- st.session_state.trace_link = url
 
278
  with sidebar:
279
  st.markdown(
280
  f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: πŸ› οΈ</button></a>',
@@ -282,26 +303,50 @@ if st.session_state.llm:
282
  )
283
 
284
  # --- Feedback ---
285
- if client and st.session_state.get("run_id"):
286
- scores = {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0}
287
- feedback = streamlit_feedback(
288
- feedback_type="faces",
289
  optional_text_label="[Optional] Please provide an explanation",
290
  key=f"feedback_{st.session_state.run_id}",
291
  )
292
- if feedback:
293
- score = scores[feedback["score"]]
294
- feedback = client.create_feedback(
295
- st.session_state.run_id,
296
- feedback["type"],
297
- score=score,
298
- comment=feedback.get("text", None),
 
 
 
 
 
 
 
299
  )
300
- st.session_state.feedback = {
301
- "feedback_id": str(feedback.id),
302
- "score": score,
303
- }
304
- st.toast("Feedback recorded!", icon="πŸ“")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  else:
307
  st.error(f"Please enter a valid {st.session_state.provider} API key.", icon="❌")
 
30
 
31
 
32
  st_init_null(
33
+ "chain",
34
+ "client",
35
+ "feedback",
36
+ "feedback_option",
37
+ "feedback_record",
38
+ "feedback_type_str",
39
+ "feedback_update",
40
+ "full_response",
41
+ "llm",
42
  "model",
43
+ "prompt",
44
  "provider",
45
+ "provider_api_key",
 
 
46
  "retriever",
47
+ "run_collector",
48
+ "run_id",
49
+ "runnable_config",
50
+ "score",
51
+ "stream_handler",
52
+ "system_prompt",
53
+ "trace_link",
54
  )
55
 
56
  # --- Memory ---
 
130
 
131
  # --- Advanced Options ---
132
  with st.expander("Advanced Options", expanded=False):
133
+ st.markdown("## Feedback Scale")
134
+ st.session_state.feedback_option = (
135
+ "faces" if st.toggle(label="`Faces` ⇄ `Thumbs`", value=False) else "thumbs"
136
+ )
137
+
138
  st.session_state.system_prompt = (
139
  st.text_area(
140
  "Custom Instructions",
 
167
  # --- API Keys ---
168
  st.session_state.provider = _MODEL_DICT[st.session_state.model]
169
 
170
+ st.session_state.provider_api_key = st.text_input(
171
  f"{st.session_state.provider} API key",
172
  value=api_key_from_env(st.session_state.provider) or "",
173
  type="password",
 
187
  os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
188
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
189
  os.environ["LANGCHAIN_PROJECT"] = langsmith_project
190
+ st.session_state.client = Client(api_key=langsmith_api_key)
191
 
192
 
193
  # --- LLM Instantiation ---
194
+ if st.session_state.provider_api_key:
195
  if st.session_state.provider == "OpenAI":
196
+ st.session_state.llm = ChatOpenAI(
197
  model=st.session_state.model,
198
+ openai_api_key=st.session_state.provider_api_key,
199
  temperature=temperature,
200
  streaming=True,
201
  max_tokens=max_tokens,
202
  )
203
  elif st.session_state.provider == "Anthropic":
204
+ st.session_state.llm = ChatAnthropic(
205
  model_name=st.session_state.model,
206
+ anthropic_api_key=st.session_state.provider_api_key,
207
  temperature=temperature,
208
  streaming=True,
209
  max_tokens_to_sample=max_tokens,
210
  )
211
  elif st.session_state.provider == "Anyscale Endpoints":
212
+ st.session_state.llm = ChatAnyscale(
213
  model=st.session_state.model,
214
+ anyscale_api_key=st.session_state.provider_api_key,
215
  temperature=temperature,
216
  streaming=True,
217
  max_tokens=max_tokens,
 
234
  # if isinstance(retriever, BaseRetriever):
235
  # # --- Document Chat ---
236
  # chain = ConversationalRetrievalChain.from_llm(
237
+ # st.session_state.llm,
238
  # retriever,
239
  # memory=_MEMORY,
240
  # )
241
  # else:
242
  # --- Regular Chat ---
243
+ st.session_state.prompt = ChatPromptTemplate.from_messages(
244
  [
245
  (
246
  "system",
 
251
  ],
252
  ).partial(time=lambda: str(datetime.now()))
253
  st.session_state.chain = LLMChain(
254
+ prompt=st.session_state.prompt,
255
  llm=st.session_state.llm,
256
  memory=_MEMORY,
257
  )
258
 
259
  # --- Chat Input ---
260
+ st.session_state.prompt = st.chat_input(placeholder="Ask me a question!")
261
+ if st.session_state.prompt:
262
+ st.chat_message("user").write(st.session_state.prompt)
263
  st.session_state.feedback_update = None
264
  st.session_state.feedback = None
265
 
266
  # --- Chat Output ---
267
  with st.chat_message("assistant", avatar="🦜"):
268
  message_placeholder = st.empty()
269
+ st.session_state.stream_handler = StreamHandler(message_placeholder)
270
+ st.session_state.runnable_config = RunnableConfig(
271
+ callbacks=[
272
+ st.session_state.run_collector,
273
+ st.session_state.stream_handler,
274
+ ],
275
  tags=["Streamlit Chat"],
276
  )
277
  try:
278
+ st.session_state.full_response = st.session_state.chain.invoke(
279
+ {"input": st.session_state.prompt},
280
+ config=st.session_state.runnable_config,
281
  )["text"]
282
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
283
  st.error(
 
285
  icon="❌",
286
  )
287
  st.stop()
288
+ message_placeholder.markdown(st.session_state.full_response)
289
 
290
  # --- Tracing ---
291
+ if st.session_state.client:
292
+ st.session_state.run = st.session_state.run_collector.traced_runs[0]
293
+ st.session_state.run_id = st.session_state.run.id
294
  st.session_state.run_collector.traced_runs = []
 
295
  wait_for_all_tracers()
296
+ st.session_state.trace_link = st.session_state.client.read_run(
297
+ st.session_state.run_id,
298
+ ).url
299
  with sidebar:
300
  st.markdown(
301
  f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: πŸ› οΈ</button></a>',
 
303
  )
304
 
305
  # --- Feedback ---
306
+ if st.session_state.client and st.session_state.get("run_id"):
307
+ st.session_state.feedback = streamlit_feedback(
308
+ feedback_type=st.session_state.feedback_option,
 
309
  optional_text_label="[Optional] Please provide an explanation",
310
  key=f"feedback_{st.session_state.run_id}",
311
  )
312
+
313
+ # Define score mappings for both "thumbs" and "faces" feedback systems
314
+ score_mappings: dict[str, dict[str, Union[int, float]]] = {
315
+ "thumbs": {"πŸ‘": 1, "πŸ‘Ž": 0},
316
+ "faces": {"πŸ˜€": 1, "πŸ™‚": 0.75, "😐": 0.5, "πŸ™": 0.25, "😞": 0},
317
+ }
318
+
319
+ # Get the score mapping based on the selected feedback option
320
+ st.session_state.scores = score_mappings[st.session_state.feedback_option]
321
+
322
+ if st.session_state.feedback:
323
+ # Get the score from the selected feedback option's score mapping
324
+ st.session_state.score = st.session_state.scores.get(
325
+ st.session_state.feedback["score"],
326
  )
327
+
328
+ if st.session_state.score is not None:
329
+ # Formulate feedback type string incorporating the feedback option
330
+ # and score value
331
+ st.session_state.feedback_type_str = f"{st.session_state.feedback_option} {st.session_state.feedback['score']}"
332
+
333
+ # Record the feedback with the formulated feedback type string
334
+ # and optional comment
335
+ st.session_state.feedback_record = (
336
+ st.session_state.client.create_feedback(
337
+ st.session_state.run_id,
338
+ st.session_state.feedback_type_str,
339
+ score=st.session_state.score,
340
+ comment=st.session_state.feedback.get("text"),
341
+ )
342
+ )
343
+ st.session_state.feedback = {
344
+ "feedback_id": str(st.session_state.feedback_record.id),
345
+ "score": st.session_state.score,
346
+ }
347
+ st.toast("Feedback recorded!", icon="πŸ“")
348
+ else:
349
+ st.warning("Invalid feedback score.")
350
 
351
  else:
352
  st.error(f"Please enter a valid {st.session_state.provider} API key.", icon="❌")