John Graham Reynolds committed on
Commit 81f74ed · 1 Parent(s): b5c44b5

clean up comments and limit chat history

Files changed (1)
  1. chain.py +6 -10
chain.py CHANGED
@@ -57,11 +57,9 @@ class ChainBuilder:
     def load_embedding_model(self):
         model_name = self.retriever_config.get("embedding_model")
 
-        # make sure we cache this so that it doesn't redownload each time, hindering Space start time if sleeping
-        # try adding this st caching decorator to ensure the embeddings class gets cached after downloading the entirety of the model
+        # make sure we cache this so that it doesn't redownload each time
         # cannot directly use @st.cache_resource on a method (function within a class) that has a self argument
-        # does this cache to the given folder though? It does appear to populate the folder as expected after being run
-        @st.cache_resource # will this work here? https://docs.streamlit.io/develop/concepts/architecture/caching
+        @st.cache_resource # https://docs.streamlit.io/develop/concepts/architecture/caching
         def load_and_cache_embedding_model(model_name):
             embeddings = HuggingFaceEmbeddings(model_name=model_name, cache_folder="./langchain_cache/") # this cache isn't working because we're in the Docker container
             # update this to read from a presaved cache of bge-large
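For context on the caching pattern these comments discuss: `@st.cache_resource` cannot decorate a bound method because Streamlit would have to hash `self`, so the download is wrapped in a nested function that only takes hashable arguments. A minimal, self-contained sketch of that workaround (the `EmbeddingLoader` class name is illustrative, not from the repo):

```python
import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings


class EmbeddingLoader:
    """Illustrative stand-in for ChainBuilder.load_embedding_model."""

    def __init__(self, model_name: str):
        self.model_name = model_name

    def load(self) -> HuggingFaceEmbeddings:
        # A nested function avoids passing the unhashable `self` to Streamlit;
        # the cache key is derived from the function's code and `model_name`.
        @st.cache_resource
        def load_and_cache_embedding_model(model_name: str) -> HuggingFaceEmbeddings:
            return HuggingFaceEmbeddings(
                model_name=model_name,
                cache_folder="./langchain_cache/",  # local folder for downloaded weights
            )

        return load_and_cache_embedding_model(self.model_name)
```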
@@ -78,7 +76,7 @@ class ChainBuilder:
         # you cannot directly use @st.cache_resource on a method (function within a class) that has a self argument.
         # This is because Streamlit's caching mechanism relies on hashing the function's code and input parameters, and the self argument represents the instance of the class, which is not hashable by default.
         # 'Cannot hash argument 'embeddings' (of type `langchain_huggingface.embeddings.huggingface.HuggingFaceEmbeddings`) in 'get_and_cache_retriever''
-        # this is fine, we are caching the entire function above for embeddings, so recalling it entirely is fast. We use _embeddings so Streamlit does not try to hash this argument
+        # this is fine, we are caching the entire function above for 'embeddings', so recalling it entirely is fast. We use _embeddings so Streamlit does not try to hash this argument
         @st.cache_resource # cache the Databricks vector store retriever
         def get_and_cache_retriever(endpoint, index_name, _embeddings, search_kwargs):
             vector_search_as_retriever = DatabricksVectorSearch(
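The underscore prefix on `_embeddings` is Streamlit's documented convention for excluding an argument from the cache key, which is what avoids the "Cannot hash argument" error quoted above. A small sketch of the convention in isolation (the retriever construction itself is elided; the dict return value is purely illustrative):

```python
import streamlit as st


@st.cache_resource
def get_and_cache_retriever(endpoint: str, index_name: str, _embeddings, search_kwargs: dict):
    # Arguments whose names start with "_" are not hashed by st.cache_resource,
    # so an unhashable embeddings object can be passed straight through.
    # The cache key is built from endpoint, index_name, and search_kwargs only.
    return {"endpoint": endpoint, "index": index_name, "search_kwargs": search_kwargs}
```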
@@ -120,8 +118,6 @@ class ChainBuilder:
         prompt = ChatPromptTemplate.from_messages(
             [
                 ("system", self.get_system_prompt()),
-                # *** Note: This chain does not compress the history, so very long conversations can overflow the context window. TODO
-                # We need to at some point chop this history down to a fixed number of recent messages
                 MessagesPlaceholder(variable_name="formatted_chat_history"), # placeholder for var named 'formatted_chat_history' with messages to be passed
                 # User's most current question
                 ("user", "{question}"),
@@ -130,12 +126,12 @@ class ChainBuilder:
         return prompt # return directly?
 
     # Format the conversation history to fit into the prompt template above.
-    # **** TODO after only a few statements this will likely overflow the context window
     def format_chat_history_for_prompt(self, chat_messages_array):
         history = self.extract_chat_history(chat_messages_array)
         formatted_chat_history = []
         if len(history) > 0:
-            for chat_message in history:
+            # grab at most the last three sets of queries and responses as chat history for relevant context - limit history so as to not overflow the 32k context window
+            for chat_message in history[-6:]:
                 if chat_message["role"] == "user":
                     formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
                 elif chat_message["role"] == "assistant":
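The new `history[-6:]` slice keeps only the last six raw messages (three question/answer pairs) before they are converted into LangChain message objects. A self-contained sketch of that truncation-plus-conversion step, assuming the OpenAI-style `{"role": ..., "content": ...}` dicts the loop iterates over:

```python
from langchain_core.messages import AIMessage, HumanMessage


def format_chat_history(chat_messages_array: list[dict]) -> list:
    formatted_chat_history = []
    # keep at most the last 6 messages (3 user/assistant exchanges) so a long
    # conversation cannot overflow the model's 32k-token context window
    for chat_message in chat_messages_array[-6:]:
        if chat_message["role"] == "user":
            formatted_chat_history.append(HumanMessage(content=chat_message["content"]))
        elif chat_message["role"] == "assistant":
            formatted_chat_history.append(AIMessage(content=chat_message["content"]))
    return formatted_chat_history


# Example usage with a two-message history
history = [
    {"role": "user", "content": "What is Databricks Vector Search?"},
    {"role": "assistant", "content": "A managed vector index service."},
]
print(format_chat_history(history))
```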
@@ -201,5 +197,5 @@ class ChainBuilder:
         )
         return chain
 
-    # ## Tell MLflow logging where to find your chain.
+    # ## Tell MLflow logging where to find chain. # TODO can we implement this later for logging?
     # mlflow.models.set_model(model=chain)
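The commented-out `mlflow.models.set_model` call refers to MLflow's models-from-code logging, in which the script that builds the chain registers the object to be logged. A toy sketch of how that would look if re-enabled (a trivial runnable stands in for the real RAG chain):

```python
import mlflow
from langchain_core.runnables import RunnableLambda

# Trivial stand-in for the chain that ChainBuilder assembles.
chain = RunnableLambda(lambda inputs: f"echo: {inputs['question']}")

# Register this object as the model; a separate driver script would then call
# mlflow.langchain.log_model(lc_model="path/to/chain.py", ...) to log it.
mlflow.models.set_model(model=chain)
```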
 