Joshua Sundance Bailey commited on
Commit
bbf7b6c
·
1 Parent(s): 8534ca8

add support for anthropic and anysclae

Browse files
docker-compose.yml CHANGED
@@ -1,8 +1,8 @@
1
  version: '3.8'
2
 
3
  services:
4
- streamlit-chat:
5
- image: streamlit-chat:latest
6
  build: .
7
  ports:
8
  - "${APP_PORT:-7860}:${APP_PORT:-7860}"
 
1
  version: '3.8'
2
 
3
  services:
4
+ langchain-streamlit-demo:
5
+ image: langchain-streamlit-demo:latest
6
  build: .
7
  ports:
8
  - "${APP_PORT:-7860}:${APP_PORT:-7860}"
langchain-streamlit-demo/app.py CHANGED
@@ -4,58 +4,76 @@ from langchain.callbacks.manager import tracing_v2_enabled
4
  from langchain.callbacks.tracers.langchain import wait_for_all_tracers
5
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
6
  from langchain.schema.runnable import RunnableConfig
7
- from openai.error import AuthenticationError
 
 
8
 
9
  from llm_stuff import (
10
  _DEFAULT_SYSTEM_PROMPT,
11
- get_memory,
12
  get_llm_chain,
13
  StreamHandler,
14
  feedback_component,
15
- get_langsmith_client,
16
  )
17
 
18
  st.set_page_config(
19
- page_title="Chat LangSmith",
20
  page_icon="🦜",
21
  )
22
 
23
- # "# Chat🦜🛠️"
24
  # Initialize State
25
  if "trace_link" not in st.session_state:
26
  st.session_state.trace_link = None
27
  if "run_id" not in st.session_state:
28
  st.session_state.run_id = None
29
- st.sidebar.markdown(
30
- """
31
- # Menu
32
- """,
33
- )
34
 
35
- openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
36
- st.session_state.openai_api_key = openai_api_key
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  langsmith_api_key = st.sidebar.text_input(
39
  "LangSmith API Key (optional)",
40
  type="password",
41
  )
42
- st.session_state.langsmith_api_key = langsmith_api_key
43
- if st.session_state.langsmith_api_key.startswith("ls__"):
44
  langsmith_project = st.sidebar.text_input(
45
  "LangSmith Project Name",
46
  value="langchain-streamlit-demo",
47
  )
48
  os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
49
- os.environ["LANGCHAIN_API_KEY"] = st.session_state.langsmith_api_key
50
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
51
  os.environ["LANGCHAIN_PROJECT"] = langsmith_project
52
 
53
- client = get_langsmith_client()
54
  else:
55
  langsmith_project = None
56
  client = None
57
 
