Joshua Sundance Bailey committed
Commit 4d225cc · unverified
Parents: b44a3fc c603886

Merge pull request #46 from joshuasundance-swca/cleanup

.idea/langchain-streamlit-demo.iml CHANGED
@@ -1,7 +1,9 @@
 <?xml version="1.0" encoding="UTF-8"?>
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
-    <content url="file://$MODULE_DIR$" />
+    <content url="file://$MODULE_DIR$">
+      <sourceFolder url="file://$MODULE_DIR$/langchain-streamlit-demo" isTestSource="false" />
+    </content>
     <orderEntry type="jdk" jdkName="Remote Python 3.11.4 Docker (&lt;none&gt;:&lt;none&gt;) (5)" jdkType="Python SDK" />
     <orderEntry type="sourceFolder" forTests="false" />
   </component>
kubernetes/resources.yaml CHANGED
@@ -39,6 +39,11 @@ spec:
             secretKeyRef:
               name: langchain-streamlit-demo-secret
               key: AZURE_OPENAI_DEPLOYMENT_NAME
+        - name: AZURE_OPENAI_EMB_DEPLOYMENT_NAME
+          valueFrom:
+            secretKeyRef:
+              name: langchain-streamlit-demo-secret
+              key: AZURE_OPENAI_EMB_DEPLOYMENT_NAME
         - name: AZURE_OPENAI_API_KEY
           valueFrom:
             secretKeyRef:
@@ -71,6 +76,10 @@ spec:
               key: LANGCHAIN_API_KEY
         - name: LANGCHAIN_PROJECT
           value: "langchain-streamlit-demo"
+        - name: SHOW_LANGCHAIN_OPTIONS
+          value: "False"
+        - name: SHOW_AZURE_OPTIONS
+          value: "False"
       securityContext:
         runAsNonRoot: true
 ---
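The two SHOW_* values above are consumed by the new defaults.py module (added below) as case-insensitive string booleans. Note that the manifest sets SHOW_LANGCHAIN_OPTIONS while defaults.py reads SHOW_LANGSMITH_OPTIONS, so as committed the "False" above appears to reach only the Azure expander, not the LangSmith one. A minimal sketch of the parsing, mirroring defaults.py:

import os

# Any casing of "true" enables the expander; everything else disables it.
SHOW_LANGSMITH_OPTIONS = (
    os.environ.get("SHOW_LANGSMITH_OPTIONS", "true").lower() == "true"
)
SHOW_AZURE_OPTIONS = os.environ.get("SHOW_AZURE_OPTIONS", "true").lower() == "true"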
langchain-streamlit-demo/app.py CHANGED
@@ -1,37 +1,22 @@
-import os
 from datetime import datetime
-from tempfile import NamedTemporaryFile
-from typing import Tuple, List, Dict, Any, Union
+from typing import Tuple, List, Dict, Any, Union, Optional
 
 import anthropic
 import langsmith.utils
 import openai
 import streamlit as st
-from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
 from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
-from langchain.chains import RetrievalQA
-from langchain.chains.llm import LLMChain
-from langchain.chat_models import (
-    AzureChatOpenAI,
-    ChatAnthropic,
-    ChatAnyscale,
-    ChatOpenAI,
-)
-from langchain.document_loaders import PyPDFLoader
-from langchain.embeddings import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain.retrievers import BM25Retriever, EnsembleRetriever
 from langchain.schema.document import Document
 from langchain.schema.retriever import BaseRetriever
-from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain.vectorstores import FAISS
 from langsmith.client import Client
 from streamlit_feedback import streamlit_feedback
 
-from qagen import get_rag_qa_gen_chain
-from summarize import get_rag_summarization_chain
+from defaults import default_values
+
+from llm_resources import get_runnable, get_llm, get_texts_and_retriever, StreamHandler
 
 __version__ = "0.0.13"
 
@@ -62,119 +47,72 @@ st_init_null(
     "trace_link",
 )
 
-# --- Memory ---
+# --- LLM globals ---
 STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
 MEMORY = ConversationBufferMemory(
     chat_memory=STMEMORY,
     return_messages=True,
     memory_key="chat_history",
 )
-
-
-# --- Callbacks ---
-class StreamHandler(BaseCallbackHandler):
-    def __init__(self, container, initial_text=""):
-        self.container = container
-        self.text = initial_text
-
-    def on_llm_new_token(self, token: str, **kwargs) -> None:
-        self.text += token
-        self.container.markdown(self.text)
-
-
 RUN_COLLECTOR = RunCollectorCallbackHandler()
 
-
-# --- Model Selection Helpers ---
-MODEL_DICT = {
-    "gpt-3.5-turbo": "OpenAI",
-    "gpt-4": "OpenAI",
-    "claude-instant-v1": "Anthropic",
-    "claude-2": "Anthropic",
-    "meta-llama/Llama-2-7b-chat-hf": "Anyscale Endpoints",
-    "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
-    "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
-    "codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
-    "Azure OpenAI": "Azure OpenAI",
-}
-SUPPORTED_MODELS = list(MODEL_DICT.keys())
-
-
-# --- Constants from Environment Variables ---
-DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")
-DEFAULT_SYSTEM_PROMPT = os.environ.get(
-    "DEFAULT_SYSTEM_PROMPT",
-    "You are a helpful chatbot.",
+LANGSMITH_API_KEY = default_values.PROVIDER_KEY_DICT.get("LANGSMITH")
+LANGSMITH_PROJECT = (
+    default_values.DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo"
 )
-MIN_TEMP = float(os.environ.get("MIN_TEMPERATURE", 0.0))
-MAX_TEMP = float(os.environ.get("MAX_TEMPERATURE", 1.0))
-DEFAULT_TEMP = float(os.environ.get("DEFAULT_TEMPERATURE", 0.7))
-MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
-MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
-DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
-DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
-
-AZURE_VARS = [
-    "AZURE_OPENAI_BASE_URL",
-    "AZURE_OPENAI_API_VERSION",
-    "AZURE_OPENAI_DEPLOYMENT_NAME",
-    "AZURE_OPENAI_API_KEY",
-    "AZURE_OPENAI_MODEL_VERSION",
+AZURE_OPENAI_BASE_URL = default_values.AZURE_DICT["AZURE_OPENAI_BASE_URL"]
+AZURE_OPENAI_API_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_API_VERSION"]
+AZURE_OPENAI_DEPLOYMENT_NAME = default_values.AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"]
+AZURE_OPENAI_EMB_DEPLOYMENT_NAME = default_values.AZURE_DICT[
+    "AZURE_OPENAI_EMB_DEPLOYMENT_NAME"
 ]
+AZURE_OPENAI_API_KEY = default_values.AZURE_DICT["AZURE_OPENAI_API_KEY"]
+AZURE_OPENAI_MODEL_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"]
 
-AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
-
-PROVIDER_KEY_DICT = {
-    "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
-    "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
-    "Anyscale Endpoints": os.environ.get("ANYSCALE_API_KEY", ""),
-    "LANGSMITH": os.environ.get("LANGCHAIN_API_KEY", ""),
-}
-OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
-
-MIN_CHUNK_SIZE = 1
-MAX_CHUNK_SIZE = 10000
-DEFAULT_CHUNK_SIZE = 1000
-
-MIN_CHUNK_OVERLAP = 0
-MAX_CHUNK_OVERLAP = 10000
-DEFAULT_CHUNK_OVERLAP = 0
-
-DEFAULT_RETRIEVER_K = 4
+AZURE_AVAILABLE = all(
+    [
+        AZURE_OPENAI_BASE_URL,
+        AZURE_OPENAI_API_VERSION,
+        AZURE_OPENAI_DEPLOYMENT_NAME,
+        AZURE_OPENAI_API_KEY,
+        AZURE_OPENAI_MODEL_VERSION,
+    ],
+)
+
+AZURE_EMB_AVAILABLE = AZURE_AVAILABLE and AZURE_OPENAI_EMB_DEPLOYMENT_NAME
+
+AZURE_KWARGS = (
+    None
+    if not AZURE_EMB_AVAILABLE
+    else {
+        "openai_api_base": AZURE_OPENAI_BASE_URL,
+        "openai_api_version": AZURE_OPENAI_API_VERSION,
+        "deployment": AZURE_OPENAI_EMB_DEPLOYMENT_NAME,
+        "openai_api_key": AZURE_OPENAI_API_KEY,
+        "openai_api_type": "azure",
+    }
+)
 
 
 @st.cache_data
-def get_texts_and_retriever(
+def get_texts_and_retriever_cacheable_wrapper(
     uploaded_file_bytes: bytes,
-    chunk_size: int = DEFAULT_CHUNK_SIZE,
-    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
-    k: int = DEFAULT_RETRIEVER_K,
+    openai_api_key: str,
+    chunk_size: int = default_values.DEFAULT_CHUNK_SIZE,
+    chunk_overlap: int = default_values.DEFAULT_CHUNK_OVERLAP,
+    k: int = default_values.DEFAULT_RETRIEVER_K,
+    azure_kwargs: Optional[Dict[str, str]] = None,
+    use_azure: bool = False,
 ) -> Tuple[List[Document], BaseRetriever]:
-    with NamedTemporaryFile() as temp_file:
-        temp_file.write(uploaded_file_bytes)
-        temp_file.seek(0)
-
-        loader = PyPDFLoader(temp_file.name)
-        documents = loader.load()
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-        )
-        texts = text_splitter.split_documents(documents)
-        embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
-
-        bm25_retriever = BM25Retriever.from_documents(texts)
-        bm25_retriever.k = k
-
-        faiss_vectorstore = FAISS.from_documents(texts, embeddings)
-        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
-
-        ensemble_retriever = EnsembleRetriever(
-            retrievers=[bm25_retriever, faiss_retriever],
-            weights=[0.5, 0.5],
-        )
-
-        return texts, ensemble_retriever
+    return get_texts_and_retriever(
+        uploaded_file_bytes=uploaded_file_bytes,
+        openai_api_key=openai_api_key,
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap,
+        k=k,
+        azure_kwargs=azure_kwargs,
+        use_azure=use_azure,
+    )
 
 
 # --- Sidebar ---
@@ -184,14 +122,14 @@ with sidebar:
 
     model = st.selectbox(
         label="Chat Model",
-        options=SUPPORTED_MODELS,
-        index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
+        options=default_values.SUPPORTED_MODELS,
+        index=default_values.SUPPORTED_MODELS.index(default_values.DEFAULT_MODEL),
     )
 
-    st.session_state.provider = MODEL_DICT[model]
+    st.session_state.provider = default_values.MODEL_DICT[model]
 
     provider_api_key = (
-        PROVIDER_KEY_DICT.get(
+        default_values.PROVIDER_KEY_DICT.get(
            st.session_state.provider,
        )
        or st.text_input(
@@ -214,7 +152,7 @@ with sidebar:
     openai_api_key = (
         provider_api_key
         if st.session_state.provider == "OpenAI"
-        else OPENAI_API_KEY
+        else default_values.OPENAI_API_KEY
         or st.sidebar.text_input("OpenAI API Key: ", type="password")
     )
 
@@ -227,7 +165,7 @@ with sidebar:
     k = st.slider(
         label="Number of Chunks",
         help="How many document chunks will be used for context?",
-        value=DEFAULT_RETRIEVER_K,
+        value=default_values.DEFAULT_RETRIEVER_K,
         min_value=1,
         max_value=10,
     )
@@ -235,21 +173,23 @@ with sidebar:
     chunk_size = st.slider(
         label="Number of Tokens per Chunk",
         help="Size of each chunk of text",
-        min_value=MIN_CHUNK_SIZE,
-        max_value=MAX_CHUNK_SIZE,
-        value=DEFAULT_CHUNK_SIZE,
+        min_value=default_values.MIN_CHUNK_SIZE,
+        max_value=default_values.MAX_CHUNK_SIZE,
+        value=default_values.DEFAULT_CHUNK_SIZE,
     )
+
     chunk_overlap = st.slider(
         label="Chunk Overlap",
         help="Number of characters to overlap between chunks",
-        min_value=MIN_CHUNK_OVERLAP,
-        max_value=MAX_CHUNK_OVERLAP,
-        value=DEFAULT_CHUNK_OVERLAP,
+        min_value=default_values.MIN_CHUNK_OVERLAP,
+        max_value=default_values.MAX_CHUNK_OVERLAP,
+        value=default_values.DEFAULT_CHUNK_OVERLAP,
    )
 
     chain_type_help_root = (
         "https://python.langchain.com/docs/modules/chains/document/"
     )
+
     chain_type_help = "\n".join(
         f"- [{chain_type_name}]({chain_type_help_root}/{chain_type_name})"
         for chain_type_name in (
@@ -259,6 +199,7 @@ with sidebar:
            "map_rerank",
        )
    )
+
     document_chat_chain_type = st.selectbox(
         label="Document Chat Chain Type",
         options=[
@@ -273,17 +214,28 @@ with sidebar:
         help=chain_type_help,
         disabled=not document_chat,
     )
+    use_azure = False
+
+    if AZURE_EMB_AVAILABLE:
+        use_azure = st.toggle(
+            label="Use Azure OpenAI",
+            value=AZURE_EMB_AVAILABLE,
+            help="Use Azure for embeddings instead of using OpenAI directly.",
+        )
 
     if uploaded_file:
-        if openai_api_key:
+        if AZURE_EMB_AVAILABLE or openai_api_key:
             (
                 st.session_state.texts,
                 st.session_state.retriever,
-            ) = get_texts_and_retriever(
+            ) = get_texts_and_retriever_cacheable_wrapper(
                 uploaded_file_bytes=uploaded_file.getvalue(),
+                openai_api_key=openai_api_key,
                 chunk_size=chunk_size,
                 chunk_overlap=chunk_overlap,
                 k=k,
+                azure_kwargs=AZURE_KWARGS,
+                use_azure=use_azure,
             )
         else:
             st.error("Please enter a valid OpenAI API key.", icon="❌")
@@ -297,123 +249,100 @@ with sidebar:
     system_prompt = (
         st.text_area(
             "Custom Instructions",
-            DEFAULT_SYSTEM_PROMPT,
+            default_values.DEFAULT_SYSTEM_PROMPT,
             help="Custom instructions to provide the language model to determine style, personality, etc.",
         )
         .strip()
         .replace("{", "{{")
         .replace("}", "}}")
     )
+
     temperature = st.slider(
         "Temperature",
-        min_value=MIN_TEMP,
-        max_value=MAX_TEMP,
-        value=DEFAULT_TEMP,
+        min_value=default_values.MIN_TEMP,
+        max_value=default_values.MAX_TEMP,
+        value=default_values.DEFAULT_TEMP,
         help="Higher values give more random results.",
     )
 
     max_tokens = st.slider(
         "Max Tokens",
-        min_value=MIN_MAX_TOKENS,
-        max_value=MAX_MAX_TOKENS,
-        value=DEFAULT_MAX_TOKENS,
+        min_value=default_values.MIN_MAX_TOKENS,
+        max_value=default_values.MAX_MAX_TOKENS,
+        value=default_values.DEFAULT_MAX_TOKENS,
         help="Higher values give longer results.",
     )
 
     # --- LangSmith Options ---
-    with st.expander("LangSmith Options", expanded=False):
-        LANGSMITH_API_KEY = st.text_input(
-            "LangSmith API Key (optional)",
-            type="password",
-            value=PROVIDER_KEY_DICT.get("LANGSMITH"),
-        )
-        LANGSMITH_PROJECT = st.text_input(
-            "LangSmith Project Name",
-            value=DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo",
-        )
-        if st.session_state.client is None and LANGSMITH_API_KEY:
-            st.session_state.client = Client(
-                api_url="https://api.smith.langchain.com",
-                api_key=LANGSMITH_API_KEY,
-            )
-            st.session_state.ls_tracer = LangChainTracer(
-                project_name=LANGSMITH_PROJECT,
-                client=st.session_state.client,
-            )
+    if default_values.SHOW_LANGSMITH_OPTIONS:
+        with st.expander("LangSmith Options", expanded=False):
+            LANGSMITH_API_KEY = st.text_input(
+                "LangSmith API Key (optional)",
+                value=LANGSMITH_API_KEY,
+                type="password",
+            )
+
+            LANGSMITH_PROJECT = st.text_input(
+                "LangSmith Project Name",
+                value=LANGSMITH_PROJECT,
+            )
+
+    if st.session_state.client is None and LANGSMITH_API_KEY:
+        st.session_state.client = Client(
+            api_url="https://api.smith.langchain.com",
+            api_key=LANGSMITH_API_KEY,
+        )
+        st.session_state.ls_tracer = LangChainTracer(
+            project_name=LANGSMITH_PROJECT,
+            client=st.session_state.client,
+        )
 
     # --- Azure Options ---
-    with st.expander("Azure Options", expanded=False):
-        AZURE_OPENAI_BASE_URL = st.text_input(
-            "AZURE_OPENAI_BASE_URL",
-            value=AZURE_DICT["AZURE_OPENAI_BASE_URL"],
-        )
-        AZURE_OPENAI_API_VERSION = st.text_input(
-            "AZURE_OPENAI_API_VERSION",
-            value=AZURE_DICT["AZURE_OPENAI_API_VERSION"],
-        )
-        AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
-            "AZURE_OPENAI_DEPLOYMENT_NAME",
-            value=AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"],
-        )
-        AZURE_OPENAI_API_KEY = st.text_input(
-            "AZURE_OPENAI_API_KEY",
-            value=AZURE_DICT["AZURE_OPENAI_API_KEY"],
-            type="password",
-        )
-        AZURE_OPENAI_MODEL_VERSION = st.text_input(
-            "AZURE_OPENAI_MODEL_VERSION",
-            value=AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"],
-        )
-
-    AZURE_AVAILABLE = all(
-        [
-            AZURE_OPENAI_BASE_URL,
-            AZURE_OPENAI_API_VERSION,
-            AZURE_OPENAI_DEPLOYMENT_NAME,
-            AZURE_OPENAI_API_KEY,
-            AZURE_OPENAI_MODEL_VERSION,
-        ],
-    )
+    if default_values.SHOW_AZURE_OPTIONS:
+        with st.expander("Azure Options", expanded=False):
+            AZURE_OPENAI_BASE_URL = st.text_input(
+                "AZURE_OPENAI_BASE_URL",
+                value=AZURE_OPENAI_BASE_URL,
+            )
+
+            AZURE_OPENAI_API_VERSION = st.text_input(
+                "AZURE_OPENAI_API_VERSION",
+                value=AZURE_OPENAI_API_VERSION,
+            )
+
+            AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
+                "AZURE_OPENAI_DEPLOYMENT_NAME",
+                value=AZURE_OPENAI_DEPLOYMENT_NAME,
+            )
+
+            AZURE_OPENAI_API_KEY = st.text_input(
+                "AZURE_OPENAI_API_KEY",
+                value=AZURE_OPENAI_API_KEY,
+                type="password",
+            )
+
+            AZURE_OPENAI_MODEL_VERSION = st.text_input(
+                "AZURE_OPENAI_MODEL_VERSION",
+                value=AZURE_OPENAI_MODEL_VERSION,
+            )
 
 
 # --- LLM Instantiation ---
-if provider_api_key:
-    if st.session_state.provider == "OpenAI":
-        st.session_state.llm = ChatOpenAI(
-            model_name=model,
-            openai_api_key=provider_api_key,
-            temperature=temperature,
-            streaming=True,
-            max_tokens=max_tokens,
-        )
-    elif st.session_state.provider == "Anthropic":
-        st.session_state.llm = ChatAnthropic(
-            model=model,
-            anthropic_api_key=provider_api_key,
-            temperature=temperature,
-            streaming=True,
-            max_tokens_to_sample=max_tokens,
-        )
-    elif st.session_state.provider == "Anyscale Endpoints":
-        st.session_state.llm = ChatAnyscale(
-            model_name=model,
-            anyscale_api_key=provider_api_key,
-            temperature=temperature,
-            streaming=True,
-            max_tokens=max_tokens,
-        )
-    elif AZURE_AVAILABLE and st.session_state.provider == "Azure OpenAI":
-        st.session_state.llm = AzureChatOpenAI(
-            openai_api_base=AZURE_OPENAI_BASE_URL,
-            openai_api_version=AZURE_OPENAI_API_VERSION,
-            deployment_name=AZURE_OPENAI_DEPLOYMENT_NAME,
-            openai_api_key=AZURE_OPENAI_API_KEY,
-            openai_api_type="azure",
-            model_version=AZURE_OPENAI_MODEL_VERSION,
-            temperature=temperature,
-            streaming=True,
-            max_tokens=max_tokens,
-        )
+st.session_state.llm = get_llm(
+    provider=st.session_state.provider,
+    model=model,
+    provider_api_key=provider_api_key,
+    temperature=temperature,
+    max_tokens=max_tokens,
+    azure_available=AZURE_AVAILABLE,
+    azure_dict={
+        "AZURE_OPENAI_BASE_URL": AZURE_OPENAI_BASE_URL,
+        "AZURE_OPENAI_API_VERSION": AZURE_OPENAI_API_VERSION,
+        "AZURE_OPENAI_DEPLOYMENT_NAME": AZURE_OPENAI_DEPLOYMENT_NAME,
+        "AZURE_OPENAI_API_KEY": AZURE_OPENAI_API_KEY,
+        "AZURE_OPENAI_MODEL_VERSION": AZURE_OPENAI_MODEL_VERSION,
+    },
+)
 
 # --- Chat History ---
 if len(STMEMORY.messages) == 0:
@@ -474,38 +403,17 @@ if st.session_state.llm:
         stream_handler = StreamHandler(message_placeholder)
         callbacks.append(stream_handler)
 
-        def get_rag_runnable():
-            if document_chat_chain_type == "Q&A Generation":
-                return get_rag_qa_gen_chain(
-                    st.session_state.retriever,
-                    st.session_state.llm,
-                )
-            elif document_chat_chain_type == "Summarization":
-                return get_rag_summarization_chain(
-                    prompt,
-                    st.session_state.retriever,
-                    st.session_state.llm,
-                )
-            else:
-                return RetrievalQA.from_chain_type(
-                    llm=st.session_state.llm,
-                    chain_type=document_chat_chain_type,
-                    retriever=st.session_state.retriever,
-                    memory=MEMORY,
-                    output_key="output_text",
-                ) | (lambda output: output["output_text"])
-
-        st.session_state.chain = (
-            get_rag_runnable()
-            if use_document_chat
-            else LLMChain(
-                prompt=chat_prompt,
-                llm=st.session_state.llm,
-                memory=MEMORY,
-            )
-            | (lambda output: output["text"])
+        st.session_state.chain = get_runnable(
+            use_document_chat,
+            document_chat_chain_type,
+            st.session_state.llm,
+            st.session_state.retriever,
+            MEMORY,
+            chat_prompt,
+            prompt,
         )
 
+        # --- LLM call ---
         try:
             full_response = st.session_state.chain.invoke(prompt, config)
 
@@ -515,6 +423,7 @@ if st.session_state.llm:
                 icon="❌",
             )
 
+        # --- Display output ---
        if full_response is not None:
            message_placeholder.markdown(full_response)
 
@@ -530,6 +439,8 @@ if st.session_state.llm:
                ).url
            except langsmith.utils.LangSmithError:
                st.session_state.trace_link = None
+
+        # --- LangSmith Trace Link ---
        if st.session_state.trace_link:
            with sidebar:
                st.markdown(
@@ -573,10 +484,6 @@ if st.session_state.llm:
                    score=score,
                    comment=feedback.get("text"),
                )
-                # feedback = {
-                #     "feedback_id": str(feedback_record.id),
-                #     "score": score,
-                # }
                st.toast("Feedback recorded!", icon="📝")
            else:
                st.warning("Invalid feedback score.")
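The diff above leaves only a thin st.cache_data wrapper in app.py and moves the real work to llm_resources.get_texts_and_retriever. The split matters because st.cache_data memoizes on a hash of the wrapper's arguments, so the cache boundary has to sit on a function whose parameters (bytes, str, int, bool, small dicts) hash cleanly. A toy sketch of the same pattern; the function names here are hypothetical, not part of the commit:

import streamlit as st


def expensive_build(data: bytes, k: int) -> list:
    # stand-in for the real PDF -> chunks -> ensemble retriever pipeline
    return [data[:k]]


@st.cache_data
def cacheable_wrapper(data: bytes, k: int = 4) -> list:
    # re-runs expensive_build only when (data, k) changes;
    # otherwise Streamlit returns the cached result
    return expensive_build(data, k)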
langchain-streamlit-demo/defaults.py ADDED
@@ -0,0 +1,129 @@
+import os
+from collections import namedtuple
+
+
+MODEL_DICT = {
+    "gpt-3.5-turbo": "OpenAI",
+    "gpt-4": "OpenAI",
+    "claude-instant-v1": "Anthropic",
+    "claude-2": "Anthropic",
+    "meta-llama/Llama-2-7b-chat-hf": "Anyscale Endpoints",
+    "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
+    "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
+    "codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
+    "Azure OpenAI": "Azure OpenAI",
+}
+
+SUPPORTED_MODELS = list(MODEL_DICT.keys())
+
+DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")
+
+DEFAULT_SYSTEM_PROMPT = os.environ.get(
+    "DEFAULT_SYSTEM_PROMPT",
+    "You are a helpful chatbot.",
+)
+
+MIN_TEMP = float(os.environ.get("MIN_TEMPERATURE", 0.0))
+MAX_TEMP = float(os.environ.get("MAX_TEMPERATURE", 1.0))
+DEFAULT_TEMP = float(os.environ.get("DEFAULT_TEMPERATURE", 0.7))
+
+MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
+MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
+DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
+
+DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
+
+AZURE_VARS = [
+    "AZURE_OPENAI_BASE_URL",
+    "AZURE_OPENAI_API_VERSION",
+    "AZURE_OPENAI_DEPLOYMENT_NAME",
+    "AZURE_OPENAI_EMB_DEPLOYMENT_NAME",
+    "AZURE_OPENAI_API_KEY",
+    "AZURE_OPENAI_MODEL_VERSION",
+]
+
+AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
+
+
+SHOW_LANGSMITH_OPTIONS = (
+    os.environ.get("SHOW_LANGSMITH_OPTIONS", "true").lower() == "true"
+)
+SHOW_AZURE_OPTIONS = os.environ.get("SHOW_AZURE_OPTIONS", "true").lower() == "true"
+
+PROVIDER_KEY_DICT = {
+    "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
+    "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
+    "Anyscale Endpoints": os.environ.get("ANYSCALE_API_KEY", ""),
+    "LANGSMITH": os.environ.get("LANGCHAIN_API_KEY", ""),
+}
+
+OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
+
+
+MIN_CHUNK_SIZE = 1
+MAX_CHUNK_SIZE = 10000
+DEFAULT_CHUNK_SIZE = 1000
+
+MIN_CHUNK_OVERLAP = 0
+MAX_CHUNK_OVERLAP = 10000
+DEFAULT_CHUNK_OVERLAP = 0
+
+DEFAULT_RETRIEVER_K = 4
+
+DEFAULT_VALUES = namedtuple(
+    "DEFAULT_VALUES",
+    [
+        "MODEL_DICT",
+        "SUPPORTED_MODELS",
+        "DEFAULT_MODEL",
+        "DEFAULT_SYSTEM_PROMPT",
+        "MIN_TEMP",
+        "MAX_TEMP",
+        "DEFAULT_TEMP",
+        "MIN_MAX_TOKENS",
+        "MAX_MAX_TOKENS",
+        "DEFAULT_MAX_TOKENS",
+        "DEFAULT_LANGSMITH_PROJECT",
+        "AZURE_VARS",
+        "AZURE_DICT",
+        "PROVIDER_KEY_DICT",
+        "OPENAI_API_KEY",
+        "MIN_CHUNK_SIZE",
+        "MAX_CHUNK_SIZE",
+        "DEFAULT_CHUNK_SIZE",
+        "MIN_CHUNK_OVERLAP",
+        "MAX_CHUNK_OVERLAP",
+        "DEFAULT_CHUNK_OVERLAP",
+        "DEFAULT_RETRIEVER_K",
+        "SHOW_LANGSMITH_OPTIONS",
+        "SHOW_AZURE_OPTIONS",
+    ],
+)
+
+
+default_values = DEFAULT_VALUES(
+    MODEL_DICT,
+    SUPPORTED_MODELS,
+    DEFAULT_MODEL,
+    DEFAULT_SYSTEM_PROMPT,
+    MIN_TEMP,
+    MAX_TEMP,
+    DEFAULT_TEMP,
+    MIN_MAX_TOKENS,
+    MAX_MAX_TOKENS,
+    DEFAULT_MAX_TOKENS,
+    DEFAULT_LANGSMITH_PROJECT,
+    AZURE_VARS,
+    AZURE_DICT,
+    PROVIDER_KEY_DICT,
+    OPENAI_API_KEY,
+    MIN_CHUNK_SIZE,
+    MAX_CHUNK_SIZE,
+    DEFAULT_CHUNK_SIZE,
+    MIN_CHUNK_OVERLAP,
+    MAX_CHUNK_OVERLAP,
+    DEFAULT_CHUNK_OVERLAP,
+    DEFAULT_RETRIEVER_K,
+    SHOW_LANGSMITH_OPTIONS,
+    SHOW_AZURE_OPTIONS,
+)
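Packing the module constants into the default_values namedtuple gives app.py a single import with attribute access instead of two dozen names. A short usage sketch; printed values depend on the environment:

from defaults import default_values

print(default_values.DEFAULT_MODEL)        # "gpt-3.5-turbo" unless DEFAULT_MODEL is set
print(default_values.DEFAULT_RETRIEVER_K)  # 4
print(default_values.SHOW_AZURE_OPTIONS)   # True only when the env var is unset or "true"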
langchain-streamlit-demo/llm_resources.py ADDED
@@ -0,0 +1,160 @@
+from tempfile import NamedTemporaryFile
+from typing import Tuple, List, Optional, Dict
+
+from langchain.callbacks.base import BaseCallbackHandler
+from langchain.chains import RetrievalQA, LLMChain
+from langchain.chat_models import (
+    AzureChatOpenAI,
+    ChatOpenAI,
+    ChatAnthropic,
+    ChatAnyscale,
+)
+from langchain.document_loaders import PyPDFLoader
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.retrievers import BM25Retriever, EnsembleRetriever
+from langchain.schema import Document, BaseRetriever
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.vectorstores import FAISS
+
+from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
+from qagen import get_rag_qa_gen_chain
+from summarize import get_rag_summarization_chain
+
+
+def get_runnable(
+    use_document_chat: bool,
+    document_chat_chain_type: str,
+    llm,
+    retriever,
+    memory,
+    chat_prompt,
+    summarization_prompt,
+):
+    if not use_document_chat:
+        return LLMChain(
+            prompt=chat_prompt,
+            llm=llm,
+            memory=memory,
+        ) | (lambda output: output["text"])
+
+    if document_chat_chain_type == "Q&A Generation":
+        return get_rag_qa_gen_chain(
+            retriever,
+            llm,
+        )
+    elif document_chat_chain_type == "Summarization":
+        return get_rag_summarization_chain(
+            summarization_prompt,
+            retriever,
+            llm,
+        )
+    else:
+        return RetrievalQA.from_chain_type(
+            llm=llm,
+            chain_type=document_chat_chain_type,
+            retriever=retriever,
+            memory=memory,
+            output_key="output_text",
+        ) | (lambda output: output["output_text"])
+
+
+def get_llm(
+    provider: str,
+    model: str,
+    provider_api_key: str,
+    temperature: float,
+    max_tokens: int,
+    azure_available: bool,
+    azure_dict: dict[str, str],
+):
+    if azure_available and provider == "Azure OpenAI":
+        return AzureChatOpenAI(
+            openai_api_base=azure_dict["AZURE_OPENAI_BASE_URL"],
+            openai_api_version=azure_dict["AZURE_OPENAI_API_VERSION"],
+            deployment_name=azure_dict["AZURE_OPENAI_DEPLOYMENT_NAME"],
+            openai_api_key=azure_dict["AZURE_OPENAI_API_KEY"],
+            openai_api_type="azure",
+            model_version=azure_dict["AZURE_OPENAI_MODEL_VERSION"],
+            temperature=temperature,
+            streaming=True,
+            max_tokens=max_tokens,
+        )
+
+    elif provider_api_key:
+        if provider == "OpenAI":
+            return ChatOpenAI(
+                model_name=model,
+                openai_api_key=provider_api_key,
+                temperature=temperature,
+                streaming=True,
+                max_tokens=max_tokens,
+            )
+
+        elif provider == "Anthropic":
+            return ChatAnthropic(
+                model=model,
+                anthropic_api_key=provider_api_key,
+                temperature=temperature,
+                streaming=True,
+                max_tokens_to_sample=max_tokens,
+            )
+
+        elif provider == "Anyscale Endpoints":
+            return ChatAnyscale(
+                model_name=model,
+                anyscale_api_key=provider_api_key,
+                temperature=temperature,
+                streaming=True,
+                max_tokens=max_tokens,
+            )
+
+    return None
+
+
+def get_texts_and_retriever(
+    uploaded_file_bytes: bytes,
+    openai_api_key: str,
+    chunk_size: int = DEFAULT_CHUNK_SIZE,
+    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    k: int = DEFAULT_RETRIEVER_K,
+    azure_kwargs: Optional[Dict[str, str]] = None,
+    use_azure: bool = False,
+) -> Tuple[List[Document], BaseRetriever]:
+    with NamedTemporaryFile() as temp_file:
+        temp_file.write(uploaded_file_bytes)
+        temp_file.seek(0)
+
+        loader = PyPDFLoader(temp_file.name)
+        documents = loader.load()
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+        texts = text_splitter.split_documents(documents)
+        embeddings_kwargs = {"openai_api_key": openai_api_key}
+        if use_azure and azure_kwargs:
+            embeddings_kwargs.update(azure_kwargs)
+        embeddings = OpenAIEmbeddings(**embeddings_kwargs)
+
+        bm25_retriever = BM25Retriever.from_documents(texts)
+        bm25_retriever.k = k
+
+        faiss_vectorstore = FAISS.from_documents(texts, embeddings)
+        faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
+
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[bm25_retriever, faiss_retriever],
+            weights=[0.5, 0.5],
+        )
+
+        return texts, ensemble_retriever
+
+
+class StreamHandler(BaseCallbackHandler):
+    def __init__(self, container, initial_text=""):
+        self.container = container
+        self.text = initial_text
+
+    def on_llm_new_token(self, token: str, **kwargs) -> None:
+        self.text += token
+        self.container.markdown(self.text)
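get_llm returns None when no provider configuration matches, so callers can branch on the result rather than catching exceptions. A hedged standalone sketch, outside Streamlit; the API key is a placeholder, and predict() is the langchain 0.0.x-era convenience method used with these chat models:

from llm_resources import get_llm

llm = get_llm(
    provider="OpenAI",
    model="gpt-3.5-turbo",
    provider_api_key="sk-placeholder",  # placeholder, supply a real key
    temperature=0.7,
    max_tokens=1000,
    azure_available=False,
    azure_dict={},
)
if llm is None:
    raise SystemExit("no usable provider configuration")
print(llm.predict("Say hello."))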