import gradio as gr
import os
import spaces  # Import the spaces library
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
import torch
from threading import Thread

# Model IDs from the Hugging Face Hub (Unsloth 4-bit DeepSeek-R1-Distill checkpoints)
model_ids = {
    "7B-Unsloth": "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit", # Unsloth 7B model
    "32B-Unsloth": "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit",       # Unsloth 32B model
}

models = {}  # Keep models as a dictionary, but initially empty
tokenizers = {} # Keep tokenizers as a dictionary, initially empty

# BitsAndBytesConfig for 4-bit quantization (shared by both models)
bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16, # Or torch.float16 if needed
)
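
# Note (hardware assumption): bfloat16 compute generally requires an
# Ampere-or-newer GPU; on older cards, switch bnb_4bit_compute_dtype to
# torch.float16 as the inline comment above suggests.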


def get_model_and_tokenizer(size):
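    """Lazily load and cache the 4-bit quantized model/tokenizer for `size`."""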
    if size not in models: # Load only if not already loaded
        model_id = model_ids[size]
        print(f"Loading {size} model: {model_id} on demand")
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=bnb_config_4bit, # Apply 4-bit config for BOTH models
            torch_dtype=torch.bfloat16, # Or torch.float16 if needed
            device_map='auto',
            trust_remote_code=True
        )
        models[size] = model
        tokenizers[size] = tokenizer
        print(f"Loaded {size} model on demand.")
    return models[size], tokenizers[size]
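
# Note: nothing is loaded at import time; the first call to
# get_model_and_tokenizer() for a given size downloads and quantizes the
# checkpoint, and subsequent calls return the cached pair.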


# Default prompt templates for the two stages (editable in the UI)
default_prompt_1_5b = """**Code Analysis Task**
As a Senior Code Analyst, analyze this programming problem:

**User Request:**
{user_prompt}

**Relevant Context:**
{context_1_5b}

**Analysis Required:**
1. Briefly break down the problem, including key constraints and edge cases.
2. Suggest 2-3 potential approach options (algorithms/data structures).
3. Recommend ONE primary strategy and briefly justify your choice.
4. Provide a very brief initial pseudocode sketch of the core logic."""


default_prompt_7b = """**Code Implementation Task**
As a Principal Software Engineer, provide production-ready Streamlit/Python code based on this analysis:

**Initial Analysis:**
{response_1_5b}

**Relevant Context:**
{context_7b}

**Code Requirements:**
1.  Generate concise, production-grade Python code for a Streamlit app.
2.  Include necessary imports, UI elements, and basic functionality.
3.  Add comments for clarity."""


# --- Shared Memory Implementation ---
shared_memory = []

def store_in_memory(memory_item):
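    """Append an item to the global shared memory and log a short preview."""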
    shared_memory.append(memory_item)
    print(f"\n[Memory Stored]: {memory_item[:50]}...")

def retrieve_from_memory(query, top_k=2):
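    """Return up to `top_k` stored items containing `query` as a
    case-insensitive substring (naive keyword matching, not semantic search)."""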
    relevant_memories = []
    query_lower = query.lower()
    for memory_item in shared_memory:
        if query_lower in memory_item.lower():
            relevant_memories.append(memory_item)

    if not relevant_memories:
        print("\n[Memory Retrieval]: No relevant memories found.")
        return []

    print(f"\n[Memory Retrieval]: Found {len(relevant_memories)} relevant memories.")
    return relevant_memories[:top_k]
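
# Example of the retrieval behavior: after store_in_memory("Streamlit app draft"),
# retrieve_from_memory("streamlit") returns ["Streamlit app draft"].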


# --- Streaming Swarm Agent Function (Unsloth 7B brainstorm -> 32B code generation) ---
@spaces.GPU(duration=120) # Request a ZeroGPU slot for up to 120 seconds per call
def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_template, temperature=0.5, top_p=0.9, max_new_tokens=300):
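    """Two-stage pipeline: the 7B model brainstorms an analysis of the user
    prompt, then the 32B model turns that analysis into code. Partial text
    from each stage is yielded as it streams so the UI updates live."""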
    global shared_memory
    shared_memory = [] # Clear memory for each new request

    print(f"\n--- Swarm Agent Processing with Shared Memory (RAG) - GPU ACCELERATED - Final Model: 32B Unsloth ---") # Updated message

    # 7B Unsloth Model - Brainstorming/Initial Draft (Lazy Load and get model)
    print("\n[7B Unsloth Model - Brainstorming] - GPU Accelerated") # Now 7B Unsloth is brainstorming
    model_7b, tokenizer_7b = get_model_and_tokenizer("7B-Unsloth") # Lazy load 7B Unsloth
    retrieved_memory_7b = retrieve_from_memory(user_prompt)
    context_7b = "\n".join([f"- {mem}" for mem in retrieved_memory_7b]) if retrieved_memory_7b else "No relevant context found in memory."

    # Use user-provided prompt template for 7B model (as brainstorming model now)
    prompt_7b_brainstorm = prompt_1_5b_template.format(user_prompt=user_prompt, context_1_5b=context_7b) # Brainstorming template keeps its legacy {context_1_5b} placeholder name

    input_ids_7b = tokenizer_7b.encode(prompt_7b_brainstorm, return_tensors="pt").to(model_7b.device)
    streamer_7b = TextIteratorStreamer(tokenizer_7b, timeout=10.0, skip_prompt=True, skip_special_tokens=True) # Streamer for 7B

    generate_kwargs_7b = dict( # Generation kwargs for 7B
        input_ids=input_ids_7b,
        streamer=streamer_7b,
        max_new_tokens=max_new_tokens, # Use user-defined max_new_tokens
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        # eos_token_id=tokenizer_7b.eos_token_id, # Not strictly needed as streamer handles EOS
    )

    thread_7b = Thread(target=model_7b.generate, kwargs=generate_kwargs_7b) # Thread for 7B generation
    thread_7b.start()

    response_7b_stream = "" # Accumulate streamed 7B response
    print(f"7B Unsloth Response (Brainstorming):\n", end="")
    for text in streamer_7b: # Stream and print 7B response
        print(text, end="", flush=True) # Print in place
        response_7b_stream += text
        yield response_7b_stream # Yield intermediate 7B response

    thread_7b.join() # Make sure the 7B generation thread has finished before moving on
    store_in_memory(f"7B Unsloth Model Initial Response: {response_7b_stream[:200]}...") # Store accumulated 7B response

    # 32B Unsloth Model - Final Code Generation (Lazy Load and get model)
    final_model, final_tokenizer = get_model_and_tokenizer("32B-Unsloth") # Lazy load 32B Unsloth
    print("\n[32B Unsloth Model - Final Code Generation] - GPU Accelerated") # Model-specific message
    model_stage_name = "32B Unsloth Model - Final Code"
    final_max_new_tokens = max_new_tokens + 200 # More tokens for 32B model

    retrieved_memory_final = retrieve_from_memory(response_7b_stream) # Memory from streamed 7B response
    context_final = "\n".join([f"- {mem}" for mem in retrieved_memory_final]) if retrieved_memory_final else "No relevant context found in memory."

    # Use user-provided prompt template for final model (using 7B template)
    prompt_final = prompt_7b_template.format(response_1_5b=response_7b_stream, context_7b=context_final) # Using prompt_7b_template for final stage

    input_ids_final = final_tokenizer.encode(prompt_final, return_tensors="pt").to(final_model.device)
    streamer_final = TextIteratorStreamer(final_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) # Streamer for 32B

    generate_kwargs_final = dict( # Generation kwargs for 32B
        input_ids=input_ids_final,
        streamer=streamer_final,
        max_new_tokens=final_max_new_tokens,
        do_sample=True, # Without this, temperature/top_p would be silently ignored
        temperature=temperature,
        top_p=top_p,
        # eos_token_id=final_tokenizer.eos_token_id, # Not strictly needed as streamer handles EOS
    )

    thread_final = Thread(target=final_model.generate, kwargs=generate_kwargs_final) # Thread for 32B generation
    thread_final.start()

    response_final_stream = "" # Accumulate streamed 32B response
    print(f"\n{model_stage_name} Response:\n", end="")
    for text in streamer_final: # Stream and print 32B response
        print(text, end="", flush=True) # Print in place
        response_final_stream += text
        yield response_final_stream # Yield intermediate 32B response

    thread_final.join() # Make sure the 32B generation thread has finished
    store_in_memory(f"{model_stage_name} Response: {response_final_stream[:200]}...") # Store accumulated 32B response

    return response_final_stream # Generator return value; the final text has already been yielded


# --- Gradio ChatInterface callback ---
def gradio_interface(message, history, temp, top_p, max_tokens, prompt_1_5b_text, prompt_7b_text):
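    """Forward the chat message and UI settings to the swarm generator and
    stream its partial responses back to the chatbot."""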
    # history is automatically managed by ChatInterface
    full_response = "" # Accumulate full response from generator
    for partial_response in swarm_agent_sequential_rag( # Iterate through generator
        message,
        prompt_1_5b_template=prompt_1_5b_text, # Pass prompt templates
        prompt_7b_template=prompt_7b_text,
        temperature=temp,
        top_p=top_p,
        max_new_tokens=int(max_tokens) # Ensure max_tokens is an integer
    ):
        full_response = partial_response # Update full response with partial response
        yield full_response # Yield intermediate full response


DESCRIPTION = '''
<div>
<h1 style="text-align: center;">DeepSeek Agent Swarm Chat (Unsloth 7B + 32B) - Streaming Demo</h1>
<p style="text-align: center;">Agent swarm using Unsloth DeepSeek-R1-Distill models (7B + 32B) with shared memory, adjustable settings, and customizable prompts.  GPU accelerated using ZeroGPU! (Requires Pro Space)</p>
</div>
'''

LICENSE = """
<p/>
---
"""

PLACEHOLDER = """
Ask me anything...
"""


css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""
# Gradio ChatInterface with streaming
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Agent Swarm Output')

with gr.Blocks(fill_height=True, css=css) as demo:

    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=gradio_interface,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False), # Accordion for params
        additional_inputs=[
            gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature"),
            gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P"),
            gr.Number(value=300, label="Max Tokens", precision=0),
            gr.Textbox(value=default_prompt_1_5b, lines=7, label="Brainstorming Model Prompt Template (Unsloth 7B)"),
            gr.Textbox(value=default_prompt_7b, lines=7, label="Code Generation Prompt Template (Unsloth 32B)"),
        ],
        examples=[
            ['How to setup a human base on Mars? Give short answer.'],
            ['Explain theory of relativity to me like I’m 8 years old.'],
            ['Write a streamlit app to track my finances'],
            ['Write a pun-filled happy birthday message to my friend Alex.'],
            ['Justify why a penguin might make a good king of the jungle.']
        ],
        cache_examples=False,
    )

    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch()