jedick committed
Commit 3575a77 · 1 Parent(s): ace4242

Enable thinking for answer

Files changed (5)
  1. app.py +9 -8
  2. graph.py +17 -13
  3. main.py +8 -8
  4. mods/tool_calling_llm.py +11 -3
  5. prompts.py +2 -2
app.py CHANGED
@@ -88,7 +88,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
  # Get the chat model and build the graph
  chat_model = GetChatModel(compute_mode)
  graph_builder = BuildGraph(
- chat_model, compute_mode, search_type, think_query=True
+ chat_model, compute_mode, search_type, think_answer=True
  )
  # Compile the graph with an in-memory checkpointer
  memory = MemorySaver()
@@ -184,7 +184,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
  retrieved_emails = "\n\n".join(retrieved_emails)
  yield history, retrieved_emails, []

- if node == "generate":
+ if node == "answer":
  # Append messages (thinking and non-thinking) to history
  chunk_messages = chunk["messages"]
  history = append_content(chunk_messages, history, thinking_about="answer")
@@ -383,8 +383,9 @@ with gr.Blocks(
  status_text = f"""
  📍 Now in **local** mode, using ZeroGPU hardware<br>
  ⌛ Response time is about one minute<br>
- 🔍 Thinking is enabled for the query<br>
- &emsp;&nbsp; 🧠 Add **/think** to enable thinking for the answer</br>
+ 🧠 Thinking is enabled for the answer<br>
+ &emsp;&nbsp; 🔍 Add **/think** to enable thinking for the query</br>
+ &emsp;&nbsp; 🚫 Add **/no_think** to disable all thinking</br>
  ✨ [nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) and [{model_id.split("/")[-1]}](https://huggingface.co/{model_id})<br>
  🏠 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
  """
@@ -412,15 +413,15 @@ with gr.Blocks(
  """Get example questions based on compute mode"""
  questions = [
  # "What is today's date?",
- "Summarize emails from the last two months",
- "Advice on using plotmath /think",
+ "Summarize emails from the last two months /no_think",
+ "Show me code examples using plotmath",
  "When was has.HLC mentioned?",
  "Who reported installation problems in 2023-2024?",
  ]

  if compute_mode == "remote":
- # Remove "/think" from questions in remote mode
- questions = [q.replace(" /think", "") for q in questions]
+ # Remove "/no_think" from questions in remote mode
+ questions = [q.replace(" /no_think", "") for q in questions]

  # cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
  return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
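Note: the `node == "answer"` check keys off LangGraph's per-node streaming updates. A minimal sketch of that loop is below; the `stream_mode`, message coercion, and surrounding variable names are assumptions, since `run_workflow`'s streaming code is not part of this diff.

```python
# Sketch only: shows how the renamed "answer" node is picked out of streamed updates.
config = {"configurable": {"thread_id": thread_id}}
for step in graph.stream(
    {"messages": [("user", input)]}, config, stream_mode="updates"
):
    for node, chunk in step.items():
        if node == "answer":
            # Append messages (thinking and non-thinking) to history
            history = append_content(chunk["messages"], history, thinking_about="answer")
```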
graph.py CHANGED
@@ -9,7 +9,7 @@ import os

  # Local modules
  from retriever import BuildRetriever
- from prompts import query_prompt, generate_prompt, generic_tools_template
+ from prompts import query_prompt, answer_prompt, generic_tools_template
  from mods.tool_calling_llm import ToolCallingLLM

  # For tracing (disabled)
@@ -94,6 +94,7 @@ def BuildGraph(
  search_type,
  top_k=6,
  think_query=False,
+ think_answer=False,
  ):
  """
  Build conversational RAG graph for email retrieval and answering with citations.
@@ -103,7 +104,8 @@ def BuildGraph(
  compute_mode: remote or local (for retriever)
  search_type: dense, sparse, or hybrid (for retriever)
  top_k: number of documents to retrieve
- think_query: Whether to use thinking mode for query
+ think_query: Whether to use thinking mode for the query
+ think_answer: Whether to use thinking mode for the answer

  Based on:
  https://python.langchain.com/docs/how_to/qa_sources
@@ -193,11 +195,11 @@ def BuildGraph(
  chat_model, query_prompt(chat_model, think=think_query)
  ).bind_tools([retrieve_emails])
  # Don't use answer_with_citations tool because responses with are sometimes unparseable
- generate_model = chat_model
+ answer_model = chat_model
  else:
  # For remote model (OpenAI API)
  query_model = chat_model.bind_tools([retrieve_emails])
- generate_model = chat_model.bind_tools([answer_with_citations])
+ answer_model = chat_model.bind_tools([answer_with_citations])

  # Initialize the graph object
  graph = StateGraph(MessagesState)
@@ -216,27 +218,29 @@ def BuildGraph(

  return {"messages": response}

- def generate(state: MessagesState):
+ def answer(state: MessagesState):
  """Generates an answer with the chat model"""
  if is_local:
  messages = state["messages"]
- # print_message_summaries(messages, "--- generate: before normalization ---")
+ # print_message_summaries(messages, "--- answer: before normalization ---")
  messages = normalize_messages(messages)
  # Add the system message here because we're not using tools
- messages = [SystemMessage(generate_prompt(chat_model))] + messages
- # print_message_summaries(messages, "--- generate: after normalization ---")
+ messages = [
+ SystemMessage(answer_prompt(chat_model, think=think_answer))
+ ] + messages
+ # print_message_summaries(messages, "--- answer: after normalization ---")
  else:
  messages = [
- SystemMessage(generate_prompt(chat_model, with_tools=True))
+ SystemMessage(answer_prompt(chat_model, with_tools=True))
  ] + state["messages"]
- response = generate_model.invoke(messages)
+ response = answer_model.invoke(messages)

  return {"messages": response}

  # Define model and tool nodes
  graph.add_node("query", query)
- graph.add_node("generate", generate)
  graph.add_node("retrieve_emails", ToolNode([retrieve_emails]))
+ graph.add_node("answer", answer)
  graph.add_node("answer_with_citations", ToolNode([answer_with_citations]))

  # Route the user's input to the query model
@@ -249,13 +253,13 @@ def BuildGraph(
  {END: END, "tools": "retrieve_emails"},
  )
  graph.add_conditional_edges(
- "generate",
+ "answer",
  tools_condition,
  {END: END, "tools": "answer_with_citations"},
  )

  # Add edge from the retrieval tool to the generating model
- graph.add_edge("retrieve_emails", "generate")
+ graph.add_edge("retrieve_emails", "answer")

  # Done!
  return graph
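Note: after the rename the flow is query → retrieve_emails → answer, with answer optionally calling the `answer_with_citations` tool in remote mode. A usage sketch under assumptions follows; `GetChatModel`'s home module, the compute mode, search type, and the question are illustrative, not taken from this commit.

```python
from langgraph.checkpoint.memory import MemorySaver
from graph import BuildGraph
from main import GetChatModel  # assumption: GetChatModel is importable from main.py

# Build with thinking enabled for the answer step only
chat_model = GetChatModel("local")
builder = BuildGraph(chat_model, "local", "hybrid", think_query=False, think_answer=True)
graph = builder.compile(checkpointer=MemorySaver())

# One turn through query -> retrieve_emails -> answer
config = {"configurable": {"thread_id": "demo"}}
result = graph.invoke(
    {"messages": [("user", "Show me code examples using plotmath")]}, config
)
print(result["messages"][-1].content)
```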
main.py CHANGED
@@ -23,7 +23,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
  from index import ProcessFile
  from retriever import BuildRetriever, db_dir
  from graph import BuildGraph
- from prompts import generate_prompt
+ from prompts import answer_prompt

  # -----------
  # R-help-chat
@@ -201,7 +201,7 @@
  chat_model = GetChatModel(compute_mode)

  # Get prompt with /no_think for SmolLM3/Qwen
- system_prompt = generate_prompt(chat_model)
+ system_prompt = answer_prompt(chat_model)

  # Create a prompt template
  system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])
@@ -236,8 +236,8 @@
  compute_mode: str = "remote",
  search_type: str = "hybrid",
  top_k: int = 6,
- think_retrieve=False,
- think_generate=False,
+ think_query=False,
+ think_answer=False,
  thread_id=None,
  ):
  """Run graph for conversational RAG app
@@ -247,8 +247,8 @@
  compute_mode: Compute mode for embedding and chat models (remote or local)
  search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
  top_k: Number of documents to retrieve
- think_retrieve: Whether to use thinking mode for retrieval (tool-calling)
- think_generate: Whether to use thinking mode for generation
+ think_query: Whether to use thinking mode for the query
+ think_answer: Whether to use thinking mode for the answer
  thread_id: Thread ID for memory (optional)

  Example:
@@ -263,8 +263,8 @@
  compute_mode,
  search_type,
  top_k,
- think_retrieve,
- think_generate,
+ think_query,
+ think_answer,
  )

  # Compile the graph with an in-memory checkpointer
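Note: `RunGraph` keeps its defaults but under the new flag names. A call sketch is below; the first positional argument is assumed to be the user question, which this diff does not show.

```python
from main import RunGraph

# think_query/think_answer replace the old think_retrieve/think_generate flags
RunGraph(
    "Who reported installation problems in 2023-2024?",  # assumed first positional argument
    compute_mode="remote",
    search_type="hybrid",
    top_k=6,
    think_query=False,
    think_answer=True,
)
```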
mods/tool_calling_llm.py CHANGED
@@ -183,10 +183,18 @@ class ToolCallingLLM(BaseChatModel, ABC):

  # Parse output for JSON (support multiple objects separated by commas)
  try:
+ # Works for one or more JSON objects not enclosed in "[]"
  parsed_json_results = json.loads(f"[{post_think}]")
- except json.JSONDecodeError:
- # Return entire response if JSON wasn't parsed (or is missing)
- return AIMessage(content=response_message.content)
+ except:
+ try:
+ # Works for one or more JSON objects already enclosed in "[]"
+ parsed_json_results = json.loads(f"{post_think}")
+ except json.JSONDecodeError:
+ # Return entire response if JSON wasn't parsed (or is missing)
+ return AIMessage(content=response_message.content)
+
+ # print("parsed_json_results")
+ # print(parsed_json_results)

  tool_calls = []
  for parsed_json_result in parsed_json_results:
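Note: the new code tries the bracket-wrapped parse first and falls back to the raw string. A standalone sketch of the same two-stage parse is below; unlike the commit it narrows the first bare `except:` to `json.JSONDecodeError` and returns `None` (rather than the raw `AIMessage`) when both attempts fail, and the tool-call payload is purely illustrative.

```python
import json

def parse_tool_json(post_think: str):
    """Parse one or more JSON tool calls from model output (sketch)."""
    try:
        # Works for comma-separated objects without enclosing brackets,
        # e.g. '{"name": ...}, {"name": ...}'
        return json.loads(f"[{post_think}]")
    except json.JSONDecodeError:
        pass
    try:
        # Fall back to text that is already valid JSON on its own
        return json.loads(post_think)
    except json.JSONDecodeError:
        # The caller (ToolCallingLLM) returns the raw AIMessage in this case
        return None

# Illustrative payload; the exact keys emitted by the model are not shown in this diff
print(parse_tool_json('{"name": "retrieve_emails", "parameters": {"query": "plotmath"}}'))
```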
prompts.py CHANGED
@@ -46,8 +46,8 @@ def query_prompt(chat_model, think=False):
  return prompt


- def generate_prompt(chat_model, think=False, with_tools=False):
- """Return system prompt for generate step"""
+ def answer_prompt(chat_model, think=False, with_tools=False):
+ """Return system prompt for answer step"""
  prompt = (
  f"Today Date: {date.today()}. "
  "You are a helpful chatbot designed to answer questions about R programming based on the R-help mailing list archives. "