Update app.py
app.py
CHANGED
@@ -1,7 +1,9 @@
 import gradio as gr
-
-import torch
 import spaces  # Import the spaces library
 
 # Model IDs from Hugging Face Hub (Fixed to Unsloth 7B and 32B Unsloth 4bit)
 model_ids = {
@@ -93,8 +95,8 @@ def retrieve_from_memory(query, top_k=2):
     return relevant_memories[:top_k]
 
 
-# --- Swarm Agent Function - Fixed Models (Unsloth 7B and 32B Unsloth) ---
-@spaces.GPU
 def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_template, temperature=0.5, top_p=0.9, max_new_tokens=300):  # Removed final_model_size
     global shared_memory
     shared_memory = []  # Clear memory for each new request
@@ -111,16 +113,29 @@ def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_temp
     prompt_7b_brainstorm = prompt_1_5b_template.format(user_prompt=user_prompt, context_1_5b=context_7b)  # Reusing 1.5B template - adjust if needed
 
     input_ids_7b = tokenizer_7b.encode(prompt_7b_brainstorm, return_tensors="pt").to(model_7b.device)
-
-
         max_new_tokens=max_new_tokens,  # Use user-defined max_new_tokens
-
-
-
     )
-
-
-
 
     # 32B Unsloth Model - Final Code Generation (Lazy Load and get model)
     final_model, final_tokenizer = get_model_and_tokenizer("32B-Unsloth")  # Lazy load 32B Unsloth
@@ -128,54 +143,113 @@ def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_temp
     model_stage_name = "32B Unsloth Model - Final Code"
     final_max_new_tokens = max_new_tokens + 200  # More tokens for 32B model
 
-    retrieve_from_memory(
     context_final = "\n".join([f"- {mem}" for mem in retrieved_memory_final]) if retrieved_memory_final else "No relevant context found in memory."
 
     # Use user-provided prompt template for final model (using 7B template)
-    prompt_final = prompt_7b_template.format(response_1_5b=
-
 
     input_ids_final = final_tokenizer.encode(prompt_final, return_tensors="pt").to(final_model.device)
-
-
         max_new_tokens=final_max_new_tokens,
         temperature=temperature,
         top_p=top_p,
-
     )
-    response_final = final_tokenizer.decode(output_final[0], skip_special_tokens=True)
-    print(f"{model_stage_name} Response:\n{response_final}")
-    store_in_memory(f"{model_stage_name} Response: {response_final[:200]}...")
 
-
 
 
 # --- Gradio ChatInterface --- (No Model Selection Dropdown anymore)
 def gradio_interface(message, history, temp, top_p, max_tokens, prompt_1_5b_text, prompt_7b_text):  # Removed final_model_selector
     # history is automatically managed by ChatInterface
-
         message,
         prompt_1_5b_template=prompt_1_5b_text,  # Pass prompt templates
         prompt_7b_template=prompt_7b_text,
         temperature=temp,
         top_p=top_p,
         max_new_tokens=int(max_tokens)  # Ensure max_tokens is an integer
     )
-
-
-iface = gr.ChatInterface(  # Using ChatInterface now
-    fn=gradio_interface,
-    # Define additional inputs for settings and prompts (NO model dropdown)
-    additional_inputs=[
-        gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature"),  # Lowered default temp to 0.5
-        gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P"),
-        gr.Number(value=300, label="Max Tokens", precision=0),  # Use Number for integer tokens
-        gr.Textbox(value=default_prompt_1_5b, lines=10, label="Brainstorming Model Prompt Template (Unsloth 7B)"),  # Updated label - Unsloth 7B now brainstormer
-        gr.Textbox(value=default_prompt_7b, lines=10, label="Code Generation Prompt Template (Unsloth 32B)"),  # Updated label - Unsloth 32B is code generator
-    ],
-    title="DeepSeek Agent Swarm Chat (ZeroGPU Demo - Fixed Models: Unsloth 7B + 32B)",  # Updated title
-    description="Chat with a DeepSeek agent swarm (Unsloth 7B + 32B) with shared memory, adjustable settings, **and customizable prompts!** **GPU accelerated using ZeroGPU!** (Requires Pro Space)",  # Updated description
-)
 
 if __name__ == "__main__":
-
@@ -1,7 +1,9 @@
 import gradio as gr
+import os
 import spaces  # Import the spaces library
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+import torch
+from threading import Thread
 
 # Model IDs from Hugging Face Hub (Fixed to Unsloth 7B and 32B Unsloth 4bit)
 model_ids = {
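The diff context stops at the opening of model_ids; lines 10 through 94 of the new file are not shown, and that skipped region is where the model id mapping, the lazy 4-bit loader, and the shared-memory helpers presumably live. For orientation, here is a minimal sketch of a lazy loader that would be consistent with the calls visible later in the diff (get_model_and_tokenizer("32B-Unsloth"), model_7b, tokenizer_7b). The _loaded cache, the exact BitsAndBytesConfig flags, and any key name other than "32B-Unsloth" are assumptions, not part of this commit.

# Hypothetical sketch only; the real loader is in the lines the diff does not show.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

_loaded = {}  # cache so each checkpoint is loaded at most once per process

def get_model_and_tokenizer(size):
    """Lazily load and cache the tokenizer and 4-bit quantized model for `size`."""
    if size not in _loaded:
        model_id = model_ids[size]  # e.g. keys like "7B" and "32B-Unsloth"
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,                  # assumed: Unsloth bnb-4bit checkpoints
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quant_config,
            device_map="auto",                  # place layers on the available GPU
        )
        _loaded[size] = (model, tokenizer)
    return _loaded[size]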
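The same skipped region defines the shared-memory helpers used throughout the swarm function; only the tail of retrieve_from_memory appears as context in the next hunk. Below is a minimal sketch that matches the calls in the diff, store_in_memory(text) and retrieve_from_memory(query, top_k=2) returning relevant_memories[:top_k]; the keyword-overlap scoring is an assumed placeholder, not necessarily the Space's actual heuristic.

shared_memory = []  # global snippet store; the swarm function clears it per request

def store_in_memory(memory_text):
    # Append one snippet; the swarm function stores truncated stage outputs here.
    shared_memory.append(memory_text)

def retrieve_from_memory(query, top_k=2):
    # Rank stored snippets by naive word overlap with the query (assumed heuristic).
    query_words = set(query.lower().split())
    scored = [(len(query_words & set(mem.lower().split())), mem) for mem in shared_memory]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    relevant_memories = [mem for score, mem in scored if score > 0]
    return relevant_memories[:top_k]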
@@ -93,8 +95,8 @@ def retrieve_from_memory(query, top_k=2):
     return relevant_memories[:top_k]
 
 
+# --- Streaming Swarm Agent Function - Fixed Models (Unsloth 7B and 32B Unsloth) ---
+@spaces.GPU(duration=120)  # Added duration
 def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_template, temperature=0.5, top_p=0.9, max_new_tokens=300):  # Removed final_model_size
     global shared_memory
     shared_memory = []  # Clear memory for each new request
@@ -111,16 +113,29 @@ def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_temp
     prompt_7b_brainstorm = prompt_1_5b_template.format(user_prompt=user_prompt, context_1_5b=context_7b)  # Reusing 1.5B template - adjust if needed
 
     input_ids_7b = tokenizer_7b.encode(prompt_7b_brainstorm, return_tensors="pt").to(model_7b.device)
+    streamer_7b = TextIteratorStreamer(tokenizer_7b, timeout=10.0, skip_prompt=True, skip_special_tokens=True)  # Streamer for 7B
+
+    generate_kwargs_7b = dict(  # Generation kwargs for 7B
+        input_ids=input_ids_7b,
+        streamer=streamer_7b,
         max_new_tokens=max_new_tokens,  # Use user-defined max_new_tokens
+        do_sample=True,
+        temperature=temperature,
+        top_p=top_p,
+        # eos_token_id=tokenizer_7b.eos_token_id,  # Not strictly needed as streamer handles EOS
     )
+
+    thread_7b = Thread(target=model_7b.generate, kwargs=generate_kwargs_7b)  # Thread for 7B generation
+    thread_7b.start()
+
+    response_7b_stream = ""  # Accumulate streamed 7B response
+    print("7B Unsloth Response (Brainstorming):\n", end="")
+    for text in streamer_7b:  # Stream and print 7B response
+        print(text, end="", flush=True)  # Print in place
+        response_7b_stream += text
+        yield response_7b_stream  # Yield intermediate 7B response
+
+    store_in_memory(f"7B Unsloth Model Initial Response: {response_7b_stream[:200]}...")  # Store accumulated 7B response
 
     # 32B Unsloth Model - Final Code Generation (Lazy Load and get model)
     final_model, final_tokenizer = get_model_and_tokenizer("32B-Unsloth")  # Lazy load 32B Unsloth
@@ -128,54 +143,113 @@ def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_temp
     model_stage_name = "32B Unsloth Model - Final Code"
     final_max_new_tokens = max_new_tokens + 200  # More tokens for 32B model
 
+    retrieved_memory_final = retrieve_from_memory(response_7b_stream)  # Memory from streamed 7B response
     context_final = "\n".join([f"- {mem}" for mem in retrieved_memory_final]) if retrieved_memory_final else "No relevant context found in memory."
 
     # Use user-provided prompt template for final model (using 7B template)
+    prompt_final = prompt_7b_template.format(response_1_5b=response_7b_stream, context_7b=context_final)  # Using prompt_7b_template for final stage
 
     input_ids_final = final_tokenizer.encode(prompt_final, return_tensors="pt").to(final_model.device)
+    streamer_final = TextIteratorStreamer(final_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)  # Streamer for 32B
+
+    generate_kwargs_final = dict(  # Generation kwargs for 32B
+        input_ids=input_ids_final,
+        streamer=streamer_final,
         max_new_tokens=final_max_new_tokens,
         temperature=temperature,
         top_p=top_p,
+        # eos_token_id=final_tokenizer.eos_token_id,  # Not strictly needed as streamer handles EOS
     )
 
+    thread_final = Thread(target=final_model.generate, kwargs=generate_kwargs_final)  # Thread for 32B generation
+    thread_final.start()
+
+    response_final_stream = ""  # Accumulate streamed 32B response
+    print(f"\n{model_stage_name} Response:\n", end="")
+    for text in streamer_final:  # Stream and print 32B response
+        print(text, end="", flush=True)  # Print in place
+        response_final_stream += text
+        yield response_final_stream  # Yield intermediate 32B response
+
+    store_in_memory(f"{model_stage_name} Response: {response_final_stream[:200]}...")  # Store accumulated 32B response
+
+    return response_final_stream  # Returns final streamed response
 
 
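The default prompt templates wired into the Gradio Textboxes further down (default_prompt_1_5b and default_prompt_7b) are also outside the hunks. Whatever their wording, the .format calls above require the brainstorming template to expose {user_prompt} and {context_1_5b}, and the final template to expose {response_1_5b} and {context_7b}. The text below is purely illustrative filler that satisfies those placeholders; it is not the commit's actual prompt text.

default_prompt_1_5b = """You are a brainstorming assistant.
User request:
{user_prompt}

Relevant memory:
{context_1_5b}

List concrete ideas and a short outline for the solution."""

default_prompt_7b = """You are a senior engineer. Turn the brainstorm below into final, working code.
Brainstorm:
{response_1_5b}

Relevant memory:
{context_7b}

Return only the code, with brief comments."""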
 # --- Gradio ChatInterface --- (No Model Selection Dropdown anymore)
 def gradio_interface(message, history, temp, top_p, max_tokens, prompt_1_5b_text, prompt_7b_text):  # Removed final_model_selector
     # history is automatically managed by ChatInterface
+    full_response = ""  # Accumulate full response from generator
+    for partial_response in swarm_agent_sequential_rag(  # Iterate through generator
         message,
         prompt_1_5b_template=prompt_1_5b_text,  # Pass prompt templates
         prompt_7b_template=prompt_7b_text,
         temperature=temp,
         top_p=top_p,
         max_new_tokens=int(max_tokens)  # Ensure max_tokens is an integer
+    ):
+        full_response = partial_response  # Update full response with partial response
+        yield full_response  # Yield intermediate full response
+
+
+DESCRIPTION = '''
+<div>
+<h1 style="text-align: center;">DeepSeek Agent Swarm Chat (Unsloth 7B + 32B) - Streaming Demo</h1>
+<p style="text-align: center;">Agent swarm using Unsloth DeepSeek-R1-Distill models (7B + 32B) with shared memory, adjustable settings, and customizable prompts. GPU accelerated using ZeroGPU! (Requires Pro Space)</p>
+</div>
+'''
+
+LICENSE = """
+<p/>
+---
+"""
+
+PLACEHOLDER = """
+Ask me anything...
+"""
+
+
+css = """
+h1 {
+    text-align: center;
+    display: block;
+}
+#duplicate-button {
+    margin: auto;
+    color: white;
+    background: #1565c0;
+    border-radius: 100vh;
+}
+"""
+# Gradio ChatInterface with streaming
+chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Agent Swarm Output')
+
+with gr.Blocks(fill_height=True, css=css) as demo:
+
+    gr.Markdown(DESCRIPTION)
+    gr.ChatInterface(
+        fn=gradio_interface,
+        chatbot=chatbot,
+        fill_height=True,
+        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),  # Accordion for params
+        additional_inputs=[
+            gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature"),
+            gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P"),
+            gr.Number(value=300, label="Max Tokens", precision=0),
+            gr.Textbox(value=default_prompt_1_5b, lines=7, label="Brainstorming Model Prompt Template (Unsloth 7B)"),
+            gr.Textbox(value=default_prompt_7b, lines=7, label="Code Generation Prompt Template (Unsloth 32B)"),
+        ],
+        examples=[
+            ['How to setup a human base on Mars? Give short answer.'],
+            ['Explain theory of relativity to me like I’m 8 years old.'],
+            ['Write a streamlit app to track my finances'],
+            ['Write a pun-filled happy birthday message to my friend Alex.'],
+            ['Justify why a penguin might make a good king of the jungle.']
+        ],
+        cache_examples=False,
     )
+
+    gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
+    demo.launch()
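A small optional note, not part of this commit: recent Gradio versions queue requests by default, but if streaming updates need tuning on the Space, the queue can be configured explicitly before launching, for example:

# Optional tweak (not in this commit): cap the request queue before launching.
demo.queue(max_size=20).launch()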