Joshua Sundance Bailey commited on
Commit
72c3d8c
·
1 Parent(s): e4b72fe

update langsmith & add support for azure chat

Browse files
Files changed (1) hide show
  1. langchain-streamlit-demo/app.py +74 -7
langchain-streamlit-demo/app.py CHANGED
@@ -12,7 +12,12 @@ from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_
12
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
13
  from langchain.chains import RetrievalQA
14
  from langchain.chains.llm import LLMChain
15
- from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
 
 
 
 
 
16
  from langchain.document_loaders import PyPDFLoader
17
  from langchain.embeddings import OpenAIEmbeddings
18
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
@@ -90,6 +95,7 @@ MODEL_DICT = {
90
  "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
91
  "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
92
  "codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
 
93
  }
94
  SUPPORTED_MODELS = list(MODEL_DICT.keys())
95
 
@@ -107,6 +113,17 @@ MIN_MAX_TOKENS = int(os.environ.get("MIN_MAX_TOKENS", 1))
107
  MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
108
  DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
109
  DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
 
 
 
 
 
 
 
 
 
 
 
110
  PROVIDER_KEY_DICT = {
111
  "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
112
  "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
@@ -173,11 +190,16 @@ with sidebar:
173
 
174
  st.session_state.provider = MODEL_DICT[model]
175
 
176
- provider_api_key = PROVIDER_KEY_DICT.get(
177
- st.session_state.provider,
178
- ) or st.text_input(
179
- f"{st.session_state.provider} API key",
180
- type="password",
 
 
 
 
 
181
  )
182
 
183
  if st.button("Clear message history"):
@@ -317,6 +339,40 @@ with sidebar:
317
  client=st.session_state.client,
318
  )
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
  # --- LLM Instantiation ---
322
  if provider_api_key:
@@ -344,7 +400,18 @@ if provider_api_key:
344
  streaming=True,
345
  max_tokens=max_tokens,
346
  )
347
-
 
 
 
 
 
 
 
 
 
 
 
348
 
349
  # --- Chat History ---
350
  if len(STMEMORY.messages) == 0:
 
12
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
13
  from langchain.chains import RetrievalQA
14
  from langchain.chains.llm import LLMChain
15
+ from langchain.chat_models import (
16
+ AzureChatOpenAI,
17
+ ChatAnthropic,
18
+ ChatAnyscale,
19
+ ChatOpenAI,
20
+ )
21
  from langchain.document_loaders import PyPDFLoader
22
  from langchain.embeddings import OpenAIEmbeddings
23
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
 
95
  "meta-llama/Llama-2-13b-chat-hf": "Anyscale Endpoints",
96
  "meta-llama/Llama-2-70b-chat-hf": "Anyscale Endpoints",
97
  "codellama/CodeLlama-34b-Instruct-hf": "Anyscale Endpoints",
98
+ "AZURE": "AZURE",
99
  }
100
  SUPPORTED_MODELS = list(MODEL_DICT.keys())
101
 
 
113
  MAX_MAX_TOKENS = int(os.environ.get("MAX_MAX_TOKENS", 100000))
114
  DEFAULT_MAX_TOKENS = int(os.environ.get("DEFAULT_MAX_TOKENS", 1000))
115
  DEFAULT_LANGSMITH_PROJECT = os.environ.get("LANGCHAIN_PROJECT")
116
+
117
+ AZURE_VARS = [
118
+ "AZURE_OPENAI_BASE_URL",
119
+ "AZURE_OPENAI_API_VERSION",
120
+ "AZURE_OPENAI_DEPLOYMENT_NAME",
121
+ "AZURE_OPENAI_API_KEY",
122
+ "AZURE_OPENAI_MODEL_VERSION",
123
+ ]
124
+
125
+ AZURE_DICT = {v: os.environ.get(v, "") for v in AZURE_VARS}
126
+
127
  PROVIDER_KEY_DICT = {
128
  "OpenAI": os.environ.get("OPENAI_API_KEY", ""),
129
  "Anthropic": os.environ.get("ANTHROPIC_API_KEY", ""),
 
190
 
191
  st.session_state.provider = MODEL_DICT[model]
192
 
193
+ provider_api_key = (
194
+ PROVIDER_KEY_DICT.get(
195
+ st.session_state.provider,
196
+ )
197
+ or st.text_input(
198
+ f"{st.session_state.provider} API key",
199
+ type="password",
200
+ )
201
+ if st.session_state.provider != "AZURE"
202
+ else ""
203
  )
204
 
205
  if st.button("Clear message history"):
 
339
  client=st.session_state.client,
340
  )
341
 
342
+ # --- Azure Options ---
343
+ with st.expander("Azure Options", expanded=False):
344
+ AZURE_OPENAI_BASE_URL = st.text_input(
345
+ "AZURE_OPENAI_BASE_URL",
346
+ value=AZURE_DICT["AZURE_OPENAI_BASE_URL"],
347
+ )
348
+ AZURE_OPENAI_API_VERSION = st.text_input(
349
+ "AZURE_OPENAI_API_VERSION",
350
+ value=AZURE_DICT["AZURE_OPENAI_API_VERSION"],
351
+ )
352
+ AZURE_OPENAI_DEPLOYMENT_NAME = st.text_input(
353
+ "AZURE_OPENAI_DEPLOYMENT_NAME",
354
+ value=AZURE_DICT["AZURE_OPENAI_DEPLOYMENT_NAME"],
355
+ )
356
+ AZURE_OPENAI_API_KEY = st.text_input(
357
+ "AZURE_OPENAI_API_KEY",
358
+ value=AZURE_DICT["AZURE_OPENAI_API_KEY"],
359
+ type="password",
360
+ )
361
+ AZURE_OPENAI_MODEL_VERSION = st.text_input(
362
+ "AZURE_OPENAI_MODEL_VERSION",
363
+ value=AZURE_DICT["AZURE_OPENAI_MODEL_VERSION"],
364
+ )
365
+
366
+ AZURE_AVAILABLE = all(
367
+ [
368
+ AZURE_OPENAI_BASE_URL,
369
+ AZURE_OPENAI_API_VERSION,
370
+ AZURE_OPENAI_DEPLOYMENT_NAME,
371
+ AZURE_OPENAI_API_KEY,
372
+ AZURE_OPENAI_MODEL_VERSION,
373
+ ],
374
+ )
375
+
376
 
377
  # --- LLM Instantiation ---
378
  if provider_api_key:
 
400
  streaming=True,
401
  max_tokens=max_tokens,
402
  )
403
+ elif AZURE_AVAILABLE and st.session_state.provider == "AZURE":
404
+ st.session_state.llm = AzureChatOpenAI(
405
+ openai_api_base=AZURE_OPENAI_BASE_URL,
406
+ openai_api_version=AZURE_OPENAI_API_VERSION,
407
+ deployment_name=AZURE_OPENAI_DEPLOYMENT_NAME,
408
+ openai_api_key=AZURE_OPENAI_API_KEY,
409
+ openai_api_type="azure",
410
+ model_version=AZURE_OPENAI_MODEL_VERSION,
411
+ temperature=temperature,
412
+ streaming=True,
413
+ max_tokens=max_tokens,
414
+ )
415
 
416
  # --- Chat History ---
417
  if len(STMEMORY.messages) == 0: