Joshua Sundance Bailey committed
Commit 38e6840 · 1 Parent(s): ad0c8c1

experimental session_state

Files changed (1)
  1. langchain-streamlit-demo/app.py +209 -195
langchain-streamlit-demo/app.py CHANGED
@@ -10,21 +10,38 @@ from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.langchain import wait_for_all_tracers
 from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
 from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
-from langchain.chat_models.base import BaseChatModel
 from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain.schema.runnable import RunnableConfig
 from langsmith.client import Client
 from streamlit_feedback import streamlit_feedback
 
+# --- Initialization ---
 st.set_page_config(
     page_title="langchain-streamlit-demo",
     page_icon="🦜",
 )
 
-st.sidebar.markdown("# Menu")
 
+def st_init_null(*variable_names) -> None:
+    for variable_name in variable_names:
+        if variable_name not in st.session_state:
+            st.session_state[variable_name] = None
 
+
+st_init_null(
+    "trace_link",
+    "run_id",
+    "model",
+    "provider",
+    "system_prompt",
+    "llm",
+    "chain",
+    "retriever",
+    "client",
+)
+
+# --- Memory ---
 _STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
 _MEMORY = ConversationBufferMemory(
     chat_memory=_STMEMORY,
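Context for the hunk above: Streamlit re-executes the whole script on every widget interaction, and st.session_state is the store that survives those reruns, which is why st_init_null assigns a key only when it does not already exist. A minimal, self-contained sketch of that behavior (the "counter" key is hypothetical, not part of this app):

    import streamlit as st

    # Runs only on the very first execution of the script in this session.
    if "counter" not in st.session_state:
        st.session_state.counter = 0

    st.session_state.counter += 1  # persists across reruns, unlike a local variable
    st.write(f"This script has run {st.session_state.counter} time(s)")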
@@ -32,11 +49,22 @@ _MEMORY = ConversationBufferMemory(
     memory_key="chat_history",
 )
 
-_DEFAULT_SYSTEM_PROMPT = os.environ.get(
-    "DEFAULT_SYSTEM_PROMPT",
-    "You are a helpful chatbot.",
-)
 
+# --- Callbacks ---
+class StreamHandler(BaseCallbackHandler):
+    def __init__(self, container, initial_text=""):
+        self.container = container
+        self.text = initial_text
+
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        self.text += token
+        self.container.markdown(self.text)
+
+
+st.session_state.run_collector = RunCollectorCallbackHandler()
+
+
+# --- Model Selection Helpers ---
 _MODEL_DICT = {
     "gpt-3.5-turbo": "OpenAI",
     "gpt-4": "OpenAI",
@@ -47,106 +75,133 @@ _MODEL_DICT = {
     "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
 }
 _SUPPORTED_MODELS = list(_MODEL_DICT.keys())
-_DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")
-
-_DEFAULT_TEMPERATURE = float(os.environ.get("DEFAULT_TEMPERATURE", 0.7))
-_MIN_TEMPERATURE = float(os.environ.get("MIN_TEMPERATURE", 0.0))
-_MAX_TEMPERATURE = float(os.environ.get("MAX_TEMPERATURE", 1.0))
-
-_DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
-_MIN_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
-_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
-
-
-def get_llm(
-    model: str,
-    provider_api_key: str,
-    temperature: float,
-    max_tokens: int = _DEFAULT_MAX_TOKENS,
-) -> BaseChatModel:
-    if _MODEL_DICT[model] == "OpenAI":
-        return ChatOpenAI(
-            model=model,
+
+
+def api_key_from_env(provider_name: str) -> Union[str, None]:
+    if provider_name == "OpenAI":
+        return os.environ.get("OPENAI_API_KEY")
+    elif provider_name == "Anthropic":
+        return os.environ.get("ANTHROPIC_API_KEY")
+    elif provider_name == "Anyscale Endpoints":
+        return os.environ.get("ANYSCALE_API_KEY")
+    elif provider_name == "LANGSMITH":
+        return os.environ.get("LANGCHAIN_API_KEY")
+    else:
+        return None
+
+
+# --- Sidebar ---
+sidebar = st.sidebar
+with sidebar:
+    st.markdown("# Menu")
+
+    st.session_state.model = st.selectbox(
+        label="Chat Model",
+        options=_SUPPORTED_MODELS,
+        index=_SUPPORTED_MODELS.index(
+            st.session_state.model
+            or os.environ.get("DEFAULT_MODEL")
+            or "gpt-3.5-turbo",
+        ),
+    )
+
+    # document_chat = st.checkbox(
+    #     "Document Chat",
+    #     value=False,
+    #     help="Upload a document",
+    # )
+
+    if st.button("Clear message history"):
+        _STMEMORY.clear()
+        st.session_state.trace_link = None
+        st.session_state.run_id = None
+
+    # --- Advanced Options ---
+    with st.expander("Advanced Options", expanded=False):
+        st.session_state.system_prompt = (
+            st.text_area(
+                "Custom Instructions",
+                st.session_state.system_prompt
+                or os.environ.get("DEFAULT_SYSTEM_PROMPT")
+                or "You are a helpful chatbot.",
+                help="Custom instructions to provide the language model to determine style, personality, etc.",
+            )
+            .strip()
+            .replace("{", "{{")
+            .replace("}", "}}")
+        )
+
+        temperature = st.slider(
+            "Temperature",
+            min_value=float(os.environ.get("MIN_TEMPERATURE", 0.0)),
+            max_value=float(os.environ.get("MAX_TEMPERATURE", 1.0)),
+            value=float(os.environ.get("DEFAULT_TEMPERATURE", 0.7)),
+            help="Higher values give more random results.",
+        )
+
+        max_tokens = st.slider(
+            "Max Tokens",
+            min_value=int(os.environ.get("MIN_MAX_TOKENS", 1)),
+            max_value=int(os.environ.get("MAX_MAX_TOKENS", 100000)),
+            value=int(os.environ.get("DEFAULT_MAX_TOKENS", 1000)),
+            help="Higher values give longer results.",
+        )
+
+    # --- API Keys ---
+    st.session_state.provider = _MODEL_DICT[st.session_state.model]
+
+    provider_api_key = st.text_input(
+        f"{st.session_state.provider} API key",
+        value=api_key_from_env(st.session_state.provider) or "",
+        type="password",
+    )
+
+    langsmith_api_key = st.text_input(
+        "LangSmith API Key (optional)",
+        value=api_key_from_env("LANGSMITH") or "",
+        type="password",
+    )
+    langsmith_project = st.text_input(
+        "LangSmith Project Name",
+        value=os.environ.get("LANGCHAIN_PROJECT") or "langchain-streamlit-demo",
+    )
+    if langsmith_api_key:
+        os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
+        os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
+        os.environ["LANGCHAIN_TRACING_V2"] = "true"
+        os.environ["LANGCHAIN_PROJECT"] = langsmith_project
+    client = Client(api_key=langsmith_api_key) if langsmith_api_key else None
+
+
+# --- LLM Instantiation ---
+if provider_api_key:
+    if st.session_state.provider == "OpenAI":
+        st.session_state.llm = ChatOpenAI(
+            model=st.session_state.model,
             openai_api_key=provider_api_key,
             temperature=temperature,
             streaming=True,
             max_tokens=max_tokens,
         )
-    elif _MODEL_DICT[model] == "Anthropic":
-        return ChatAnthropic(
-            model_name=model,
+    elif st.session_state.provider == "Anthropic":
+        st.session_state.llm = ChatAnthropic(
+            model_name=st.session_state.model,
             anthropic_api_key=provider_api_key,
             temperature=temperature,
             streaming=True,
             max_tokens_to_sample=max_tokens,
         )
-    elif _MODEL_DICT[model] == "Anyscale Endpoints":
-        return ChatAnyscale(
-            model=model,
+    elif st.session_state.provider == "Anyscale Endpoints":
+        st.session_state.llm = ChatAnyscale(
+            model=st.session_state.model,
             anyscale_api_key=provider_api_key,
             temperature=temperature,
             streaming=True,
             max_tokens=max_tokens,
         )
-    else:
-        raise NotImplementedError(f"Unknown model {model}")
-
 
-def get_llm_chain(
-    model: str,
-    provider_api_key: str,
-    system_prompt: str = _DEFAULT_SYSTEM_PROMPT,
-    temperature: float = _DEFAULT_TEMPERATURE,
-    max_tokens: int = _DEFAULT_MAX_TOKENS,
-) -> LLMChain:
-    """Return a basic LLMChain with memory."""
-    prompt = ChatPromptTemplate.from_messages(
-        [
-            (
-                "system",
-                system_prompt + "\nIt's currently {time}.",
-            ),
-            MessagesPlaceholder(variable_name="chat_history"),
-            ("human", "{input}"),
-        ],
-    ).partial(time=lambda: str(datetime.now()))
-    llm = get_llm(model, provider_api_key, temperature, max_tokens)
-    return LLMChain(prompt=prompt, llm=llm, memory=_MEMORY)
 
-
-class StreamHandler(BaseCallbackHandler):
-    def __init__(self, container, initial_text=""):
-        self.container = container
-        self.text = initial_text
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        self.text += token
-        self.container.markdown(self.text)
-
-
-def feedback_component(client):
-    scores = {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0}
-    if feedback := streamlit_feedback(
-        feedback_type="faces",
-        optional_text_label="[Optional] Please provide an explanation",
-        key=f"feedback_{st.session_state.run_id}",
-    ):
-        score = scores[feedback["score"]]
-        feedback = client.create_feedback(
-            st.session_state.run_id,
-            feedback["type"],
-            score=score,
-            comment=feedback.get("text", None),
-        )
-        st.session_state.feedback = {"feedback_id": str(feedback.id), "score": score}
-        st.toast("Feedback recorded!", icon="📝")
-
-
-# Initialize State
-if "trace_link" not in st.session_state:
-    st.session_state.trace_link = None
-if "run_id" not in st.session_state:
-    st.session_state.run_id = None
+# --- Chat History ---
 if len(_STMEMORY.messages) == 0:
     _STMEMORY.add_ai_message("Hello! I'm a helpful AI chatbot. Ask me a question!")
 
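A subtle detail in this hunk: the custom instructions from the sidebar are escaped with .replace("{", "{{").replace("}", "}}") before being embedded in the prompt, because ChatPromptTemplate parses single braces as input variables. A short illustration:

    from langchain.prompts import ChatPromptTemplate

    # Unescaped braces become required template variables:
    t1 = ChatPromptTemplate.from_messages([("system", "Reply as {name}")])
    print(t1.input_variables)  # ['name']

    # Doubled braces render as literal braces and declare no variables:
    t2 = ChatPromptTemplate.from_messages([("system", "Reply as {{name}}")])
    print(t2.input_variables)  # []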
@@ -156,138 +211,97 @@ for msg in _STMEMORY.messages:
         avatar="🦜" if msg.type in ("ai", "assistant") else None,
     ).write(msg.content)
 
-model = st.sidebar.selectbox(
-    label="Chat Model",
-    options=_SUPPORTED_MODELS,
-    index=_SUPPORTED_MODELS.index(_DEFAULT_MODEL),
-)
-provider = _MODEL_DICT[model]
-
-
-def api_key_from_env(_provider: str) -> Union[str, None]:
-    if _provider == "OpenAI":
-        return os.environ.get("OPENAI_API_KEY")
-    elif _provider == "Anthropic":
-        return os.environ.get("ANTHROPIC_API_KEY")
-    elif _provider == "Anyscale Endpoints":
-        return os.environ.get("ANYSCALE_API_KEY")
-    elif _provider == "LANGSMITH":
-        return os.environ.get("LANGCHAIN_API_KEY")
-    else:
-        return None
-
-
-provider_api_key = api_key_from_env(provider) or st.sidebar.text_input(
-    f"{provider} API key",
-    type="password",
-)
-langsmith_api_key = api_key_from_env("LANGSMITH") or st.sidebar.text_input(
-    "LangSmith API Key (optional)",
-    type="password",
-)
-if langsmith_api_key:
-    langsmith_project = os.environ.get("LANGCHAIN_PROJECT") or st.sidebar.text_input(
-        "LangSmith Project Name",
-        value="langchain-streamlit-demo",
-    )
-    os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
-    os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
-    os.environ["LANGCHAIN_TRACING_V2"] = "true"
-    os.environ["LANGCHAIN_PROJECT"] = langsmith_project
-
-    client = Client(api_key=langsmith_api_key)
-else:
-    langsmith_project = None
-    client = None
-
-system_prompt = (
-    st.sidebar.text_area(
-        "Custom Instructions",
-        _DEFAULT_SYSTEM_PROMPT,
-        help="Custom instructions to provide the language model to determine style, personality, etc.",
-    )
-    .strip()
-    .replace("{", "{{")
-    .replace("}", "}}")
-)
-
-if st.sidebar.button("Clear message history"):
-    print("Clearing message history")
-    _STMEMORY.clear()
-    st.session_state.trace_link = None
-    st.session_state.run_id = None
-
-temperature = st.sidebar.slider(
-    "Temperature",
-    min_value=_MIN_TEMPERATURE,
-    max_value=_MAX_TEMPERATURE,
-    value=_DEFAULT_TEMPERATURE,
-    help="Higher values give more random results.",
-)
 
-max_tokens = st.sidebar.slider(
-    "Max Tokens",
-    min_value=_MIN_TOKENS,
-    max_value=_MAX_TOKENS,
-    value=_DEFAULT_MAX_TOKENS,
-    help="Higher values give longer results.",
-)
-chain = None
-if provider_api_key:
-    chain = get_llm_chain(
-        model,
-        provider_api_key,
-        system_prompt,
-        temperature,
-        max_tokens,
+# --- Current Chat ---
+if st.session_state.llm:
+    # if isinstance(retriever, BaseRetriever):
+    #     # --- Document Chat ---
+    #     chain = ConversationalRetrievalChain.from_llm(
+    #         llm,
+    #         retriever,
+    #         memory=_MEMORY,
+    #     )
+    # else:
+    # --- Regular Chat ---
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                st.session_state.system_prompt + "\nIt's currently {time}.",
+            ),
+            MessagesPlaceholder(variable_name="chat_history"),
+            ("human", "{input}"),
+        ],
+    ).partial(time=lambda: str(datetime.now()))
+    st.session_state.chain = LLMChain(
+        prompt=prompt,
+        llm=st.session_state.llm,
+        memory=_MEMORY,
     )
 
-run_collector = RunCollectorCallbackHandler()
-
-
-def _reset_feedback():
-    st.session_state.feedback_update = None
-    st.session_state.feedback = None
-
-
-if chain:
+    # --- Chat Input ---
     prompt = st.chat_input(placeholder="Ask me a question!")
     if prompt:
         st.chat_message("user").write(prompt)
-        _reset_feedback()
+        st.session_state.feedback_update = None
+        st.session_state.feedback = None
 
+        # --- Chat Output ---
        with st.chat_message("assistant", avatar="🦜"):
             message_placeholder = st.empty()
             stream_handler = StreamHandler(message_placeholder)
             runnable_config = RunnableConfig(
-                callbacks=[run_collector, stream_handler],
+                callbacks=[st.session_state.run_collector, stream_handler],
                 tags=["Streamlit Chat"],
             )
             try:
-                full_response = chain.invoke(
+                full_response = st.session_state.chain.invoke(
                     {"input": prompt},
                     config=runnable_config,
                 )["text"]
             except (openai.error.AuthenticationError, anthropic.AuthenticationError):
-                st.error(f"Please enter a valid {provider} API key.", icon="❌")
+                st.error(
+                    f"Please enter a valid {st.session_state.provider} API key.",
+                    icon="❌",
+                )
                 st.stop()
             message_placeholder.markdown(full_response)
 
+        # --- Tracing ---
         if client:
-            run = run_collector.traced_runs[0]
-            run_collector.traced_runs = []
+            run = st.session_state.run_collector.traced_runs[0]
+            st.session_state.run_collector.traced_runs = []
             st.session_state.run_id = run.id
             wait_for_all_tracers()
             url = client.read_run(run.id).url
             st.session_state.trace_link = url
+            with sidebar:
+                st.markdown(
+                    f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: 🛠️</button></a>',
+                    unsafe_allow_html=True,
+                )
+
+    # --- Feedback ---
     if client and st.session_state.get("run_id"):
-        feedback_component(client)
+        scores = {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0}
+        feedback = streamlit_feedback(
+            feedback_type="faces",
+            optional_text_label="[Optional] Please provide an explanation",
+            key=f"feedback_{st.session_state.run_id}",
+        )
+        if feedback:
+            score = scores[feedback["score"]]
+            feedback = client.create_feedback(
+                st.session_state.run_id,
+                feedback["type"],
+                score=score,
+                comment=feedback.get("text", None),
+            )
+            st.session_state.feedback = {
+                "feedback_id": str(feedback.id),
+                "score": score,
+            }
+            st.toast("Feedback recorded!", icon="📝")
 
 else:
-    st.error(f"Please enter a valid {provider} API key.", icon="❌")
-
-if client and st.session_state.get("trace_link"):
-    st.sidebar.markdown(
-        f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: 🛠️</button></a>',
-        unsafe_allow_html=True,
-    )
+    st.error(f"Please enter a valid {st.session_state.provider} API key.", icon="❌")
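The feedback block translates the emoji face returned by streamlit_feedback into a numeric score before sending it to LangSmith. A sketch of that mapping, assuming a payload of the shape the code above consumes (the sample values here are hypothetical):

    # Payload shape consumed by the feedback block; sample values hypothetical.
    feedback = {"type": "faces", "score": "🙂", "text": "Helpful answer"}
    scores = {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0}
    score = scores[feedback["score"]]  # -> 0.75
    # The app then records it against the traced run:
    # client.create_feedback(st.session_state.run_id, feedback["type"],
    #                        score=score, comment=feedback.get("text"))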
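Worth noting when running this version: every default above (model, system prompt, temperature and token bounds, LangSmith project) is read from an environment variable, so the app can be reconfigured without editing the code. A sketch of a launcher setting these in the same process before the app runs (variable names are from the diff; the values are hypothetical):

    import os

    # Environment knobs read by app.py; values must be strings, since the app
    # converts them with float()/int() where needed.
    os.environ.setdefault("DEFAULT_MODEL", "gpt-4")
    os.environ.setdefault("DEFAULT_SYSTEM_PROMPT", "You are a terse assistant.")
    os.environ.setdefault("DEFAULT_TEMPERATURE", "0.3")
    os.environ.setdefault("MAX_MAX_TOKENS", "4000")
    os.environ.setdefault("LANGCHAIN_PROJECT", "langchain-streamlit-demo")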