jedick committed · Commit 3575a77 · Parent: ace4242

Enable thinking for answer

Files changed:
- app.py: +9 -8
- graph.py: +17 -13
- main.py: +8 -8
- mods/tool_calling_llm.py: +11 -3
- prompts.py: +2 -2
app.py CHANGED

@@ -88,7 +88,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
     # Get the chat model and build the graph
     chat_model = GetChatModel(compute_mode)
     graph_builder = BuildGraph(
-        chat_model, compute_mode, search_type,
+        chat_model, compute_mode, search_type, think_answer=True
     )
     # Compile the graph with an in-memory checkpointer
     memory = MemorySaver()

@@ -184,7 +184,7 @@ def run_workflow(input, history, compute_mode, thread_id, session_hash):
             retrieved_emails = "\n\n".join(retrieved_emails)
             yield history, retrieved_emails, []
 
-        if node == "generate":
+        if node == "answer":
             # Append messages (thinking and non-thinking) to history
             chunk_messages = chunk["messages"]
             history = append_content(chunk_messages, history, thinking_about="answer")

@@ -383,8 +383,9 @@ with gr.Blocks(
             status_text = f"""
 Now in **local** mode, using ZeroGPU hardware<br>
 Response time is about one minute<br>
-
-
+🧠 Thinking is enabled for the answer<br>
+  Add **/think** to enable thinking for the query</br>
+🚫 Add **/no_think** to disable all thinking</br>
 ✨ [nomic-embed-text-v1.5](https://huggingface.co/nomic-ai/nomic-embed-text-v1.5) and [{model_id.split("/")[-1]}](https://huggingface.co/{model_id})<br>
 See the project's [GitHub repository](https://github.com/jedick/R-help-chat)
             """

@@ -412,15 +413,15 @@ with gr.Blocks(
         """Get example questions based on compute mode"""
         questions = [
             # "What is today's date?",
-            "Summarize emails from the last two months",
-            "
+            "Summarize emails from the last two months /no_think",
+            "Show me code examples using plotmath",
             "When was has.HLC mentioned?",
             "Who reported installation problems in 2023-2024?",
         ]
 
         if compute_mode == "remote":
-            # Remove "/
-            questions = [q.replace(" /
+            # Remove "/no_think" from questions in remote mode
+            questions = [q.replace(" /no_think", "") for q in questions]
 
         # cf. https://github.com/gradio-app/gradio/pull/8745 for updating examples
         return gr.Dataset(samples=[[q] for q in questions]) if as_dataset else questions
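This change passes think_answer=True to BuildGraph and keys the streamed output on the renamed "answer" node. Below is a minimal routing sketch in the same spirit, assuming `graph` is the compiled graph returned by BuildGraph and using LangGraph's "updates" stream mode; the loop, the print_update helper, and the question text are illustrative, not the app's exact code:

# Hypothetical sketch: `graph`, the config, and print_update are placeholders.
def print_update(node, update):
    """Print the message(s) produced by one graph node."""
    messages = update.get("messages", [])
    if not isinstance(messages, list):
        messages = [messages]
    for message in messages:
        print(f"[{node}] {getattr(message, 'content', message)}")

config = {"configurable": {"thread_id": "demo"}}
inputs = {"messages": [("user", "Who reported installation problems in 2023-2024?")]}
for chunk in graph.stream(inputs, config, stream_mode="updates"):
    for node, update in chunk.items():
        if node == "retrieve_emails":
            print_update(node, update)  # retrieved email documents
        if node == "answer":
            print_update(node, update)  # thinking content plus the final answer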
graph.py CHANGED

@@ -9,7 +9,7 @@ import os
 
 # Local modules
 from retriever import BuildRetriever
-from prompts import query_prompt,
+from prompts import query_prompt, answer_prompt, generic_tools_template
 from mods.tool_calling_llm import ToolCallingLLM
 
 # For tracing (disabled)

@@ -94,6 +94,7 @@ def BuildGraph(
     search_type,
     top_k=6,
     think_query=False,
+    think_answer=False,
 ):
     """
     Build conversational RAG graph for email retrieval and answering with citations.

@@ -103,7 +104,8 @@ def BuildGraph(
         compute_mode: remote or local (for retriever)
         search_type: dense, sparse, or hybrid (for retriever)
         top_k: number of documents to retrieve
-        think_query: Whether to use thinking mode for query
+        think_query: Whether to use thinking mode for the query
+        think_answer: Whether to use thinking mode for the answer
 
     Based on:
         https://python.langchain.com/docs/how_to/qa_sources

@@ -193,11 +195,11 @@ def BuildGraph(
             chat_model, query_prompt(chat_model, think=think_query)
         ).bind_tools([retrieve_emails])
         # Don't use answer_with_citations tool because responses with are sometimes unparseable
-
+        answer_model = chat_model
     else:
         # For remote model (OpenAI API)
         query_model = chat_model.bind_tools([retrieve_emails])
-
+        answer_model = chat_model.bind_tools([answer_with_citations])
 
     # Initialize the graph object
     graph = StateGraph(MessagesState)

@@ -216,27 +218,29 @@ def BuildGraph(
 
         return {"messages": response}
 
-    def generate(state: MessagesState):
+    def answer(state: MessagesState):
         """Generates an answer with the chat model"""
         if is_local:
             messages = state["messages"]
-            # print_message_summaries(messages, "---
+            # print_message_summaries(messages, "--- answer: before normalization ---")
             messages = normalize_messages(messages)
             # Add the system message here because we're not using tools
-            messages = [
-
+            messages = [
+                SystemMessage(answer_prompt(chat_model, think=think_answer))
+            ] + messages
+            # print_message_summaries(messages, "--- answer: after normalization ---")
         else:
             messages = [
-                SystemMessage(
+                SystemMessage(answer_prompt(chat_model, with_tools=True))
             ] + state["messages"]
-        response =
+        response = answer_model.invoke(messages)
 
         return {"messages": response}
 
     # Define model and tool nodes
     graph.add_node("query", query)
-    graph.add_node("generate", generate)
     graph.add_node("retrieve_emails", ToolNode([retrieve_emails]))
+    graph.add_node("answer", answer)
     graph.add_node("answer_with_citations", ToolNode([answer_with_citations]))
 
     # Route the user's input to the query model

@@ -249,13 +253,13 @@ def BuildGraph(
         {END: END, "tools": "retrieve_emails"},
     )
     graph.add_conditional_edges(
-        "generate",
+        "answer",
         tools_condition,
         {END: END, "tools": "answer_with_citations"},
     )
 
     # Add edge from the retrieval tool to the generating model
-    graph.add_edge("retrieve_emails", "generate")
+    graph.add_edge("retrieve_emails", "answer")
 
     # Done!
     return graph
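After this rewiring, the "answer" node's result is routed by tools_condition: a response carrying an answer_with_citations tool call goes on to the ToolNode, anything else ends the run. A small standalone illustration of that routing follows; the tool-call arguments are invented for the demo and are not the tool's real schema:

# Standalone illustration of tools_condition routing; args are made up.
from langgraph.graph import END
from langgraph.prebuilt import tools_condition
from langchain_core.messages import AIMessage

with_tool_call = AIMessage(
    content="",
    tool_calls=[{
        "name": "answer_with_citations",
        "args": {"answer": "...", "citations": "..."},  # illustrative args only
        "id": "call_1",
    }],
)
plain = AIMessage(content="A plain answer with no tool call.")

print(tools_condition({"messages": [with_tool_call]}))  # "tools": go to the answer_with_citations ToolNode
print(tools_condition({"messages": [plain]}) == END)    # True: the graph ends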
main.py CHANGED

@@ -23,7 +23,7 @@ from langchain_huggingface import ChatHuggingFace, HuggingFacePipeline
 from index import ProcessFile
 from retriever import BuildRetriever, db_dir
 from graph import BuildGraph
-from prompts import
+from prompts import answer_prompt
 
 # -----------
 # R-help-chat

@@ -201,7 +201,7 @@ def RunChain(
     chat_model = GetChatModel(compute_mode)
 
     # Get prompt with /no_think for SmolLM3/Qwen
-    system_prompt =
+    system_prompt = answer_prompt(chat_model)
 
     # Create a prompt template
     system_template = ChatPromptTemplate.from_messages([SystemMessage(system_prompt)])

@@ -236,8 +236,8 @@ def RunGraph(
     compute_mode: str = "remote",
     search_type: str = "hybrid",
     top_k: int = 6,
-
-
+    think_query=False,
+    think_answer=False,
     thread_id=None,
 ):
     """Run graph for conversational RAG app

@@ -247,8 +247,8 @@ def RunGraph(
         compute_mode: Compute mode for embedding and chat models (remote or local)
         search_type: Type of search to use. Options: "dense", "sparse", or "hybrid"
         top_k: Number of documents to retrieve
-
-
+        think_query: Whether to use thinking mode for the query
+        think_answer: Whether to use thinking mode for the answer
         thread_id: Thread ID for memory (optional)
 
     Example:

@@ -263,8 +263,8 @@ def RunGraph(
         compute_mode,
         search_type,
         top_k,
-
-
+        think_query,
+        think_answer,
     )
 
     # Compile the graph with an in-memory checkpointer
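For completeness, a hedged sketch of calling RunGraph with the new flags. Only the keyword arguments added in this commit are certain from the diff; the positional question argument and its name are assumptions based on the docstring:

# Assumed call pattern; the first positional argument (the user's question)
# is not shown in this diff, so treat it as hypothetical.
RunGraph(
    "Who reported installation problems in 2023-2024?",
    compute_mode="local",   # thinking flags matter for the local model
    search_type="hybrid",
    top_k=6,
    think_query=False,      # keep the query step fast
    think_answer=True,      # enable thinking for the final answer
)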
mods/tool_calling_llm.py CHANGED

@@ -183,10 +183,18 @@ class ToolCallingLLM(BaseChatModel, ABC):
 
         # Parse output for JSON (support multiple objects separated by commas)
         try:
+            # Works for one or more JSON objects not enclosed in "[]"
             parsed_json_results = json.loads(f"[{post_think}]")
-        except
-
-
+        except:
+            try:
+                # Works for one or more JSON objects already enclosed in "[]"
+                parsed_json_results = json.loads(f"{post_think}")
+            except json.JSONDecodeError:
+                # Return entire response if JSON wasn't parsed (or is missing)
+                return AIMessage(content=response_message.content)
+
+        # print("parsed_json_results")
+        # print(parsed_json_results)
 
         tool_calls = []
         for parsed_json_result in parsed_json_results:
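The new fallback chain tries the bracket-wrapped text first (one or more bare, comma-separated JSON objects), then the text as-is, and finally returns the raw response content. A self-contained sketch of the same idea outside the class; the function name and the tool-call keys are illustrative:

import json

def parse_tool_json(post_think: str):
    """Simplified stand-in for the parsing logic above."""
    try:
        # One or more bare JSON objects, e.g. '{"a": 1}' or '{"a": 1}, {"b": 2}'
        return json.loads(f"[{post_think}]")
    except json.JSONDecodeError:
        try:
            # Second chance: parse the text exactly as given
            return json.loads(post_think)
        except json.JSONDecodeError:
            # No JSON at all; the caller falls back to a plain AIMessage
            return None

print(parse_tool_json('{"name": "retrieve_emails", "args": {"query": "has.HLC"}}'))
print(parse_tool_json('{"name": "a", "args": {}}, {"name": "b", "args": {}}'))
print(parse_tool_json("I could not produce a tool call."))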
prompts.py CHANGED

@@ -46,8 +46,8 @@ def query_prompt(chat_model, think=False):
     return prompt
 
 
-def
-    """Return system prompt for
+def answer_prompt(chat_model, think=False, with_tools=False):
+    """Return system prompt for answer step"""
     prompt = (
         f"Today Date: {date.today()}. "
         "You are a helpful chatbot designed to answer questions about R programming based on the R-help mailing list archives. "