wuhp committed
Commit 361c4d3 · verified · Parent: 85bfd55

Update app.py

Files changed (1):
1. app.py  +113 -39
app.py CHANGED
@@ -1,7 +1,9 @@
  import gradio as gr
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
- import torch
+ import os
  import spaces  # Import the spaces library
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
+ import torch
+ from threading import Thread

  # Model IDs from Hugging Face Hub (Fixed to Unsloth 7B and 32B Unsloth 4bit)
  model_ids = {
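
The new imports set up threaded token streaming: `generate()` blocks until completion, so the updated app runs it on a `Thread` and drains a `TextIteratorStreamer` from the calling thread. Below is a minimal, self-contained sketch of that pattern; the tiny `sshleifer/tiny-gpt2` checkpoint is a stand-in chosen for illustration, not one of the models this Space loads.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Stand-in checkpoint for illustration only; the Space itself loads
# Unsloth DeepSeek-R1-Distill models via get_model_and_tokenizer().
tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tok("Hello", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until done, so it runs on a worker thread while the
# main thread iterates the streamer, which yields decoded text chunks.
thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=20))
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```
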
@@ -93,8 +95,8 @@ def retrieve_from_memory(query, top_k=2):
      return relevant_memories[:top_k]


- # --- Swarm Agent Function - Fixed Models (Unsloth 7B and 32B Unsloth) ---
- @spaces.GPU  # <---- GPU DECORATOR ADDED HERE!
+ # --- Streaming Swarm Agent Function - Fixed Models (Unsloth 7B and 32B Unsloth) ---
+ @spaces.GPU(duration=120)  # Added duration
  def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_template, temperature=0.5, top_p=0.9, max_new_tokens=300):  # Removed final_model_size
      global shared_memory
      shared_memory = []  # Clear memory for each new request
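
On ZeroGPU Spaces, `@spaces.GPU` attaches a GPU only for the duration of the decorated call, and `duration=120` requests a longer allocation window than the default so that both sequential generations (7B brainstorm, then 32B code) fit in a single call. A minimal sketch of the decorator's contract, assuming the code runs on a ZeroGPU Space:

```python
import spaces
import torch

@spaces.GPU(duration=120)  # ZeroGPU grants a GPU only while this call runs
def gpu_probe() -> str:
    # Inside the decorated call CUDA should be available; at module import
    # time (outside any @spaces.GPU function) it generally is not.
    return f"CUDA available: {torch.cuda.is_available()}"
```
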
@@ -111,16 +113,29 @@ def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_temp
      prompt_7b_brainstorm = prompt_1_5b_template.format(user_prompt=user_prompt, context_1_5b=context_7b)  # Reusing 1.5B template - adjust if needed

      input_ids_7b = tokenizer_7b.encode(prompt_7b_brainstorm, return_tensors="pt").to(model_7b.device)
-     output_7b = model_7b.generate(
-         input_ids_7b,
+     streamer_7b = TextIteratorStreamer(tokenizer_7b, timeout=10.0, skip_prompt=True, skip_special_tokens=True)  # Streamer for 7B
+
+     generate_kwargs_7b = dict(  # Generation kwargs for 7B
+         input_ids=input_ids_7b,
+         streamer=streamer_7b,
          max_new_tokens=max_new_tokens,  # Use user-defined max_new_tokens
-         temperature=temperature,  # Use user-defined temperature
-         top_p=top_p,  # Use user-defined top_p
-         do_sample=True
+         do_sample=True,
+         temperature=temperature,
+         top_p=top_p,
+         # eos_token_id=tokenizer_7b.eos_token_id,  # Not strictly needed as streamer handles EOS
      )
-     response_7b = tokenizer_7b.decode(output_7b[0], skip_special_tokens=True)
-     print(f"7B Unsloth Response (Brainstorming):\n{response_7b}")  # Updated message
-     store_in_memory(f"7B Unsloth Model Initial Response: {response_7b[:200]}...")
+
+     thread_7b = Thread(target=model_7b.generate, kwargs=generate_kwargs_7b)  # Thread for 7B generation
+     thread_7b.start()
+
+     response_7b_stream = ""  # Accumulate streamed 7B response
+     print("7B Unsloth Response (Brainstorming):\n", end="")
+     for text in streamer_7b:  # Stream and print 7B response
+         print(text, end="", flush=True)  # Print in place
+         response_7b_stream += text
+         yield response_7b_stream  # Yield the accumulated 7B response so far
+
+     store_in_memory(f"7B Unsloth Model Initial Response: {response_7b_stream[:200]}...")  # Store accumulated 7B response

      # 32B Unsloth Model - Final Code Generation (Lazy Load and get model)
      final_model, final_tokenizer = get_model_and_tokenizer("32B-Unsloth")  # Lazy load 32B Unsloth
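
Two details of this hunk are worth noting. `timeout=10.0` makes iteration over the streamer raise `queue.Empty` if no new text arrives within 10 seconds, so a stalled generation cannot hang the request indefinitely. And each `yield` carries the full accumulated text rather than a delta, so a consumer only ever has to re-render the latest value. A small self-contained sketch of that accumulate-and-yield contract, with a hypothetical `stream_words` helper standing in for the model loop:

```python
from typing import Iterator

def stream_words(text: str) -> Iterator[str]:
    # Mirrors the loop above: every yield is the full prefix so far,
    # not just the newly produced chunk.
    acc = ""
    for word in text.split():
        acc += word + " "
        yield acc

latest = ""
for partial in stream_words("each yield re-sends the whole message"):
    latest = partial  # a UI would simply re-render this value
print(latest.strip())
```
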
@@ -128,54 +143,113 @@ def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_temp
      model_stage_name = "32B Unsloth Model - Final Code"
      final_max_new_tokens = max_new_tokens + 200  # More tokens for 32B model

-     retrieved_memory_final = retrieve_from_memory(response_7b)  # Memory from 7B brainstorm
+     retrieved_memory_final = retrieve_from_memory(response_7b_stream)  # Memory from streamed 7B response
      context_final = "\n".join([f"- {mem}" for mem in retrieved_memory_final]) if retrieved_memory_final else "No relevant context found in memory."

      # Use user-provided prompt template for final model (using 7B template)
-     prompt_final = prompt_7b_template.format(response_1_5b=response_7b, context_7b=context_final)  # Using prompt_7b_template for final stage
-
+     prompt_final = prompt_7b_template.format(response_1_5b=response_7b_stream, context_7b=context_final)  # Using prompt_7b_template for final stage

      input_ids_final = final_tokenizer.encode(prompt_final, return_tensors="pt").to(final_model.device)
-     output_final = final_model.generate(
-         input_ids_final,
+     streamer_final = TextIteratorStreamer(final_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)  # Streamer for 32B
+
+     generate_kwargs_final = dict(  # Generation kwargs for 32B
+         input_ids=input_ids_final,
+         streamer=streamer_final,
          max_new_tokens=final_max_new_tokens,
          temperature=temperature,
          top_p=top_p,
-         do_sample=True
+         # eos_token_id=final_tokenizer.eos_token_id,  # Not strictly needed as streamer handles EOS
      )
-     response_final = final_tokenizer.decode(output_final[0], skip_special_tokens=True)
-     print(f"{model_stage_name} Response:\n{response_final}")
-     store_in_memory(f"{model_stage_name} Response: {response_final[:200]}...")

-     return response_final  # Returns final model's response
+     thread_final = Thread(target=final_model.generate, kwargs=generate_kwargs_final)  # Thread for 32B generation
+     thread_final.start()
+
+     response_final_stream = ""  # Accumulate streamed 32B response
+     print(f"\n{model_stage_name} Response:\n", end="")
+     for text in streamer_final:  # Stream and print 32B response
+         print(text, end="", flush=True)  # Print in place
+         response_final_stream += text
+         yield response_final_stream  # Yield the accumulated 32B response so far
+
+     store_in_memory(f"{model_stage_name} Response: {response_final_stream[:200]}...")  # Store accumulated 32B response
+
+     return response_final_stream  # Ends the generator; the full response has already been yielded


  # --- Gradio ChatInterface --- (No Model Selection Dropdown anymore)
  def gradio_interface(message, history, temp, top_p, max_tokens, prompt_1_5b_text, prompt_7b_text):  # Removed final_model_selector
      # history is automatically managed by ChatInterface
-     response = swarm_agent_sequential_rag(
+     full_response = ""  # Accumulate full response from generator
+     for partial_response in swarm_agent_sequential_rag(  # Iterate through generator
          message,
          prompt_1_5b_template=prompt_1_5b_text,  # Pass prompt templates
          prompt_7b_template=prompt_7b_text,
          temperature=temp,
          top_p=top_p,
          max_new_tokens=int(max_tokens)  # Ensure max_tokens is an integer
+     ):
+         full_response = partial_response  # Update full response with partial response
+         yield full_response  # Yield intermediate full response
+
+
+ DESCRIPTION = '''
+ <div>
+ <h1 style="text-align: center;">DeepSeek Agent Swarm Chat (Unsloth 7B + 32B) - Streaming Demo</h1>
+ <p style="text-align: center;">Agent swarm using Unsloth DeepSeek-R1-Distill models (7B + 32B) with shared memory, adjustable settings, and customizable prompts. GPU accelerated using ZeroGPU! (Requires Pro Space)</p>
+ </div>
+ '''
+
+ LICENSE = """
+ <p/>
+ ---
+ """
+
+ PLACEHOLDER = """
+ Ask me anything...
+ """
+
+
+ css = """
+ h1 {
+     text-align: center;
+     display: block;
+ }
+ #duplicate-button {
+     margin: auto;
+     color: white;
+     background: #1565c0;
+     border-radius: 100vh;
+ }
+ """
+
+ # Gradio ChatInterface with streaming
+ chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Agent Swarm Output')
+
+ with gr.Blocks(fill_height=True, css=css) as demo:
+
+     gr.Markdown(DESCRIPTION)
+     gr.ChatInterface(
+         fn=gradio_interface,
+         chatbot=chatbot,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),  # Accordion for params
+         additional_inputs=[
+             gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature"),
+             gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P"),
+             gr.Number(value=300, label="Max Tokens", precision=0),
+             gr.Textbox(value=default_prompt_1_5b, lines=7, label="Brainstorming Model Prompt Template (Unsloth 7B)"),
+             gr.Textbox(value=default_prompt_7b, lines=7, label="Code Generation Prompt Template (Unsloth 32B)"),
+         ],
+         examples=[
+             ['How to setup a human base on Mars? Give short answer.'],
+             ['Explain theory of relativity to me like I’m 8 years old.'],
+             ['Write a streamlit app to track my finances'],
+             ['Write a pun-filled happy birthday message to my friend Alex.'],
+             ['Justify why a penguin might make a good king of the jungle.']
+         ],
+         cache_examples=False,
      )
-     return response
-
- iface = gr.ChatInterface(  # Using ChatInterface now
-     fn=gradio_interface,
-     # Define additional inputs for settings and prompts (NO model dropdown)
-     additional_inputs=[
-         gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature"),  # Lowered default temp to 0.5
-         gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P"),
-         gr.Number(value=300, label="Max Tokens", precision=0),  # Use Number for integer tokens
-         gr.Textbox(value=default_prompt_1_5b, lines=10, label="Brainstorming Model Prompt Template (Unsloth 7B)"),  # Updated label - Unsloth 7B now brainstormer
-         gr.Textbox(value=default_prompt_7b, lines=10, label="Code Generation Prompt Template (Unsloth 32B)"),  # Updated label - Unsloth 32B is code generator
-     ],
-     title="DeepSeek Agent Swarm Chat (ZeroGPU Demo - Fixed Models: Unsloth 7B + 32B)",  # Updated title
-     description="Chat with a DeepSeek agent swarm (Unsloth 7B + 32B) with shared memory, adjustable settings, **and customizable prompts!** **GPU accelerated using ZeroGPU!** (Requires Pro Space)",  # Updated description
- )
+
+     gr.Markdown(LICENSE)

  if __name__ == "__main__":
-     iface.launch()  # Only launch locally if running this script directly
+     demo.launch()
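Two things stand out in this final hunk. First, because `gradio_interface` is now a generator, `gr.ChatInterface` re-renders the pending bot message on every `yield`, which produces the live-typing effect. Second, the new 32B `generate_kwargs_final` pass `temperature` and `top_p` but drop `do_sample=True`, so transformers will fall back to greedy decoding and ignore those sampling knobs unless sampling is re-enabled. A minimal, self-contained sketch of the generator-to-ChatInterface wiring, with a hypothetical `echo_stream` function standing in for the real swarm:

```python
import time

import gradio as gr

def echo_stream(message, history):
    # Like gradio_interface above: yield a growing prefix and let
    # ChatInterface replace the pending bot reply on every yield.
    acc = ""
    for ch in message:
        acc += ch
        time.sleep(0.02)  # simulate token latency
        yield acc

demo = gr.ChatInterface(fn=echo_stream, title="Streaming echo (sketch)")

if __name__ == "__main__":
    demo.launch()
```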