58
- if st.session_state.openai_api_key.startswith("sk-"):
59
  system_prompt = (
60
  st.sidebar.text_area(
61
  "Custom Instructions",
@@ -75,15 +93,13 @@ if st.session_state.openai_api_key.startswith("sk-"):
75
  help="Higher values give more random results.",
76
  )
77
 
78
- memory = get_memory()
79
-
80
- chain = get_llm_chain(memory, system_prompt, temperature)
81
 
82
  run_collector = RunCollectorCallbackHandler()
83
 
84
  if st.sidebar.button("Clear message history"):
85
  print("Clearing message history")
86
- memory.clear()
87
  st.session_state.trace_link = None
88
  st.session_state.run_id = None
89
 
@@ -103,7 +119,7 @@ if st.session_state.openai_api_key.startswith("sk-"):
103
  with st.chat_message(streamlit_type, avatar=avatar):
104
  st.markdown(msg.content)
105
 
106
- if st.session_state.trace_link:
107
  st.sidebar.markdown(
108
  f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: 🛠️</button></a>',
109
  unsafe_allow_html=True,
@@ -136,8 +152,8 @@ if st.session_state.openai_api_key.startswith("sk-"):
136
  {"input": prompt},
137
  config=runnable_config,
138
  )["text"]
139
- except AuthenticationError:
140
- st.error("Please enter a valid OpenAI API key.", icon="❌")
141
  st.stop()
142
  message_placeholder.markdown(full_response)
143
 
@@ -153,5 +169,5 @@ if st.session_state.openai_api_key.startswith("sk-"):
153
  feedback_component(client)
154
 
155
  else:
156
- st.error("Please enter a valid OpenAI API key.", icon="❌")
157
  st.stop()
 
4
  from langchain.callbacks.tracers.langchain import wait_for_all_tracers
5
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
6
  from langchain.schema.runnable import RunnableConfig
7
+ import openai
8
+ import anthropic
9
+ from langsmith.client import Client
10
 
11
  from llm_stuff import (
12
  _DEFAULT_SYSTEM_PROMPT,
 
13
  get_llm_chain,
14
  StreamHandler,
15
  feedback_component,
 
16
  )
17
 
18
  st.set_page_config(
19
+ page_title="langchain-streamlit-demo",
20
  page_icon="🦜",
21
  )
22
 
 
23
  # Initialize State
24
  if "trace_link" not in st.session_state:
25
  st.session_state.trace_link = None
26
  if "run_id" not in st.session_state:
27
  st.session_state.run_id = None
 
 
 
 
 
28
 
29
+ st.sidebar.markdown("# Menu")
30
+ models = [
31
+ "gpt-3.5-turbo",
32
+ "gpt-4",
33
+ "claude-instant-v1",
34
+ "claude-2",
35
+ "meta-llama/Llama-2-7b-chat-hf",
36
+ "meta-llama/Llama-2-13b-chat-hf",
37
+ "meta-llama/Llama-2-70b-chat-hf",
38
+ ]
39
+ model = st.sidebar.selectbox(label="Chat Model", options=models, index=0)
40
+
41
+ if model.startswith("gpt"):
42
+ provider = "OpenAI"
43
+ elif model.startswith("claude"):
44
+ provider = "Anthropic"
45
+ elif model.startswith("meta-llama"):
46
+ provider = "Anyscale"
47
+ else:
48
+ st.stop()
49
+
50
+ if not model:
51
+ st.error("Please select a model and provide an API key.", icon="❌")
52
+ st.stop()
53
+
54
+ provider_api_key = st.sidebar.text_input(f"{provider} API key", type="password")
55
 
56
  langsmith_api_key = st.sidebar.text_input(
57
  "LangSmith API Key (optional)",
58
  type="password",
59
  )
60
+
61
+ if langsmith_api_key.startswith("ls__"):
62
  langsmith_project = st.sidebar.text_input(
63
  "LangSmith Project Name",
64
  value="langchain-streamlit-demo",
65
  )
66
  os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
67
+ os.environ["LANGCHAIN_API_KEY"] = langsmith_api_key
68
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
69
  os.environ["LANGCHAIN_PROJECT"] = langsmith_project
70
 
71
+ client = Client(api_key=langsmith_api_key)
72
  else:
73
  langsmith_project = None
74
  client = None
75
 
76
+ if provider_api_key:
77
  system_prompt = (
78
  st.sidebar.text_area(
79
  "Custom Instructions",
 
93
  help="Higher values give more random results.",
94
  )
95
 
96
+ chain = get_llm_chain(model, provider_api_key, system_prompt, temperature)
 
 
97
 
98
  run_collector = RunCollectorCallbackHandler()
99
 
100
  if st.sidebar.button("Clear message history"):
101
  print("Clearing message history")
102
+ chain.memory.clear()
103
  st.session_state.trace_link = None
104
  st.session_state.run_id = None
105
 
 
119
  with st.chat_message(streamlit_type, avatar=avatar):
120
  st.markdown(msg.content)
121
 
122
+ if client and st.session_state.trace_link:
123
  st.sidebar.markdown(
124
  f'<a href="{st.session_state.trace_link}" target="_blank"><button>Latest Trace: 🛠️</button></a>',
125
  unsafe_allow_html=True,
 
152
  {"input": prompt},
153
  config=runnable_config,
154
  )["text"]
155
+ except (openai.error.AuthenticationError, anthropic.AuthenticationError):
156
+ st.error("Please enter a valid {provider} API key.", icon="❌")
157
  st.stop()
158
  message_placeholder.markdown(full_response)
159
 
 
169
  feedback_component(client)
170
 
171
  else:
172
+ st.error(f"Please enter a valid {provider} API key.", icon="❌")
173
  st.stop()
langchain-streamlit-demo/llm_stuff.py CHANGED
@@ -3,21 +3,14 @@ from datetime import datetime
3
  import streamlit as st
4
  from langchain import LLMChain
5
  from langchain.callbacks.base import BaseCallbackHandler
6
- from langchain.chat_models import ChatOpenAI
7
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
8
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
9
- from langsmith.client import Client
10
  from streamlit_feedback import streamlit_feedback
11
 
12
  _DEFAULT_SYSTEM_PROMPT = "You are a helpful chatbot."
13
 
14
 
15
- def get_langsmith_client():
16
- return Client(
17
- api_key=st.session_state.langsmith_api_key,
18
- )
19
-
20
-
21
  def get_memory() -> ConversationBufferMemory:
22
  return ConversationBufferMemory(
23
  chat_memory=StreamlitChatMessageHistory(key="langchain_messages"),
@@ -26,8 +19,39 @@ def get_memory() -> ConversationBufferMemory:
26
  )
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  def get_llm_chain(
30
- memory: ConversationBufferMemory,
 
31
  system_prompt: str = _DEFAULT_SYSTEM_PROMPT,
32
  temperature: float = 0.7,
33
  ) -> LLMChain:
@@ -42,11 +66,8 @@ def get_llm_chain(
42
  ("human", "{input}"),
43
  ],
44
  ).partial(time=lambda: str(datetime.now()))
45
- llm = ChatOpenAI(
46
- temperature=temperature,
47
- streaming=True,
48
- openai_api_key=st.session_state.openai_api_key,
49
- )
50
  return LLMChain(prompt=prompt, llm=llm, memory=memory or get_memory())
51
 
52
 
 
3
  import streamlit as st
4
  from langchain import LLMChain
5
  from langchain.callbacks.base import BaseCallbackHandler
6
+ from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
7
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
8
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 
9
  from streamlit_feedback import streamlit_feedback
10
 
11
  _DEFAULT_SYSTEM_PROMPT = "You are a helpful chatbot."
12
 
13
 
 
 
 
 
 
 
14
  def get_memory() -> ConversationBufferMemory:
15
  return ConversationBufferMemory(
16
  chat_memory=StreamlitChatMessageHistory(key="langchain_messages"),
 
19
  )
20
 
21
 
22
+ def get_llm(
23
+ model: str,
24
+ provider_api_key: str,
25
+ temperature,
26
+ ):
27
+ if model.startswith("gpt"):
28
+ return ChatOpenAI(
29
+ model=model,
30
+ openai_api_key=provider_api_key,
31
+ temperature=temperature,
32
+ streaming=True,
33
+ )
34
+ elif model.startswith("claude"):
35
+ return ChatAnthropic(
36
+ model_name=model,
37
+ anthropic_api_key=provider_api_key,
38
+ temperature=temperature,
39
+ streaming=True,
40
+ )
41
+ elif model.startswith("meta-llama"):
42
+ return ChatAnyscale(
43
+ model=model,
44
+ anyscale_api_key=provider_api_key,
45
+ temperature=temperature,
46
+ streaming=True,
47
+ )
48
+ else:
49
+ raise NotImplementedError(f"Unknown model {model}")
50
+
51
+
52
  def get_llm_chain(
53
+ model: str,
54
+ provider_api_key: str,
55
  system_prompt: str = _DEFAULT_SYSTEM_PROMPT,
56
  temperature: float = 0.7,
57
  ) -> LLMChain:
 
66
  ("human", "{input}"),
67
  ],
68
  ).partial(time=lambda: str(datetime.now()))
69
+ memory = get_memory()
70
+ llm = get_llm(model, provider_api_key, temperature)
 
 
 
71
  return LLMChain(prompt=prompt, llm=llm, memory=memory or get_memory())
72
 
73