John Graham Reynolds committed on
Commit
676f0fa
·
1 Parent(s): 81f74ed

remove comments, add MAX_CHAT_TURNS and other comments orienting the user

Files changed (1)
  1. app.py +21 -87
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import time
3
  import threading
4
  import streamlit as st
5
  from itertools import tee
@@ -17,8 +16,8 @@ if DATABRICKS_TOKEN is None:
17
  raise ValueError("DATABRICKS_TOKEN environment variable must be set")
18
 
19
  MODEL_AVATAR_URL= "./VU.jpeg"
20
-
21
- # MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns. Click the 'Clear Chat' button or refresh the page to start a new conversation."
22
  # MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
23
 
24
  EXAMPLE_PROMPTS = [
@@ -42,6 +41,15 @@ We hope to gradually improve this AI assistant to create a large-scale, all-incl
42
 
43
  GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
44
45
  # @st.cache_resource
46
  # def get_global_semaphore():
47
  # return threading.BoundedSemaphore(QUEUE_SIZE)
@@ -49,13 +57,8 @@ GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new c
49
 
50
  st.set_page_config(layout="wide")
51
 
52
- # # To prevent streaming to fast, chunk the output into TOKEN_CHUNK_SIZE chunks
53
- TOKEN_CHUNK_SIZE = 1
54
- # if TOKEN_CHUNK_SIZE_ENV is not None:
55
- # TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
56
-
57
  st.title(TITLE)
58
- # st.image("sunrise.jpg", caption="Sunrise by the mountains") # add a Vanderbilt related picture to the head of our Space!
59
  st.markdown(DESCRIPTION)
60
  st.markdown("\n")
61
 
@@ -96,42 +99,8 @@ def get_stream_warning_error(stream):
96
  return warning, error
97
 
98
  # @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
99
- def chain_call(history):
100
- # *** original code for instantiating the DBRX model through the OpenAI client *** skip this and introduce our chain eventually
101
- # extra_body = {}
102
- # if SAFETY_FILTER:
103
- # extra_body["enable_safety_filter"] = SAFETY_FILTER
104
- # chat_completion = client.chat.completions.create(
105
- # messages=[
106
- # {"role": m["role"], "content": m["content"]}
107
- # for m in history
108
- # ],
109
- # model="databricks-dbrx-instruct",
110
- # stream=True,
111
- # max_tokens=MAX_TOKENS,
112
- # temperature=0.7,
113
- # extra_body= extra_body
114
- # )
115
-
116
- # *** can we stream the chain's response by incorporating the above OpenAI streaming functionality?
117
- # *** Look back at the predict_stream function and see if we can incorporate that!
118
- # *** looks like we want to use either chain.stream() or chain.astream()
119
- # test first with invoke
120
-
121
- input_example = {'messages':
122
- [{'content': 'What does EDW stand for?', 'role': 'user'},
123
- {'content': 'Enterprise Data Warehouse.', 'role': 'assistant'},
124
- {'content': 'Thank you. What is the data lake?', 'role': 'user'},
125
- {'content': 'A data lake is a centralized repository of structured and unstructured data. It allows data to be stored in its native state, without the need for transformations, so that it can be consumed by other users later. It is not just a term for storage, but also covers functionalities required for a platform, including data analysis, machine learning, cataloging and data movement.', 'role': 'assistant'},
126
- {'content': 'Can you tell me more about how they are used?', 'role': 'user'},
127
- {'content': 'At Vanderbilt University Medical Center, a data lake is used as a centralized repository for storing and managing large amounts of data in its native format. This allows for the data to be easily accessed and analyzed by different teams and business units within the organization. The data lake also provides functionalities such as data analysis, machine learning, cataloging and data movement, making it a versatile tool for handling diverse data sets.\n\nAn Enterprise Data Warehouse (EDW) is used for executing analytic queries on structured data. It is optimized for this purpose, with data being stored in a way that allows for efficient querying and analysis. This makes it a useful tool for teams that need to perform complex analyses on large data sets.\n\nA data mart is a specific organizational structure or pattern used in the context of data warehouses. It is a layer that has specific subdivisions for each business unit or team, such as finance, marketing, and product. This allows users to consume data in a format that meets their specific needs.\n\nA data lakehouse is a term used to describe approaches that attempt to combine the data structure and management features of a data warehouse with the low cost of storage of a data lake. This includes a structured transactional layer, which allows for efficient querying and analysis of data. This approach aims to provide the benefits of both data lakes and data warehouses in a single platform.', 'role': 'assistant'},
128
- {'content': 'Nice answer. Can you tell me what the HCERA is?', 'role': 'user'}]}
129
-
130
  input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}
131
-
132
- # search_result = vector_store.similarity_search(query=st.session_state["messages"][-1]["content"], k=5)
133
- # chat_completion = search_result # TODO update this after we implement our chain
134
- # chat_completion = chain.invoke(input_example) # *** TODO here we will pass only the chat history, the chain handles the system prompt
135
  chat_completion = chain.stream(input)
136
  return chat_completion
137
 
@@ -150,18 +119,13 @@ def write_response():
150
  return response, stream_warning, stream_error
151
 
152
  def chat_completion(messages):
153
- # history_dbrx_format = [
154
- # {"role": "system", "content": SYSTEM_PROMPT} # no longer need this because the chain handles all of this for us
155
- # ]
156
-
157
- # history_dbrx_format = history_dbrx_format + messages
158
- # if (len(history_dbrx_format)-1)//2 >= MAX_CHAT_TURNS:
159
- # yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
160
- # return
161
 
162
  chat_completion = None
163
  error = None
164
- # *** original code for querying DBRX through the OpenAI cleint for chat completion
165
  # wait to be in queue
166
  # with global_semaphore:
167
  # try:
@@ -169,7 +133,7 @@ def chat_completion(messages):
169
  # except Exception as e:
170
  # error = e
171
  # chat_completion = chain_call(history_dbrx_format)
172
- chat_completion = chain_call(messages) # simply pass the old messages, need not worry about the system prompt
173
  if error is not None:
174
  yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
175
  print(error)
@@ -179,13 +143,8 @@ def chat_completion(messages):
179
  partial_message = ""
180
  chunk_counter = 0
181
  for chunk in chat_completion:
182
- # if chunk.choices[0].delta.content is not None:
183
- # TODO *** we need to refactor this logic to match what happens with the response from our chain - it should be strings or an iterator of strings
184
- # if chunk.page_content is not None:
185
  if chunk is not None:
186
  chunk_counter += 1
187
- # partial_message += chunk.choices[0].delta.content
188
- # partial_message += f"* {chunk.page_content} [{chunk.metadata}]"
189
  partial_message += chunk
190
  if chunk_counter % TOKEN_CHUNK_SIZE == 0:
191
  chunk_counter = 0
@@ -218,34 +177,16 @@ def handle_user_input(user_input):
218
  def feedback():
219
  with st.form("feedback_form"):
220
  st.title("Feedback Form")
221
- # sentiment_mapping = [":material/thumb_down:", ":material/thumb_up:"]
222
- # rating = None
223
- # while not rating:
224
  rating = st.feedback()
225
- # feedback = st.text_input(f"Please detail your rationale for choosing {sentiment_mapping[rating]}: ", "")
226
  feedback = st.text_input("Please detail your feedback: ")
227
- # feedback = ""
228
- # review = {}
229
- # if rating is not None:
230
- # # st.markdown(f"You selected: {sentiment_mapping[rating]}")
231
- # # rating = st.radio("Rate your experience:", ["👍", "Neutral", "👎"])
232
- # review = {"rating": {rating}, "feedback": {feedback}}
233
  submitted = st.form_submit_button("Submit Feedback")
234
- # if submitted:
235
- # st.write(f"The feedback was: {sentiment_mapping[rating]} : {feedback}")
236
-
237
- # st.markdown(review)
238
- # time.sleep(5)
239
- # # Save the feedback data
240
- # if st.button("Submit"):
241
- # with open("feedback.json", "a") as f:
242
- # f.write()
243
- # st.write("Thank you for your feedback!")
244
 
245
  main = st.container()
246
  with main:
247
  if st.session_state["feedback"][-1] is not None: # TODO clean this up in a fn?
248
- st.markdown("Thank you! Feedback received! Type a new message to continue.")
249
  history = st.container(height=400)
250
  with history:
251
  for message in st.session_state["messages"]:
@@ -255,9 +196,6 @@ with main:
255
  with st.chat_message(message["role"], avatar=avatar):
256
  if message["content"] is not None:
257
  st.markdown(message["content"])
258
- # receive feedback on AI outputs if the user feels inclined to give it
259
- # rating = st.radio("Rate your experience:", ["Very satisfied", "Somewhat satisfied", "Neutral", "Somewhat dissatisfied", "Very dissatisfied"])
260
- # st.button("Provide Feedback", on_click=feedback)
261
  if message["error"] is not None:
262
  st.error(message["error"],icon="🚨")
263
  if message["warning"] is not None:
@@ -266,15 +204,11 @@ with main:
266
  if prompt := st.chat_input("Type a message!", max_chars=5000):
267
  handle_user_input(prompt)
268
  st.markdown("\n") #add some space for iphone users
269
- # with st.container():
270
  gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
271
  if gave_feedback: # TODO clean up the conditions here with a function
272
  st.session_state["feedback"].append("given")
273
  else:
274
  st.session_state["feedback"].append(None)
275
- # st.markdown("Feedback received! Thank you for your insight.")
276
- # time.sleep(3)
277
- # add st.session_state["feedback"] var for keeping track of when the user gives feedback!
278
 
279
 
280
  with st.sidebar:
 
1
  import os
 
2
  import threading
3
  import streamlit as st
4
  from itertools import tee
 
16
  raise ValueError("DATABRICKS_TOKEN environment variable must be set")
17
 
18
  MODEL_AVATAR_URL= "./VU.jpeg"
19
+ MAX_CHAT_TURNS = 5 # limit this for preliminary testing
20
+ MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns in a single history. Click the 'Clear Chat' button or refresh the page to start a new conversation."
21
  # MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
22
 
23
  EXAMPLE_PROMPTS = [
 
41
 
42
  GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
43
 
44
+ # # To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
45
+ TOKEN_CHUNK_SIZE = 2 # test this number
46
+ # if TOKEN_CHUNK_SIZE_ENV is not None:
47
+ # TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
48
+
49
+ QUEUE_SIZE = 20 # maximize this value for adding enough places in the global queue?
50
+ # if QUEUE_SIZE_ENV is not None:
51
+ # QUEUE_SIZE = int(QUEUE_SIZE_ENV)
52
+
53
  # @st.cache_resource
54
  # def get_global_semaphore():
55
  # return threading.BoundedSemaphore(QUEUE_SIZE)
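For orientation (not part of this commit): a minimal sketch of how the commented-out bounded-semaphore queue above could be assembled, reusing the QUEUE_SIZE value added in this commit and the chain_call function defined further down in app.py.

```python
# Sketch only: reassembles the commented-out queueing pieces. QUEUE_SIZE
# matches the constant added in this commit; chain_call is assumed to be
# the function defined later in app.py.
import threading
import streamlit as st

QUEUE_SIZE = 20

@st.cache_resource  # one semaphore shared across all user sessions
def get_global_semaphore():
    return threading.BoundedSemaphore(QUEUE_SIZE)

def queued_chain_call(messages):
    # wait for a free slot in the global queue before calling the model;
    # as in the commented-out code, the slot is released as soon as the
    # call returns, i.e. before the returned stream is fully consumed
    with get_global_semaphore():
        return chain_call(messages)
```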
 
57
 
58
  st.set_page_config(layout="wide")
59
60
  st.title(TITLE)
61
+ # st.image("sunrise.jpg", caption="Sunrise by the mountains") # TODO add a Vanderbilt related picture to the head of our Space!
62
  st.markdown(DESCRIPTION)
63
  st.markdown("\n")
64
 
 
99
  return warning, error
100
 
101
  # @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
102
+ def chain_call(history):
103
  input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}
104
  chat_completion = chain.stream(input)
105
  return chat_completion
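A minimal usage sketch of the new chain_call (not from the commit), assuming `chain` is the retrieval chain built elsewhere in app.py and that chain.stream yields plain string chunks, as the accumulation loop further down implies:

```python
# hypothetical call; the message format mirrors the removed input_example
history = [{"role": "user", "content": "What does EDW stand for?"}]
for chunk in chain_call(history):
    print(chunk, end="", flush=True)  # chain.stream is expected to yield strings
```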
106
 
 
119
  return response, stream_warning, stream_error
120
 
121
  def chat_completion(messages):
122
+ if (len(messages)-1)//2 >= MAX_CHAT_TURNS:
123
+ yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
124
+ return
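For orientation, the turn arithmetic behind the new guard, assuming `messages` already contains the newest user prompt when chat_completion runs:

```python
# with MAX_CHAT_TURNS = 5: five completed user/assistant turns leave 10
# messages in the history, so the sixth user prompt gives len(messages) == 11
# and (11 - 1) // 2 == 5, which trips the guard above
assert (11 - 1) // 2 >= 5
```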
125
 
126
  chat_completion = None
127
  error = None
128
+ # *** TODO add code for implementing a global queue with a bounded semaphore?
129
  # wait to be in queue
130
  # with global_semaphore:
131
  # try:
 
133
  # except Exception as e:
134
  # error = e
135
  # chat_completion = chain_call(history_dbrx_format)
136
+ chat_completion = chain_call(messages)
137
  if error is not None:
138
  yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
139
  print(error)
 
143
  partial_message = ""
144
  chunk_counter = 0
145
  for chunk in chat_completion:
146
  if chunk is not None:
147
  chunk_counter += 1
148
  partial_message += chunk
149
  if chunk_counter % TOKEN_CHUNK_SIZE == 0:
150
  chunk_counter = 0
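The yield that flushes partial_message falls outside this hunk, but the counter logic suggests output is released every TOKEN_CHUNK_SIZE chunks. A self-contained sketch of that throttling idea (an assumption, not the file's actual code):

```python
def throttled(chunks, chunk_size=2):  # chunk_size plays the role of TOKEN_CHUNK_SIZE
    """Accumulate streamed chunks and yield the running text every chunk_size chunks."""
    partial_message = ""
    chunk_counter = 0
    for chunk in chunks:
        if chunk is not None:
            chunk_counter += 1
            partial_message += chunk
            if chunk_counter % chunk_size == 0:
                chunk_counter = 0
                yield partial_message
    if partial_message:
        yield partial_message  # flush whatever remains after the stream ends
```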
 
177
  def feedback():
178
  with st.form("feedback_form"):
179
  st.title("Feedback Form")
180
+ st.markdown("Please select either 👍 or 👎 before providing a reason for your review of the most recent response. Don't forget to click submit!")
181
  rating = st.feedback()
 
182
  feedback = st.text_input("Please detail your feedback: ")
183
+ # implement a method for writing these responses to storage!
184
  submitted = st.form_submit_button("Submit Feedback")
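One way the storage note above could be addressed, sketched from the feedback.json idea in the comments this commit removes (hypothetical helper, not part of the commit):

```python
import json

def save_feedback(rating, feedback, path="feedback.json"):
    # hypothetical helper: append one JSON record per submitted review
    record = {"rating": rating, "feedback": feedback}
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")
```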
185
 
186
  main = st.container()
187
  with main:
188
  if st.session_state["feedback"][-1] is not None: # TODO clean this up in a fn?
189
+ st.markdown("Thank you! Feedback received! Type a new message to continue your conversation.")
190
  history = st.container(height=400)
191
  with history:
192
  for message in st.session_state["messages"]:
 
196
  with st.chat_message(message["role"], avatar=avatar):
197
  if message["content"] is not None:
198
  st.markdown(message["content"])
199
  if message["error"] is not None:
200
  st.error(message["error"],icon="🚨")
201
  if message["warning"] is not None:
 
204
  if prompt := st.chat_input("Type a message!", max_chars=5000):
205
  handle_user_input(prompt)
206
  st.markdown("\n") #add some space for iphone users
 
207
  gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
208
  if gave_feedback: # TODO clean up the conditions here with a function
209
  st.session_state["feedback"].append("given")
210
  else:
211
  st.session_state["feedback"].append(None)
212
 
213
 
214
  with st.sidebar: