John Graham Reynolds committed on
Commit f45b463 · 1 Parent(s): 7edfd1a

add build out of app using streamlit from DBRX template

Files changed (1)
  1. app.py +200 -39
app.py CHANGED
@@ -1,7 +1,9 @@
 import os
+import threading
 import streamlit as st
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_databricks.vectorstores import DatabricksVectorSearch
+from itertools import tee
 
 DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
 DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
@@ -13,19 +15,45 @@ if DATABRICKS_HOST is None:
 if DATABRICKS_TOKEN is None:
     raise ValueError("DATABRICKS_API_TOKEN environment variable must be set")
 
-TITLE = "VUMC Chatbot"
-DESCRIPTION = "The first generation VUMC chatbot with knowledge of Vanderbilt specific terms."
+MODEL_AVATAR_URL = "./VU.jpeg"
+
+# MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns. Click the 'Clear Chat' button or refresh the page to start a new conversation."
+# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
+
 EXAMPLE_PROMPTS = [
-    "Write a short story about a robot that has a nice day.",
-    "In a table, what are some of the most common misconceptions about birds?",
-    "Give me a recipe for vegan banana bread.",
-    "Code a python function that can run merge sort on a list.",
-    "Give me the character profile of a gumdrop obsessed knight in JSON.",
-    "Write a rap battle between Alan Turing and Claude Shannon.",
+    "Tell me about maximum out-of-pocket costs in healthcare.",
+    "Write a haiku about Nashville, Tennessee.",
+    "How is a data lake used at Vanderbilt University Medical Center?",
+    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
+    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
+    "Code a sql statement that can query a database named 'VUMC'.",
+    "Write a short story about a country concert in Nashville, Tennessee.",
 ]
 
+TITLE = "VUMC Chatbot"
+DESCRIPTION = """Welcome to the first generation Vanderbilt AI assistant! This AI assistant is built atop the Databricks DBRX large language model
+and is augmented with additional organization-specific knowledge. Specifically, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
+terms like **Data Lake**, **EDW** (Enterprise Data Warehouse), **HCERA** (Health Care and Education Reconciliation Act), and **thousands more!** The model has **no access to PHI**.
+Try querying the model with any of the example prompts below for a simple introduction to both Vanderbilt-specific and general knowledge queries. The purpose of this
+model is to allow VUMC employees access to an intelligent assistant that improves and expedites VUMC work. Please provide any feedback, ideas, or issues to the email: **[email protected]**.
+Feedback and ideas are very welcome! We hope to gradually improve this AI assistant to create a large-scale, all-inclusive tool to complement the work of all VUMC staff."""
+
+GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
+
+# @st.cache_resource
+# def get_global_semaphore():
+#     return threading.BoundedSemaphore(QUEUE_SIZE)
+# global_semaphore = get_global_semaphore()
+
 st.set_page_config(layout="wide")
+
+# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
+TOKEN_CHUNK_SIZE = 1
+# if TOKEN_CHUNK_SIZE_ENV is not None:
+#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
+
 st.title(TITLE)
+# st.image("sunrise.jpg", caption="Sunrise by the mountains") # add a Vanderbilt related picture to the head of our Space!
 st.markdown(DESCRIPTION)
 st.markdown("\n")
 
@@ -33,10 +61,34 @@ st.markdown("\n")
 with open("style.css") as css:
     st.markdown(f'<style>{css.read()}</style>', unsafe_allow_html=True)
 
+if "messages" not in st.session_state:
+    st.session_state["messages"] = []
+
+def clear_chat_history():
+    st.session_state["messages"] = []
+
+st.button('Clear Chat', on_click=clear_chat_history)
+
+def last_role_is_user():
+    return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"
+
+def get_system_prompt():
+    return ""
+
+# ** working logic for querying glossary embeddings
 # Same embedding model we used to create embeddings of terms
 # make sure we cache this so that it doesn't redownload each time, hindering Space start time if sleeping
-embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en", cache_folder="./langchain_cache/")
+# try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model
+# does this cache to the given folder though? It does appear to populate the folder as expected after being run
+@st.experimental_memo
+def load_embedding_model():
+    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en", cache_folder="./langchain_cache/")
+    return embeddings
 
+embeddings = load_embedding_model()
+# instantiate the vector store for similarity search in our chain
+# need to make this a function and decorate it with @st.experimental_memo as above?
+# We are only calling this initially when the Space starts. Can we expedite this process for users when opening up this Space?
 vector_store = DatabricksVectorSearch(
     endpoint=VS_ENDPOINT_NAME,
     index_name=VS_INDEX_NAME,
@@ -45,33 +97,142 @@ vector_store = DatabricksVectorSearch(
     columns=["name", "description"],
 )
 
-results = vector_store.similarity_search(query="Tell me about what a data lake is.", k=5)
-st.write(results)
-
-# DBRX mainbody minus functions
-
-# main = st.container()
-# with main:
-#     history = st.container(height=400)
-#     with history:
-#         for message in st.session_state["messages"]:
-#             avatar = None
-#             if message["role"] == "assistant":
-#                 avatar = MODEL_AVATAR_URL
-#             with st.chat_message(message["role"], avatar=avatar):
-#                 if message["content"] is not None:
-#                     st.markdown(message["content"])
-#                 if message["error"] is not None:
-#                     st.error(message["error"], icon="🚨")
-#                 if message["warning"] is not None:
-#                     st.warning(message["warning"], icon="⚠️")
-
-# if prompt := st.chat_input("Type a message!", max_chars=1000):
-#     handle_user_input(prompt)
-# st.markdown("\n") # add some space for iphone users
-
-# with st.sidebar:
-#     with st.container():
-#         st.title("Examples")
-#         for prompt in EXAMPLE_PROMPTS:
-#             st.button(prompt, args=(prompt,), on_click=handle_user_input)
+def text_stream(stream):
+    for chunk in stream:
+        if chunk["content"] is not None:
+            yield chunk["content"]
+
+def get_stream_warning_error(stream):
+    error = None
+    warning = None
+    # for chunk in stream:
+    #     if chunk["error"] is not None:
+    #         error = chunk["error"]
+    #     if chunk["warning"] is not None:
+    #         warning = chunk["warning"]
+    return warning, error
+
+# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
+def chat_api_call(history):
+    # *** original code for instantiating the DBRX model through the OpenAI client *** skip this and introduce our chain eventually
+    # extra_body = {}
+    # if SAFETY_FILTER:
+    #     extra_body["enable_safety_filter"] = SAFETY_FILTER
+    # chat_completion = client.chat.completions.create(
+    #     messages=[
+    #         {"role": m["role"], "content": m["content"]}
+    #         for m in history
+    #     ],
+    #     model="databricks-dbrx-instruct",
+    #     stream=True,
+    #     max_tokens=MAX_TOKENS,
+    #     temperature=0.7,
+    #     extra_body=extra_body
+    # )
+
+    # ** TODO update this next to take and do similarity search on user input!
+    search_result = vector_store.similarity_search(query="Tell me about what a data lake is.", k=5)
+    chat_completion = search_result  # TODO update this after we implement our chain
+    return chat_completion
+
+def write_response():
+    stream = chat_completion(st.session_state["messages"])
+    content_stream, error_stream = tee(stream)
+    response = st.write_stream(text_stream(content_stream))
+    stream_warning, stream_error = get_stream_warning_error(error_stream)
+    if stream_warning is not None:
+        st.warning(stream_warning, icon="⚠️")
+    if stream_error is not None:
+        st.error(stream_error, icon="🚨")
+    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
+    if isinstance(response, list):
+        response = None
+    return response, stream_warning, stream_error
+
+def chat_completion(messages):
+    history_dbrx_format = [
+        {"role": "system", "content": get_system_prompt()}
+    ]
+
+    history_dbrx_format = history_dbrx_format + messages
+    # if (len(history_dbrx_format) - 1) // 2 >= MAX_CHAT_TURNS:
+    #     yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
+    #     return
+
+    chat_completion = None
+    error = None
+    # *** original code for querying DBRX through the OpenAI client for chat completion
+    # wait to be in queue
+    # with global_semaphore:
+    #     try:
+    #         chat_completion = chat_api_call(history_dbrx_format)
+    #     except Exception as e:
+    #         error = e
+    chat_completion = chat_api_call(history_dbrx_format)
+    if error is not None:
+        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
+        print(error)
+        return
+
+    max_token_warning = None
+    partial_message = ""
+    chunk_counter = 0
+    for chunk in chat_completion:
+        # if chunk.choices[0].delta.content is not None:
+        if chunk.page_content is not None:
+            chunk_counter += 1
+            # partial_message += chunk.choices[0].delta.content
+            partial_message += f"* {chunk.page_content} [{chunk.metadata}]"
+            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
+                chunk_counter = 0
+                yield {"content": partial_message, "error": None, "warning": None}
+                partial_message = ""
+        # if chunk.choices[0].finish_reason == "length":
+        #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
+
+    yield {"content": partial_message, "error": None, "warning": max_token_warning}
+
+# if assistant is the last message, we need to prompt the user
+# if user is the last message, we need to retry the assistant.
+def handle_user_input(user_input):
+    with history:
+        response, stream_warning, stream_error = [None, None, None]
+        if last_role_is_user():
+            # retry the assistant if the user tries to send a new message
+            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
+                response, stream_warning, stream_error = write_response()
+        else:
+            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
+            with st.chat_message("user"):
+                st.markdown(user_input)
+            stream = chat_completion(st.session_state["messages"])
+            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
+                response, stream_warning, stream_error = write_response()
+
+    st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
+
+main = st.container()
+with main:
+    history = st.container(height=400)
+    with history:
+        for message in st.session_state["messages"]:
+            avatar = "🧑‍💻"
+            if message["role"] == "assistant":
+                avatar = MODEL_AVATAR_URL
+            with st.chat_message(message["role"], avatar=avatar):
+                if message["content"] is not None:
+                    st.markdown(message["content"])
+                # if message["error"] is not None:
+                #     st.error(message["error"], icon="🚨")
+                # if message["warning"] is not None:
+                #     st.warning(message["warning"], icon="⚠️")
+
+    if prompt := st.chat_input("Type a message!", max_chars=1000):
+        handle_user_input(prompt)
+    st.markdown("\n")  # add some space for iphone users
+
+with st.sidebar:
+    with st.container():
+        st.title("Examples")
+        for prompt in EXAMPLE_PROMPTS:
+            st.button(prompt, args=(prompt,), on_click=handle_user_input)
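A note on the caching question raised in the comments above: `@st.experimental_memo` was later renamed `st.cache_data` and is intended for serializable return values, while `st.cache_resource` is the usual fit for an unhashable handle like a loaded model. The decorator caches the Python object per server process; `cache_folder` separately controls where Hugging Face writes the downloaded weights on disk, which is why the folder fills up after the first run. A minimal sketch of that variant, keeping the commit's model name and folder but swapping the decorator (an assumption, not part of this commit):

import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings

@st.cache_resource  # one shared instance per process, reused across reruns
def load_embedding_model() -> HuggingFaceEmbeddings:
    return HuggingFaceEmbeddings(
        model_name="BAAI/bge-large-en",     # same model used to embed the glossary terms
        cache_folder="./langchain_cache/",  # on-disk weight cache, distinct from the object cache
    )

embeddings = load_embedding_model()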
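The `itertools.tee` import supports the streaming pattern in `write_response`: `st.write_stream` exhausts one copy of the chunk generator while a second copy is retained to scan for warnings and errors afterward (that scan is still commented out in `get_stream_warning_error`, so it currently returns `(None, None)`). A self-contained sketch of the pattern, with a stub in place of the real chat stream:

from itertools import tee

def stub_stream():
    # same chunk shape the app yields from chat_completion()
    yield {"content": "Hello, ", "error": None, "warning": None}
    yield {"content": "world!", "error": None, "warning": "demo warning"}

content_stream, status_stream = tee(stub_stream())

# first pass renders the text (st.write_stream plays this role in the app)
text = "".join(c["content"] for c in content_stream if c["content"] is not None)

# second pass replays the buffered chunks to collect status flags
warning = error = None
for c in status_stream:
    warning = c["warning"] or warning
    error = c["error"] or error

print(text)     # Hello, world!
print(warning)  # demo warning

One caveat: tee buffers every chunk one copy has consumed until the other catches up, so draining the content copy first is fine for chat-sized outputs but holds the whole response in memory.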
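The TODO in `chat_api_call` still hard-codes the retrieval query. One hypothetical way to complete it, not part of this commit, is to use the newest user turn from `history` as the query; `vector_store` and the chunk format are as defined in the diff:

def chat_api_call(history):
    # hypothetical: take the latest user message as the similarity-search query,
    # falling back to the commit's hard-coded string if no user turn exists yet
    query = next(
        (m["content"] for m in reversed(history) if m["role"] == "user"),
        "Tell me about what a data lake is.",
    )
    # returns a list of langchain Documents; chat_completion() reads
    # .page_content and .metadata from each one
    return vector_store.similarity_search(query=query, k=5)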