# r1-agents / app.py
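"""DeepSeek Agent Swarm Chat.

A Gradio app that lazily loads a 4-bit quantized DeepSeek-R1-Distill-Qwen-7B model
and answers each request through a 4-round prompt chain (brainstorm, round 2,
synthesis, rationale) with presets for coding, math, and writing.
"""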
import gradio as gr
import os
import spaces  # Hugging Face Spaces GPU support (provides the @spaces.GPU decorator)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
import torch
from threading import Thread
import logging
from typing import Tuple, List, Dict, Generator
# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Model & Quantization Settings ---
MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
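# Lazily populated caches keyed by model size (currently only "7B"); see get_model_and_tokenizer().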
models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}
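# 4-bit NF4 quantization keeps the 7B model within a modest GPU memory budget;
# bfloat16 is used as the compute dtype once weights are dequantized.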
bnb_config_4bit = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, # Or torch.float16 if needed
)
def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
"""
Lazy-load the model and tokenizer if not already loaded.
Returns:
Tuple[model, tokenizer]: The loaded model and tokenizer.
"""
if "7B" not in models:
logging.info(f"Loading 7B model: {MODEL_ID} on demand")
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
quantization_config=bnb_config_4bit,
torch_dtype=torch.bfloat16, # Or torch.float16 if needed
device_map='auto',
trust_remote_code=True,
)
model.eval() # Set the model to evaluation mode
models["7B"] = model
tokenizers["7B"] = tokenizer
logging.info("Loaded 7B model on demand.")
except Exception as e:
logging.error(f"Failed to load model and tokenizer: {e}")
raise e
return models["7B"], tokenizers["7B"]
# --- Default Prompt Templates for Multiple Presets ---
default_prompts = {
"coding": {
"brainstorm": """**Coding Brainstorm (Round 1)**
As a Senior Code Analyst, analyze the following problem and list key challenges and potential approaches.
**User Request:**
{user_prompt}
**Guidelines:**
1. Identify coding challenges.
2. Suggest potential methods and approaches.
3. Highlight any critical edge cases.
""",
"round2": """**Advanced Reasoning & Code Generation (Round 2)**
Based on your initial analysis:
**Initial Analysis:**
{brainstorm_response}
**User Request:**
{user_prompt}
**Task:**
1. Generate production-ready code with advanced reasoning.
2. Include a pun-filled birthday message with a coding twist within your output.
3. Comment the code clearly.
""",
"synthesis": """**Synthesis & Final Refinement (Round 3)**
Review the detailed code and reasoning below, and synthesize a final, refined response that:
1. Combines the brainstorming insights and advanced code generation.
2. Summarizes the solution succinctly.
3. Provides any additional improvements.
**Detailed Code & Reasoning:**
{round2_response}
""",
"rationale": """**Pun Generation and Rationale (Round 4)**
Based on the final refined response below, generate a clear, stand-alone pun-filled birthday message with a coding twist, then explain in detail why that pun was chosen.
Final Refined Response:
{final_response}
Your answer should:
1. Clearly output the pun as a separate line.
2. Explain the pun’s connection to birthdays and coding concepts (e.g., binary, syntax).
3. Describe any creative insights behind the choice.
"""
},
"math": {
"brainstorm": """**Math Problem Brainstorm (Round 1)**
As an expert mathematician, analyze the following problem and outline key concepts and strategies.
**Problem:**
{user_prompt}
**Guidelines:**
1. Identify the mathematical concepts involved.
2. List potential strategies or methods.
3. Note any assumptions or conditions.
""",
"round2": """**Solution Strategy Development (Round 2)**
Based on the initial analysis:
**Initial Analysis:**
{brainstorm_response}
**Problem:**
{user_prompt}
**Task:**
1. Develop a detailed strategy to solve the problem.
2. Include potential methods and intermediate steps.
""",
"synthesis": """**Solution Synthesis (Round 3)**
Review the strategy and previous analysis below, and produce a refined, step-by-step solution that:
1. Clearly explains the solution path.
2. Highlights key steps and justifications.
3. Summarizes the final answer.
**Detailed Strategy:**
{round2_response}
""",
"rationale": """**Solution Rationale (Round 4)**
Based on the final refined solution below, provide a detailed explanation of the key steps and mathematical insights.
Final Refined Solution:
{final_response}
Your response should:
1. Clearly explain why each step was taken.
2. Detail any assumptions and mathematical principles used.
3. Summarize the creative reasoning behind the solution.
"""
},
"writing": {
"brainstorm": """**Creative Brainstorm (Round 1)**
As a seasoned writer, brainstorm creative ideas for the following writing prompt.
**Writing Prompt:**
{user_prompt}
**Guidelines:**
1. List key themes and creative directions.
2. Suggest multiple approaches to the narrative.
3. Highlight any unique stylistic ideas.
""",
"round2": """**Outline Generation (Round 2)**
Based on the brainstorming below:
**Brainstormed Ideas:**
{brainstorm_response}
**Writing Prompt:**
{user_prompt}
**Task:**
1. Generate a detailed outline for a creative piece.
2. Organize the ideas into a coherent structure.
3. Provide bullet points or sections for the narrative.
""",
"synthesis": """**Draft Writing (Round 3)**
Review the outline below and produce a refined draft of the creative piece that:
1. Synthesizes the brainstorming insights and the outline.
2. Provides a coherent and engaging narrative.
3. Includes stylistic and thematic elements.
**Outline:**
{round2_response}
""",
"rationale": """**Final Editing and Rationale (Round 4)**
Based on the final draft below, refine the piece further and provide a detailed explanation of your creative choices.
Final Draft:
{final_response}
Your answer should:
1. Present the final refined text.
2. Explain the narrative choices, stylistic decisions, and thematic connections.
3. Detail any creative insights that influenced the final version.
"""
}
}
# --- Domain Detection ---
def detect_domain(user_prompt: str) -> str:
"""
Detect the domain based on keywords.
Args:
user_prompt (str): The user query.
Returns:
str: One of 'math', 'writing', or 'coding' (defaulting to coding).
"""
prompt_lower = user_prompt.lower()
math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]
if any(kw in prompt_lower for kw in math_keywords):
logging.info("Domain detected as: math")
return "math"
elif any(kw in prompt_lower for kw in writing_keywords):
logging.info("Domain detected as: writing")
return "writing"
elif any(kw in prompt_lower for kw in coding_keywords):
logging.info("Domain detected as: coding")
return "coding"
else:
logging.info("No specific domain detected; defaulting to coding")
return "coding"
# --- Memory Management ---
class MemoryManager:
"""Encapsulate shared memory for storing and retrieving conversation items."""
def __init__(self) -> None:
self.shared_memory: List[str] = []
def store(self, item: str) -> None:
"""Store a memory item and log an excerpt."""
self.shared_memory.append(item)
logging.info(f"[Memory Stored]: {item[:50]}...")
def retrieve(self, query: str, top_k: int = 3) -> List[str]:
"""Retrieve recent memory items containing the query text."""
query_lower = query.lower()
relevant = [item for item in self.shared_memory if query_lower in item.lower()]
if not relevant:
logging.info("[Memory Retrieval]: No relevant memories found.")
else:
logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
return relevant[-top_k:]
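# Shared across the app: each round stores its output here, and the explanation handler
# later retrieves entries via e.g. global_memory_manager.retrieve("Round 4 Response:", top_k=3).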
global_memory_manager = MemoryManager()
# --- Unified Generation Function ---
def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float) -> str:
"""Generate a response for a given prompt."""
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
)
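    # model.generate runs in a background thread; the streamer yields decoded text chunks
    # in this thread as they are produced, and they are accumulated into the full response.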
thread = Thread(target=model.generate, kwargs=kwargs)
    # model.generate already disables gradient tracking internally, and a torch.no_grad()
    # context opened here would not carry over to the worker thread, so start it directly.
    thread.start()
response = ""
try:
for text in streamer:
response += text
except Exception as e:
logging.error(f"Error during generation: {e}")
raise e
thread.join()
return response
# --- Multi-Round Agent Class ---
class MultiRoundAgent:
"""
Encapsulate the multi-round prompt chaining and response generation.
This class runs a 4-round pipeline based on the given preset.
"""
def __init__(self, model, tokenizer, prompt_templates: Dict[str, str], memory_manager: MemoryManager):
self.model = model
self.tokenizer = tokenizer
self.prompt_templates = prompt_templates
self.memory_manager = memory_manager
def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
# Round 1: Brainstorming / Analysis
logging.info("--- Round 1 ---")
prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"), params.get("temp"), params.get("top_p"))
self.memory_manager.store(f"Round 1 Response: {r1}")
# Round 2: Secondary Generation (strategy/outline/code)
logging.info("--- Round 2 ---")
prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100, params.get("temp"), params.get("top_p"))
self.memory_manager.store(f"Round 2 Response: {r2}")
# Round 3: Synthesis & Refinement (streaming updates)
logging.info("--- Round 3 ---")
prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device)
streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
kwargs_r3 = dict(
input_ids=input_ids_r3,
streamer=streamer_r3,
max_new_tokens=params.get("max_new_tokens") // 2,
temperature=params.get("temp"),
            top_p=params.get("top_p"),
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
)
thread_r3 = Thread(target=self.model.generate, kwargs=kwargs_r3)
        # As in generate_response, the thread is started directly; model.generate handles
        # gradient-free execution itself, so no torch.no_grad() wrapper is needed here.
        thread_r3.start()
r3 = ""
try:
for text in streamer_r3:
r3 += text
yield r3 # Yield progressive updates from Round 3
except Exception as e:
logging.error(f"Error during Round 3 streaming: {e}")
raise e
thread_r3.join()
self.memory_manager.store(f"Final Synthesis Response: {r3}")
# Round 4: Rationale / Final Output
logging.info("--- Round 4 ---")
prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300, params.get("temp"), params.get("top_p"))
self.memory_manager.store(f"Round 4 Response: {r4}")
# Construct final output based on the show_raw flag.
if show_raw:
final_output = (
f"{r4}\n\n[Raw Outputs]\n"
f"Round 1:\n{r1}\n\n"
f"Round 2:\n{r2}\n\n"
f"Round 3:\n{r3}\n\n"
f"Round 4:\n{r4}\n"
)
else:
final_output = r4
yield final_output
# --- Swarm Agent Iterative Function ---
@spaces.GPU(duration=180) # Adjust duration as needed
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
prompt_templates: Dict[str, str], domain: str, show_raw: bool) -> Generator[str, None, None]:
"""
Wraps the multi-round agent functionality. Depending on the detected domain,
it runs the 4-round pipeline.
"""
model, tokenizer = get_model_and_tokenizer()
agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
params = {"temp": temp, "top_p": top_p, "max_new_tokens": max_new_tokens}
return agent.run_pipeline(user_prompt, params, show_raw)
# --- Explanation Function for Additional Requests ---
def handle_explanation_request(user_prompt: str, history: List) -> str:
"""
Retrieve stored rationale and additional context from conversation history,
then generate an explanation.
"""
retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
explanation_prompt = "Below are previous final outputs and related context from our conversation:\n"
if retrieved:
for item in retrieved:
explanation_prompt += f"- {item}\n"
else:
explanation_prompt += "No stored final output found.\n"
explanation_prompt += "\nRecent related exchanges:\n"
    # History entries may be [user, assistant] pairs or message dicts; normalize them first.
    for msg in format_history(history):
        content = msg.get("content") or ""
        if "explain" in content.lower():
            explanation_prompt += f"{msg.get('role', 'user').capitalize()}: {content}\n"
explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices."
model, tokenizer = get_model_and_tokenizer()
explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)
return explanation
# --- Helper to Format History ---
def format_history(history: List) -> List[Dict[str, str]]:
"""
Convert history (list of [user, assistant] pairs) into a list of message dictionaries.
"""
messages = []
for item in history:
if isinstance(item, (list, tuple)) and len(item) == 2:
user_msg, assistant_msg = item
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
elif isinstance(item, dict):
messages.append(item)
return messages
# --- Gradio Chat Interface Function ---
def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[str, None, None]:
    """
    Called by Gradio's ChatInterface. Uses the current generation parameters and preset
    prompt templates, and yields the assistant's response incrementally; ChatInterface
    itself manages the conversation history. Explanation requests are routed separately.
    """
    if "explain" in message.lower():
        yield handle_explanation_request(message, history)
        return
try:
temp = float(param_state.get("temperature", 0.5))
top_p = float(param_state.get("top_p", 0.9))
max_new_tokens = int(param_state.get("max_new_tokens", 300))
memory_top_k = int(param_state.get("memory_top_k", 2))
show_raw = bool(param_state.get("show_raw_output", False))
except Exception as e:
logging.error(f"Parameter conversion error: {e}")
temp, top_p, max_new_tokens, memory_top_k, show_raw = 0.5, 0.9, 300, 2, False
domain = detect_domain(message)
# Get the prompt templates for the detected domain; default to coding if not set.
prompt_templates = prompt_state.get(domain, default_prompts.get(domain, default_prompts["coding"]))
for partial_response in swarm_agent_iterative(
user_prompt=message,
temp=temp,
top_p=top_p,
max_new_tokens=max_new_tokens,
memory_top_k=memory_top_k,
prompt_templates=prompt_templates,
domain=domain,
show_raw=show_raw
):
        yield partial_response
# --- UI Settings & Styling ---
ui_description = '''
<div>
<h1 style="text-align: center;">DeepSeek Agent Swarm Chat</h1>
<p style="text-align: center;">
Multi-round agent with 4-round prompt chaining for three presets:
<br>- Coding
<br>- Math
<br>- Writing
</p>
</div>
'''
ui_license = """
<p/>
---
"""
ui_placeholder = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">DeepSeek Agent Swarm</h1>
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# --- Gradio UI ---
with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
gr.Markdown(ui_description)
# Hidden states for parameters and prompt configurations.
param_state = gr.State({
"temperature": 0.5,
"top_p": 0.9,
"max_new_tokens": 300,
"memory_top_k": 2,
"show_raw_output": False, # New parameter for raw output
})
prompt_state = gr.State({
"coding": default_prompts["coding"],
"math": default_prompts["math"],
"writing": default_prompts["writing"],
})
with gr.Tabs():
with gr.Tab("Chat"):
chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
gr.ChatInterface(
fn=gradio_interface,
chatbot=chatbot,
additional_inputs=[param_state, prompt_state],
examples=[
['How can we build a robust web service that scales efficiently under load?'],
['Solve the integral of x^2 from 0 to 1.'],
['Write a short story about a mysterious writer in a busy city.'],
['Create a pun-filled birthday message with a coding twist.']
],
cache_examples=False,
type="messages",
)
with gr.Tab("Parameters"):
gr.Markdown("### Generation Parameters")
temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output") # New checkbox for raw output
save_params_btn = gr.Button("Save Parameters")
save_params_btn.click(
lambda t, p, m, k, s: {
"temperature": t,
"top_p": p,
"max_new_tokens": m,
"memory_top_k": k,
"show_raw_output": s
},
inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider, show_raw_checkbox],
outputs=param_state,
)
with gr.Tab("Prompt Config"):
gr.Markdown("### Configure Prompt Templates for Each Preset")
with gr.Tabs():
with gr.Tab("Coding"):
prompt_brainstorm_box_code = gr.Textbox(
value=default_prompts["coding"]["brainstorm"],
label="Brainstorm Prompt (Coding)",
lines=8,
)
prompt_round2_box_code = gr.Textbox(
value=default_prompts["coding"]["round2"],
label="Round 2 Prompt (Coding)",
lines=8,
)
prompt_synthesis_box_code = gr.Textbox(
value=default_prompts["coding"]["synthesis"],
label="Synthesis Prompt (Coding)",
lines=8,
)
prompt_rationale_box_code = gr.Textbox(
value=default_prompts["coding"]["rationale"],
label="Rationale Prompt (Coding)",
lines=8,
)
with gr.Tab("Math"):
prompt_brainstorm_box_math = gr.Textbox(
value=default_prompts["math"]["brainstorm"],
label="Brainstorm Prompt (Math)",
lines=8,
)
prompt_round2_box_math = gr.Textbox(
value=default_prompts["math"]["round2"],
label="Round 2 Prompt (Math)",
lines=8,
)
prompt_synthesis_box_math = gr.Textbox(
value=default_prompts["math"]["synthesis"],
label="Synthesis Prompt (Math)",
lines=8,
)
prompt_rationale_box_math = gr.Textbox(
value=default_prompts["math"]["rationale"],
label="Rationale Prompt (Math)",
lines=8,
)
with gr.Tab("Writing"):
prompt_brainstorm_box_writing = gr.Textbox(
value=default_prompts["writing"]["brainstorm"],
label="Brainstorm Prompt (Writing)",
lines=8,
)
prompt_round2_box_writing = gr.Textbox(
value=default_prompts["writing"]["round2"],
label="Round 2 Prompt (Writing)",
lines=8,
)
prompt_synthesis_box_writing = gr.Textbox(
value=default_prompts["writing"]["synthesis"],
label="Synthesis Prompt (Writing)",
lines=8,
)
prompt_rationale_box_writing = gr.Textbox(
value=default_prompts["writing"]["rationale"],
label="Rationale Prompt (Writing)",
lines=8,
)
save_prompts_btn = gr.Button("Save Prompts")
def save_prompts(code_brain, code_r2, code_syn, code_rat, math_brain, math_r2, math_syn, math_rat, writing_brain, writing_r2, writing_syn, writing_rat):
return {
"coding": {
"brainstorm": code_brain,
"round2": code_r2,
"synthesis": code_syn,
"rationale": code_rat,
},
"math": {
"brainstorm": math_brain,
"round2": math_r2,
"synthesis": math_syn,
"rationale": math_rat,
},
"writing": {
"brainstorm": writing_brain,
"round2": writing_r2,
"synthesis": writing_syn,
"rationale": writing_rat,
}
}
save_prompts_btn.click(
save_prompts,
inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code, prompt_rationale_box_code,
prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math, prompt_rationale_box_math,
prompt_brainstorm_box_writing, prompt_round2_box_writing, prompt_synthesis_box_writing, prompt_rationale_box_writing],
outputs=prompt_state,
)
gr.Markdown(ui_license)
if __name__ == "__main__":
demo.launch()