jedick committed
Commit b7a3bb3 · 1 Parent(s): 9ac80a4

Convert consecutive ToolMessages to AIMessage

Files changed (4)
  1. app.py +5 -3
  2. graph.py +32 -11
  3. main.py +2 -3
  4. requirements.txt +4 -2
app.py CHANGED
@@ -26,7 +26,9 @@ if torch.cuda.is_available():
     ckpt_dir = snapshot_download(model_id, local_dir_use_symlinks=False)
     print(f"Using checkpoints from {ckpt_dir}")
     print(f"Downloading checkpoints for {embedding_model_id}...")
-    embedding_ckpt_dir = snapshot_download(embedding_model_id, local_dir_use_symlinks=False)
+    embedding_ckpt_dir = snapshot_download(
+        embedding_model_id, local_dir_use_symlinks=False
+    )
     print(f"Using embedding checkpoints from {embedding_ckpt_dir}")
 else:
     ckpt_dir = None
@@ -173,7 +175,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
         count += 1
         # Get the retrieved emails as a list
         email_list = message.content.replace(
-            "### Retrieved Emails:\n\n\n\n", ""
+            "### Retrieved Emails:\n\n", ""
         ).split("--- --- --- --- Next Email --- --- --- ---\n\n")
         # Get the list of source files (e.g. R-help/2024-December.txt) for retrieved emails
         month_list = [text.splitlines()[0] for text in email_list]
@@ -196,7 +198,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
         # Format the retrieved emails with Tool Call heading
         retrieved_emails.append(
             message.content.replace(
-                "### Retrieved Emails:\n\n\n\n",
+                "### Retrieved Emails:\n\n",
                 f"### ### ### ### Tool Call {count} ### ### ### ###\n\n",
             )
         )
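
For reference, a minimal sketch of what the updated parsing in run_workflow() does, assuming message.content uses the heading and separator strings from this commit (the email bodies and file names below are invented):

    content = (
        "### Retrieved Emails:"
        "\n\nR-help/2024-December.txt\nFrom alice ..."
        "\n\n--- --- --- --- Next Email --- --- --- ---"
        "\n\nR-help/2025-January.txt\nFrom bob ..."
    )
    # Strip the heading, then split on the separator, as in run_workflow()
    email_list = content.replace("### Retrieved Emails:\n\n", "").split(
        "--- --- --- --- Next Email --- --- --- ---\n\n"
    )
    # The first line of each email is its source file
    month_list = [text.splitlines()[0] for text in email_list]
    print(month_list)  # ['R-help/2024-December.txt', 'R-help/2025-January.txt']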
graph.py CHANGED
@@ -44,24 +44,45 @@ def print_message_summaries(messages, header):
 
 
 def normalize_messages(messages):
-    """Normalize messages to sequence of types expected by chat templates"""
+    """Normalize messages to sequence of types expected by chat models"""
     # Copy the most recent HumanMessage to the end
-    # (avoids SmolLM and Qwen ValueError: Last message must be a HumanMessage!)
+    # - Avoids SmolLM and Qwen ValueError: Last message must be a HumanMessage!
     if not type(messages[-1]) is HumanMessage:
         for msg in reversed(messages):
             if type(msg) is HumanMessage:
                 messages.append(msg)
                 break
-    # Convert tool output (ToolMessage) to AIMessage
-    # (avoids SmolLM and Qwen ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>)
-    messages = [
-        AIMessage(msg.content) if type(msg) is ToolMessage else msg for msg in messages
-    ]
+
+    # Convert tool output (one or more consecutive ToolMessages) to AIMessage
+    # - Avoids SmolLM and Qwen ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>
+    processed_messages = []
+    i = 0
+    while i < len(messages):
+        msg = messages[i]
+
+        if type(msg) is ToolMessage:
+            # Collect consecutive ToolMessages
+            tool_messages = []
+            count = 1
+            while i < len(messages) and type(messages[i]) is ToolMessage:
+                tool_msg = messages[i]
+                formatted_content = f"## Tool Call {count}\n\n{tool_msg.content}"
+                tool_messages.append(formatted_content)
+                count += 1
+                i += 1
+
+            # Combine all tool message contents into a single AIMessage
+            combined_content = "\n\n".join(tool_messages)
+            processed_messages.append(AIMessage(combined_content))
+        else:
+            processed_messages.append(msg)
+            i += 1
+
     # Delete tool call (AIMessage)
-    # (avoids Gemma TemplateError: Conversation roles must alternate user/assistant/user/assistant/...)
+    # - Avoids Gemma TemplateError: Conversation roles must alternate user/assistant/user/assistant/...
     messages = [
         msg
-        for msg in messages
+        for msg in processed_messages
         if not hasattr(msg, "tool_calls")
         or (hasattr(msg, "tool_calls") and not msg.tool_calls)
     ]
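
As a standalone check, the new merging loop from the hunk above can be run on a hypothetical pair of consecutive tool results (message contents invented; requires langchain-core):

    from langchain_core.messages import AIMessage, ToolMessage

    messages = [
        ToolMessage("email A ...", tool_call_id="call_1"),
        ToolMessage("email B ...", tool_call_id="call_2"),
    ]

    # Same loop as in normalize_messages(): collect runs of ToolMessages
    # into a single AIMessage with numbered Tool Call headings
    processed_messages, i = [], 0
    while i < len(messages):
        if type(messages[i]) is ToolMessage:
            tool_messages, count = [], 1
            while i < len(messages) and type(messages[i]) is ToolMessage:
                tool_messages.append(f"## Tool Call {count}\n\n{messages[i].content}")
                count += 1
                i += 1
            processed_messages.append(AIMessage("\n\n".join(tool_messages)))
        else:
            processed_messages.append(messages[i])
            i += 1

    # One AIMessage: "## Tool Call 1\n\nemail A ...\n\n## Tool Call 2\n\nemail B ..."
    print(len(processed_messages), repr(processed_messages[0].content))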
 
@@ -168,12 +189,12 @@ def BuildGraph(
     search_query = " ".join([search_query, start_year, end_year])
     retrieved_docs = retriever.invoke(search_query)
     serialized = "\n\n--- --- --- --- Next Email --- --- --- ---".join(
-        # source key has file names (e.g. R-help/2024-December.txt), useful for retrieval and reporting
+        # Add file name (e.g. R-help/2024-December.txt) from source key
        "\n\n" + doc.metadata["source"] + doc.page_content
         for doc in retrieved_docs
     )
     retrieved_emails = (
-        "### Retrieved Emails:\n\n" + serialized
+        "### Retrieved Emails:" + serialized
         if serialized
         else "### No emails were retrieved"
     )
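
The heading change works because each joined element already starts with "\n\n", so "### Retrieved Emails:" no longer needs its own trailing newlines. A small illustration with hypothetical Documents (contents invented):

    from langchain_core.documents import Document

    retrieved_docs = [
        Document(page_content="\nFrom alice ...", metadata={"source": "R-help/2024-December.txt"}),
        Document(page_content="\nFrom bob ...", metadata={"source": "R-help/2025-January.txt"}),
    ]
    serialized = "\n\n--- --- --- --- Next Email --- --- --- ---".join(
        "\n\n" + doc.metadata["source"] + doc.page_content for doc in retrieved_docs
    )
    # The heading is followed directly by "\n\n" + source, which matches the
    # "### Retrieved Emails:\n\n" prefix that app.py strips
    retrieved_emails = "### Retrieved Emails:" + serialized
    print(retrieved_emails.splitlines()[2])  # R-help/2024-December.txt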
main.py CHANGED
@@ -207,15 +207,14 @@ def RunChain(
 
     # Create a prompt template
     system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
+    # NOTE: Each new email starts with \n\n\nFrom, so we don't need newlines after Retrieved Emails:
     human_template = ChatPromptTemplate.from_template(
         """"
         ### Question:
 
         {question}
 
-        ### Retrieved Emails:
-
-        {context}
+        ### Retrieved Emails:{context}
         """
     )
     prompt_template = system_template + human_template
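
A quick render of the tightened template, assuming a context value that begins with the \n\n\nFrom prefix noted in the comment (question and email text invented):

    from langchain_core.prompts import ChatPromptTemplate

    human_template = ChatPromptTemplate.from_template(
        """
        ### Question:

        {question}

        ### Retrieved Emails:{context}
        """
    )
    msg = human_template.format_messages(
        question="How do I reshape a data frame?",
        context="\n\n\nFrom alice ...",  # hypothetical; each email starts with \n\n\nFrom
    )[0]
    # The heading and the first email remain separated by blank lines
    print(msg.content)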
requirements.txt CHANGED
@@ -9,6 +9,10 @@ chromadb==1.0.13
 # NOTE: Gemma 3 with transformers==4.54.0 gives:
 # ValueError: Max cache length is not consistent across layers
 transformers==4.51.3
+# Required by langchain-huggingface
+sentence-transformers==5.0.0
+# For snapshot_download
+huggingface-hub==0.34.3
 
 # Langchain packages
 langchain==0.3.26
@@ -23,8 +27,6 @@ langgraph-sdk==0.1.72
 langgraph-prebuilt==0.5.2
 langgraph-checkpoint==2.1.0
 
-# Required by langchain-huggingface
-sentence-transformers==5.0.0
 # Required by Nomic embeddings
 einops==0.8.1
 
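
A minimal import check, assuming an environment built from these pins, of what the re-pinned packages provide:

    from huggingface_hub import snapshot_download  # used by app.py for checkpoints
    import sentence_transformers  # required by langchain-huggingface
    import einops  # required by Nomic embeddings

    print(sentence_transformers.__version__)  # expect 5.0.0 per the pin above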