import os
import threading
import streamlit as st
from itertools import tee
from chain import ChainBuilder
DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
# remove these secrets from the container
# VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
# VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")
if DATABRICKS_HOST is None:
    raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_TOKEN is None:
    raise ValueError("DATABRICKS_TOKEN environment variable must be set")
MODEL_AVATAR_URL = "./VU.jpeg"
MAX_CHAT_TURNS = 10 # limit this for preliminary testing
MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns in a single history. Click the 'Clear Chat' button or refresh the page to start a new conversation."
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"
EXAMPLE_PROMPTS = [
    "How is a data lake used at Vanderbilt University Medical Center?",
    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
    "Code a SQL statement that can query a database named 'VUMC'.",
    "Write a short story about a country concert in Nashville, Tennessee.",
    "Tell me about maximum out-of-pocket costs in healthcare.",
]
TITLE = "Vanderbilt AI Assistant"
DESCRIPTION= """Welcome to the first generation Vanderbilt AI assistant! \n
**WARNING**: Unfortunately this space is currently deprecated. The serving endpoint used to serve the pay-per-token Databricks DBRX language model has been rate-limited by
staff for security reasons to accept no queries. I am in the process of reworking this augmented model to hit an available endpoint for community use. Nonetheless, if you are interested
in seeing this model's functionality in a 24 hour time window, send me an email at `[email protected]`, or the email below, and I will temporarily activate the serving endpoint
for you to query the model. \n
**Overview and Usage**: This AI assistant is built atop the Databricks DBRX large language model
and is augmented with additional organization-specific knowledge. Particularly, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
terms like **EDW**, **HCERA**, **NRHA** and **thousands more**. (Ask the assistant if you don't know what any of these terms mean!) On the left is a sidebar of **Examples**;
click any of these examples to issue the corresponding query to the AI.
**Feedback**: Feedback is welcomed, encouraged, and invaluable! To give feedback in regards to one of the model's responses, click the **Give Feedback on Last Response** button just below
the user input bar. This allows you to provide either positive or negative feedback in regards to the model's most recent response. A **Feedback Form** will appear above the model's title.
Please be sure to select either π or π before adding additional notes about your choice. Be as brief or as detailed as you like! Note that you are making a difference; this
feedback allows us to later improve this model for your usage through a training technique known as reinforcement learning through human feedback. \n
**Disclaimer**: The model has **no access to PHI**. \n
Please provide any additional, larger feedback, ideas, or issues to the email: **[email protected]**. Happy chatting!"""
GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."
# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1  # test this number
# TOKEN_CHUNK_SIZE_ENV = os.environ.get("TOKEN_CHUNK_SIZE")  # assumed env var name
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
QUEUE_SIZE = 20  # tune this so the global queue has enough slots
# QUEUE_SIZE_ENV = os.environ.get("QUEUE_SIZE")  # assumed env var name
# if QUEUE_SIZE_ENV is not None:
#     QUEUE_SIZE = int(QUEUE_SIZE_ENV)
# @st.cache_resource
# def get_global_semaphore():
#     return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()
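# A minimal sketch of how the semaphore above could gate concurrent requests once
# re-enabled (hypothetical wrapper, mirroring the commented TODO in chat_completion
# below, which would call this instead of chain_call directly):
# def gated_chain_call(history):
#     # block until one of the QUEUE_SIZE slots frees up
#     with global_semaphore:
#         return chain_call(history)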
st.set_page_config(layout="wide")
st.title(TITLE)
# st.image("sunrise.jpg", caption="Sunrise by the mountains") # TODO add a Vanderbilt related picture to the head of our Space!
st.markdown(DESCRIPTION)
st.markdown("\n")
# inject the custom stylesheet for page formatting
with open("./style.css") as css:
    st.markdown(f'<style>{css.read()}</style>', unsafe_allow_html=True)
if "messages" not in st.session_state:
st.session_state["messages"] = []
if "feedback" not in st.session_state:
st.session_state["feedback"] = [None]
def clear_chat_history():
st.session_state["messages"] = []
st.button('Clear Chat', on_click=clear_chat_history)
# build our chain outside the main body so that it's only instantiated once; simply pass it the chat history for chat completion
chain = ChainBuilder().build_chain()
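# A possible hardening of the once-only intent above (an assumption: the chain is
# safe to share across sessions): cache it as a Streamlit resource so every rerun
# and session reuses a single instance.
# @st.cache_resource
# def get_chain():
#     return ChainBuilder().build_chain()
# chain = get_chain()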
def last_role_is_user():
    return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"

def text_stream(stream):
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]
def get_stream_warning_error(stream):
    error = None
    warning = None
    for chunk in stream:
        if chunk["error"] is not None:
            error = chunk["error"]
        if chunk["warning"] is not None:
            warning = chunk["warning"]
    return warning, error
# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
def chain_call(history):
    chain_input = {'messages': [{"role": m["role"], "content": m["content"]} for m in history]}
    chat_completion = chain.stream(chain_input)
    return chat_completion
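# A minimal sketch of the retry wrapper hinted at by the commented decorator above,
# assuming the tenacity package (hypothetical; not currently applied to chain_call):
# from tenacity import retry, stop_after_attempt, wait_random_exponential
# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
# def chain_call_with_retry(history):
#     return chain_call(history)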
def write_response():
    stream = chat_completion(st.session_state["messages"])
    # tee the generator so we can both render the text and scan it for warnings/errors
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None
    return response, stream_warning, stream_error
def chat_completion(messages):
    # each turn is a user/assistant pair; the just-sent user message is excluded from the count
    if (len(messages) - 1) // 2 >= MAX_CHAT_TURNS:
        yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
        return
    chat_completion = None
    error = None
    # *** TODO add code for implementing a global queue with a bounded semaphore
    # wait for a spot in the queue
    # with global_semaphore:
    #     try:
    #         chat_completion = chat_api_call(history_dbrx_format)
    #     except Exception as e:
    #         error = e
    # chat_completion = chain_call(history_dbrx_format)
    chat_completion = chain_call(messages)
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return
    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in chat_completion:
        if chunk is not None:
            chunk_counter += 1
            partial_message += chunk
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
        # if chunk.choices[0].finish_reason == "length":
        #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS
    yield {"content": partial_message, "error": None, "warning": max_token_warning}
# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant
def handle_user_input(user_input):
    with history:
        response, stream_warning, stream_error = None, None, None
        if last_role_is_user():
            # retry the assistant if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user", avatar="🧑‍💻"):
                st.markdown(user_input)
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})
def feedback():
    with st.form("feedback_form"):
        st.title("Feedback Form")
        st.markdown("Please select either 👍 or 👎 before providing a reason for your review of the most recent response. Don't forget to click submit!")
        rating = st.feedback()
        feedback = st.text_input("Please detail your feedback: ")
        # TODO implement a method for writing these responses to storage!
        submitted = st.form_submit_button("Submit Feedback")
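# A minimal sketch of the storage TODO above, assuming a local JSONL file is enough
# for now (hypothetical helper and path; Space storage is ephemeral, so durable
# feedback would need a database or mounted volume):
# import json, time
# def save_feedback(rating, notes, path="./feedback.jsonl"):
#     record = {"ts": time.time(), "rating": rating, "notes": notes}
#     with open(path, "a") as f:
#         f.write(json.dumps(record) + "\n")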
main = st.container()
with main:
    if st.session_state["feedback"][-1] is not None:  # TODO clean this up in a fn?
        st.markdown("Thank you! Feedback received! Type a new message to continue your conversation.")
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = "🧑‍💻"
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                if message["error"] is not None:
                    st.error(message["error"], icon="🚨")
                if message["warning"] is not None:
                    st.warning(message["warning"], icon="⚠️")
    if prompt := st.chat_input("Type a message!", max_chars=5000):
        handle_user_input(prompt)
    st.markdown("\n")  # add some space for iPhone users
    gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
    if gave_feedback:  # TODO clean up the conditions here with a function
        st.session_state["feedback"].append("given")
    else:
        st.session_state["feedback"].append(None)
with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)