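"""DeepSeek Agent Swarm Chat (Unsloth 7B + 32B) - a streaming Gradio Space.

Two 4-bit DeepSeek-R1-Distill-Qwen models run in sequence with a simple shared
memory: the 7B model brainstorms an analysis, the 32B model turns it into final
code, and both stages stream their tokens to the chat UI. GPU access is provided
by ZeroGPU.
"""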
import gradio as gr
import os
import spaces # Hugging Face Spaces library (provides the @spaces.GPU decorator for ZeroGPU)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
import torch
from threading import Thread
# Model IDs from the Hugging Face Hub (the Unsloth 7B and 32B 4-bit checkpoints)
model_ids = {
"7B-Unsloth": "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit", # Unsloth 7B model
"32B-Unsloth": "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit", # Unsloth 32B model
}
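# Both checkpoints are pre-quantized 4-bit (bitsandbytes) uploads published by Unsloth.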
models = {} # Lazily populated cache: size -> loaded model
tokenizers = {} # Lazily populated cache: size -> tokenizer
# BitsAndBytesConfig for 4-bit quantization (for BOTH models now)
bnb_config_4bit = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16, # Or torch.float16 if needed
)
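# NF4 quantization with bfloat16 compute is the standard 4-bit setup popularized by QLoRA.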
def get_model_and_tokenizer(size): # Function to load model on demand
if size not in models: # Load only if not already loaded
model_id = model_ids[size]
print(f"Loading {size} model: {model_id} on demand")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=bnb_config_4bit, # Apply 4-bit config for BOTH models
torch_dtype=torch.bfloat16, # Or torch.float16 if needed
device_map='auto',
trust_remote_code=True
)
models[size] = model
tokenizers[size] = tokenizer
print(f"Loaded {size} model on demand.")
return models[size], tokenizers[size]
# Default prompt templates. NOTE: the "_1_5b" names are legacy; the first template drives the 7B brainstorming stage and the second drives the 32B code-generation stage.
default_prompt_1_5b = """**Code Analysis Task**
As a Senior Code Analyst, analyze this programming problem:
**User Request:**
{user_prompt}
**Relevant Context:**
{context_1_5b}
**Analysis Required:**
1. Briefly break down the problem, including key constraints and edge cases.
2. Suggest 2-3 potential approach options (algorithms/data structures).
3. Recommend ONE primary strategy and briefly justify your choice.
4. Provide a very brief initial pseudocode sketch of the core logic."""
default_prompt_7b = """**Code Implementation Task**
As a Principal Software Engineer, provide production-ready Streamlit/Python code based on this analysis:
**Initial Analysis:**
{response_1_5b}
**Relevant Context:**
{context_7b}
**Code Requirements:**
1. Generate concise, production-grade Python code for a Streamlit app.
2. Include necessary imports, UI elements, and basic functionality.
3. Add comments for clarity.
"""
# --- Shared Memory Implementation ---
shared_memory = []
def store_in_memory(memory_item):
shared_memory.append(memory_item)
print(f"\n[Memory Stored]: {memory_item[:50]}...")
def retrieve_from_memory(query, top_k=2):
relevant_memories = []
query_lower = query.lower()
for memory_item in shared_memory:
if query_lower in memory_item.lower():
relevant_memories.append(memory_item)
if not relevant_memories:
print("\n[Memory Retrieval]: No relevant memories found.")
return []
print(f"\n[Memory Retrieval]: Found {len(relevant_memories)} relevant memories.")
return relevant_memories[:top_k]
# --- Streaming Swarm Agent Function - Fixed Models (Unsloth 7B and 32B Unsloth) ---
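# ZeroGPU attaches a GPU only while the decorated function runs; duration=120 requests up to 120 seconds per call.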
@spaces.GPU(duration=120)
def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_template, temperature=0.5, top_p=0.9, max_new_tokens=300): # Generator: streams the 7B stage, then the 32B stage
global shared_memory
shared_memory = [] # Clear memory for each new request
print(f"\n--- Swarm Agent Processing with Shared Memory (RAG) - GPU ACCELERATED - Final Model: 32B Unsloth ---") # Updated message
# 7B Unsloth Model - Brainstorming/Initial Draft (Lazy Load and get model)
print("\n[7B Unsloth Model - Brainstorming] - GPU Accelerated") # Now 7B Unsloth is brainstorming
model_7b, tokenizer_7b = get_model_and_tokenizer("7B-Unsloth") # Lazy load 7B Unsloth
retrieved_memory_7b = retrieve_from_memory(user_prompt)
context_7b = "\n".join([f"- {mem}" for mem in retrieved_memory_7b]) if retrieved_memory_7b else "No relevant context found in memory."
# Use user-provided prompt template for 7B model (as brainstorming model now)
prompt_7b_brainstorm = prompt_1_5b_template.format(user_prompt=user_prompt, context_1_5b=context_7b) # Reusing 1.5B template - adjust if needed
input_ids_7b = tokenizer_7b.encode(prompt_7b_brainstorm, return_tensors="pt").to(model_7b.device)
streamer_7b = TextIteratorStreamer(tokenizer_7b, timeout=10.0, skip_prompt=True, skip_special_tokens=True) # Streamer for 7B
generate_kwargs_7b = dict( # Generation kwargs for 7B
input_ids= input_ids_7b,
streamer=streamer_7b,
max_new_tokens=max_new_tokens, # Use user-defined max_new_tokens
do_sample=True,
temperature=temperature,
top_p=top_p,
# eos_token_id=tokenizer_7b.eos_token_id, # Not needed: generate() already stops at the model's default EOS token
)
thread_7b = Thread(target=model_7b.generate, kwargs=generate_kwargs_7b) # Thread for 7B generation
thread_7b.start()
response_7b_stream = "" # Accumulate streamed 7B response
print(f"7B Unsloth Response (Brainstorming):\n", end="")
for text in streamer_7b: # Stream and print 7B response
print(text, end="", flush=True) # Print in place
response_7b_stream += text
yield response_7b_stream # Yield intermediate 7B response
store_in_memory(f"7B Unsloth Model Initial Response: {response_7b_stream[:200]}...") # Store accumulated 7B response
# 32B Unsloth Model - Final Code Generation (Lazy Load and get model)
final_model, final_tokenizer = get_model_and_tokenizer("32B-Unsloth") # Lazy load 32B Unsloth
print("\n[32B Unsloth Model - Final Code Generation] - GPU Accelerated") # Model-specific message
model_stage_name = "32B Unsloth Model - Final Code"
final_max_new_tokens = max_new_tokens + 200 # More tokens for 32B model
retrieved_memory_final = retrieve_from_memory(response_7b_stream) # Memory from streamed 7B response
context_final = "\n".join([f"- {mem}" for mem in retrieved_memory_final]) if retrieved_memory_final else "No relevant context found in memory."
# Use user-provided prompt template for final model (using 7B template)
prompt_final = prompt_7b_template.format(response_1_5b=response_7b_stream, context_7b=context_final) # Using prompt_7b_template for final stage
input_ids_final = final_tokenizer.encode(prompt_final, return_tensors="pt").to(final_model.device)
streamer_final = TextIteratorStreamer(final_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) # Streamer for 32B
generate_kwargs_final = dict( # Generation kwargs for 32B
input_ids= input_ids_final,
streamer=streamer_final,
max_new_tokens=final_max_new_tokens,
do_sample=True, # Enable sampling explicitly so temperature/top_p are honored (matches the 7B stage)
temperature=temperature,
top_p=top_p,
# eos_token_id=final_tokenizer.eos_token_id, # Not needed: generate() already stops at the model's default EOS token
)
thread_final = Thread(target=final_model.generate, kwargs=generate_kwargs_final) # Thread for 32B generation
thread_final.start()
response_final_stream = "" # Accumulate streamed 32B response
print(f"\n{model_stage_name} Response:\n", end="")
for text in streamer_final: # Stream and print 32B response
print(text, end="", flush=True) # Print in place
response_final_stream += text
yield response_final_stream # Yield intermediate 32B response
store_in_memory(f"{model_stage_name} Response: {response_final_stream[:200]}...") # Store accumulated 32B response
return response_final_stream # Returns final streamed response
# --- Gradio ChatInterface ---
def gradio_interface(message, history, temp, top_p, max_tokens, prompt_1_5b_text, prompt_7b_text):
# history is automatically managed by ChatInterface
full_response = "" # Accumulate full response from generator
for partial_response in swarm_agent_sequential_rag( # Iterate through generator
message,
prompt_1_5b_template=prompt_1_5b_text, # Pass prompt templates
prompt_7b_template=prompt_7b_text,
temperature=temp,
top_p=top_p,
max_new_tokens=int(max_tokens) # Ensure max_tokens is an integer
):
full_response = partial_response # Update full response with partial response
yield full_response # Yield intermediate full response
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">DeepSeek Agent Swarm Chat (Unsloth 7B + 32B) - Streaming Demo</h1>
<p style="text-align: center;">Agent swarm using Unsloth DeepSeek-R1-Distill models (7B + 32B) with shared memory, adjustable settings, and customizable prompts. GPU accelerated using ZeroGPU! (Requires Pro Space)</p>
</div>
'''
LICENSE = """
<p/>
---
"""
PLACEHOLDER = """
Ask me anything...
"""
css = """
h1 {
text-align: center;
display: block;
}
#duplicate-button {
margin: auto;
color: white;
background: #1565c0;
border-radius: 100vh;
}
"""
# Gradio ChatInterface with streaming
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Agent Swarm Output')
with gr.Blocks(fill_height=True, css=css) as demo:
gr.Markdown(DESCRIPTION)
gr.ChatInterface(
fn=gradio_interface,
chatbot=chatbot,
fill_height=True,
additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False), # Accordion for params
additional_inputs=[
gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature"),
gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P"),
gr.Number(value=300, label="Max Tokens", precision=0),
gr.Textbox(value=default_prompt_1_5b, lines=7, label="Brainstorming Model Prompt Template (Unsloth 7B)"),
gr.Textbox(value=default_prompt_7b, lines=7, label="Code Generation Prompt Template (Unsloth 32B)"),
],
examples=[
['How to set up a human base on Mars? Give a short answer.'],
['Explain theory of relativity to me like I’m 8 years old.'],
['Write a streamlit app to track my finances'],
['Write a pun-filled happy birthday message to my friend Alex.'],
['Justify why a penguin might make a good king of the jungle.']
],
cache_examples=False,
)
gr.Markdown(LICENSE)
if __name__ == "__main__":
demo.launch() |