Joshua Sundance Bailey committed · Commit 72c3d8c · 1 Parent(s): e4b72fe

update langsmith & add support for azure chat
langchain-streamlit-demo/app.py CHANGED
@@ -12,7 +12,12 @@ from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_
 from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
 from langchain.chains import RetrievalQA
 from langchain.chains.llm import LLMChain
-from langchain.chat_models import ChatAnthropic, ChatAnyscale, ChatOpenAI
+from langchain.chat_models import (
+    AzureChatOpenAI,
+    ChatAnthropic,
+    ChatAnyscale,
+    ChatOpenAI,
+)
 from langchain.document_loaders import PyPDFLoader
 from langchain.embeddings import OpenAIEmbeddings
 from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
@@ -90,6 +95,7 @@ MODEL_DICT = {
     "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
     "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
     "codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
+    "AZURE": "AZURE",
 }
 SUPPORTED_MODELS = list(MODEL_DICT.keys())
 
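
Registering "AZURE" as both key and value in MODEL_DICT makes Azure appear in SUPPORTED_MODELS and routes it to its own provider branch. A minimal sketch of the lookup (dict trimmed to two illustrative entries):

MODEL_DICT = {
    "gpt-3.5-turbo": "OpenAI",
    "AZURE": "AZURE",  # sentinel entry: the picker value doubles as the provider
}
SUPPORTED_MODELS = list(MODEL_DICT.keys())

assert "AZURE" in SUPPORTED_MODELS
assert MODEL_DICT["AZURE"] == "AZURE"
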
@@ -107,6 +113,17 @@ MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
 MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
 DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
 DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
+
+AZURE_VARS = [
+    "AZURE_OPENAI_BASE_URL",
+    "AZURE_OPENAI_API_VERSION",
+    "AZURE_OPENAI_DEPLOYMENT_NAME",
+    "AZURE_OPENAI_API_KEY",
+    "AZURE_OPENAI_MODEL_VERSION",
+]
+
+AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
+
 PROVIDER_KEY_DICT = {
     "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
     "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
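
AZURE_DICT gathers all five Azure settings from the environment in one comprehension, falling back to "" for anything unset. A sketch of configuring them before launch (placeholder values, not real endpoints or credentials):

import os

os.environ["AZURE_OPENAI_BASE_URL"] = "https://my-resource.openai.azure.com"  # placeholder
os.environ["AZURE_OPENAI_API_VERSION"] = "2023-07-01-preview"                 # placeholder
os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "my-gpt-deployment"              # placeholder
os.environ["AZURE_OPENAI_API_KEY"] = "placeholder-key"
os.environ["AZURE_OPENAI_MODEL_VERSION"] = "0613"                             # placeholder

AZURE_VARS = [
    "AZURE_OPENAI_BASE_URL",
    "AZURE_OPENAI_API_VERSION",
    "AZURE_OPENAI_DEPLOYMENT_NAME",
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_MODEL_VERSION",
]

# Same comprehension as the diff: unset variables fall back to "".
AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
assert all(AZURE_DICT.values())
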
@@ -173,11 +190,16 @@ with sidebar:
 
     st.session_state.provider = MODEL_DICT[model]
 
-    provider_api_key = PROVIDER_KEY_DICT.get(
-        st.session_state.provider,
-    ) or st.text_input(
-        f"{st.session_state.provider} API key",
-        type="password",
+    provider_api_key = (
+        PROVIDER_KEY_DICT.get(
+            st.session_state.provider,
+        )
+        or st.text_input(
+            f"{st.session_state.provider} API key",
+            type="password",
+        )
+        if st.session_state.provider != "AZURE"
+        else ""
     )
 
     if st.button("Clear message history"):
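
For the "AZURE" provider the usual single-key prompt is skipped, since Azure credentials are gathered field by field in the expander added below. The same conditional unrolled into a plain function (a sketch; PROVIDER_KEY_DICT is stubbed here, the real one lives in app.py):

import streamlit as st

PROVIDER_KEY_DICT = {"OpenAI": "", "Anthropic": ""}  # stub of the app's dict

def get_provider_api_key() -> str:
    if st.session_state.provider == "AZURE":
        return ""  # Azure credentials come from their own expander
    # Fall back to a password input when no key is set in the environment.
    return PROVIDER_KEY_DICT.get(
        st.session_state.provider,
    ) or st.text_input(
        f"{st.session_state.provider} API key",
        type="password",
    )
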
@@ -317,6 +339,40 @@ with sidebar:
         client=st.session_state.client,
     )
 
+    # --- Azure Options ---
+    with st.expander("Azure Options", expanded=False):
+        AZURE_OPENAI_BASE_URL = st.text_input(
+            "AZURE_OPENAI_BASE_URL",
+            value=AZURE_DICT["AZURE_OPENAI_BASE_URL"],
+        )
+        AZURE_OPENAI_API_VERSION = st.text_input(
+            "AZURE_OPENAI_API_VERSION",
+            value=AZURE_DICT["AZURE_OPENAI_API_VERSION"],
+        )
+        AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
+            "AZURE_OPENAI_DEPLOYMENT_NAME",
+            value=AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"],
+        )
+        AZURE_OPENAI_API_KEY = st.text_input(
+            "AZURE_OPENAI_API_KEY",
+            value=AZURE_DICT["AZURE_OPENAI_API_KEY"],
+            type="password",
+        )
+        AZURE_OPENAI_MODEL_VERSION = st.text_input(
+            "AZURE_OPENAI_MODEL_VERSION",
+            value=AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"],
+        )
+
+    AZURE_AVAILABLE = all(
+        [
+            AZURE_OPENAI_BASE_URL,
+            AZURE_OPENAI_API_VERSION,
+            AZURE_OPENAI_DEPLOYMENT_NAME,
+            AZURE_OPENAI_API_KEY,
+            AZURE_OPENAI_MODEL_VERSION,
+        ],
+    )
+
 
 # --- LLM Instantiation ---
 if provider_api_key:
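
AZURE_AVAILABLE applies all() to the five raw strings, and empty strings are falsy, so a single blank field disables the Azure branch:

fields = ["https://x.openai.azure.com", "2023-07-01-preview", "gpt4", "key", "0613"]
print(all(fields))  # True: every field is non-empty

fields[3] = ""      # clear the API key
print(all(fields))  # False: one empty string is enough to fail
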
@@ -344,7 +400,18 @@ if provider_api_key:
         streaming=True,
         max_tokens=max_tokens,
     )
-
+elif AZURE_AVAILABLE and st.session_state.provider == "AZURE":
+    st.session_state.llm = AzureChatOpenAI(
+        openai_api_base=AZURE_OPENAI_BASE_URL,
+        openai_api_version=AZURE_OPENAI_API_VERSION,
+        deployment_name=AZURE_OPENAI_DEPLOYMENT_NAME,
+        openai_api_key=AZURE_OPENAI_API_KEY,
+        openai_api_type="azure",
+        model_version=AZURE_OPENAI_MODEL_VERSION,
+        temperature=temperature,
+        streaming=True,
+        max_tokens=max_tokens,
+    )
 
 # --- Chat History ---
 if len(STMEMORY.messages) == 0:
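
Taken together, the new elif branch builds the chat model from the five sidebar fields. A standalone sketch of the same AzureChatOpenAI construction outside Streamlit, assuming the five environment variables are set; temperature and max_tokens are hard-coded here where the app uses its sidebar controls:

import os

from langchain.chat_models import AzureChatOpenAI

llm = AzureChatOpenAI(
    openai_api_base=os.environ["AZURE_OPENAI_BASE_URL"],
    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
    deployment_name=os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"],
    openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
    openai_api_type="azure",
    model_version=os.environ["AZURE_OPENAI_MODEL_VERSION"],
    temperature=0.7,
    streaming=True,
    max_tokens=1000,
)
print(llm.predict("Say hello!"))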