Merge pull request #46 from joshuasundance-swca/cleanup
Browse files- .idea/langchain-streamlit-demo.iml +3 -1
- kubernetes/resources.yaml +9 -0
- langchain-streamlit-demo/app.py +162 -255
- langchain-streamlit-demo/defaults.py +129 -0
- langchain-streamlit-demo/llm_resources.py +160 -0
.idea/langchain-streamlit-demo.iml
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
<module type="PYTHON_MODULE" version="4">
|
3 |
<component name="NewModuleRootManager">
|
4 |
-
<content url="file://$MODULE_DIR$"
|
|
|
|
|
5 |
<orderEntry type="jdk" jdkName="Remote Python 3.11.4 Docker (<none>:<none>) (5)" jdkType="Python SDK" />
|
6 |
<orderEntry type="sourceFolder" forTests="false" />
|
7 |
</component>
|
|
|
1 |
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
<module type="PYTHON_MODULE" version="4">
|
3 |
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$">
|
5 |
+
<sourceFolder url="file://$MODULE_DIR$/langchain-streamlit-demo" isTestSource="false" />
|
6 |
+
</content>
|
7 |
<orderEntry type="jdk" jdkName="Remote Python 3.11.4 Docker (<none>:<none>) (5)" jdkType="Python SDK" />
|
8 |
<orderEntry type="sourceFolder" forTests="false" />
|
9 |
</component>
|
kubernetes/resources.yaml
CHANGED
@@ -39,6 +39,11 @@ spec:
|
|
39 |
secretKeyRef:
|
40 |
name: langchain-streamlit-demo-secret
|
41 |
key: AZURE_OPENAI_DEPLOYMENT_NAME
|
|
|
|
|
|
|
|
|
|
|
42 |
- name: AZURE_OPENAI_API_KEY
|
43 |
valueFrom:
|
44 |
secretKeyRef:
|
@@ -71,6 +76,10 @@ spec:
|
|
71 |
key: LANGCHAIN_API_KEY
|
72 |
- name: LANGCHAIN_PROJECT
|
73 |
value: "langchain-streamlit-demo"
|
|
|
|
|
|
|
|
|
74 |
securityContext:
|
75 |
runAsNonRoot: true
|
76 |
---
|
|
|
39 |
secretKeyRef:
|
40 |
name: langchain-streamlit-demo-secret
|
41 |
key: AZURE_OPENAI_DEPLOYMENT_NAME
|
42 |
+
- name: AZURE_OPENAI_EMB_DEPLOYMENT_NAME
|
43 |
+
valueFrom:
|
44 |
+
secretKeyRef:
|
45 |
+
name: langchain-streamlit-demo-secret
|
46 |
+
key: AZURE_OPENAI_EMB_DEPLOYMENT_NAME
|
47 |
- name: AZURE_OPENAI_API_KEY
|
48 |
valueFrom:
|
49 |
secretKeyRef:
|
|
|
76 |
key: LANGCHAIN_API_KEY
|
77 |
- name: LANGCHAIN_PROJECT
|
78 |
value: "langchain-streamlit-demo"
|
79 |
+
- name: SHOW_LANGCHAIN_OPTIONS
|
80 |
+
value: "False"
|
81 |
+
- name: SHOW_AZURE_OPTIONS
|
82 |
+
value: "False"
|
83 |
securityContext:
|
84 |
runAsNonRoot: true
|
85 |
---
|
langchain-streamlit-demo/app.py
CHANGED
@@ -1,37 +1,22 @@
|
|
1 |
-
import os
|
2 |
from datetime import datetime
|
3 |
-
from
|
4 |
-
from typing import Tuple, List, Dict, Any, Union
|
5 |
|
6 |
import anthropic
|
7 |
import langsmith.utils
|
8 |
import openai
|
9 |
import streamlit as st
|
10 |
-
from langchain.callbacks.base import BaseCallbackHandler
|
11 |
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
12 |
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
|
13 |
-
from langchain.chains import RetrievalQA
|
14 |
-
from langchain.chains.llm import LLMChain
|
15 |
-
from langchain.chat_models import (
|
16 |
-
AzureChatOpenAI,
|
17 |
-
ChatAnthropic,
|
18 |
-
ChatAnyscale,
|
19 |
-
ChatOpenAI,
|
20 |
-
)
|
21 |
-
from langchain.document_loaders import PyPDFLoader
|
22 |
-
from langchain.embeddings import OpenAIEmbeddings
|
23 |
from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
|
24 |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
25 |
-
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
26 |
from langchain.schema.document import Document
|
27 |
from langchain.schema.retriever import BaseRetriever
|
28 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
29 |
-
from langchain.vectorstores import FAISS
|
30 |
from langsmith.client import Client
|
31 |
from streamlit_feedback import streamlit_feedback
|
32 |
|
33 |
-
from
|
34 |
-
|
|
|
35 |
|
36 |
__version__ = "0.0.13"
|
37 |
|
@@ -62,119 +47,72 @@ st_init_null(
|
|
62 |
"trace_link",
|
63 |
)
|
64 |
|
65 |
-
# ---
|
66 |
STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
|
67 |
MEMORY = ConversationBufferMemory(
|
68 |
chat_memory=STMEMORY,
|
69 |
return_messages=True,
|
70 |
memory_key="chat_history",
|
71 |
)
|
72 |
-
|
73 |
-
|
74 |
-
# --- Callbacks ---
|
75 |
-
class StreamHandler(BaseCallbackHandler):
|
76 |
-
def __init__(self, container, initial_text=""):
|
77 |
-
self.container = container
|
78 |
-
self.text = initial_text
|
79 |
-
|
80 |
-
def on_llm_new_token(self, token: str, **kwargs) -> None:
|
81 |
-
self.text += token
|
82 |
-
self.container.markdown(self.text)
|
83 |
-
|
84 |
-
|
85 |
RUN_COLLECTOR = RunCollectorCallbackHandler()
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
"gpt-3.5-turbo": "OpenAI",
|
91 |
-
"gpt-4": "OpenAI",
|
92 |
-
"claude-instant-v1": "Anthropic",
|
93 |
-
"claude-2": "Anthropic",
|
94 |
-
"meta-llama/Llama-2-7b-chat-hf": "Anyscale Endpoints",
|
95 |
-
"meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
|
96 |
-
"meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
|
97 |
-
"codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
|
98 |
-
"Azure OpenAI": "Azure OpenAI",
|
99 |
-
}
|
100 |
-
SUPPORTED_MODELS = list(MODEL_DICT.keys())
|
101 |
-
|
102 |
-
|
103 |
-
# --- Constants from Environment Variables ---
|
104 |
-
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")
|
105 |
-
DEFAULT_SYSTEM_PROMPT = os.environ.get(
|
106 |
-
"DEFAULT_SYSTEM_PROMPT",
|
107 |
-
"You are a helpful chatbot.",
|
108 |
)
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
|
115 |
-
DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
|
116 |
-
|
117 |
-
AZURE_VARS = [
|
118 |
-
"AZURE_OPENAI_BASE_URL",
|
119 |
-
"AZURE_OPENAI_API_VERSION",
|
120 |
-
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
121 |
-
"AZURE_OPENAI_API_KEY",
|
122 |
-
"AZURE_OPENAI_MODEL_VERSION",
|
123 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
124 |
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
MIN_CHUNK_OVERLAP = 0
|
140 |
-
MAX_CHUNK_OVERLAP = 10000
|
141 |
-
DEFAULT_CHUNK_OVERLAP = 0
|
142 |
-
|
143 |
-
DEFAULT_RETRIEVER_K = 4
|
144 |
|
145 |
|
146 |
@st.cache_data
|
147 |
-
def
|
148 |
uploaded_file_bytes: bytes,
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
152 |
) -> Tuple[List[Document], BaseRetriever]:
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
)
|
163 |
-
texts = text_splitter.split_documents(documents)
|
164 |
-
embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
|
165 |
-
|
166 |
-
bm25_retriever = BM25Retriever.from_documents(texts)
|
167 |
-
bm25_retriever.k = k
|
168 |
-
|
169 |
-
faiss_vectorstore = FAISS.from_documents(texts, embeddings)
|
170 |
-
faiss_retriever = faiss_vectorstore.as_retriever(search_kwargs={"k": k})
|
171 |
-
|
172 |
-
ensemble_retriever = EnsembleRetriever(
|
173 |
-
retrievers=[bm25_retriever, faiss_retriever],
|
174 |
-
weights=[0.5, 0.5],
|
175 |
-
)
|
176 |
-
|
177 |
-
return texts, ensemble_retriever
|
178 |
|
179 |
|
180 |
# --- Sidebar ---
|
@@ -184,14 +122,14 @@ with sidebar:
|
|
184 |
|
185 |
model = st.selectbox(
|
186 |
label="Chat Model",
|
187 |
-
options=SUPPORTED_MODELS,
|
188 |
-
index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
|
189 |
)
|
190 |
|
191 |
-
st.session_state.provider = MODEL_DICT[model]
|
192 |
|
193 |
provider_api_key = (
|
194 |
-
PROVIDER_KEY_DICT.get(
|
195 |
st.session_state.provider,
|
196 |
)
|
197 |
or st.text_input(
|
@@ -214,7 +152,7 @@ with sidebar:
|
|
214 |
openai_api_key = (
|
215 |
provider_api_key
|
216 |
if st.session_state.provider == "OpenAI"
|
217 |
-
else OPENAI_API_KEY
|
218 |
or st.sidebar.text_input("OpenAI API Key: ", type="password")
|
219 |
)
|
220 |
|
@@ -227,7 +165,7 @@ with sidebar:
|
|
227 |
k = st.slider(
|
228 |
label="Number of Chunks",
|
229 |
help="How many document chunks will be used for context?",
|
230 |
-
value=DEFAULT_RETRIEVER_K,
|
231 |
min_value=1,
|
232 |
max_value=10,
|
233 |
)
|
@@ -235,21 +173,23 @@ with sidebar:
|
|
235 |
chunk_size = st.slider(
|
236 |
label="Number of Tokens per Chunk",
|
237 |
help="Size of each chunk of text",
|
238 |
-
min_value=MIN_CHUNK_SIZE,
|
239 |
-
max_value=MAX_CHUNK_SIZE,
|
240 |
-
value=DEFAULT_CHUNK_SIZE,
|
241 |
)
|
|
|
242 |
chunk_overlap = st.slider(
|
243 |
label="Chunk Overlap",
|
244 |
help="Number of characters to overlap between chunks",
|
245 |
-
min_value=MIN_CHUNK_OVERLAP,
|
246 |
-
max_value=MAX_CHUNK_OVERLAP,
|
247 |
-
value=DEFAULT_CHUNK_OVERLAP,
|
248 |
)
|
249 |
|
250 |
chain_type_help_root = (
|
251 |
"https://python.langchain.com/docs/modules/chains/document/"
|
252 |
)
|
|
|
253 |
chain_type_help = "\n".join(
|
254 |
f"- [{chain_type_name}]({chain_type_help_root}/{chain_type_name})"
|
255 |
for chain_type_name in (
|
@@ -259,6 +199,7 @@ with sidebar:
|
|
259 |
"map_rerank",
|
260 |
)
|
261 |
)
|
|
|
262 |
document_chat_chain_type = st.selectbox(
|
263 |
label="Document Chat Chain Type",
|
264 |
options=[
|
@@ -273,17 +214,28 @@ with sidebar:
|
|
273 |
help=chain_type_help,
|
274 |
disabled=not document_chat,
|
275 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
if uploaded_file:
|
278 |
-
if openai_api_key:
|
279 |
(
|
280 |
st.session_state.texts,
|
281 |
st.session_state.retriever,
|
282 |
-
) =
|
283 |
uploaded_file_bytes=uploaded_file.getvalue(),
|
|
|
284 |
chunk_size=chunk_size,
|
285 |
chunk_overlap=chunk_overlap,
|
286 |
k=k,
|
|
|
|
|
287 |
)
|
288 |
else:
|
289 |
st.error("Please enter a valid OpenAI API key.", icon="❌")
|
@@ -297,123 +249,100 @@ with sidebar:
|
|
297 |
system_prompt = (
|
298 |
st.text_area(
|
299 |
"Custom Instructions",
|
300 |
-
DEFAULT_SYSTEM_PROMPT,
|
301 |
help="Custom instructions to provide the language model to determine style, personality, etc.",
|
302 |
)
|
303 |
.strip()
|
304 |
.replace("{", "{{")
|
305 |
.replace("}", "}}")
|
306 |
)
|
|
|
307 |
temperature = st.slider(
|
308 |
"Temperature",
|
309 |
-
min_value=MIN_TEMP,
|
310 |
-
max_value=MAX_TEMP,
|
311 |
-
value=DEFAULT_TEMP,
|
312 |
help="Higher values give more random results.",
|
313 |
)
|
314 |
|
315 |
max_tokens = st.slider(
|
316 |
"Max Tokens",
|
317 |
-
min_value=MIN_MAX_TOKENS,
|
318 |
-
max_value=MAX_MAX_TOKENS,
|
319 |
-
value=DEFAULT_MAX_TOKENS,
|
320 |
help="Higher values give longer results.",
|
321 |
)
|
322 |
|
323 |
# --- LangSmith Options ---
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
LANGSMITH_PROJECT = st.text_input(
|
331 |
-
"LangSmith Project Name",
|
332 |
-
value=DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo",
|
333 |
-
)
|
334 |
-
if st.session_state.client is None and LANGSMITH_API_KEY:
|
335 |
-
st.session_state.client = Client(
|
336 |
-
api_url="https://api.smith.langchain.com",
|
337 |
-
api_key=LANGSMITH_API_KEY,
|
338 |
)
|
339 |
-
|
340 |
-
|
341 |
-
|
|
|
342 |
)
|
343 |
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
value=AZURE_DICT["AZURE_OPENAI_BASE_URL"],
|
349 |
-
)
|
350 |
-
AZURE_OPENAI_API_VERSION = st.text_input(
|
351 |
-
"AZURE_OPENAI_API_VERSION",
|
352 |
-
value=AZURE_DICT["AZURE_OPENAI_API_VERSION"],
|
353 |
-
)
|
354 |
-
AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
|
355 |
-
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
356 |
-
value=AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"],
|
357 |
-
)
|
358 |
-
AZURE_OPENAI_API_KEY = st.text_input(
|
359 |
-
"AZURE_OPENAI_API_KEY",
|
360 |
-
value=AZURE_DICT["AZURE_OPENAI_API_KEY"],
|
361 |
-
type="password",
|
362 |
)
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
)
|
367 |
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
377 |
|
378 |
|
379 |
# --- LLM Instantiation ---
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
max_tokens_to_sample=max_tokens,
|
396 |
-
)
|
397 |
-
elif st.session_state.provider == "Anyscale Endpoints":
|
398 |
-
st.session_state.llm = ChatAnyscale(
|
399 |
-
model_name=model,
|
400 |
-
anyscale_api_key=provider_api_key,
|
401 |
-
temperature=temperature,
|
402 |
-
streaming=True,
|
403 |
-
max_tokens=max_tokens,
|
404 |
-
)
|
405 |
-
elif AZURE_AVAILABLE and st.session_state.provider == "Azure OpenAI":
|
406 |
-
st.session_state.llm = AzureChatOpenAI(
|
407 |
-
openai_api_base=AZURE_OPENAI_BASE_URL,
|
408 |
-
openai_api_version=AZURE_OPENAI_API_VERSION,
|
409 |
-
deployment_name=AZURE_OPENAI_DEPLOYMENT_NAME,
|
410 |
-
openai_api_key=AZURE_OPENAI_API_KEY,
|
411 |
-
openai_api_type="azure",
|
412 |
-
model_version=AZURE_OPENAI_MODEL_VERSION,
|
413 |
-
temperature=temperature,
|
414 |
-
streaming=True,
|
415 |
-
max_tokens=max_tokens,
|
416 |
-
)
|
417 |
|
418 |
# --- Chat History ---
|
419 |
if len(STMEMORY.messages) == 0:
|
@@ -474,38 +403,17 @@ if st.session_state.llm:
|
|
474 |
stream_handler = StreamHandler(message_placeholder)
|
475 |
callbacks.append(stream_handler)
|
476 |
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
prompt,
|
486 |
-
st.session_state.retriever,
|
487 |
-
st.session_state.llm,
|
488 |
-
)
|
489 |
-
else:
|
490 |
-
return RetrievalQA.from_chain_type(
|
491 |
-
llm=st.session_state.llm,
|
492 |
-
chain_type=document_chat_chain_type,
|
493 |
-
retriever=st.session_state.retriever,
|
494 |
-
memory=MEMORY,
|
495 |
-
output_key="output_text",
|
496 |
-
) | (lambda output: output["output_text"])
|
497 |
-
|
498 |
-
st.session_state.chain = (
|
499 |
-
get_rag_runnable()
|
500 |
-
if use_document_chat
|
501 |
-
else LLMChain(
|
502 |
-
prompt=chat_prompt,
|
503 |
-
llm=st.session_state.llm,
|
504 |
-
memory=MEMORY,
|
505 |
-
)
|
506 |
-
| (lambda output: output["text"])
|
507 |
)
|
508 |
|
|
|
509 |
try:
|
510 |
full_response = st.session_state.chain.invoke(prompt, config)
|
511 |
|
@@ -515,6 +423,7 @@ if st.session_state.llm:
|
|
515 |
icon="❌",
|
516 |
)
|
517 |
|
|
|
518 |
if full_response is not None:
|
519 |
message_placeholder.markdown(full_response)
|
520 |
|
@@ -530,6 +439,8 @@ if st.session_state.llm:
|
|
530 |
).url
|
531 |
except langsmith.utils.LangSmithError:
|
532 |
st.session_state.trace_link = None
|
|
|
|
|
533 |
if st.session_state.trace_link:
|
534 |
with sidebar:
|
535 |
st.markdown(
|
@@ -573,10 +484,6 @@ if st.session_state.llm:
|
|
573 |
score=score,
|
574 |
comment=feedback.get("text"),
|
575 |
)
|
576 |
-
# feedback = {
|
577 |
-
# "feedback_id": str(feedback_record.id),
|
578 |
-
# "score": score,
|
579 |
-
# }
|
580 |
st.toast("Feedback recorded!", icon="📝")
|
581 |
else:
|
582 |
st.warning("Invalid feedback score.")
|
|
|
|
|
1 |
from datetime import datetime
|
2 |
+
from typing import Tuple, List, Dict, Any, Union, Optional
|
|
|
3 |
|
4 |
import anthropic
|
5 |
import langsmith.utils
|
6 |
import openai
|
7 |
import streamlit as st
|
|
|
8 |
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
|
9 |
from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
|
11 |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
|
12 |
from langchain.schema.document import Document
|
13 |
from langchain.schema.retriever import BaseRetriever
|
|
|
|
|
14 |
from langsmith.client import Client
|
15 |
from streamlit_feedback import streamlit_feedback
|
16 |
|
17 |
+
from defaults import default_values
|
18 |
+
|
19 |
+
from llm_resources import get_runnable, get_llm, get_texts_and_retriever, StreamHandler
|
20 |
|
21 |
__version__ = "0.0.13"
|
22 |
|
|
|
47 |
"trace_link",
|
48 |
)
|
49 |
|
50 |
+
# --- LLM globals ---
|
51 |
STMEMORY = StreamlitChatMessageHistory(key="langchain_messages")
|
52 |
MEMORY = ConversationBufferMemory(
|
53 |
chat_memory=STMEMORY,
|
54 |
return_messages=True,
|
55 |
memory_key="chat_history",
|
56 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
RUN_COLLECTOR = RunCollectorCallbackHandler()
|
58 |
|
59 |
+
LANGSMITH_API_KEY = default_values.PROVIDER_KEY_DICT.get("LANGSMITH")
|
60 |
+
LANGSMITH_PROJECT = (
|
61 |
+
default_values.DEFAULT_LANGSMITH_PROJECT or "langchain-streamlit-demo"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
)
|
63 |
+
AZURE_OPENAI_BASE_URL = default_values.AZURE_DICT["AZURE_OPENAI_BASE_URL"]
|
64 |
+
AZURE_OPENAI_API_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_API_VERSION"]
|
65 |
+
AZURE_OPENAI_DEPLOYMENT_NAME = default_values.AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"]
|
66 |
+
AZURE_OPENAI_EMB_DEPLOYMENT_NAME = default_values.AZURE_DICT[
|
67 |
+
"AZURE_OPENAI_EMB_DEPLOYMENT_NAME"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
]
|
69 |
+
AZURE_OPENAI_API_KEY = default_values.AZURE_DICT["AZURE_OPENAI_API_KEY"]
|
70 |
+
AZURE_OPENAI_MODEL_VERSION = default_values.AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"]
|
71 |
+
|
72 |
+
AZURE_AVAILABLE = all(
|
73 |
+
[
|
74 |
+
AZURE_OPENAI_BASE_URL,
|
75 |
+
AZURE_OPENAI_API_VERSION,
|
76 |
+
AZURE_OPENAI_DEPLOYMENT_NAME,
|
77 |
+
AZURE_OPENAI_API_KEY,
|
78 |
+
AZURE_OPENAI_MODEL_VERSION,
|
79 |
+
],
|
80 |
+
)
|
81 |
|
82 |
+
AZURE_EMB_AVAILABLE = AZURE_AVAILABLE and AZURE_OPENAI_EMB_DEPLOYMENT_NAME
|
83 |
+
|
84 |
+
AZURE_KWARGS = (
|
85 |
+
None
|
86 |
+
if not AZURE_EMB_AVAILABLE
|
87 |
+
else {
|
88 |
+
"openai_api_base": AZURE_OPENAI_BASE_URL,
|
89 |
+
"openai_api_version": AZURE_OPENAI_API_VERSION,
|
90 |
+
"deployment": AZURE_OPENAI_EMB_DEPLOYMENT_NAME,
|
91 |
+
"openai_api_key": AZURE_OPENAI_API_KEY,
|
92 |
+
"openai_api_type": "azure",
|
93 |
+
}
|
94 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
|
97 |
@st.cache_data
|
98 |
+
def get_texts_and_retriever_cacheable_wrapper(
|
99 |
uploaded_file_bytes: bytes,
|
100 |
+
openai_api_key: str,
|
101 |
+
chunk_size: int = default_values.DEFAULT_CHUNK_SIZE,
|
102 |
+
chunk_overlap: int = default_values.DEFAULT_CHUNK_OVERLAP,
|
103 |
+
k: int = default_values.DEFAULT_RETRIEVER_K,
|
104 |
+
azure_kwargs: Optional[Dict[str, str]] = None,
|
105 |
+
use_azure: bool = False,
|
106 |
) -> Tuple[List[Document], BaseRetriever]:
|
107 |
+
return get_texts_and_retriever(
|
108 |
+
uploaded_file_bytes=uploaded_file_bytes,
|
109 |
+
openai_api_key=openai_api_key,
|
110 |
+
chunk_size=chunk_size,
|
111 |
+
chunk_overlap=chunk_overlap,
|
112 |
+
k=k,
|
113 |
+
azure_kwargs=azure_kwargs,
|
114 |
+
use_azure=use_azure,
|
115 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
|
118 |
# --- Sidebar ---
|
|
|
122 |
|
123 |
model = st.selectbox(
|
124 |
label="Chat Model",
|
125 |
+
options=default_values.SUPPORTED_MODELS,
|
126 |
+
index=default_values.SUPPORTED_MODELS.index(default_values.DEFAULT_MODEL),
|
127 |
)
|
128 |
|
129 |
+
st.session_state.provider = default_values.MODEL_DICT[model]
|
130 |
|
131 |
provider_api_key = (
|
132 |
+
default_values.PROVIDER_KEY_DICT.get(
|
133 |
st.session_state.provider,
|
134 |
)
|
135 |
or st.text_input(
|
|
|
152 |
openai_api_key = (
|
153 |
provider_api_key
|
154 |
if st.session_state.provider == "OpenAI"
|
155 |
+
else default_values.OPENAI_API_KEY
|
156 |
or st.sidebar.text_input("OpenAI API Key: ", type="password")
|
157 |
)
|
158 |
|
|
|
165 |
k = st.slider(
|
166 |
label="Number of Chunks",
|
167 |
help="How many document chunks will be used for context?",
|
168 |
+
value=default_values.DEFAULT_RETRIEVER_K,
|
169 |
min_value=1,
|
170 |
max_value=10,
|
171 |
)
|
|
|
173 |
chunk_size = st.slider(
|
174 |
label="Number of Tokens per Chunk",
|
175 |
help="Size of each chunk of text",
|
176 |
+
min_value=default_values.MIN_CHUNK_SIZE,
|
177 |
+
max_value=default_values.MAX_CHUNK_SIZE,
|
178 |
+
value=default_values.DEFAULT_CHUNK_SIZE,
|
179 |
)
|
180 |
+
|
181 |
chunk_overlap = st.slider(
|
182 |
label="Chunk Overlap",
|
183 |
help="Number of characters to overlap between chunks",
|
184 |
+
min_value=default_values.MIN_CHUNK_OVERLAP,
|
185 |
+
max_value=default_values.MAX_CHUNK_OVERLAP,
|
186 |
+
value=default_values.DEFAULT_CHUNK_OVERLAP,
|
187 |
)
|
188 |
|
189 |
chain_type_help_root = (
|
190 |
"https://python.langchain.com/docs/modules/chains/document/"
|
191 |
)
|
192 |
+
|
193 |
chain_type_help = "\n".join(
|
194 |
f"- [{chain_type_name}]({chain_type_help_root}/{chain_type_name})"
|
195 |
for chain_type_name in (
|
|
|
199 |
"map_rerank",
|
200 |
)
|
201 |
)
|
202 |
+
|
203 |
document_chat_chain_type = st.selectbox(
|
204 |
label="Document Chat Chain Type",
|
205 |
options=[
|
|
|
214 |
help=chain_type_help,
|
215 |
disabled=not document_chat,
|
216 |
)
|
217 |
+
use_azure = False
|
218 |
+
|
219 |
+
if AZURE_EMB_AVAILABLE:
|
220 |
+
use_azure = st.toggle(
|
221 |
+
label="Use Azure OpenAI",
|
222 |
+
value=AZURE_EMB_AVAILABLE,
|
223 |
+
help="Use Azure for embeddings instead of using OpenAI directly.",
|
224 |
+
)
|
225 |
|
226 |
if uploaded_file:
|
227 |
+
if AZURE_EMB_AVAILABLE or openai_api_key:
|
228 |
(
|
229 |
st.session_state.texts,
|
230 |
st.session_state.retriever,
|
231 |
+
) = get_texts_and_retriever_cacheable_wrapper(
|
232 |
uploaded_file_bytes=uploaded_file.getvalue(),
|
233 |
+
openai_api_key=openai_api_key,
|
234 |
chunk_size=chunk_size,
|
235 |
chunk_overlap=chunk_overlap,
|
236 |
k=k,
|
237 |
+
azure_kwargs=AZURE_KWARGS,
|
238 |
+
use_azure=use_azure,
|
239 |
)
|
240 |
else:
|
241 |
st.error("Please enter a valid OpenAI API key.", icon="❌")
|
|
|
249 |
system_prompt = (
|
250 |
st.text_area(
|
251 |
"Custom Instructions",
|
252 |
+
default_values.DEFAULT_SYSTEM_PROMPT,
|
253 |
help="Custom instructions to provide the language model to determine style, personality, etc.",
|
254 |
)
|
255 |
.strip()
|
256 |
.replace("{", "{{")
|
257 |
.replace("}", "}}")
|
258 |
)
|
259 |
+
|
260 |
temperature = st.slider(
|
261 |
"Temperature",
|
262 |
+
min_value=default_values.MIN_TEMP,
|
263 |
+
max_value=default_values.MAX_TEMP,
|
264 |
+
value=default_values.DEFAULT_TEMP,
|
265 |
help="Higher values give more random results.",
|
266 |
)
|
267 |
|
268 |
max_tokens = st.slider(
|
269 |
"Max Tokens",
|
270 |
+
min_value=default_values.MIN_MAX_TOKENS,
|
271 |
+
max_value=default_values.MAX_MAX_TOKENS,
|
272 |
+
value=default_values.DEFAULT_MAX_TOKENS,
|
273 |
help="Higher values give longer results.",
|
274 |
)
|
275 |
|
276 |
# --- LangSmith Options ---
|
277 |
+
if default_values.SHOW_LANGSMITH_OPTIONS:
|
278 |
+
with st.expander("LangSmith Options", expanded=False):
|
279 |
+
LANGSMITH_API_KEY = st.text_input(
|
280 |
+
"LangSmith API Key (optional)",
|
281 |
+
value=LANGSMITH_API_KEY,
|
282 |
+
type="password",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
283 |
)
|
284 |
+
|
285 |
+
LANGSMITH_PROJECT = st.text_input(
|
286 |
+
"LangSmith Project Name",
|
287 |
+
value=LANGSMITH_PROJECT,
|
288 |
)
|
289 |
|
290 |
+
if st.session_state.client is None and LANGSMITH_API_KEY:
|
291 |
+
st.session_state.client = Client(
|
292 |
+
api_url="https://api.smith.langchain.com",
|
293 |
+
api_key=LANGSMITH_API_KEY,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
)
|
295 |
+
st.session_state.ls_tracer = LangChainTracer(
|
296 |
+
project_name=LANGSMITH_PROJECT,
|
297 |
+
client=st.session_state.client,
|
298 |
)
|
299 |
|
300 |
+
# --- Azure Options ---
|
301 |
+
if default_values.SHOW_AZURE_OPTIONS:
|
302 |
+
with st.expander("Azure Options", expanded=False):
|
303 |
+
AZURE_OPENAI_BASE_URL = st.text_input(
|
304 |
+
"AZURE_OPENAI_BASE_URL",
|
305 |
+
value=AZURE_OPENAI_BASE_URL,
|
306 |
+
)
|
307 |
+
|
308 |
+
AZURE_OPENAI_API_VERSION = st.text_input(
|
309 |
+
"AZURE_OPENAI_API_VERSION",
|
310 |
+
value=AZURE_OPENAI_API_VERSION,
|
311 |
+
)
|
312 |
+
|
313 |
+
AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
|
314 |
+
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
315 |
+
value=AZURE_OPENAI_DEPLOYMENT_NAME,
|
316 |
+
)
|
317 |
+
|
318 |
+
AZURE_OPENAI_API_KEY = st.text_input(
|
319 |
+
"AZURE_OPENAI_API_KEY",
|
320 |
+
value=AZURE_OPENAI_API_KEY,
|
321 |
+
type="password",
|
322 |
+
)
|
323 |
+
|
324 |
+
AZURE_OPENAI_MODEL_VERSION = st.text_input(
|
325 |
+
"AZURE_OPENAI_MODEL_VERSION",
|
326 |
+
value=AZURE_OPENAI_MODEL_VERSION,
|
327 |
+
)
|
328 |
|
329 |
|
330 |
# --- LLM Instantiation ---
|
331 |
+
st.session_state.llm = get_llm(
|
332 |
+
provider=st.session_state.provider,
|
333 |
+
model=model,
|
334 |
+
provider_api_key=provider_api_key,
|
335 |
+
temperature=temperature,
|
336 |
+
max_tokens=max_tokens,
|
337 |
+
azure_available=AZURE_AVAILABLE,
|
338 |
+
azure_dict={
|
339 |
+
"AZURE_OPENAI_BASE_URL": AZURE_OPENAI_BASE_URL,
|
340 |
+
"AZURE_OPENAI_API_VERSION": AZURE_OPENAI_API_VERSION,
|
341 |
+
"AZURE_OPENAI_DEPLOYMENT_NAME": AZURE_OPENAI_DEPLOYMENT_NAME,
|
342 |
+
"AZURE_OPENAI_API_KEY": AZURE_OPENAI_API_KEY,
|
343 |
+
"AZURE_OPENAI_MODEL_VERSION": AZURE_OPENAI_MODEL_VERSION,
|
344 |
+
},
|
345 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
|
347 |
# --- Chat History ---
|
348 |
if len(STMEMORY.messages) == 0:
|
|
|
403 |
stream_handler = StreamHandler(message_placeholder)
|
404 |
callbacks.append(stream_handler)
|
405 |
|
406 |
+
st.session_state.chain = get_runnable(
|
407 |
+
use_document_chat,
|
408 |
+
document_chat_chain_type,
|
409 |
+
st.session_state.llm,
|
410 |
+
st.session_state.retriever,
|
411 |
+
MEMORY,
|
412 |
+
chat_prompt,
|
413 |
+
prompt,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
)
|
415 |
|
416 |
+
# --- LLM call ---
|
417 |
try:
|
418 |
full_response = st.session_state.chain.invoke(prompt, config)
|
419 |
|
|
|
423 |
icon="❌",
|
424 |
)
|
425 |
|
426 |
+
# --- Display output ---
|
427 |
if full_response is not None:
|
428 |
message_placeholder.markdown(full_response)
|
429 |
|
|
|
439 |
).url
|
440 |
except langsmith.utils.LangSmithError:
|
441 |
st.session_state.trace_link = None
|
442 |
+
|
443 |
+
# --- LangSmith Trace Link ---
|
444 |
if st.session_state.trace_link:
|
445 |
with sidebar:
|
446 |
st.markdown(
|
|
|
484 |
score=score,
|
485 |
comment=feedback.get("text"),
|
486 |
)
|
|
|
|
|
|
|
|
|
487 |
st.toast("Feedback recorded!", icon="📝")
|
488 |
else:
|
489 |
st.warning("Invalid feedback score.")
|
langchain-streamlit-demo/defaults.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from collections import namedtuple
|
3 |
+
|
4 |
+
|
5 |
+
MODEL_DICT = {
|
6 |
+
"gpt-3.5-turbo": "OpenAI",
|
7 |
+
"gpt-4": "OpenAI",
|
8 |
+
"claude-instant-v1": "Anthropic",
|
9 |
+
"claude-2": "Anthropic",
|
10 |
+
"meta-llama/Llama-2-7b-chat-hf": "Anyscale Endpoints",
|
11 |
+
"meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
|
12 |
+
"meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
|
13 |
+
"codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
|
14 |
+
"Azure OpenAI": "Azure OpenAI",
|
15 |
+
}
|
16 |
+
|
17 |
+
SUPPORTED_MODELS = list(MODEL_DICT.keys())
|
18 |
+
|
19 |
+
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "gpt-3.5-turbo")
|
20 |
+
|
21 |
+
DEFAULT_SYSTEM_PROMPT = os.environ.get(
|
22 |
+
"DEFAULT_SYSTEM_PROMPT",
|
23 |
+
"You are a helpful chatbot.",
|
24 |
+
)
|
25 |
+
|
26 |
+
MIN_TEMP = float(os.environ.get("MIN_TEMPERATURE", 0.0))
|
27 |
+
MAX_TEMP = float(os.environ.get("MAX_TEMPERATURE", 1.0))
|
28 |
+
DEFAULT_TEMP = float(os.environ.get("DEFAULT_TEMPERATURE", 0.7))
|
29 |
+
|
30 |
+
MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
|
31 |
+
MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
|
32 |
+
DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
|
33 |
+
|
34 |
+
DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
|
35 |
+
|
36 |
+
AZURE_VARS = [
|
37 |
+
"AZURE_OPENAI_BASE_URL",
|
38 |
+
"AZURE_OPENAI_API_VERSION",
|
39 |
+
"AZURE_OPENAI_DEPLOYMENT_NAME",
|
40 |
+
"AZURE_OPENAI_EMB_DEPLOYMENT_NAME",
|
41 |
+
"AZURE_OPENAI_API_KEY",
|
42 |
+
"AZURE_OPENAI_MODEL_VERSION",
|
43 |
+
]
|
44 |
+
|
45 |
+
AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
|
46 |
+
|
47 |
+
|
48 |
+
SHOW_LANGSMITH_OPTIONS = (
|
49 |
+
os.environ.get("SHOW_LANGSMITH_OPTIONS", "true").lower() == "true"
|
50 |
+
)
|
51 |
+
SHOW_AZURE_OPTIONS = os.environ.get("SHOW_AZURE_OPTIONS", "true").lower() == "true"
|
52 |
+
|
53 |
+
PROVIDER_KEY_DICT = {
|
54 |
+
"OpenAI": os.environ.get("OPENAI_API_KEY", ""),
|
55 |
+
"Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
|
56 |
+
"Anyscale Endpoints": os.environ.get("ANYSCALE_API_KEY", ""),
|
57 |
+
"LANGSMITH": os.environ.get("LANGCHAIN_API_KEY", ""),
|
58 |
+
}
|
59 |
+
|
60 |
+
OPENAI_API_KEY = PROVIDER_KEY_DICT["OpenAI"]
|
61 |
+
|
62 |
+
|
63 |
+
MIN_CHUNK_SIZE = 1
|
64 |
+
MAX_CHUNK_SIZE = 10000
|
65 |
+
DEFAULT_CHUNK_SIZE = 1000
|
66 |
+
|
67 |
+
MIN_CHUNK_OVERLAP = 0
|
68 |
+
MAX_CHUNK_OVERLAP = 10000
|
69 |
+
DEFAULT_CHUNK_OVERLAP = 0
|
70 |
+
|
71 |
+
DEFAULT_RETRIEVER_K = 4
|
72 |
+
|
73 |
+
DEFAULT_VALUES = namedtuple(
|
74 |
+
"DEFAULT_VALUES",
|
75 |
+
[
|
76 |
+
"MODEL_DICT",
|
77 |
+
"SUPPORTED_MODELS",
|
78 |
+
"DEFAULT_MODEL",
|
79 |
+
"DEFAULT_SYSTEM_PROMPT",
|
80 |
+
"MIN_TEMP",
|
81 |
+
"MAX_TEMP",
|
82 |
+
"DEFAULT_TEMP",
|
83 |
+
"MIN_MAX_TOKENS",
|
84 |
+
"MAX_MAX_TOKENS",
|
85 |
+
"DEFAULT_MAX_TOKENS",
|
86 |
+
"DEFAULT_LANGSMITH_PROJECT",
|
87 |
+
"AZURE_VARS",
|
88 |
+
"AZURE_DICT",
|
89 |
+
"PROVIDER_KEY_DICT",
|
90 |
+
"OPENAI_API_KEY",
|
91 |
+
"MIN_CHUNK_SIZE",
|
92 |
+
"MAX_CHUNK_SIZE",
|
93 |
+
"DEFAULT_CHUNK_SIZE",
|
94 |
+
"MIN_CHUNK_OVERLAP",
|
95 |
+
"MAX_CHUNK_OVERLAP",
|
96 |
+
"DEFAULT_CHUNK_OVERLAP",
|
97 |
+
"DEFAULT_RETRIEVER_K",
|
98 |
+
"SHOW_LANGSMITH_OPTIONS",
|
99 |
+
"SHOW_AZURE_OPTIONS",
|
100 |
+
],
|
101 |
+
)
|
102 |
+
|
103 |
+
|
104 |
+
default_values = DEFAULT_VALUES(
|
105 |
+
MODEL_DICT,
|
106 |
+
SUPPORTED_MODELS,
|
107 |
+
DEFAULT_MODEL,
|
108 |
+
DEFAULT_SYSTEM_PROMPT,
|
109 |
+
MIN_TEMP,
|
110 |
+
MAX_TEMP,
|
111 |
+
DEFAULT_TEMP,
|
112 |
+
MIN_MAX_TOKENS,
|
113 |
+
MAX_MAX_TOKENS,
|
114 |
+
DEFAULT_MAX_TOKENS,
|
115 |
+
DEFAULT_LANGSMITH_PROJECT,
|
116 |
+
AZURE_VARS,
|
117 |
+
AZURE_DICT,
|
118 |
+
PROVIDER_KEY_DICT,
|
119 |
+
OPENAI_API_KEY,
|
120 |
+
MIN_CHUNK_SIZE,
|
121 |
+
MAX_CHUNK_SIZE,
|
122 |
+
DEFAULT_CHUNK_SIZE,
|
123 |
+
MIN_CHUNK_OVERLAP,
|
124 |
+
MAX_CHUNK_OVERLAP,
|
125 |
+
DEFAULT_CHUNK_OVERLAP,
|
126 |
+
DEFAULT_RETRIEVER_K,
|
127 |
+
SHOW_LANGSMITH_OPTIONS,
|
128 |
+
SHOW_AZURE_OPTIONS,
|
129 |
+
)
|
langchain-streamlit-demo/llm_resources.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tempfile import NamedTemporaryFile
|
2 |
+
from typing import Tuple, List, Optional, Dict
|
3 |
+
|
4 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
5 |
+
from langchain.chains import RetrievalQA, LLMChain
|
6 |
+
from langchain.chat_models import (
|
7 |
+
AzureChatOpenAI,
|
8 |
+
ChatOpenAI,
|
9 |
+
ChatAnthropic,
|
10 |
+
ChatAnyscale,
|
11 |
+
)
|
12 |
+
from langchain.document_loaders import PyPDFLoader
|
13 |
+
from langchain.embeddings import OpenAIEmbeddings
|
14 |
+
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
15 |
+
from langchain.schema import Document, BaseRetriever
|
16 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
+
from langchain.vectorstores import FAISS
|
18 |
+
|
19 |
+
from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
|
20 |
+
from qagen import get_rag_qa_gen_chain
|
21 |
+
from summarize import get_rag_summarization_chain
|
22 |
+
|
23 |
+
|
24 |
+
def get_runnable(
    use_document_chat: bool,
    document_chat_chain_type: str,
    llm,
    retriever,
    memory,
    chat_prompt,
    summarization_prompt,
):
    """Select and build the runnable that answers the user's message.

    Without document chat this is a plain conversational LLMChain; with
    document chat the chain type string picks Q&A generation, summarization,
    or a RetrievalQA chain of that type. Every branch returns a runnable
    that yields a plain string.
    """
    # Plain chat: no retriever involved, unwrap the chain's "text" output.
    if not use_document_chat:
        chat_chain = LLMChain(prompt=chat_prompt, llm=llm, memory=memory)
        return chat_chain | (lambda output: output["text"])

    if document_chat_chain_type == "Q&A Generation":
        return get_rag_qa_gen_chain(retriever, llm)

    if document_chat_chain_type == "Summarization":
        return get_rag_summarization_chain(summarization_prompt, retriever, llm)

    # Any other value is treated as a RetrievalQA chain_type (e.g. "stuff");
    # unwrap its "output_text" key so callers always receive a string.
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type=document_chat_chain_type,
        retriever=retriever,
        memory=memory,
        output_key="output_text",
    )
    return qa_chain | (lambda output: output["output_text"])
59 |
+
|
60 |
+
|
61 |
+
def get_llm(
    provider: str,
    model: str,
    provider_api_key: str,
    temperature: float,
    max_tokens: int,
    azure_available: bool,
    azure_dict: dict[str, str],
):
    """Instantiate a streaming chat model for the selected provider.

    Azure OpenAI takes precedence whenever it is available and selected;
    the remaining providers require ``provider_api_key``. Returns ``None``
    when no provider can be configured (missing key or unknown provider).
    """
    if azure_available and provider == "Azure OpenAI":
        return AzureChatOpenAI(
            openai_api_base=azure_dict["AZURE_OPENAI_BASE_URL"],
            openai_api_version=azure_dict["AZURE_OPENAI_API_VERSION"],
            deployment_name=azure_dict["AZURE_OPENAI_DEPLOYMENT_NAME"],
            openai_api_key=azure_dict["AZURE_OPENAI_API_KEY"],
            openai_api_type="azure",
            model_version=azure_dict["AZURE_OPENAI_MODEL_VERSION"],
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        )

    if not provider_api_key:
        return None

    # Each provider's constructor spells its model/key/token kwargs slightly
    # differently (e.g. Anthropic's max_tokens_to_sample), so dispatch through
    # per-provider factories rather than a shared kwargs dict.
    factories = {
        "OpenAI": lambda: ChatOpenAI(
            model_name=model,
            openai_api_key=provider_api_key,
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        ),
        "Anthropic": lambda: ChatAnthropic(
            model=model,
            anthropic_api_key=provider_api_key,
            temperature=temperature,
            streaming=True,
            max_tokens_to_sample=max_tokens,
        ),
        "Anyscale Endpoints": lambda: ChatAnyscale(
            model_name=model,
            anyscale_api_key=provider_api_key,
            temperature=temperature,
            streaming=True,
            max_tokens=max_tokens,
        ),
    }
    make_llm = factories.get(provider)
    return make_llm() if make_llm is not None else None
112 |
+
|
113 |
+
|
114 |
+
def get_texts_and_retriever(
    uploaded_file_bytes: bytes,
    openai_api_key: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    k: int = DEFAULT_RETRIEVER_K,
    azure_kwargs: Optional[Dict[str, str]] = None,
    use_azure: bool = False,
) -> Tuple[List[Document], BaseRetriever]:
    """Split an uploaded PDF into chunks and build a hybrid retriever over them.

    The PDF bytes are written to a temp file for PyPDFLoader, split with a
    RecursiveCharacterTextSplitter, then indexed twice: BM25 (keyword) and
    FAISS (dense, via OpenAI embeddings). Both retrievers are combined in an
    equally-weighted EnsembleRetriever returning the top ``k`` per retriever.
    """
    # PyPDFLoader reads from a path, so spill the upload to a temp file first.
    with NamedTemporaryFile() as temp_file:
        temp_file.write(uploaded_file_bytes)
        temp_file.seek(0)

        pages = PyPDFLoader(temp_file.name).load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        texts = splitter.split_documents(pages)

        embeddings_kwargs = {"openai_api_key": openai_api_key}
        # When Azure is in play, layer its endpoint/deployment kwargs on top.
        if use_azure and azure_kwargs:
            embeddings_kwargs.update(azure_kwargs)
        embeddings = OpenAIEmbeddings(**embeddings_kwargs)

        keyword_retriever = BM25Retriever.from_documents(texts)
        keyword_retriever.k = k

        dense_retriever = FAISS.from_documents(texts, embeddings).as_retriever(
            search_kwargs={"k": k},
        )

        hybrid_retriever = EnsembleRetriever(
            retrievers=[keyword_retriever, dense_retriever],
            weights=[0.5, 0.5],
        )
        return texts, hybrid_retriever
151 |
+
|
152 |
+
|
153 |
+
class StreamHandler(BaseCallbackHandler):
    """Callback that streams LLM tokens into a Streamlit container."""

    def __init__(self, container, initial_text=""):
        """Remember the render target and seed the accumulated text."""
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        """Append the newly generated token and re-render everything so far."""
        self.text = self.text + token
        self.container.markdown(self.text)