import gradio as gr
import os
import spaces  # Import the spaces library
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
import torch
from threading import Thread
import logging
from typing import Tuple, List, Dict, Generator

# --- Logging Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Model & Quantization Settings ---
MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"

models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # Or torch.float16 if needed
)

def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Lazy-load the model and tokenizer if not already loaded.

    Returns:
        Tuple[model, tokenizer]: The loaded model and tokenizer.
    """
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,  # Or torch.float16 if needed
                device_map='auto',
                trust_remote_code=True,
            )
            model.eval()  # Set the model to evaluation mode
            models["7B"] = model
            tokenizers["7B"] = tokenizer
            logging.info("Loaded 7B model on demand.")
        except Exception as e:
            logging.error(f"Failed to load model and tokenizer: {e}")
            raise e
    return models["7B"], tokenizers["7B"]
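# Illustrative sketch (not part of the app flow): repeated calls reuse the
# module-level cache, so only the first call pays the load cost. The timing
# check below is an assumption for demonstration, not something the app does.
#
#   import time
#   t0 = time.time()
#   model, tokenizer = get_model_and_tokenizer()  # first call: loads weights
#   model, tokenizer = get_model_and_tokenizer()  # second call: returns cache
#   print(f"Total time, dominated by the first load: {time.time() - t0:.1f}s")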
""", "round2": """**Solution Strategy Development (Round 2)** Based on the initial analysis: **Initial Analysis:** {brainstorm_response} **Problem:** {user_prompt} **Task:** 1. Develop a detailed strategy to solve the problem. 2. Include potential methods and intermediate steps. """, "synthesis": """**Solution Synthesis (Round 3)** Review the strategy and previous analysis below, and produce a refined, step-by-step solution that: 1. Clearly explains the solution path. 2. Highlights key steps and justifications. 3. Summarizes the final answer. **Detailed Strategy:** {round2_response} """, "rationale": """**Solution Rationale (Round 4)** Based on the final refined solution below, provide a detailed explanation of the key steps and mathematical insights. Final Refined Solution: {final_response} Your response should: 1. Clearly explain why each step was taken. 2. Detail any assumptions and mathematical principles used. 3. Summarize the creative reasoning behind the solution. """ }, "writing": { "brainstorm": """**Creative Brainstorm (Round 1)** As a seasoned writer, brainstorm creative ideas for the following writing prompt. **Writing Prompt:** {user_prompt} **Guidelines:** 1. List key themes and creative directions. 2. Suggest multiple approaches to the narrative. 3. Highlight any unique stylistic ideas. """, "round2": """**Outline Generation (Round 2)** Based on the brainstorming below: **Brainstormed Ideas:** {brainstorm_response} **Writing Prompt:** {user_prompt} **Task:** 1. Generate a detailed outline for a creative piece. 2. Organize the ideas into a coherent structure. 3. Provide bullet points or sections for the narrative. """, "synthesis": """**Draft Writing (Round 3)** Review the outline below and produce a refined draft of the creative piece that: 1. Synthesizes the brainstorming insights and the outline. 2. Provides a coherent and engaging narrative. 3. Includes stylistic and thematic elements. **Outline:** {round2_response} """, "rationale": """**Final Editing and Rationale (Round 4)** Based on the final draft below, refine the piece further and provide a detailed explanation of your creative choices. Final Draft: {final_response} Your answer should: 1. Present the final refined text. 2. Explain the narrative choices, stylistic decisions, and thematic connections. 3. Detail any creative insights that influenced the final version. """ } } # --- Domain Detection --- def detect_domain(user_prompt: str) -> str: """ Detect the domain based on keywords. Args: user_prompt (str): The user query. Returns: str: One of 'math', 'writing', or 'coding' (defaulting to coding). 
""" prompt_lower = user_prompt.lower() math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"] writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"] coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"] if any(kw in prompt_lower for kw in math_keywords): logging.info("Domain detected as: math") return "math" elif any(kw in prompt_lower for kw in writing_keywords): logging.info("Domain detected as: writing") return "writing" elif any(kw in prompt_lower for kw in coding_keywords): logging.info("Domain detected as: coding") return "coding" else: logging.info("No specific domain detected; defaulting to coding") return "coding" # --- Memory Management --- class MemoryManager: """Encapsulate shared memory for storing and retrieving conversation items.""" def __init__(self) -> None: self.shared_memory: List[str] = [] def store(self, item: str) -> None: """Store a memory item and log an excerpt.""" self.shared_memory.append(item) logging.info(f"[Memory Stored]: {item[:50]}...") def retrieve(self, query: str, top_k: int = 3) -> List[str]: """Retrieve recent memory items containing the query text.""" query_lower = query.lower() relevant = [item for item in self.shared_memory if query_lower in item.lower()] if not relevant: logging.info("[Memory Retrieval]: No relevant memories found.") else: logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.") return relevant[-top_k:] global_memory_manager = MemoryManager() # --- Unified Generation Function --- def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float) -> str: """Generate a response for a given prompt.""" input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) kwargs = dict( input_ids=input_ids, streamer=streamer, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True, ) thread = Thread(target=model.generate, kwargs=kwargs) with torch.no_grad(): thread.start() response = "" try: for text in streamer: response += text except Exception as e: logging.error(f"Error during generation: {e}") raise e thread.join() return response # --- Multi-Round Agent Class --- class MultiRoundAgent: """ Encapsulate the multi-round prompt chaining and response generation. This class runs a 4-round pipeline based on the given preset. 
""" def __init__(self, model, tokenizer, prompt_templates: Dict[str, str], memory_manager: MemoryManager): self.model = model self.tokenizer = tokenizer self.prompt_templates = prompt_templates self.memory_manager = memory_manager def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]: # Round 1: Brainstorming / Analysis logging.info("--- Round 1 ---") prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt) r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"), params.get("temp"), params.get("top_p")) self.memory_manager.store(f"Round 1 Response: {r1}") # Round 2: Secondary Generation (strategy/outline/code) logging.info("--- Round 2 ---") prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt) r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100, params.get("temp"), params.get("top_p")) self.memory_manager.store(f"Round 2 Response: {r2}") # Round 3: Synthesis & Refinement (streaming updates) logging.info("--- Round 3 ---") prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2) input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device) streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) kwargs_r3 = dict( input_ids=input_ids_r3, streamer=streamer_r3, max_new_tokens=params.get("max_new_tokens") // 2, temperature=params.get("temp"), top_p=params.get("top_p") ) thread_r3 = Thread(target=self.model.generate, kwargs=kwargs_r3) with torch.no_grad(): thread_r3.start() r3 = "" try: for text in streamer_r3: r3 += text yield r3 # Yield progressive updates from Round 3 except Exception as e: logging.error(f"Error during Round 3 streaming: {e}") raise e thread_r3.join() self.memory_manager.store(f"Final Synthesis Response: {r3}") # Round 4: Rationale / Final Output logging.info("--- Round 4 ---") prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3) r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300, params.get("temp"), params.get("top_p")) self.memory_manager.store(f"Round 4 Response: {r4}") # Construct final output based on the show_raw flag. if show_raw: final_output = ( f"{r4}\n\n[Raw Outputs]\n" f"Round 1:\n{r1}\n\n" f"Round 2:\n{r2}\n\n" f"Round 3:\n{r3}\n\n" f"Round 4:\n{r4}\n" ) else: final_output = r4 yield final_output # --- Swarm Agent Iterative Function --- @spaces.GPU(duration=180) # Adjust duration as needed def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int, prompt_templates: Dict[str, str], domain: str, show_raw: bool) -> Generator[str, None, None]: """ Wraps the multi-round agent functionality. Depending on the detected domain, it runs the 4-round pipeline. """ model, tokenizer = get_model_and_tokenizer() agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager) params = {"temp": temp, "top_p": top_p, "max_new_tokens": max_new_tokens} return agent.run_pipeline(user_prompt, params, show_raw) # --- Explanation Function for Additional Requests --- def handle_explanation_request(user_prompt: str, history: List) -> str: """ Retrieve stored rationale and additional context from conversation history, then generate an explanation. 
""" retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3) explanation_prompt = "Below are previous final outputs and related context from our conversation:\n" if retrieved: for item in retrieved: explanation_prompt += f"- {item}\n" else: explanation_prompt += "No stored final output found.\n" explanation_prompt += "\nRecent related exchanges:\n" for chat in history: if ("explain" in chat[0].lower()) or (chat[1] and "explain" in chat[1].lower()): explanation_prompt += f"User: {chat[0]}\nAssistant: {chat[1]}\n" explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices." model, tokenizer = get_model_and_tokenizer() explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9) return explanation # --- Helper to Format History --- def format_history(history: List) -> List[Dict[str, str]]: """ Convert history (list of [user, assistant] pairs) into a list of message dictionaries. """ messages = [] for item in history: if isinstance(item, (list, tuple)) and len(item) == 2: user_msg, assistant_msg = item messages.append({"role": "user", "content": user_msg}) if assistant_msg: messages.append({"role": "assistant", "content": assistant_msg}) elif isinstance(item, dict): messages.append(item) return messages # --- Gradio Chat Interface Function --- def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[List[Dict[str, str]], None, None]: """ Called by Gradio's ChatInterface. Uses current generation parameters and preset prompt templates. If the user asks for an explanation, routes the request accordingly. """ if "explain" in message.lower(): explanation = handle_explanation_request(message, history) history = history + [[message, explanation]] yield format_history(history) return try: temp = float(param_state.get("temperature", 0.5)) top_p = float(param_state.get("top_p", 0.9)) max_new_tokens = int(param_state.get("max_new_tokens", 300)) memory_top_k = int(param_state.get("memory_top_k", 2)) show_raw = bool(param_state.get("show_raw_output", False)) except Exception as e: logging.error(f"Parameter conversion error: {e}") temp, top_p, max_new_tokens, memory_top_k, show_raw = 0.5, 0.9, 300, 2, False domain = detect_domain(message) # Get the prompt templates for the detected domain; default to coding if not set. prompt_templates = prompt_state.get(domain, default_prompts.get(domain, default_prompts["coding"])) history = history + [[message, ""]] for partial_response in swarm_agent_iterative( user_prompt=message, temp=temp, top_p=top_p, max_new_tokens=max_new_tokens, memory_top_k=memory_top_k, prompt_templates=prompt_templates, domain=domain, show_raw=show_raw ): history[-1][1] = partial_response yield format_history(history) # --- UI Settings & Styling --- ui_description = '''
# --- UI Settings & Styling ---
ui_description = '''
Multi-round agent with 4-round prompt chaining for three presets:
- Coding
- Math
- Writing
Ask me anything...