jedick committed
Commit 7d21953 · 1 Parent(s): 556cc72

Handle tool calls with thinking enabled

Files changed (4):
  1. graph.py +15 -14
  2. main.py +2 -2
  3. prompts.py +16 -12
  4. requirements.txt +1 -1
graph.py CHANGED
@@ -9,7 +9,7 @@ import os
 
 # Local modules
 from retriever import BuildRetriever
-from prompts import retrieve_prompt, answer_prompt, gemma_tools_template
+from prompts import query_prompt, generate_prompt, gemma_tools_template
 from mods.tool_calling_llm import ToolCallingLLM
 
 # Local modules
@@ -49,13 +49,14 @@ def print_message_summaries(messages, header):
 def normalize_messages(messages):
     """Normalize messages to sequence of types expected by chat templates"""
     # Copy the most recent HumanMessage to the end
-    # (avoids SmolLM3 ValueError: Last message must be a HumanMessage!)
+    # (avoids SmolLM and Qwen ValueError: Last message must be a HumanMessage!)
     if not type(messages[-1]) is HumanMessage:
         for msg in reversed(messages):
             if type(msg) is HumanMessage:
                 messages.append(msg)
+                break
     # Convert tool output (ToolMessage) to AIMessage
-    # (avoids SmolLM3 ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>)
+    # (avoids SmolLM and Qwen ValueError: Unknown message type: <class 'langchain_core.messages.tool.ToolMessage'>)
     messages = [
         AIMessage(msg.content) if type(msg) is ToolMessage else msg for msg in messages
     ]
@@ -75,7 +76,7 @@ def ToolifyHF(chat_model, system_message, system_message_suffix="", think=False)
     Get a Hugging Face model ready for bind_tools().
     """
 
-    ## Add /no_think flag to turn off thinking mode (SmolLM3)
+    ## Add /no_think flag to turn off thinking mode (SmolLM3 and Qwen)
     # if not think:
     #     system_message = "/no_think\n" + system_message
 
@@ -203,14 +204,12 @@ def BuildGraph(
     # Add tools to the local or remote chat model
     is_local = hasattr(chat_model, "model_id")
     if is_local:
-        # For local model (ChatHuggingFace)
+        # For local models (ChatHuggingFace with SmolLM, Gemma, or Qwen)
         query_model = ToolifyHF(
-            chat_model, retrieve_prompt(compute_mode), "", think_retrieve
+            chat_model, query_prompt(compute_mode), "", think_retrieve
         ).bind_tools([retrieve_emails])
-        # Don't use answer_with_citations tool here because responses with Gemma are sometimes unparseable
-        generate_model = ToolifyHF(
-            chat_model, answer_prompt(with_tools=False), "", think_generate
-        )
+        # Don't use answer_with_citations tool because responses are sometimes unparseable
+        generate_model = chat_model
     else:
         # For remote model (OpenAI API)
         query_model = chat_model.bind_tools([retrieve_emails])
@@ -228,9 +227,7 @@ def BuildGraph(
         messages = normalize_messages(messages)
         print_message_summaries(messages, "--- query: after normalization ---")
     else:
-        messages = [SystemMessage(retrieve_prompt(compute_mode))] + state[
-            "messages"
-        ]
+        messages = [SystemMessage(query_prompt(compute_mode))] + state["messages"]
     response = query_model.invoke(messages)
 
     return {"messages": response}
@@ -241,9 +238,13 @@ def BuildGraph(
         messages = state["messages"]
         print_message_summaries(messages, "--- generate: before normalization ---")
         messages = normalize_messages(messages)
+        # Add the system message here because we're not using tools
+        messages = [
+            SystemMessage(generate_prompt(with_tools=False, think=False))
+        ] + messages
         print_message_summaries(messages, "--- generate: after normalization ---")
     else:
-        messages = [SystemMessage(answer_prompt())] + state["messages"]
+        messages = [SystemMessage(generate_prompt())] + state["messages"]
     response = generate_model.invoke(messages)
 
     return {"messages": response}
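Note: the revised normalize_messages can be exercised on its own. A minimal sketch, reconstructed from the hunks above (the demo messages are made up; the message classes come from langchain_core):

    from langchain_core.messages import AIMessage, HumanMessage, ToolMessage

    def normalize_messages(messages):
        """Normalize messages to sequence of types expected by chat templates"""
        # Copy the most recent HumanMessage to the end
        # (avoids SmolLM and Qwen ValueError: Last message must be a HumanMessage!)
        if not type(messages[-1]) is HumanMessage:
            for msg in reversed(messages):
                if type(msg) is HumanMessage:
                    messages.append(msg)
                    break  # added in this commit: stop after the first match
        # Convert tool output (ToolMessage) to AIMessage
        # (avoids SmolLM and Qwen ValueError: Unknown message type: ToolMessage)
        return [
            AIMessage(msg.content) if type(msg) is ToolMessage else msg
            for msg in messages
        ]

    history = [
        HumanMessage("How do I write formulas for lm()?"),
        ToolMessage(content="<retrieved emails>", tool_call_id="call_1"),
    ]
    print([type(msg).__name__ for msg in normalize_messages(history)])
    # ['HumanMessage', 'AIMessage', 'HumanMessage']

Without the break, every earlier HumanMessage in the history would be appended, not just the most recent one.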
main.py CHANGED
@@ -23,7 +23,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
 from index import ProcessFile
 from retriever import BuildRetriever, db_dir
 from graph import BuildGraph
-from prompts import answer_prompt
+from prompts import generate_prompt
 
 # -----------
 # R-help-chat
@@ -200,7 +200,7 @@ def RunChain(
     chat_model = GetChatModel(compute_mode)
 
     # Control thinking for SmolLM3
-    system_prompt = answer_prompt()
+    system_prompt = generate_prompt()
     if hasattr(chat_model, "model_id") and not think:
         system_prompt = f"/no_think\n{system_prompt}"
 
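The import rename carries through to RunChain unchanged. A sketch of the thinking control, with build_system_prompt as a hypothetical wrapper around the lines changed above:

    from prompts import generate_prompt  # renamed from answer_prompt

    def build_system_prompt(chat_model, think=True):
        # Hypothetical helper; mirrors the RunChain logic in the hunk above
        system_prompt = generate_prompt()
        # Local Hugging Face models expose model_id; remote (OpenAI API) models don't
        if hasattr(chat_model, "model_id") and not think:
            # SmolLM3 accepts a /no_think prefix to disable thinking mode
            system_prompt = f"/no_think\n{system_prompt}"
        return system_prompt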
prompts.py CHANGED
@@ -3,7 +3,7 @@ from util import get_sources, get_start_end_months
 import re
 
 
-def retrieve_prompt(compute_mode):
+def query_prompt(compute_mode):
     """Return system prompt for query step
 
     Args:
@@ -13,11 +13,11 @@ def retrieve_prompt(compute_mode):
     # Get start and end months from database
     start, end = get_start_end_months(get_sources())
 
-    retrieve_prompt = (
+    query_prompt = (
         f"Today Date: {date.today()}."
         "You are a helpful RAG chatbot designed to answer questions about R programming based on the R-help mailing list."
         "Do not ask the user for more information, but retrieve emails from the R-help mailing list archives."
-        # gpt-4o-mini says last two months aren't available with this: Emails from from {start} to {end} are available for retrieval.
+        # gpt-4o-mini thinks last two months aren't available with this: "Emails from from {start} to {end} are available for retrieval."
         f"The emails available for retrieval are from {start} to {end}."
         "Write a search query based on the user's question, but do not answer the question just yet."
         "For questions about differences or comparison between X and Y, retrieve emails about X and Y."
@@ -25,19 +25,20 @@ def retrieve_prompt(compute_mode):
         "For specific questions, use retrieve_emails(search_query=<specific topic>)."
         "For questions about years, use retrieve_emails(search_query=, start_year=, end_year=) (this month is this year)."
         "For questions about months, use 3-letter abbreviations (Jan..Dec) for the 'month' argument."
-        "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list."
+        "Even if retrieved emails are already available, you should retrieve *more* emails to answer the most recent question."  # Qwen
+        # "If you decide not to retrieve emails, tell the user why and suggest how to improve their question to chat with the R-help mailing list."
     )
     # A sanity check that we don't have unassigned variables
     # (this causes KeyError in parsing by ToolCallingLLM)
-    matches = re.findall(r"\{.*?\}", " ".join(retrieve_prompt))
+    matches = re.findall(r"\{.*?\}", " ".join(query_prompt))
     if matches:
         raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
-    return retrieve_prompt
+    return query_prompt
 
 
-def answer_prompt(with_tools=True):
+def generate_prompt(with_tools=True, think=True):
     """Return system prompt for generate step"""
-    answer_prompt = (
+    generate_prompt = (
         f"Today Date: {date.today()}."
         "You are a helpful RAG chatbot designed to answer questions about R programming based on the R-help mailing list."
         "Summarize the retrieved emails from the R-help mailing list archives to answer the user's question or query."
@@ -45,17 +46,20 @@ def answer_prompt(with_tools=True):
         "Tell the user if there are no retrieved emails or if you are unable to answer the question based on the information in the emails."
         "Do not give an answer based on your own knowledge or memory, and do not include examples that aren't based on the retrieved emails."
         "Example: For a question about writing formulas for lm(), make your answer about formulas for lm() from the retrieved emails."
-        "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages."
+        # "Do not respond with packages that are only listed under sessionInfo, session info, or other attached packages."
+        "Summarize the content of the emails rather than copying the headers."  # Qwen
         "Include inline citations (email senders and dates) in your response."
        "Only answer general questions about R if the answer is given in the retrieved emails."
         "Respond with 300 words maximum and 30 lines of code maximum and include any relevant URLs from the retrieved emails."
     )
     if with_tools:
-        answer_prompt += "Use answer_with_citations to provide the complete answer and all citations used. "
-    matches = re.findall(r"\{.*?\}", " ".join(answer_prompt))
+        generate_prompt += "Use answer_with_citations to provide the complete answer and all citations used."
+    if not think:
+        generate_prompt += "/no_think"
+    matches = re.findall(r"\{.*?\}", " ".join(generate_prompt))
     if matches:
         raise ValueError(f"Unassigned variables in prompt: {' '.join(matches)}")
-    return answer_prompt
+    return generate_prompt
 
 
 # Prompt template for SmolLM3 with tools
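The new think parameter lets graph.py request a no-thinking prompt without string surgery at the call site. A sketch of the two call paths (argument combinations taken from the graph.py hunks above):

    from prompts import generate_prompt

    # Remote path (OpenAI API): keep the answer_with_citations instruction
    # and leave thinking enabled (the defaults)
    remote_system = generate_prompt()

    # Local path (ChatHuggingFace): skip the tool instruction and append
    # the /no_think flag that turns off thinking mode
    local_system = generate_prompt(with_tools=False, think=False)
    assert local_system.endswith("/no_think")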
requirements.txt CHANGED
@@ -13,7 +13,7 @@ torch==2.5.1
 # Gemma 3: transformers>=4.50
 # Gemma 3 with transformers==4.54.0 gives:
 # ValueError: Max cache length is not consistent across layers
-transformers==4.51.3
+transformers==4.50.0
 # Commented because we have local modifications
 #tool-calling-llm==0.1.2
 bm25s==0.2.12