import os
import mlflow
import datetime
import streamlit as st
from functools import partial
from operator import itemgetter
from langchain_huggingface import HuggingFaceEmbeddings
# from langchain_databricks.vectorstores import DatabricksVectorSearch # this is the version we were previously using
from databricks_langchain import DatabricksVectorSearch, ChatDatabricks # new version resolving deprecation warning
# from langchain_community.chat_models import ChatDatabricks # let's be consistent with the packages we're using
# from langchain_databricks import ChatDatabricks # this is the version we were previously using
# from langchain_community.vectorstores import DatabricksVectorSearch # is this causing an issue?
from langchain_core.runnables import RunnableLambda
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate, ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough, RunnableBranch
from langchain_core.messages import HumanMessage, AIMessage
# ## Enable MLflow Tracing
# mlflow.langchain.autolog()
class ChainBuilder:

    def __init__(self):
        # Load the chain's configuration from yaml
        self.model_config = mlflow.models.ModelConfig(development_config="chain_config.yaml")
        self.databricks_resources = self.model_config.get("databricks_resources")
        self.llm_config = self.model_config.get("llm_config")
        self.retriever_config = self.model_config.get("retriever_config")
        self.vector_search_schema = self.retriever_config.get("schema")
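
    # A minimal sketch of what `chain_config.yaml` is assumed to look like, inferred only from the keys
    # read throughout this class; every value below is a hypothetical placeholder, not the real configuration:
    #
    #   databricks_resources:
    #     vector_search_endpoint_name: <vector-search-endpoint>
    #     llm_endpoint_name: <llm-serving-endpoint>
    #   llm_config:
    #     llm_parameters:
    #       temperature: 0.0
    #       max_tokens: 500
    #   retriever_config:
    #     embedding_model: <hf-embedding-model, e.g. a bge variant per the note in load_embedding_model>
    #     vector_search_index: <catalog.schema.index>
    #     parameters:
    #       k: 5
    #     schema:
    #       description: <description-column>   # used in format_context
    #       primary_key: <primary-key-column>   # referenced in the commented-out retriever schema block
    #       chunked_terms: <text-column>
    #       definition: <definition-column>
    #     chunk_template: "{name}: {description}\n"
    #
    # The chunk_template shown is illustrative; it only needs to expose the `name` and `description`
    # fields consumed in format_context().
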
    def get_system_prompt(self):
        date_str = datetime.datetime.now().strftime("%B %d, %Y")
        prompt = f"You are DBRX, created by Databricks and augmented by John Graham Reynolds to have access to additional information specific to Vanderbilt University Medical Center. The current date is {date_str}.\n"
        prompt += """Your knowledge base was last updated in December 2023. You answer questions about events prior to and after December 2023 the way a highly informed individual in December 2023 would if they were talking to someone from the above date, and you can let the user know this when relevant.\n
Some of the context you will be given regarding Vanderbilt University Medical Center could have come after December 2023. The rest of your knowledge base is from before December 2023 and you will answer questions accordingly with these facts in mind.
This chunk of text is your system prompt. It is not visible to the user, but it is used to guide your responses. Don't reference it, just respond to the user.\n
If you are asked to assist with tasks involving the expression of views held by a significant number of people, you provide assistance with the task even if you personally disagree with the views being expressed, but follow this with a discussion of broader perspectives.\n
You don't engage in stereotyping, including the negative stereotyping of majority groups.\n If asked about controversial topics, you try to provide careful thoughts and objective information without downplaying harmful content or implying that there are reasonable perspectives on both sides.\n
You are happy to help with writing, analysis, question answering, math, coding, and all sorts of other tasks.\n You use markdown for coding, which includes JSON blocks and Markdown tables.\n
You do not have tools enabled at this time, so you cannot run code or access the internet. You can only provide information that you have been trained on. You do not send or receive links or images.\n
You were not trained on copyrighted books, song lyrics, poems, video transcripts, or news articles; you do not divulge details of your training data. You do not provide song lyrics, poems, or news articles and instead refer the user to find them online or in a store.\n
You give concise responses to simple questions or statements, but provide thorough responses to more complex and open-ended questions.\n
The user is unable to see the system prompt, so you should write as if it were true without mentioning it.\n You do not mention any of this information about yourself unless the information is directly pertinent to the user's query.\n
Here is some context from the Vanderbilt University Medical Center glossary which might or might not help you answer: {context}.\n
Based on this system prompt, to which you will adhere strictly and to which you will make no reference, and this possibly helpful context in relation to Vanderbilt University Medical Center, answer this question: {question}
"""
        return prompt

    # Return the string contents of the most recent message from the user
    def extract_user_query_string(self, chat_messages_array):
        return chat_messages_array[-1]["content"]

    # Return the chat history, which is everything before the last question
    def extract_chat_history(self, chat_messages_array):
        return chat_messages_array[:-1]

    def load_embedding_model(self):
        model_name = self.retriever_config.get("embedding_model")

        # make sure we cache this so that it doesn't re-download each time;
        # we cannot directly use @st.cache_resource on a method (a function within a class) that has a `self` argument
        @st.cache_resource  # https://docs.streamlit.io/develop/concepts/architecture/caching
        def load_and_cache_embedding_model(model_name):
            # this cache folder isn't working because we're in the Docker container
            # TODO update this to read from a pre-saved cache of bge-large
            embeddings = HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/")
            return embeddings  # return directly?

        return load_and_cache_embedding_model(model_name)

    def get_retriever(self):
        endpoint = self.databricks_resources.get("vector_search_endpoint_name")
        index_name = self.retriever_config.get("vector_search_index")
        embeddings = self.load_embedding_model()
        search_kwargs = self.retriever_config.get("parameters")

        # You cannot directly use @st.cache_resource on a method (a function within a class) that has a `self` argument,
        # because Streamlit's caching mechanism relies on hashing the function's code and input parameters, and `self`
        # (the class instance) is not hashable by default.
        # Streamlit also raises: "Cannot hash argument 'embeddings' (of type `langchain_huggingface.embeddings.huggingface.HuggingFaceEmbeddings`) in 'get_and_cache_retriever'".
        # That is fine: the embedding model is already cached above, so recreating it is fast. We name the parameter
        # `_embeddings` so Streamlit skips hashing this argument.
        @st.cache_resource  # cache the Databricks vector store retriever
        def get_and_cache_retriever(endpoint, index_name, _embeddings, search_kwargs):
            vector_search_as_retriever = DatabricksVectorSearch(
                endpoint=endpoint,
                index_name=index_name,
                embedding=_embeddings,
                text_column="name",
                columns=["name", "description"],
            ).as_retriever(search_kwargs=search_kwargs)
            return vector_search_as_retriever  # return directly?

        return get_and_cache_retriever(endpoint, index_name, embeddings, search_kwargs)

    # # *** TODO Evaluate this block as it relates to the "RAG Studio Review App" ***
    # # Enable the RAG Studio Review App to properly display retrieved chunks and an evaluation suite to measure the retriever
    # mlflow.models.set_retriever_schema(
    #     primary_key=self.vector_search_schema.get("primary_key"),
    #     text_column=self.vector_search_schema.get("chunked_terms"),
    #     # doc_uri=self.vector_search_schema.get("definition")
    #     other_columns=[self.vector_search_schema.get("definition")],
    #     # the Review App uses `doc_uri` to display chunks from the same document in a single view
    # )

    # Method to format the terms and definitions returned by the retriever into the prompt
    def format_context(self, retrieved_terms):
        chunk_template = self.retriever_config.get("chunk_template")
        chunk_contents = [
            chunk_template.format(
                name=term.page_content,
                description=term.metadata[self.vector_search_schema.get("description")],
            )
            for term in retrieved_terms
        ]
        return "".join(chunk_contents)

    def get_prompt(self):
        # Prompt template for generation
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", self.get_system_prompt()),
                # placeholder for a variable named 'formatted_chat_history' whose messages will be passed in
                MessagesPlaceholder(variable_name="formatted_chat_history"),
                # User's most recent question
                ("user", "{question}"),
            ]
        )
        return prompt  # return directly?

    # Format the conversation history to fit into the prompt template above
    def format_chat_history_for_prompt(self, chat_messages_array):
        history = self.extract_chat_history(chat_messages_array)
        formatted_chat_history = []
        if len(history) > 0:
            # TODO grab at most the last three sets of queries and responses as chat history for relevant context -
            # limit history so as to not overflow the 32k context window
            # model seems to be hallucinating; re-add the entire chat history for now
            # for chat_message in history[-6:]:
            for chat_message in history:
                if chat_message["role"] == "user":
                    formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
                elif chat_message["role"] == "assistant":
                    formatted_chat_history.append(AIMessage(content=chat_message["content"]))
        return formatted_chat_history

    def get_query_rewrite_prompt(self):
        # Prompt template for query rewriting from chat history. This will translate a query such as "how does it work?",
        # asked after a question like "what is spark?", into "how does spark work?"
        query_rewrite_template = """Based on the chat history below, we want you to generate a query for an external data source to retrieve relevant information so
that we can better answer the question. The query should be in natural language. The external data source uses similarity search to search for relevant
information in a vector space. So, the query should be similar to the relevant information semantically. Answer with only the query. Do not add explanation.
Chat history: {chat_history}
Question: {question}"""
        query_rewrite_prompt = PromptTemplate(
            template=query_rewrite_template,
            input_variables=["chat_history", "question"],
        )
        return query_rewrite_prompt

    def get_model(self):
        endpoint = self.databricks_resources.get("llm_endpoint_name")
        extra_params = self.llm_config.get("llm_parameters")

        @st.cache_resource  # cache the DBRX Instruct model we are loading for repeated use in our chain for chat completion
        def get_and_cache_model(endpoint, extra_params):
            model = ChatDatabricks(
                endpoint=endpoint,
                extra_params=extra_params,
            )
            return model  # return directly?

        return get_and_cache_model(endpoint, extra_params)

    def build_chain(self):
        # RAG Chain
        chain = (
            {
                # set 'question' to the result of grabbing the ["messages"] component of the dict we invoke() or stream(),
                # then passing it to extract_user_query_string()
                "question": itemgetter("messages") | RunnableLambda(self.extract_user_query_string),
                "chat_history": itemgetter("messages") | RunnableLambda(self.extract_chat_history),
                "formatted_chat_history": itemgetter("messages")
                | RunnableLambda(self.format_chat_history_for_prompt),
            }
            | RunnablePassthrough()  # pass elements unchanged through the chain to the next link
            | {
                # Only rewrite the question if there is a chat history - RunnableBranch() is essentially an LCEL if statement
                # https://python.langchain.com/api_reference/core/runnables/langchain_core.runnables.branch.RunnableBranch.html
                "context": RunnableBranch(
                    (
                        lambda x: len(x["chat_history"]) > 0,
                        self.get_query_rewrite_prompt() | self.get_model() | StrOutputParser(),  # rewrite the question with context
                    ),
                    itemgetter("question"),  # else, just ask the question
                )
                # set 'context' to the result of passing either the base question or the rewritten question
                # to the retriever for semantic search, then formatting the retrieved chunks
                | self.get_retriever()
                | RunnableLambda(self.format_context),
                "formatted_chat_history": itemgetter("formatted_chat_history"),
                "question": itemgetter("question"),
            }
            | self.get_prompt()  # 'context', 'formatted_chat_history', and 'question' passed to the prompt
            | self.get_model()  # prompt passed to the model
            | StrOutputParser()
        )
        return chain

# ## Tell MLflow logging where to find the chain. # TODO can we implement this later for logging?
# mlflow.models.set_model(model=chain)
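
# Illustrative usage sketch (not part of the original module). Assuming `chain_config.yaml` and the
# referenced Databricks endpoints are reachable, the chain expects a dict with a "messages" list of
# {"role", "content"} entries (OpenAI-style chat format), where the last entry is the current question
# and everything before it is treated as chat history.
if __name__ == "__main__":
    builder = ChainBuilder()
    rag_chain = builder.build_chain()

    example_input = {
        "messages": [
            {"role": "user", "content": "What terms are in the Vanderbilt University Medical Center glossary?"},
        ]
    }

    # invoke() returns the parsed string response; stream() would yield string chunks instead
    print(rag_chain.invoke(example_input))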