import gradio as gr import os import spaces # Import the spaces library from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer import torch from threading import Thread import logging from typing import Tuple, List, Dict, Generator # --- Logging Configuration --- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # --- Model & Quantization Settings --- MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit" # Dictionaries to store the loaded model and tokenizer models: Dict[str, AutoModelForCausalLM] = {} tokenizers: Dict[str, AutoTokenizer] = {} bnb_config_4bit = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, # Or torch.float16 if needed ) def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]: """ Lazy-load the model and tokenizer if not already loaded. Returns: Tuple[model, tokenizer]: The loaded model and tokenizer. """ if "7B" not in models: logging.info(f"Loading 7B model: {MODEL_ID} on demand") try: tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, quantization_config=bnb_config_4bit, torch_dtype=torch.bfloat16, # Or torch.float16 if needed device_map='auto', trust_remote_code=True, ) model.eval() # Set the model to evaluation mode models["7B"] = model tokenizers["7B"] = tokenizer logging.info("Loaded 7B model on demand.") except Exception as e: logging.error(f"Failed to load model and tokenizer: {e}") raise e return models["7B"], tokenizers["7B"] # --- Default Prompt Templates --- default_prompt_brainstorm = """**Brainstorming Task (Round 1)** As a Senior Code Analyst, provide an initial analysis of the problem below. **User Request:** {user_prompt} **Guidelines:** 1. Identify key challenges and constraints. 2. Suggest multiple potential approaches. 3. Outline any potential edge cases or critical considerations. """ default_prompt_code_generation = """**Advanced Reasoning & Code Generation (Round 2)** Based on the initial analysis below: **Initial Analysis:** {brainstorm_response} **User Request:** {user_prompt} **Task:** 1. Develop a detailed solution that includes production-ready code. 2. Explain the reasoning behind the chosen approach. 3. Incorporate advanced reasoning to handle edge cases. 4. Provide commented code that is clear and maintainable. """ default_prompt_synthesis = """**Synthesis & Final Refinement (Round 3)** Review the detailed code generation and reasoning below, and produce a final, refined response that: 1. Synthesizes the brainstorming insights and advanced reasoning. 2. Provides a concise summary of the solution. 3. Highlights any potential improvements or considerations. **Detailed Response:** {code_response} """ # --- Memory Management --- class MemoryManager: """Encapsulate shared memory for storing and retrieving conversation items.""" def __init__(self) -> None: self.shared_memory: List[str] = [] def store(self, item: str) -> None: """ Store a memory item and log an excerpt. Args: item (str): The memory content to store. """ self.shared_memory.append(item) logging.info(f"[Memory Stored]: {item[:50]}...") def retrieve(self, query: str, top_k: int = 3) -> List[str]: """ Retrieve memory items that contain the query text (case-insensitive). Args: query (str): The text query to search for. top_k (int): Maximum number of memory items to return. Returns: List[str]: A list of up to top_k memory items. """ query_lower = query.lower() relevant = [item for item in self.shared_memory if query_lower in item.lower()] if not relevant: logging.info("[Memory Retrieval]: No relevant memories found.") else: logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.") return relevant[:top_k] # Create a global memory manager instance for RAG purposes. global_memory_manager = MemoryManager() # --- Multi-Round Swarm Agent Function --- @spaces.GPU(duration=180) # Adjust duration as needed def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int, prompt_brainstorm_text: str, prompt_code_generation_text: str, prompt_synthesis_text: str ) -> Generator[str, None, None]: """ A three-round iterative process that uses the provided prompt templates: - Round 1: Brainstorming. - Round 2: Advanced reasoning & code generation. - Round 3: Synthesis & refinement. This generator yields the response from the final round as it is produced. Yields: str: Progressive updates of the final response. """ model, tokenizer = get_model_and_tokenizer() # ----- Round 1: Brainstorming ----- logging.info("--- Round 1: Brainstorming ---") prompt_r1 = prompt_brainstorm_text.format(user_prompt=user_prompt) input_ids_r1 = tokenizer.encode(prompt_r1, return_tensors="pt").to(model.device) streamer_r1 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) kwargs_r1 = dict( input_ids=input_ids_r1, streamer=streamer_r1, max_new_tokens=max_new_tokens, do_sample=True, temperature=temp, top_p=top_p, ) try: thread_r1 = Thread(target=model.generate, kwargs=kwargs_r1) with torch.no_grad(): thread_r1.start() except Exception as e: logging.error(f"Error starting Round 1 thread: {e}") raise e brainstorm_response = "" try: for text in streamer_r1: logging.info(text) brainstorm_response += text except Exception as e: logging.error(f"Error during Round 1 generation: {e}") raise e thread_r1.join() global_memory_manager.store(f"Brainstorm Response: {brainstorm_response[:200]}...") # ----- Round 2: Code Generation ----- logging.info("--- Round 2: Code Generation ---") prompt_r2 = prompt_code_generation_text.format( brainstorm_response=brainstorm_response, user_prompt=user_prompt ) input_ids_r2 = tokenizer.encode(prompt_r2, return_tensors="pt").to(model.device) streamer_r2 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) kwargs_r2 = dict( input_ids=input_ids_r2, streamer=streamer_r2, max_new_tokens=max_new_tokens + 100, # extra tokens for detail temperature=temp, top_p=top_p, ) try: thread_r2 = Thread(target=model.generate, kwargs=kwargs_r2) with torch.no_grad(): thread_r2.start() except Exception as e: logging.error(f"Error starting Round 2 thread: {e}") raise e code_response = "" try: for text in streamer_r2: logging.info(text) code_response += text except Exception as e: logging.error(f"Error during Round 2 generation: {e}") raise e thread_r2.join() global_memory_manager.store(f"Code Generation Response: {code_response[:200]}...") # ----- Round 3: Synthesis & Refinement ----- logging.info("--- Round 3: Synthesis & Refinement ---") prompt_r3 = prompt_synthesis_text.format(code_response=code_response) input_ids_r3 = tokenizer.encode(prompt_r3, return_tensors="pt").to(model.device) streamer_r3 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) kwargs_r3 = dict( input_ids=input_ids_r3, streamer=streamer_r3, max_new_tokens=max_new_tokens // 2, temperature=temp, top_p=top_p, ) try: thread_r3 = Thread(target=model.generate, kwargs=kwargs_r3) with torch.no_grad(): thread_r3.start() except Exception as e: logging.error(f"Error starting Round 3 thread: {e}") raise e final_response = "" try: for text in streamer_r3: logging.info(text) final_response += text yield final_response # Yield progressive updates except Exception as e: logging.error(f"Error during Round 3 generation: {e}") raise e thread_r3.join() global_memory_manager.store(f"Final Synthesis Response: {final_response[:200]}...") # --- Explanation Function for Puns --- def handle_explanation_request(user_prompt: str) -> str: """ If the user asks for an explanation of the puns, this function retrieves relevant stored memory items (which are expected to include pun examples) and constructs a new prompt to generate a detailed explanation. Args: user_prompt (str): The user request (e.g. "explain the different puns you mentioned") Returns: str: The explanation generated by the model. """ # Retrieve memory items that contain "pun" (assuming previous outputs include puns) retrieved = global_memory_manager.retrieve("pun", top_k=3) if not retrieved: explanation_prompt = "No previous puns found to explain. Please provide the pun examples." else: explanation_prompt = "Please explain the following coding puns in detail:\n\n" for item in retrieved: explanation_prompt += f"- {item}\n" explanation_prompt += "\nProvide a detailed explanation for each pun." model, tokenizer = get_model_and_tokenizer() input_ids = tokenizer.encode(explanation_prompt, return_tensors="pt").to(model.device) streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) kwargs = dict( input_ids=input_ids, streamer=streamer, max_new_tokens=300, temperature=0.7, top_p=0.9, ) try: thread = Thread(target=model.generate, kwargs=kwargs) with torch.no_grad(): thread.start() except Exception as e: logging.error(f"Error starting explanation thread: {e}") raise e explanation = "" try: for text in streamer: explanation += text except Exception as e: logging.error(f"Error during explanation generation: {e}") raise e thread.join() return explanation # --- Helper to Format History --- def format_history(history: List) -> List[Dict[str, str]]: """ Convert history (which might be a list of [user, assistant] pairs or already formatted dictionaries) into a list of OpenAI-style message dictionaries. Args: history (List): List of conversation items. Returns: List[Dict[str, str]]: A list of formatted message dictionaries. """ messages = [] for item in history: # If item is a list or tuple, try to unpack it if it has exactly 2 elements. if isinstance(item, (list, tuple)) and len(item) == 2: user_msg, assistant_msg = item messages.append({"role": "user", "content": user_msg}) if assistant_msg: messages.append({"role": "assistant", "content": assistant_msg}) elif isinstance(item, dict): messages.append(item) return messages # --- Gradio Chat Interface Function --- def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict) -> Generator[List[Dict[str, str]], None, None]: """ This function is called by Gradio's ChatInterface. It uses the current saved generation parameters and prompt templates. If the user request appears to ask for an explanation of puns, it routes the request to the explanation function. Args: message (str): The user message. history (List): The conversation history. param_state (Dict): Generation parameters. prompt_state (Dict): Prompt templates. Yields: Generator[List[Dict[str, str]]]: Updated history in OpenAI-style message dictionaries. """ # Check if the user is asking to explain puns. if "explain" in message.lower() and "pun" in message.lower(): explanation = handle_explanation_request(message) history = history + [[message, explanation]] yield format_history(history) return try: temp = float(param_state.get("temperature", 0.5)) top_p = float(param_state.get("top_p", 0.9)) max_new_tokens = int(param_state.get("max_new_tokens", 300)) memory_top_k = int(param_state.get("memory_top_k", 2)) except Exception as e: logging.error(f"Parameter conversion error: {e}") temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2 prompt_brainstorm_text = prompt_state.get("prompt_brainstorm", default_prompt_brainstorm) prompt_code_generation_text = prompt_state.get("prompt_code_generation", default_prompt_code_generation) prompt_synthesis_text = prompt_state.get("prompt_synthesis", default_prompt_synthesis) # Append the new user message with an empty assistant reply (as a two-item list) history = history + [[message, ""]] # Call the multi-round agent as a generator (for streaming) for partial_response in swarm_agent_iterative( user_prompt=message, temp=temp, top_p=top_p, max_new_tokens=max_new_tokens, memory_top_k=memory_top_k, prompt_brainstorm_text=prompt_brainstorm_text, prompt_code_generation_text=prompt_code_generation_text, prompt_synthesis_text=prompt_synthesis_text ): # Update the last assistant message with the new partial response. history[-1][1] = partial_response yield format_history(history) # --- UI Settings & Styling --- ui_description = '''
Multi-round agent:
- Brainstorming
- Advanced reasoning & code generation
- Synthesis & refinement
Ask me anything...