John Graham Reynolds committed
Commit · 81f74ed · 1 Parent(s): b5c44b5
clean up comments and limit chat history
chain.py CHANGED
@@ -57,11 +57,9 @@ class ChainBuilder:
     def load_embedding_model(self):
         model_name = self.retriever_config.get("embedding_model")
 
-        # make sure we cache this so that it doesnt redownload each time
-        # try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model
+        # make sure we cache this so that it doesnt redownload each time
         # cannot directly use @st.cache_resource on a method (function within a class) that has a self argument
-        #
-        @st.cache_resource # will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
+        @st.cache_resource # https://docs.streamlit.io/develop/concepts/architecture/caching
         def load_and_cache_embedding_model(model_name):
             embeddings = HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/") # this cache isnt working because were in the Docker container
             # update this to read from a presaved cache of bge-large
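Assembled from the hunk above, a minimal runnable sketch of the caching pattern this commit keeps: @st.cache_resource cannot decorate a method that takes self, so the cached loader is declared as an inner function keyed only on model_name. The class scaffolding and the shape of retriever_config are assumptions, not taken from the rest of chain.py.

import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings

class ChainBuilder:
    def __init__(self, retriever_config):
        self.retriever_config = retriever_config  # assumed to be a plain dict

    def load_embedding_model(self):
        model_name = self.retriever_config.get("embedding_model")

        @st.cache_resource  # cached across Streamlit reruns, keyed only on model_name
        def load_and_cache_embedding_model(model_name):
            # downloads the model once; the cached object is reused afterwards
            return HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/")

        return load_and_cache_embedding_model(model_name)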
@@ -78,7 +76,7 @@ class ChainBuilder:
     # you cannot directly use @st.cache_resource on a method (function within a class) that has a self argument.
     # This is because Streamlit's caching mechanism relies on hashing the function's code and input parameters, and the self argument represents the instance of the class, which is not hashable by default.
     # 'Cannot hash argument 'embeddings' (of type `langchain_huggingface.embeddings.huggingface.HuggingFaceEmbeddings`) in 'get_and_cache_retriever''
-    # this is fine, we are caching the entire function above for embeddings, so recalling it entirely is fast. We _embeddings to not ignore hashing this argument
+    # this is fine, we are caching the entire function above for 'embeddings', so recalling it entirely is fast. We _embeddings to not ignore hashing this argument
     @st.cache_resource # cache the Databricks vector store retriever
     def get_and_cache_retriever(endpoint, index_name, _embeddings, search_kwargs):
         vector_search_as_retriever = DatabricksVectorSearch(
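The leading underscore in _embeddings is the Streamlit convention the comments above refer to: st.cache_resource builds its cache key by hashing the function's arguments, and any parameter whose name starts with an underscore is skipped, which avoids the "Cannot hash argument 'embeddings'" error. A small self-contained sketch of the convention; the retriever construction is stubbed out because the DatabricksVectorSearch arguments are not shown in this hunk.

import streamlit as st

class FakeEmbeddings:
    """Stand-in for the HuggingFaceEmbeddings instance passed in from chain.py."""

@st.cache_resource  # endpoint, index_name, and search_kwargs form the cache key
def get_and_cache_retriever(endpoint, index_name, _embeddings, search_kwargs):
    # `_embeddings` is excluded from hashing but remains a normal argument here;
    # in chain.py it is handed to DatabricksVectorSearch to build the retriever.
    return {"endpoint": endpoint, "index_name": index_name,
            "embeddings": _embeddings, "search_kwargs": search_kwargs}

retriever = get_and_cache_retriever("vs-endpoint", "catalog.schema.index",
                                    FakeEmbeddings(), {"k": 5})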
@@ -120,8 +118,6 @@ class ChainBuilder:
         prompt = ChatPromptTemplate.from_messages(
             [
                 ("system", self.get_system_prompt()),
-                # *** Note: This chain does not compress the history, so very long converastions can overflow the context window. TODO
-                # We need to at some point chop this history down to fixed amount of recent messages
                 MessagesPlaceholder(variable_name="formatted_chat_history"), # placeholder for var named 'formatted_chat_history' with messages to be passed
                 # User's most current question
                 ("user", "{question}"),
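For reference, a standalone sketch of how the MessagesPlaceholder slot above is filled at prompt time. The system string is a stand-in; in chain.py it comes from self.get_system_prompt().

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant."),  # stand-in for get_system_prompt()
        MessagesPlaceholder(variable_name="formatted_chat_history"),
        ("user", "{question}"),
    ]
)

# the placeholder expands to the (now truncated) chat history; the newest question fills {question}
messages = prompt.format_messages(
    formatted_chat_history=[HumanMessage(content="hi"), AIMessage(content="hello!")],
    question="what does the retriever index?",
)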
@@ -130,12 +126,12 @@ class ChainBuilder:
         return prompt # return directly?
 
     # Format the converastion history to fit into the prompt template above.
-    # **** TODO after only a few statements this will likely overflow the context window
     def format_chat_history_for_prompt(self, chat_messages_array):
         history = self.extract_chat_history(chat_messages_array)
         formatted_chat_history = []
         if len(history) > 0:
-            for chat_message in history:
+            # grab at most just the last three sets of queries and respones as chat history for relevant context - limit history so as to not overflow 32k context window
+            for chat_message in history[-6:]:
                 if chat_message["role"] == "user":
                     formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
                 elif chat_message["role"] == "assistant":
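After this change the loop only ever sees the newest six entries, i.e. roughly the last three question/answer exchanges. A standalone sketch of the helper as it reads after the commit (it drops self and extract_chat_history, and the elif branch is cut off in the hunk, so the AIMessage counterpart is an assumption):

from langchain_core.messages import AIMessage, HumanMessage

def format_chat_history_for_prompt(history):
    # keep only the last six messages (~three user/assistant exchanges) so a long
    # conversation cannot overflow the 32k-token context window
    formatted_chat_history = []
    for chat_message in history[-6:]:
        if chat_message["role"] == "user":
            formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
        elif chat_message["role"] == "assistant":
            formatted_chat_history.append(AIMessage(content=chat_message["content"]))  # assumed branch
    return formatted_chat_history

# ten alternating messages in, only the newest six come back out
sample = [{"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"} for i in range(10)]
assert len(format_chat_history_for_prompt(sample)) == 6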
@@ -201,5 +197,5 @@ class ChainBuilder:
         )
         return chain
 
-    # ## Tell MLflow logging where to find
+    # ## Tell MLflow logging where to find chain. # TODO can we implement this later for logging?
     # mlflow.models.set_model(model=chain)