import os
import threading  # only used by the commented-out request-queue semaphore below
import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_databricks.vectorstores import DatabricksVectorSearch
from itertools import tee

DATABRICKS_HOST = os.environ.get("DATABRICKS_HOST")
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN")
VS_ENDPOINT_NAME = os.environ.get("VS_ENDPOINT_NAME")
VS_INDEX_NAME = os.environ.get("VS_INDEX_NAME")

if DATABRICKS_HOST is None:
    raise ValueError("DATABRICKS_HOST environment variable must be set")
if DATABRICKS_TOKEN is None:
    raise ValueError("DATABRICKS_API_TOKEN environment variable must be set")

MODEL_AVATAR_URL= "./VU.jpeg"

# MSG_MAX_TURNS_EXCEEDED = f"Sorry! The Vanderbilt AI assistant playground is limited to {MAX_CHAT_TURNS} turns. Click the 'Clear Chat' button or refresh the page to start a new conversation."
# MSG_CLIPPED_AT_MAX_OUT_TOKENS = "Reached maximum output tokens for DBRX Playground"

EXAMPLE_PROMPTS = [
    "Tell me about maximum out-of-pocket costs in healthcare.",
    "Write a haiku about Nashville, Tennessee.",
    "How is a data lake used at Vanderbilt University Medical Center?",
    "In a table, what are some of the greatest hurdles to healthcare in the United States?",
    "What does EDW stand for in the context of Vanderbilt University Medical Center?",
    "Code a sql statement that can query a database named 'VUMC'.",
    "Write a short story about a country concert in Nashville, Tennessee.",
]

TITLE = "Vanderbilt AI Assistant"
DESCRIPTION="""Welcome to the first generation Vanderbilt AI assistant! \n 
This AI assistant is built atop the Databricks DBRX large language model 
and is augmented with additional organization-specific knowledge. Specifically, it has been preliminarily augmented with knowledge of Vanderbilt University Medical Center
terms like **Data Lake**, **EDW** (Enterprise Data Warehouse), **HCERA** (Health Care and Education Reconciliation Act), and **thousands more!** The model has **no access to PHI**. 
Try querying the model with any of the example prompts below for a simple introduction to both Vanderbilt-specific and general knowledge queries. The purpose of this 
model is to allow VUMC employees access to an intelligent assistant that improves and expedites VUMC work. \n
Feedback and ideas are very welcome! Please provide any feedback, ideas, or issues to the email: **[email protected]**. 
We hope to gradually improve this AI assistant to create a large-scale, all-inclusive tool to compliment the work of all VUMC staff."""

GENERAL_ERROR_MSG = "An error occurred. Please refresh the page to start a new conversation."

# @st.cache_resource
# def get_global_semaphore():
#     return threading.BoundedSemaphore(QUEUE_SIZE)
# global_semaphore = get_global_semaphore()

st.set_page_config(layout="wide")

# To prevent streaming too fast, chunk the output into TOKEN_CHUNK_SIZE chunks
TOKEN_CHUNK_SIZE = 1
# if TOKEN_CHUNK_SIZE_ENV is not None:
#     TOKEN_CHUNK_SIZE = int(TOKEN_CHUNK_SIZE_ENV)
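# Commented-out sketch of where TOKEN_CHUNK_SIZE_ENV could come from, since the
# lines above reference it but never define it (the env var name is an assumption):
# TOKEN_CHUNK_SIZE_ENV = os.environ.get("TOKEN_CHUNK_SIZE")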

st.title(TITLE)
# st.image("sunrise.jpg", caption="Sunrise by the mountains") # add a Vanderbilt related picture to the head of our Space!
st.markdown(DESCRIPTION)
st.markdown("\n")

# apply custom formatting from style.css
with open("./style.css") as css:
    st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)

if "messages" not in st.session_state:
    st.session_state["messages"] = []

def clear_chat_history():
    st.session_state["messages"] = []

st.button('Clear Chat', on_click=clear_chat_history)

def last_role_is_user():
    return len(st.session_state["messages"]) > 0 and st.session_state["messages"][-1]["role"] == "user"

def get_system_prompt():
    return ""

# ** working logic for querying glossary embeddings
# Use the same embedding model that was used to create the term embeddings.
# Cache the loaded model so it is not re-downloaded on each rerun, which would slow
# Space start-up after sleeping. st.cache_resource keeps the model in memory across
# reruns (https://docs.streamlit.io/develop/concepts/architecture/caching), and
# cache_folder keeps the downloaded weights on disk (the folder is populated as
# expected after a first run).
@st.cache_resource
def load_embedding_model():
    embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en", cache_folder="./langchain_cache/")
    return embeddings

embeddings = load_embedding_model()
# Instantiate the vector store used for similarity search in our chain.
# This runs once when the Space starts; to avoid rebuilding it on every rerun it
# could be wrapped in a cached function, but with st.cache_resource rather than
# st.cache_data or st.experimental_memo, since the client is not serializable.
# A commented-out sketch follows the constructor below.
vector_store = DatabricksVectorSearch(
    endpoint=VS_ENDPOINT_NAME,
    index_name=VS_INDEX_NAME,
    embedding=embeddings,
    text_column="name",
    columns=["name", "description"],
)
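
# Commented-out sketch of the cached variant mentioned above (same constructor
# arguments; the leading underscore on _embeddings tells Streamlit not to try to
# hash the unhashable embeddings object):
# @st.cache_resource
# def get_vector_store(_embeddings):
#     return DatabricksVectorSearch(
#         endpoint=VS_ENDPOINT_NAME,
#         index_name=VS_INDEX_NAME,
#         embedding=_embeddings,
#         text_column="name",
#         columns=["name", "description"],
#     )
# vector_store = get_vector_store(embeddings)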

def text_stream(stream):
    for chunk in stream:
        if chunk["content"] is not None:
            yield chunk["content"]

def get_stream_warning_error(stream):
    # Drain the tee'd copy of the stream and surface the last warning/error seen
    # in the yielded chunks.
    error = None
    warning = None
    for chunk in stream:
        if chunk["error"] is not None:
            error = chunk["error"]
        if chunk["warning"] is not None:
            warning = chunk["warning"]
    return warning, error

# @retry(wait=wait_random_exponential(min=0.5, max=2), stop=stop_after_attempt(3))
def chat_api_call(history):
    # *** original code for instantiating the DBRX model through the OpenAI client *** skip this and introduce our chain eventually
    # extra_body = {}
    # if SAFETY_FILTER:
    #     extra_body["enable_safety_filter"] = SAFETY_FILTER
    # chat_completion = client.chat.completions.create(
    #     messages=[
    #         {"role": m["role"], "content": m["content"]}
    #         for m in history
    #     ],
    #     model="databricks-dbrx-instruct",
    #     stream=True,
    #     max_tokens=MAX_TOKENS,
    #     temperature=0.7,
    #     extra_body=extra_body
    # )

    # st.write(history)  # debug: uncomment to inspect the raw chat history
    # Run the similarity search against the most recent user message rather than a
    # hard-coded example query.
    user_messages = [m["content"] for m in history if m["role"] == "user"]
    query = user_messages[-1] if user_messages else ""
    search_result = vector_store.similarity_search(query=query, k=5)
    chat_completion = search_result  # TODO: replace with the full retrieval chain once it is implemented
    return chat_completion
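
# Commented-out sketch of the eventual retrieval chain hinted at in the TODOs above.
# It assumes a DBRX serving endpoint named "databricks-dbrx-instruct" and the
# ChatDatabricks class from langchain_databricks; the prompt wording is an assumption.
# from langchain_core.output_parsers import StrOutputParser
# from langchain_core.prompts import ChatPromptTemplate
# from langchain_core.runnables import RunnablePassthrough
# from langchain_databricks import ChatDatabricks
#
# def _format_docs(docs):
#     return "\n".join(f"- {d.page_content}: {d.metadata.get('description', '')}" for d in docs)
#
# retriever = vector_store.as_retriever(search_kwargs={"k": 5})
# rag_chain = (
#     {"context": retriever | _format_docs, "question": RunnablePassthrough()}
#     | ChatPromptTemplate.from_messages([
#         ("system", "Answer using the following VUMC glossary context:\n{context}"),
#         ("human", "{question}"),
#     ])
#     | ChatDatabricks(endpoint="databricks-dbrx-instruct")
#     | StrOutputParser()
# )
# # rag_chain.stream(query) would then replace the raw similarity_search above.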

def write_response():
    stream = chat_completion(st.session_state["messages"])
    content_stream, error_stream = tee(stream)
    response = st.write_stream(text_stream(content_stream))
    stream_warning, stream_error = get_stream_warning_error(error_stream)
    if stream_warning is not None:
        st.warning(stream_warning,icon="⚠️")
    if stream_error is not None:
        st.error(stream_error,icon="🚨")
    # if there was an error, a list will be returned instead of a string: https://docs.streamlit.io/library/api-reference/write-magic/st.write_stream
    if isinstance(response, list):
        response = None 
    return response, stream_warning, stream_error

def chat_completion(messages):
    history_dbrx_format = [
        {"role": "system", "content": get_system_prompt()}
    ]
        
    history_dbrx_format = history_dbrx_format + messages
    # if (len(history_dbrx_format)-1)//2 >= MAX_CHAT_TURNS:
    #     yield {"content": None, "error": MSG_MAX_TURNS_EXCEEDED, "warning": None}
    #     return

    # Use a local name that does not shadow the chat_completion() generator itself.
    completion = None
    error = None
    # *** original code for querying DBRX through the OpenAI client for chat completion
    # wait to be in queue
    # with global_semaphore:
    #     try:
    #         completion = chat_api_call(history_dbrx_format)
    #     except Exception as e:
    #         error = e
    completion = chat_api_call(history_dbrx_format)
    if error is not None:
        yield {"content": None, "error": GENERAL_ERROR_MSG, "warning": None}
        print(error)
        return
    
    max_token_warning = None
    partial_message = ""
    chunk_counter = 0
    for chunk in completion:
        # if chunk.choices[0].delta.content is not None:
        if chunk.page_content is not None:
            chunk_counter += 1
            # partial_message += chunk.choices[0].delta.content
            partial_message += f"* {chunk.page_content} [{chunk.metadata}]"
            if chunk_counter % TOKEN_CHUNK_SIZE == 0:
                chunk_counter = 0
                yield {"content": partial_message, "error": None, "warning": None}
                partial_message = ""
        # if chunk.choices[0].finish_reason == "length":
        #     max_token_warning = MSG_CLIPPED_AT_MAX_OUT_TOKENS

    yield {"content": partial_message, "error": None, "warning": max_token_warning}

# if assistant is the last message, we need to prompt the user
# if user is the last message, we need to retry the assistant.
def handle_user_input(user_input):
    with history:
        response, stream_warning, stream_error = [None, None, None]
        if last_role_is_user():
            # retry the assistant if the user tries to send a new message
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        else:
            st.session_state["messages"].append({"role": "user", "content": user_input, "warning": None, "error": None})
            with st.chat_message("user"):
                st.markdown(user_input)
            # write_response() builds the stream from session state itself
            with st.chat_message("assistant", avatar=MODEL_AVATAR_URL):
                response, stream_warning, stream_error = write_response()
        
        st.session_state["messages"].append({"role": "assistant", "content": response, "warning": stream_warning, "error": stream_error})

main = st.container()
with main:
    history = st.container(height=400)
    with history:
        for message in st.session_state["messages"]:
            avatar = "πŸ§‘β€πŸ’»"
            if message["role"] == "assistant":
                avatar = MODEL_AVATAR_URL
            with st.chat_message(message["role"],avatar=avatar):
                if message["content"] is not None:
                    st.markdown(message["content"])
                # if message["error"] is not None:
                #     st.error(message["error"],icon="🚨")
                # if message["warning"] is not None:
                #     st.warning(message["warning"],icon="⚠️")

    if prompt := st.chat_input("Type a message!", max_chars=1000):
        handle_user_input(prompt)
    st.markdown("\n") #add some space for iphone users

with st.sidebar:
    with st.container():
        st.title("Examples")
        for prompt in EXAMPLE_PROMPTS:
            st.button(prompt, args=(prompt,), on_click=handle_user_input)