wuhp committed on
Commit ccc6355 · verified · 1 Parent(s): 361c4d3

Update app.py

Files changed (1)
  1. app.py +291 -171
app.py CHANGED
@@ -5,210 +5,273 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 import torch
 from threading import Thread
 
-# Model IDs from Hugging Face Hub (Fixed to Unsloth 7B and 32B Unsloth 4bit)
-model_ids = {
-    "7B-Unsloth": "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit", # Unsloth 7B model
-    "32B-Unsloth": "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit", # Unsloth 32B model
-}
+# --- Model & Quantization Settings ---
+MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
 
-models = {} # Keep models as a dictionary, but initially empty
-tokenizers = {} # Keep tokenizers as a dictionary, initially empty
+# Dictionaries to store the loaded model and tokenizer
+models = {}
+tokenizers = {}
 
-# BitsAndBytesConfig for 4-bit quantization (for BOTH models now)
 bnb_config_4bit = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.bfloat16, # Or torch.float16 if needed
+    bnb_4bit_compute_dtype=torch.bfloat16, # Or torch.float16 if needed
 )
 
-
-def get_model_and_tokenizer(size): # Function to load model on demand
-    if size not in models: # Load only if not already loaded
-        model_id = model_ids[size]
-        print(f"Loading {size} model: {model_id} on demand")
-        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+def get_model_and_tokenizer():
+    """Lazy-load the model and tokenizer if not already loaded."""
+    if "7B" not in models:
+        print(f"Loading 7B model: {MODEL_ID} on demand")
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
         model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            quantization_config=bnb_config_4bit, # Apply 4-bit config for BOTH models
-            torch_dtype=torch.bfloat16, # Or torch.float16 if needed
+            MODEL_ID,
+            quantization_config=bnb_config_4bit,
+            torch_dtype=torch.bfloat16, # Or torch.float16 if needed
             device_map='auto',
-            trust_remote_code=True
+            trust_remote_code=True,
         )
-        models[size] = model
-        tokenizers[size] = tokenizer
-        print(f"Loaded {size} model on demand.")
-    return models[size], tokenizers[size]
+        models["7B"] = model
+        tokenizers["7B"] = tokenizer
+        print("Loaded 7B model on demand.")
+    return models["7B"], tokenizers["7B"]
 
-
-# Revised Default Prompts (as defined previously - these are still good)
-default_prompt_1_5b = """**Code Analysis Task**
-As a Senior Code Analyst, analyze this programming problem:
+# --- Default Prompt Templates ---
+default_prompt_brainstorm = """**Brainstorming Task (Round 1)**
+As a Senior Code Analyst, provide an initial analysis of the problem below.
 
 **User Request:**
 {user_prompt}
 
-**Relevant Context:**
-{context_1_5b}
-
-**Analysis Required:**
-1. Briefly break down the problem, including key constraints and edge cases.
-2. Suggest 2-3 potential approach options (algorithms/data structures).
-3. Recommend ONE primary strategy and briefly justify your choice.
-4. Provide a very brief initial pseudocode sketch of the core logic."""
-
+**Guidelines:**
+1. Identify key challenges and constraints.
+2. Suggest multiple potential approaches.
+3. Outline any potential edge cases or critical considerations.
+"""
 
-default_prompt_7b = """**Code Implementation Task**
-As a Principal Software Engineer, provide production-ready Streamlit/Python code based on this analysis:
+default_prompt_code_generation = """**Advanced Reasoning & Code Generation (Round 2)**
+Based on the initial analysis below:
 
 **Initial Analysis:**
-{response_1_5b}
+{brainstorm_response}
 
-**Relevant Context:**
-{context_7b}
+**User Request:**
+{user_prompt}
 
-**Code Requirements:**
-1. Generate concise, production-grade Python code for a Streamlit app.
-2. Include necessary imports, UI elements, and basic functionality.
-3. Add comments for clarity.
-"""
+**Task:**
+1. Develop a detailed solution that includes production-ready code.
+2. Explain the reasoning behind the chosen approach.
+3. Incorporate advanced reasoning to handle edge cases.
+4. Provide commented code that is clear and maintainable.
+"""
 
+default_prompt_synthesis = """**Synthesis & Final Refinement (Round 3)**
+Review the detailed code generation and reasoning below, and produce a final, refined response that:
+1. Synthesizes the brainstorming insights and advanced reasoning.
+2. Provides a concise summary of the solution.
+3. Highlights any potential improvements or considerations.
 
-# --- Shared Memory Implementation --- (Same)
+**Detailed Response:**
+{code_response}
+"""
+
+# --- Shared Memory for Rounds ---
 shared_memory = []
 
 def store_in_memory(memory_item):
+    """Store a memory item and log an excerpt."""
     shared_memory.append(memory_item)
     print(f"\n[Memory Stored]: {memory_item[:50]}...")
 
 def retrieve_from_memory(query, top_k=2):
+    """
+    Retrieve memory items that contain the query text (case-insensitive).
+    Returns up to top_k items.
+    """
     relevant_memories = []
     query_lower = query.lower()
     for memory_item in shared_memory:
         if query_lower in memory_item.lower():
             relevant_memories.append(memory_item)
-
     if not relevant_memories:
         print("\n[Memory Retrieval]: No relevant memories found.")
         return []
-
     print(f"\n[Memory Retrieval]: Found {len(relevant_memories)} relevant memories.")
     return relevant_memories[:top_k]
 
-
-# --- Streaming Swarm Agent Function - Fixed Models (Unsloth 7B and 32B Unsloth) ---
-@spaces.GPU(duration=120) # Added duration
-def swarm_agent_sequential_rag(user_prompt, prompt_1_5b_template, prompt_7b_template, temperature=0.5, top_p=0.9, max_new_tokens=300): # Removed final_model_size
+# --- Multi-Round Swarm Agent Function ---
+@spaces.GPU(duration=180) # Adjust duration as needed
+def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k,
+                          prompt_brainstorm_text, prompt_code_generation_text, prompt_synthesis_text):
+    """
+    A three-round iterative process that uses the provided prompt templates:
+      - Round 1: Brainstorming.
+      - Round 2: Advanced reasoning & code generation.
+      - Round 3: Synthesis & refinement.
+    This generator yields the response from the final round as it is produced.
+    """
     global shared_memory
-    shared_memory = [] # Clear memory for each new request
-
-    print(f"\n--- Swarm Agent Processing with Shared Memory (RAG) - GPU ACCELERATED - Final Model: 32B Unsloth ---") # Updated message
-
-    # 7B Unsloth Model - Brainstorming/Initial Draft (Lazy Load and get model)
-    print("\n[7B Unsloth Model - Brainstorming] - GPU Accelerated") # Now 7B Unsloth is brainstorming
-    model_7b, tokenizer_7b = get_model_and_tokenizer("7B-Unsloth") # Lazy load 7B Unsloth
-    retrieved_memory_7b = retrieve_from_memory(user_prompt)
-    context_7b = "\n".join([f"- {mem}" for mem in retrieved_memory_7b]) if retrieved_memory_7b else "No relevant context found in memory."
-
-    # Use user-provided prompt template for 7B model (as brainstorming model now)
-    prompt_7b_brainstorm = prompt_1_5b_template.format(user_prompt=user_prompt, context_1_5b=context_7b) # Reusing 1.5B template - adjust if needed
-
-    input_ids_7b = tokenizer_7b.encode(prompt_7b_brainstorm, return_tensors="pt").to(model_7b.device)
-    streamer_7b = TextIteratorStreamer(tokenizer_7b, timeout=10.0, skip_prompt=True, skip_special_tokens=True) # Streamer for 7B
-
-    generate_kwargs_7b = dict( # Generation kwargs for 7B
-        input_ids=input_ids_7b,
-        streamer=streamer_7b,
-        max_new_tokens=max_new_tokens, # Use user-defined max_new_tokens
+    shared_memory = [] # Clear shared memory for each new request
+
+    model, tokenizer = get_model_and_tokenizer()
+
+    # ----- Round 1: Brainstorming -----
+    print("\n--- Round 1: Brainstorming ---")
+    prompt_round1 = prompt_brainstorm_text.format(user_prompt=user_prompt)
+    input_ids_r1 = tokenizer.encode(prompt_round1, return_tensors="pt").to(model.device)
+    streamer_r1 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    kwargs_r1 = dict(
+        input_ids=input_ids_r1,
+        streamer=streamer_r1,
+        max_new_tokens=max_new_tokens,
         do_sample=True,
-        temperature=temperature,
+        temperature=temp,
         top_p=top_p,
-        # eos_token_id=tokenizer_7b.eos_token_id, # Not strictly needed as streamer handles EOS
     )
-
-    thread_7b = Thread(target=model_7b.generate, kwargs=generate_kwargs_7b) # Thread for 7B generation
-    thread_7b.start()
-
-    response_7b_stream = "" # Accumulate streamed 7B response
-    print(f"7B Unsloth Response (Brainstorming):\n", end="")
-    for text in streamer_7b: # Stream and print 7B response
-        print(text, end="", flush=True) # Print in place
-        response_7b_stream += text
-        yield response_7b_stream # Yield intermediate 7B response
-
-    store_in_memory(f"7B Unsloth Model Initial Response: {response_7b_stream[:200]}...") # Store accumulated 7B response
-
-    # 32B Unsloth Model - Final Code Generation (Lazy Load and get model)
-    final_model, final_tokenizer = get_model_and_tokenizer("32B-Unsloth") # Lazy load 32B Unsloth
-    print("\n[32B Unsloth Model - Final Code Generation] - GPU Accelerated") # Model-specific message
-    model_stage_name = "32B Unsloth Model - Final Code"
-    final_max_new_tokens = max_new_tokens + 200 # More tokens for 32B model
-
-    retrieved_memory_final = retrieve_from_memory(response_7b_stream) # Memory from streamed 7B response
-    context_final = "\n".join([f"- {mem}" for mem in retrieved_memory_final]) if retrieved_memory_final else "No relevant context found in memory."
-
-    # Use user-provided prompt template for final model (using 7B template)
-    prompt_final = prompt_7b_template.format(response_1_5b=response_7b_stream, context_7b=context_final) # Using prompt_7b_template for final stage
-
-    input_ids_final = final_tokenizer.encode(prompt_final, return_tensors="pt").to(final_model.device)
-    streamer_final = TextIteratorStreamer(final_tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) # Streamer for 32B
-
-    generate_kwargs_final = dict( # Generation kwargs for 32B
-        input_ids=input_ids_final,
-        streamer=streamer_final,
-        max_new_tokens=final_max_new_tokens,
-        temperature=temperature,
+    thread_r1 = Thread(target=model.generate, kwargs=kwargs_r1)
+    thread_r1.start()
+
+    brainstorm_response = ""
+    for text in streamer_r1:
+        print(text, end="", flush=True)
+        brainstorm_response += text
+    store_in_memory(f"Brainstorm Response: {brainstorm_response[:200]}...")
+
+    # ----- Round 2: Code Generation -----
+    print("\n\n--- Round 2: Code Generation ---")
+    prompt_round2 = prompt_code_generation_text.format(
+        brainstorm_response=brainstorm_response,
+        user_prompt=user_prompt
+    )
+    input_ids_r2 = tokenizer.encode(prompt_round2, return_tensors="pt").to(model.device)
+    streamer_r2 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    kwargs_r2 = dict(
+        input_ids=input_ids_r2,
+        streamer=streamer_r2,
+        max_new_tokens=max_new_tokens + 100, # extra tokens for detail
+        temperature=temp,
         top_p=top_p,
-        # eos_token_id=final_tokenizer.eos_token_id, # Not strictly needed as streamer handles EOS
     )
+    thread_r2 = Thread(target=model.generate, kwargs=kwargs_r2)
+    thread_r2.start()
+
+    code_response = ""
+    for text in streamer_r2:
+        print(text, end="", flush=True)
+        code_response += text
+    store_in_memory(f"Code Generation Response: {code_response[:200]}...")
+
+    # ----- Round 3: Synthesis & Refinement -----
+    print("\n\n--- Round 3: Synthesis & Refinement ---")
+    prompt_round3 = prompt_synthesis_text.format(code_response=code_response)
+    input_ids_r3 = tokenizer.encode(prompt_round3, return_tensors="pt").to(model.device)
+    streamer_r3 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    kwargs_r3 = dict(
+        input_ids=input_ids_r3,
+        streamer=streamer_r3,
+        max_new_tokens=max_new_tokens // 2,
+        temperature=temp,
+        top_p=top_p,
+    )
+    thread_r3 = Thread(target=model.generate, kwargs=kwargs_r3)
+    thread_r3.start()
 
-    thread_final = Thread(target=final_model.generate, kwargs=generate_kwargs_final) # Thread for 32B generation
-    thread_final.start()
-
-    response_final_stream = "" # Accumulate streamed 32B response
-    print(f"\n{model_stage_name} Response:\n", end="")
-    for text in streamer_final: # Stream and print 32B response
-        print(text, end="", flush=True) # Print in place
-        response_final_stream += text
-        yield response_final_stream # Yield intermediate 32B response
-
-    store_in_memory(f"{model_stage_name} Response: {response_final_stream[:200]}...") # Store accumulated 32B response
-
-    return response_final_stream # Returns final streamed response
+    final_response = ""
+    for text in streamer_r3:
+        print(text, end="", flush=True)
+        final_response += text
+        yield final_response # yield progressive updates
 
+    store_in_memory(f"Final Synthesis Response: {final_response[:200]}...")
 
-# --- Gradio ChatInterface --- (No Model Selection Dropdown anymore)
-def gradio_interface(message, history, temp, top_p, max_tokens, prompt_1_5b_text, prompt_7b_text): # Removed final_model_selector
-    # history is automatically managed by ChatInterface
-    full_response = "" # Accumulate full response from generator
-    for partial_response in swarm_agent_sequential_rag( # Iterate through generator
-        message,
-        prompt_1_5b_template=prompt_1_5b_text, # Pass prompt templates
-        prompt_7b_template=prompt_7b_text,
-        temperature=temp,
+# --- Helper to Format History ---
+def format_history(history):
+    """
+    Convert history (which might be a list of [user, assistant] pairs or already formatted dictionaries)
+    into a list of OpenAI-style message dictionaries.
+    """
+    messages = []
+    for item in history:
+        # If item is a list or tuple, try to unpack it if it has exactly 2 elements.
+        if isinstance(item, (list, tuple)):
+            if len(item) == 2:
+                user_msg, assistant_msg = item
+                messages.append({"role": "user", "content": user_msg})
+                if assistant_msg:
+                    messages.append({"role": "assistant", "content": assistant_msg})
+            else:
+                # If it doesn't have exactly two items, skip it.
+                continue
+        elif isinstance(item, dict):
+            # Already formatted message dictionary.
+            messages.append(item)
+        else:
+            continue
+    return messages
+
+# --- Gradio Chat Interface Function ---
+def gradio_interface(message, history, param_state, prompt_state):
+    """
+    This function is called by Gradio's ChatInterface.
+    It uses the current saved generation parameters and prompt templates.
+    """
+    # Unpack parameter state (with fallback defaults)
+    try:
+        temp = float(param_state.get("temperature", 0.5))
+        top_p = float(param_state.get("top_p", 0.9))
+        max_new_tokens = int(param_state.get("max_new_tokens", 300))
+        memory_top_k = int(param_state.get("memory_top_k", 2))
+    except Exception:
+        temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2
+
+    # Unpack prompt state (with fallback defaults)
+    prompt_brainstorm_text = prompt_state.get("prompt_brainstorm", default_prompt_brainstorm)
+    prompt_code_generation_text = prompt_state.get("prompt_code_generation", default_prompt_code_generation)
+    prompt_synthesis_text = prompt_state.get("prompt_synthesis", default_prompt_synthesis)
+
+    # Append the new user message with an empty assistant reply (as a two-item list)
+    history = history + [[message, ""]]
+
+    # Call the multi-round agent as a generator (for streaming)
+    for partial_response in swarm_agent_iterative(
+        user_prompt=message,
+        temp=temp,
         top_p=top_p,
-        max_new_tokens=int(max_tokens) # Ensure max_tokens is an integer
+        max_new_tokens=max_new_tokens,
+        memory_top_k=memory_top_k,
+        prompt_brainstorm_text=prompt_brainstorm_text,
+        prompt_code_generation_text=prompt_code_generation_text,
+        prompt_synthesis_text=prompt_synthesis_text
     ):
-        full_response = partial_response # Update full response with partial response
-        yield full_response # Yield intermediate full response
-
+        # Update the last assistant message with the new partial response.
+        history[-1][1] = partial_response
+        # Yield the history formatted as OpenAI-style messages.
+        yield format_history(history)
 
-DESCRIPTION = '''
+# --- UI Settings & Styling ---
+ui_description = '''
 <div>
-<h1 style="text-align: center;">DeepSeek Agent Swarm Chat (Unsloth 7B + 32B) - Streaming Demo</h1>
-<p style="text-align: center;">Agent swarm using Unsloth DeepSeek-R1-Distill models (7B + 32B) with shared memory, adjustable settings, and customizable prompts. GPU accelerated using ZeroGPU! (Requires Pro Space)</p>
+<h1 style="text-align: center;">DeepSeek Agent Swarm Chat</h1>
+<p style="text-align: center;">
+  Multi-round agent:
+  <br>- Brainstorming
+  <br>- Advanced reasoning & code generation
+  <br>- Synthesis & refinement
+</p>
 </div>
 '''
 
-LICENSE = """
+ui_license = """
 <p/>
 ---
 """
 
-PLACEHOLDER = """
-Ask me anything...
+ui_placeholder = """
+<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">DeepSeek Agent Swarm</h1>
+   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
+</div>
 """
 
-
 css = """
 h1 {
   text-align: center;
@@ -221,35 +284,92 @@ h1 {
   border-radius: 100vh;
 }
 """
-# Gradio ChatInterface with streaming
-chatbot=gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Agent Swarm Output')
-
-with gr.Blocks(fill_height=True, css=css) as demo:
-
-    gr.Markdown(DESCRIPTION)
-    gr.ChatInterface(
-        fn=gradio_interface,
-        chatbot=chatbot,
-        fill_height=True,
-        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False), # Accordion for params
-        additional_inputs=[
-            gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature"),
-            gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P"),
-            gr.Number(value=300, label="Max Tokens", precision=0),
-            gr.Textbox(value=default_prompt_1_5b, lines=7, label="Brainstorming Model Prompt Template (Unsloth 7B)"),
-            gr.Textbox(value=default_prompt_7b, lines=7, label="Code Generation Prompt Template (Unsloth 32B)"),
-        ],
-        examples=[
-            ['How to setup a human base on Mars? Give short answer.'],
-            ['Explain theory of relativity to me like I’m 8 years old.'],
-            ['Write a streamlit app to track my finances'],
-            ['Write a pun-filled happy birthday message to my friend Alex.'],
-            ['Justify why a penguin might make a good king of the jungle.']
-        ],
-        cache_examples=False,
-    )
 
-    gr.Markdown(LICENSE)
+# --- Gradio UI ---
+with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
+    gr.Markdown(ui_description)
+
+    # Hidden States to hold parameters and prompt configuration
+    param_state = gr.State({
+        "temperature": 0.5,
+        "top_p": 0.9,
+        "max_new_tokens": 300,
+        "memory_top_k": 2,
+    })
+    prompt_state = gr.State({
+        "prompt_brainstorm": default_prompt_brainstorm,
+        "prompt_code_generation": default_prompt_code_generation,
+        "prompt_synthesis": default_prompt_synthesis,
+    })
+
+    # Create top-level Tabs
+    with gr.Tabs():
+        # --- Chat Tab ---
+        with gr.Tab("Chat"):
+            # Set type="messages" for OpenAI-style message dictionaries
+            chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
+            # Use ChatInterface and pass the hidden states as additional inputs.
+            gr.ChatInterface(
+                fn=gradio_interface,
+                chatbot=chatbot,
+                additional_inputs=[param_state, prompt_state],
+                examples=[
+                    ['How can we build a robust web service that scales efficiently under load?'],
+                    ['Explain how to design a fault-tolerant distributed system.'],
+                    ['Develop a streamlit app that visualizes real-time financial data.'],
+                    ['Create a pun-filled birthday message with a coding twist.'],
+                    ['Design a system that uses machine learning to optimize resource allocation.']
+                ],
+                cache_examples=False,
+                type="messages",
+            )
+
+        # --- Parameters Tab ---
+        with gr.Tab("Parameters"):
+            gr.Markdown("### Generation Parameters")
+            temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
+            top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
+            max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
+            memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
+            save_params_btn = gr.Button("Save Parameters")
+            # When the user clicks Save, update the param_state
+            save_params_btn.click(
+                lambda t, p, m, k: {"temperature": t, "top_p": p, "max_new_tokens": m, "memory_top_k": k},
+                inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider],
+                outputs=param_state,
+            )
+
+        # --- Prompt Config Tab ---
+        with gr.Tab("Prompt Config"):
+            gr.Markdown("### Configure Prompt Templates")
+            prompt_brainstorm_box = gr.Textbox(
+                value=default_prompt_brainstorm,
+                label="Brainstorm Prompt",
+                lines=8,
+            )
+            prompt_code_generation_box = gr.Textbox(
+                value=default_prompt_code_generation,
+                label="Code Generation Prompt",
+                lines=8,
+            )
+            prompt_synthesis_box = gr.Textbox(
+                value=default_prompt_synthesis,
+                label="Synthesis Prompt",
+                lines=8,
+            )
+            save_prompts_btn = gr.Button("Save Prompts")
+            # When clicked, update the prompt_state with new values
+            save_prompts_btn.click(
+                lambda b, c, s: {
+                    "prompt_brainstorm": b,
+                    "prompt_code_generation": c,
+                    "prompt_synthesis": s,
+                },
+                inputs=[prompt_brainstorm_box, prompt_code_generation_box, prompt_synthesis_box],
+                outputs=prompt_state,
+            )
+
+    gr.Markdown(ui_license)
 
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
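
Below is a minimal, hypothetical smoke-test driver for the new multi-round generator; it is not part of this commit. It assumes a CUDA GPU with enough memory for the 4-bit 7B model, that app.py is importable from the working directory, and that the spaces.GPU decorator degrades to a no-op outside a ZeroGPU Space.

# Hypothetical local test; not part of the commit above.
from app import (
    swarm_agent_iterative,
    default_prompt_brainstorm,
    default_prompt_code_generation,
    default_prompt_synthesis,
)

final_text = ""
# Consume the generator the same way gradio_interface does; the last
# yielded value holds the complete Round 3 (synthesis) output.
for partial in swarm_agent_iterative(
    user_prompt="Write a Streamlit app that tracks daily expenses.",
    temp=0.5,
    top_p=0.9,
    max_new_tokens=300,
    memory_top_k=2,
    prompt_brainstorm_text=default_prompt_brainstorm,
    prompt_code_generation_text=default_prompt_code_generation,
    prompt_synthesis_text=default_prompt_synthesis,
):
    final_text = partial

print("\n\nFinal synthesis:\n" + final_text)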