import os
import threading  # used by the (currently commented-out) global semaphore below
import streamlit as st
from itertools import tee
from chain import ChainBuilder

DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
# remove these secrets from the container
# VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
# VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")

if DATABRICKS_HOST is None:
    raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_TOKEN is None:
    raise ValueError("DATABRICKS_TOKEN environment variable must be set")

MODEL_AVATAR_URL = "./VU.jpeg"
MAX_CHAT_TURNS = 10 # limit this for preliminary testing
MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns in a single history. Click the 'Clear Chat' button or refresh the page to start a new conversation."
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"

EXAMPLE_PROMPTS = [
    "How is a data lake used at Vanderbilt University Medical Center?",
    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
    "Code a sql statement that can query a database named 'VUMC'.",
    "Write a short story about a country concert in Nashville, Tennessee.",
    "Tell me about maximum out-of-pocket costs in healthcare.",
]

TITLE = "Vanderbilt AI Assistant"
DESCRIPTION= """Welcome to the first generation Vanderbilt AI assistant! \n

**WARNING**: Unfortunately this space is currently deprecated. The serving endpoint used to serve the pay-per-token Databricks DBRX language model has been rate-limited by 
staff for security reasons to accept no queries. I am in the process of reworking this augmented model to hit an available endpoint for community use. Nonetheless, if you are interested 
in seeing this model's functionality in a 24 hour time window, send me an email at `[email protected]`, or the email below, and I will temporarily activate the serving endpoint 
for you to query the model. \n

**Overview and Usage**: This AI assistant is built atop the Databricks DBRX large language model 
and is augmented with additional organization-specific knowledge. Particularly, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
terms like **EDW**, **HCERA**, **NRHA** and **thousands more**. (Ask the assistant if you don't know what any of these terms mean!) On the left is a sidebar of **Examples**; 
click any of these examples to issue the corresponding query to the AI.

**Feedback**: Feedback is welcomed, encouraged, and invaluable! To give feedback in regards to one of the model's responses, click the **Give Feedback on Last Response** button just below
the user input bar. This allows you to provide either positive or negative feedback in regards to the model's most recent response. A **Feedback Form** will appear above the model's title. 
Please be sure to select either πŸ‘ or πŸ‘Ž before adding additional notes about your choice. Be as brief or as detailed as you like! Note that you are making a difference; this 
feedback allows us to later improve this model for your usage through a training technique known as reinforcement learning through human feedback. \n

**Disclaimer**: The model has **no access to PHI**. \n

Please provide any additional, larger feedback, ideas, or issues to the email: **[email protected]**. Happy chatting!"""

GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."

# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1  # tune this number
# TOKEN_CHUNK_SIZE_ENV = os.environ.get("TOKEN_CHUNK_SIZE")  # assumed env var name
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)

QUEUE_SIZE = 20  # capacity of the global request queue; may need tuning
# QUEUE_SIZE_ENV = os.environ.get("QUEUE_SIZE")  # assumed env var name
# if QUEUE_SIZE_ENV is not None:
#     QUEUE_SIZE = int(QUEUE_SIZE_ENV)

# @st.cache_resource
# def get_global_semaphore():
#     return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()
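
# A minimal sketch of how the semaphore above could gate chain calls once
# re-enabled (hypothetical helper, not wired in). Because chain.stream is
# lazy, the semaphore must be held until the stream is fully consumed, so
# the wrapper is itself a generator:
# def queued_chain_call(history):
#     with get_global_semaphore():  # blocks when QUEUE_SIZE calls are in flight
#         yield from chain_call(history)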

st.set_page_config(layout="wide")

st.title(TITLE)
# st.image("sunrise.jpg", caption="Sunrise by the mountains") # TODO add a Vanderbilt related picture to the head of our Space!
st.markdown(DESCRIPTION)
st.markdown("\n")

# use this to format later
with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

if "messages" not in st.session_state:
    st.session_state["messages"] = []

if "feedback" not in st.session_state:
    st.session_state["feedback"] = [None]

def clear_chat_history():
    st.session_state["messages"] = []

st.button('Clear Chat', on_click=clear_chat_history)

# build our chain outside the working body so that it's only instantiated once; simply pass it the chat history for chat completion
chain = ChainBuilder().build_chain()

def last_role_is_user():
    return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"

def text_stream(stream):
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]

def get_stream_warning_error(stream):
    error = None
    warning = None
    for chunk in stream:
        if chunk["error"] is not None:
            error = chunk["error"]
        if chunk["warning"] is not None:
            warning = chunk["warning"]
    return warning, error

# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
def chain_call(history):
    # convert the chat history into the input format the chain expects
    chain_input = {"messages": [{"role": m["role"], "content": m["content"]} for m in history]}
    # chain.stream returns a lazy stream of completion chunks
    return chain.stream(chain_input)
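
# The commented-out @retry decorator above comes from the tenacity library; a
# sketch of how it could be re-enabled, assuming tenacity is installed
# (hypothetical wrapper, not active in this app). Note that chain.stream is
# lazy, so this only retries stream creation, not mid-stream failures:
# from tenacity import retry, wait_random_exponential, stop_after_attempt
#
# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
# def chain_call_with_retry(history):
#     return chain_call(history)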

def write_response():
    stream = chat_completion(st.session_state["messages"])
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning, icon="⚠️")
    if stream_error is not None:
        st.error(stream_error, icon="🚨")
    # if there was an error, st.write_stream returns a list instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None
    return response, stream_warning, stream_error

def chat_completion(messages):
    # each turn is one user/assistant exchange; the history ends with the newest user message
    if (len(messages) - 1) // 2 >= MAX_CHAT_TURNS:
        yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
        return

    chat_completion = None
    error = None
    # *** TODO: implement a global queue with a bounded semaphore (see the queued_chain_call sketch above)
    # wait to be in queue
    # with global_semaphore:
    #     try: 
    #         chat_completion = chat_api_call(history_dbrx_format)
    #     except Exception as e:
    #         error = e    
    # chat_completion = chain_call(history_dbrx_format)
    chat_completion = chain_call(messages)
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return
    
    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in chat_completion:
        if chunk is not None:
            chunk_counter += 1
            partial_message += chunk
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
        # if chunk.choices[0].finish_reason == "length":
        #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS

    yield {"content": partial_message, "error": None, "warning": max_token_warning}

# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
def handle_user_input(user_input):
    with history:
        response, stream_warning, stream_error = [None, None, None]
        if last_role_is_user():
            # retry the assistant if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user", avatar="πŸ§‘β€πŸ’»"):
                st.markdown(user_input)
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        
        st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})

def feedback():
    with st.form("feedback_form"):
        st.title("Feedback Form")
        st.markdown("Please select either πŸ‘ or πŸ‘Ž before providing a reason for your review of the most recent response. Don't forget to click submit!")
        rating = st.feedback("thumbs")
        notes = st.text_input("Please detail your feedback: ")
        # TODO: persist rating and notes to storage (see the save_feedback sketch below)
        submitted = st.form_submit_button("Submit Feedback")
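
# A minimal sketch of the storage TODO inside feedback(): append each
# submission to a local JSONL file. The path, record shape, and helper name
# are assumptions; this is not yet wired into the form above.
def save_feedback(rating, notes, path="./feedback_log.jsonl"):
    import json
    from datetime import datetime, timezone
    record = {
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "rating": rating,  # index returned by st.feedback(); 0 is the leftmost icon
        "notes": notes,
    }
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")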

main = st.container()
with main:
    if st.session_state["feedback"][-1] is not None:  # TODO clean this up in a fn?
        st.markdown("Thank you! Feedback received! Type a new message to continue your conversation.")
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = "πŸ§‘β€πŸ’»"
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"], avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                if message["error"] is not None:
                    st.error(message["error"],icon="🚨")
                if message["warning"] is not None:
                    st.warning(message["warning"],icon="⚠️")

    if prompt := st.chat_input("Type a message!", max_chars=5000):
        handle_user_input(prompt)
    st.markdown("\n") #add some space for iphone users
    gave_feedback = st.button('Give Feedback on Last Response', on_click=feedback)
    if gave_feedback:  # TODO clean up the conditions here with a function
        st.session_state["feedback"].append("given")
    else:
        st.session_state["feedback"].append(None)
        

with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)