Commit 7d21953
Parent(s): 556cc72
jedick committed

Handle tool calls with thinking enabled

Files changed:
- graph.py +15 -14
- main.py +2 -2
- prompts.py +16 -12
- requirements.txt +1 -1
graph.py
CHANGED
@@ -9,7 +9,7 @@ import os
 
 # Local modules
 from retriever import BuildRetriever
-from prompts import
+from prompts import query_prompt, generate_prompt, gemma_tools_template
 from mods.tool_calling_llm import ToolCallingLLM
 
 # Local modules
@@ -49,13 +49,14 @@ def print_message_summaries(messages, header):
 def normalize_messages(messages):
     """Normalize messages to sequence of types expected by chat templates"""
     # Copy the most recent HumanMessage to the end
-    # (avoids
+    # (avoids SmolLM and Qwen ValueError: Last message must be a HumanMessage!)
     if not type(messages[-1]) is HumanMessage:
         for msg in reversed(messages):
             if type(msg) is HumanMessage:
                 messages.append(msg)
+                break
     # Convert tool output (ToolMessage) to AIMessage
-    # (avoids
+    # (avoids SmolLM and Qwen ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>)
     messages = [
         AIMessage(msg.content) if type(msg) is ToolMessage else msg for msg in messages
     ]
@@ -75,7 +76,7 @@ def ToolifyHF(chat_model, system_message, system_message_suffix="", think=False)
     Get a Hugging Face model ready for bind_tools().
     """
 
-    ## Add /no_think flag to turn off thinking mode (SmolLM3)
+    ## Add /no_think flag to turn off thinking mode (SmolLM3 and Qwen)
     # if not think:
     #     system_message = "/no_think\n" + system_message
 
@@ -203,14 +204,12 @@ def BuildGraph(
     # Add tools to the local or remote chat model
     is_local = hasattr(chat_model, "model_id")
     if is_local:
-        # For local
+        # For local models (ChatHuggingFace with SmolLM, Gemma, or Qwen)
         query_model = ToolifyHF(
-            chat_model,
+            chat_model, query_prompt(compute_mode), "", think_retrieve
         ).bind_tools([retrieve_emails])
-        # Don't use answer_with_citations tool
-        generate_model =
-            chat_model, answer_prompt(with_tools=False), "", think_generate
-        )
+        # Don't use answer_with_citations tool because responses with it are sometimes unparseable
+        generate_model = chat_model
     else:
         # For remote model (OpenAI API)
         query_model = chat_model.bind_tools([retrieve_emails])
@@ -228,9 +227,7 @@ def BuildGraph(
             messages = normalize_messages(messages)
             print_message_summaries(messages, "--- query: after normalization ---")
         else:
-            messages = [SystemMessage(
-                "messages"
-            ]
+            messages = [SystemMessage(query_prompt(compute_mode))] + state["messages"]
         response = query_model.invoke(messages)
 
         return {"messages": response}
@@ -241,9 +238,13 @@ def BuildGraph(
             messages = state["messages"]
             print_message_summaries(messages, "--- generate: before normalization ---")
             messages = normalize_messages(messages)
+            # Add the system message here because we're not using tools
+            messages = [
+                SystemMessage(generate_prompt(with_tools=False, think=False))
+            ] + messages
             print_message_summaries(messages, "--- generate: after normalization ---")
         else:
-            messages = [SystemMessage(
+            messages = [SystemMessage(generate_prompt())] + state["messages"]
         response = generate_model.invoke(messages)
 
        return {"messages": response}
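For reference, here is a minimal sketch of what the updated normalize_messages() does to a tool-calling exchange; the function body mirrors the diff above, while the sample messages, the tool_call id, and the trailing return are illustrative only and not part of the commit.

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

def normalize_messages(messages):
    """Normalize messages to sequence of types expected by chat templates"""
    # Copy the most recent HumanMessage to the end (stopping at the first match, per the new break)
    if not type(messages[-1]) is HumanMessage:
        for msg in reversed(messages):
            if type(msg) is HumanMessage:
                messages.append(msg)
                break
    # Convert tool output (ToolMessage) to AIMessage
    messages = [
        AIMessage(msg.content) if type(msg) is ToolMessage else msg for msg in messages
    ]
    return messages

history = [
    HumanMessage("When was lm() formula syntax discussed?"),
    AIMessage("", tool_calls=[{"name": "retrieve_emails", "args": {"search_query": "lm formula"}, "id": "call_1"}]),
    ToolMessage("Retrieved 3 emails about lm() formulas", tool_call_id="call_1"),
]
for msg in normalize_messages(history):
    print(type(msg).__name__)
# HumanMessage, AIMessage, AIMessage, HumanMessage

Ending the sequence with a HumanMessage and removing ToolMessage objects is what keeps the SmolLM and Qwen chat templates from raising the errors quoted in the comments above.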
main.py
CHANGED
@@ -23,7 +23,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
 from index import ProcessFile
 from retriever import BuildRetriever, db_dir
 from graph import BuildGraph
-from prompts import
+from prompts import generate_prompt
 
 # -----------
 # R-help-chat
@@ -200,7 +200,7 @@ def RunChain(
     chat_model = GetChatModel(compute_mode)
 
     # Control thinking for SmolLM3
-    system_prompt =
+    system_prompt = generate_prompt()
     if hasattr(chat_model, "model_id") and not think:
         system_prompt = f"/no_think\n{system_prompt}"
 
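A small sketch of the /no_think control shown above; the hasattr(chat_model, "model_id") test is reduced to a boolean and generate_prompt() is replaced by a placeholder string, so this is illustrative only.

think = False
is_local = True  # stands in for hasattr(chat_model, "model_id")
system_prompt = "You are a helpful RAG chatbot..."  # placeholder for generate_prompt()

# SmolLM3 (and Qwen) read a leading /no_think as "answer without a thinking trace"
if is_local and not think:
    system_prompt = f"/no_think\n{system_prompt}"

print(system_prompt.splitlines()[0])  # /no_think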
prompts.py
CHANGED
@@ -3,7 +3,7 @@ from util import get_sources, get_start_end_months
 import re
 
 
-def
+def query_prompt(compute_mode):
     """Return system prompt for query step
 
     Args:
@@ -13,11 +13,11 @@ def retrieve_prompt(compute_mode):
     # Get start and end months from database
     start, end = get_start_end_months(get_sources())
 
-
+    query_prompt = (
         f"Today Date: {date.today()}."
         "You are a helpful RAG chatbot designed to answer questions about R programming based on the R-help mailing list."
         "Do not ask the user for more information, but retrieve emails from the R-help mailing list archives."
-        # gpt-4o-mini
+        # gpt-4o-mini thinks last two months aren't available with this: "Emails from from {start} to {end} are available for retrieval."
         f"The emails available for retrieval are from {start} to {end}."
         "Write a search query based on the user's question, but do not answer the question just yet."
         "For questions about differences or comparison between X and Y, retrieve emails about X and Y."
@@ -25,19 +25,20 @@ def retrieve_prompt(compute_mode):
         "For specific questions, use retrieve_emails(search_query=<specific topic>)."
         "For questions about years, use retrieve_emails(search_query=, start_year=, end_year=) (this month is this year)."
         "For questions about months, use 3-letter abbreviations (Jan..Dec) for the 'month' argument."
-        "
+        "Even if retrieved emails are already available, you should retrieve *more* emails to answer the most recent question."  # Qwen
+        # "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list."
     )
     # A sanity check that we don't have unassigned variables
     # (this causes KeyError in parsing by ToolCallingLLM)
-    matches = re.findall(r"\{.*?\}", " ".join(
+    matches = re.findall(r"\{.*?\}", " ".join(query_prompt))
     if matches:
         raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
-    return
+    return query_prompt
 
 
-def
+def generate_prompt(with_tools=True, think=True):
     """Return system prompt for generate step"""
-
+    generate_prompt = (
         f"Today Date: {date.today()}."
         "You are a helpful RAG chatbot designed to answer questions about R programming based on the R-help mailing list."
         "Summarize the retrieved emails from the R-help mailing list archives to answer the user's question or query."
@@ -45,17 +46,20 @@ def answer_prompt(with_tools=True):
         "Tell the user if there are no retrieved emails or if you are unable to answer the question based on the information in the emails."
         "Do not give an answer based on your own knowledge or memory, and do not include examples that aren't based on the retrieved emails."
         "Example: For a question about writing formulas for lm(), make your answer about formulas for lm() from the retrieved emails."
-        "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages."
+        # "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages."
+        "Summarize the content of the emails rather than copying the headers."  # Qwen
         "Include inline citations (email senders and dates) in your response."
         "Only answer general questions about R if the answer is given in the retrieved emails."
         "Respond with 300 words maximum and 30 lines of code maximum and include any relevant URLs from the retrieved emails."
     )
     if with_tools:
-
-
+        generate_prompt += "Use answer_with_citations to provide the complete answer and all citations used."
+    if not think:
+        generate_prompt += "/no_think"
+    matches = re.findall(r"\{.*?\}", " ".join(generate_prompt))
     if matches:
         raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
-    return
+    return generate_prompt
 
 
 # Prompt template for SmolLM3 with tools
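The prompts above are built from adjacent string literals, so a literal that should have been an f-string leaves a raw {placeholder} behind; that is what the re.findall sanity check guards against before it can surface as a KeyError in ToolCallingLLM. A simplified sketch with an abbreviated prompt, applying the regex directly to the string rather than the " ".join() used in prompts.py:

import re
from datetime import date

start, end = "Apr 1997", "Jun 2025"  # stand-ins for get_start_end_months(get_sources())

prompt = (
    # Adjacent literals concatenate into a single system prompt string
    f"Today Date: {date.today()}."
    "The emails available for retrieval are from {start} to {end}."  # f-prefix forgotten on purpose
)
matches = re.findall(r"\{.*?\}", prompt)
if matches:
    raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
# ValueError: Unassigned variables in prompt: {start} {end}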
requirements.txt
CHANGED
@@ -13,7 +13,7 @@ torch==2.5.1
 # Gemma 3: transformers>=4.50
 # Gemma 3 with transformers==4.54.0 gives:
 # ValueError: Max cache length is not consistent across layers
-transformers==4.
+transformers==4.50.0
 # Commented because we have local modifications
 #tool-calling-llm==0.1.2
 bm25s==0.2.12
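As a quick, optional sanity check (not part of the commit), the installed transformers release can be verified against the pin above before loading Gemma 3:

import transformers

print(transformers.__version__)
# Gemma 3 needs transformers>=4.50, but 4.54.0 raises "Max cache length is not consistent across layers"
assert transformers.__version__ == "4.50.0"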