Spaces: jedick · Running on Zero

Commit b7a3bb3 · Parent: 9ac80a4
Convert consecutive ToolMessages to AIMessage
app.py
CHANGED
@@ -26,7 +26,9 @@ if torch.cuda.is_available():
     ckpt_dir = snapshot_download(model_id, local_dir_use_symlinks=False)
     print(f"Using checkpoints from {ckpt_dir}")
     print(f"Downloading checkpoints for {embedding_model_id}...")
-    embedding_ckpt_dir = snapshot_download(embedding_model_id, local_dir_use_symlinks=False)
+    embedding_ckpt_dir = snapshot_download(
+        embedding_model_id, local_dir_use_symlinks=False
+    )
     print(f"Using embedding checkpoints from {embedding_ckpt_dir}")
 else:
     ckpt_dir = None
@@ -173,7 +175,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
             count += 1
             # Get the retrieved emails as a list
             email_list = message.content.replace(
-                "### Retrieved Emails:\n\n…
+                "### Retrieved Emails:\n\n", ""
             ).split("--- --- --- --- Next Email --- --- --- ---\n\n")
             # Get the list of source files (e.g. R-help/2024-December.txt) for retrieved emails
             month_list = [text.splitlines()[0] for text in email_list]
@@ -196,7 +198,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
             # Format the retrieved emails with Tool Call heading
             retrieved_emails.append(
                 message.content.replace(
-                    "### Retrieved Emails:\n\n…
+                    "### Retrieved Emails:\n\n",
                     f"### ### ### ### Tool Call {count} ### ### ### ###\n\n",
                 )
             )
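For reference, here is a minimal sketch of what the updated replace/split in run_workflow does, using made-up message content (the real content comes from the retrieve tool):

# Hypothetical ToolMessage content in the new format
content = (
    "### Retrieved Emails:\n\n"
    "R-help/2024-December.txt\nFrom alice ...\n\n"
    "--- --- --- --- Next Email --- --- --- ---\n\n"
    "R-help/2025-January.txt\nFrom bob ..."
)
# Strip the heading, then split on the email separator
email_list = content.replace("### Retrieved Emails:\n\n", "").split(
    "--- --- --- --- Next Email --- --- --- ---\n\n"
)
# The first line of each email is its source file
month_list = [text.splitlines()[0] for text in email_list]
print(month_list)  # ['R-help/2024-December.txt', 'R-help/2025-January.txt']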
graph.py
CHANGED
@@ -44,24 +44,45 @@ def print_message_summaries(messages, header):
 
 
 def normalize_messages(messages):
-    """Normalize messages to sequence of types expected by chat…
+    """Normalize messages to sequence of types expected by chat models"""
     # Copy the most recent HumanMessage to the end
-    # …
+    # - Avoids SmolLM and Qwen ValueError: Last message must be a HumanMessage!
     if not type(messages[-1]) is HumanMessage:
         for msg in reversed(messages):
             if type(msg) is HumanMessage:
                 messages.append(msg)
                 break
-
-    # …
-
-
-
+
+    # Convert tool output (one or more consecutive ToolMessages) to AIMessage
+    # - Avoids SmolLM and Qwen ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>
+    processed_messages = []
+    i = 0
+    while i < len(messages):
+        msg = messages[i]
+
+        if type(msg) is ToolMessage:
+            # Collect consecutive ToolMessages
+            tool_messages = []
+            count = 1
+            while i < len(messages) and type(messages[i]) is ToolMessage:
+                tool_msg = messages[i]
+                formatted_content = f"## Tool Call {count}\n\n{tool_msg.content}"
+                tool_messages.append(formatted_content)
+                count += 1
+                i += 1
+
+            # Combine all tool message contents into a single AIMessage
+            combined_content = "\n\n".join(tool_messages)
+            processed_messages.append(AIMessage(combined_content))
+        else:
+            processed_messages.append(msg)
+            i += 1
+
     # Delete tool call (AIMessage)
-    # …
+    # - Avoids Gemma TemplateError: Conversation roles must alternate user/assistant/user/assistant/...
     messages = [
         msg
-        for msg in messages
+        for msg in processed_messages
         if not hasattr(msg, "tool_calls")
         or (hasattr(msg, "tool_calls") and not msg.tool_calls)
     ]
@@ -168,12 +189,12 @@ def BuildGraph(
         search_query = " ".join([search_query, start_year, end_year])
     retrieved_docs = retriever.invoke(search_query)
     serialized = "\n\n--- --- --- --- Next Email --- --- --- ---".join(
-        # …
+        # Add file name (e.g. R-help/2024-December.txt) from source key
        "\n\n" + doc.metadata["source"] + doc.page_content
         for doc in retrieved_docs
     )
     retrieved_emails = (
-        "### Retrieved Emails…
+        "### Retrieved Emails:" + serialized
         if serialized
         else "### No emails were retrieved"
     )
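To illustrate the new normalization, here is a hypothetical message sequence run through it (the tool name, arguments, and call ids are made up, and this assumes normalize_messages returns the rebuilt list; its return statement is outside this hunk):

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

messages = [
    HumanMessage("When was Rust discussed on R-help?"),
    # AIMessage carrying the tool calls (later deleted by the comprehension)
    AIMessage("", tool_calls=[
        {"name": "retrieve", "args": {"search_query": "Rust"}, "id": "call_1"},
        {"name": "retrieve", "args": {"search_query": "Rust 2024 2025"}, "id": "call_2"},
    ]),
    # Two consecutive ToolMessages, e.g. from parallel tool calls
    ToolMessage("### Retrieved Emails: ...", tool_call_id="call_1"),
    ToolMessage("### No emails were retrieved", tool_call_id="call_2"),
]
normalized = normalize_messages(messages)
# Expected result:
# [HumanMessage("When was Rust discussed on R-help?"),
#  AIMessage("## Tool Call 1\n\n### Retrieved Emails: ...\n\n## Tool Call 2\n\n### No emails were retrieved"),
#  HumanMessage("When was Rust discussed on R-help?")]
# The HumanMessage is re-appended so the sequence ends with a user turn.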
main.py
CHANGED
@@ -207,15 +207,14 @@ def RunChain(
 
     # Create a prompt template
     system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
+    # NOTE: Each new email starts with \n\n\nFrom, so we don't need newlines after Retrieved Emails:
     human_template = ChatPromptTemplate.from_template(
         """"
 ### Question:
 
 {question}
 
-### Retrieved Emails:
-
-{context}
+### Retrieved Emails:{context}
 """
     )
     prompt_template = system_template + human_template
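A quick check of the NOTE above, with a made-up context value: the serialized emails already begin with newlines, so the heading renders cleanly without hard-coded blank lines in the template:

context = "\n\nR-help/2024-December.txt\nFrom alice ..."  # hypothetical serialized emails
print("### Retrieved Emails:" + context)
# ### Retrieved Emails:
#
# R-help/2024-December.txt
# From alice ...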
requirements.txt
CHANGED
@@ -9,6 +9,10 @@ chromadb==1.0.13
 # NOTE: Gemma 3 with transformers==4.54.0 gives:
 # ValueError: Max cache length is not consistent across layers
 transformers==4.51.3
+# Required by langchain-huggingface
+sentence-transformers==5.0.0
+# For snapshot_download
+huggingface-hub==0.34.3
 
 # Langchain packages
 langchain==0.3.26
@@ -23,8 +27,6 @@ langgraph-sdk==0.1.72
 langgraph-prebuilt==0.5.2
 langgraph-checkpoint==2.1.0
 
-# Required by langchain-huggingface
-sentence-transformers==5.0.0
 # Required by Nomic embeddings
 einops==0.8.1
 
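The new huggingface-hub pin backs the snapshot_download calls in app.py; minimal usage, with a made-up model id:

from huggingface_hub import snapshot_download

ckpt_dir = snapshot_download("org/model-name")  # app.py passes model_id / embedding_model_id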