Update app.py
app.py
CHANGED
@@ -5,210 +5,273 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
[The left-hand (removed) column of this diff is heavily truncated in the page extraction; only its outline is recoverable.]

The previous app.py drove a two-model pipeline:
- A model table pointing "7B-Unsloth" at "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit" and "32B-Unsloth" at "unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit", both loaded on demand with the same 4-bit BitsAndBytesConfig (nf4, bfloat16 compute) and device_map='auto'.
- A "Senior Code Analyst" brainstorming prompt for the 7B model (analyze the request, suggest 2-3 potential approach options, recommend one primary strategy, sketch pseudocode), followed by a final stage that reused the 7B prompt template with the 32B model (prompt_final = prompt_7b_template.format(response_1_5b=response_7b_stream, context_7b=context_final)).
- The same shared_memory helpers (store_in_memory, retrieve_from_memory) that are kept below.
- Threaded, streaming generation for both stages via TextIteratorStreamer, with user-controlled temperature, top_p and max_new_tokens.
- A Gradio Blocks + ChatInterface UI with a "⚙️ Parameters" accordion (temperature, top-p, max tokens, and two prompt-template textboxes labeled "Brainstorming Model Prompt Template (Unsloth 7B)" and "Code Generation Prompt Template (Unsloth 32B)") and general-purpose examples (a Mars base, relativity for an 8-year-old, a Streamlit finance tracker, a birthday pun, a penguin as king of the jungle), ending in demo.launch().

The new version replaces this with a single 7B model and a three-round brainstorm / code-generation / synthesis loop, and moves parameters and prompt templates into their own tabs. The right-hand (new) column of both hunks (@@ -5,210 +5,273 @@ and @@ -221,35 +284,92 @@) reconstructs to the following app.py:
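The diff starts at line 5, so the file's first four lines are not shown. The hunk header confirms a "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig" line sits above this point; the code below also uses gr, spaces, and TextIteratorStreamer, so the preamble presumably looks something like the sketch below (inferred from usage, not part of the diff):

# Presumed import preamble (lines 1-4, not shown in this hunk) -- an assumption:
import gradio as gr                             # Blocks / ChatInterface UI
import spaces                                   # @spaces.GPU decorator (Hugging Face ZeroGPU)
from transformers import TextIteratorStreamer   # token streaming used in each round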
import torch
from threading import Thread

# --- Model & Quantization Settings ---
MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"

# Dictionaries to store the loaded model and tokenizer
models = {}
tokenizers = {}

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # Or torch.float16 if needed
)

def get_model_and_tokenizer():
    """Lazy-load the model and tokenizer if not already loaded."""
    if "7B" not in models:
        print(f"Loading 7B model: {MODEL_ID} on demand")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=bnb_config_4bit,
            torch_dtype=torch.bfloat16,  # Or torch.float16 if needed
            device_map='auto',
            trust_remote_code=True,
        )
        models["7B"] = model
        tokenizers["7B"] = tokenizer
        print("Loaded 7B model on demand.")
    return models["7B"], tokenizers["7B"]

# --- Default Prompt Templates ---
default_prompt_brainstorm = """**Brainstorming Task (Round 1)**
As a Senior Code Analyst, provide an initial analysis of the problem below.

**User Request:**
{user_prompt}

**Guidelines:**
1. Identify key challenges and constraints.
2. Suggest multiple potential approaches.
3. Outline any potential edge cases or critical considerations.
"""

default_prompt_code_generation = """**Advanced Reasoning & Code Generation (Round 2)**
Based on the initial analysis below:

**Initial Analysis:**
{brainstorm_response}

**User Request:**
{user_prompt}

**Task:**
1. Develop a detailed solution that includes production-ready code.
2. Explain the reasoning behind the chosen approach.
3. Incorporate advanced reasoning to handle edge cases.
4. Provide commented code that is clear and maintainable.
"""

default_prompt_synthesis = """**Synthesis & Final Refinement (Round 3)**
Review the detailed code generation and reasoning below, and produce a final, refined response that:
1. Synthesizes the brainstorming insights and advanced reasoning.
2. Provides a concise summary of the solution.
3. Highlights any potential improvements or considerations.

**Detailed Response:**
{code_response}
"""
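# Note on the templates above: each one is filled with str.format() in the rounds below,
# so custom templates saved from the "Prompt Config" tab may only reference the
# placeholders that round supplies -- {user_prompt} in round 1, {brainstorm_response}
# and {user_prompt} in round 2, {code_response} in round 3. Any other {name} (or an
# unescaped brace, e.g. in a pasted code snippet) makes str.format() raise.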
# --- Shared Memory for Rounds ---
shared_memory = []

def store_in_memory(memory_item):
    """Store a memory item and log an excerpt."""
    shared_memory.append(memory_item)
    print(f"\n[Memory Stored]: {memory_item[:50]}...")

def retrieve_from_memory(query, top_k=2):
    """
    Retrieve memory items that contain the query text (case-insensitive).
    Returns up to top_k items.
    """
    relevant_memories = []
    query_lower = query.lower()
    for memory_item in shared_memory:
        if query_lower in memory_item.lower():
            relevant_memories.append(memory_item)
    if not relevant_memories:
        print("\n[Memory Retrieval]: No relevant memories found.")
        return []
    print(f"\n[Memory Retrieval]: Found {len(relevant_memories)} relevant memories.")
    return relevant_memories[:top_k]
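# Note: retrieve_from_memory() and the memory_top_k argument accepted below are defined
# but never actually called in the three rounds as written. A minimal (hypothetical) way
# to wire them in would be to prepend retrieved items to a round's prompt, e.g.:
#   context = "\n".join(retrieve_from_memory(user_prompt, top_k=memory_top_k))
#   prompt_round2 = context + "\n\n" + prompt_code_generation_text.format(...)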
# --- Multi-Round Swarm Agent Function ---
@spaces.GPU(duration=180)  # Adjust duration as needed
def swarm_agent_iterative(user_prompt, temp, top_p, max_new_tokens, memory_top_k,
                          prompt_brainstorm_text, prompt_code_generation_text, prompt_synthesis_text):
    """
    A three-round iterative process that uses the provided prompt templates:
      - Round 1: Brainstorming.
      - Round 2: Advanced reasoning & code generation.
      - Round 3: Synthesis & refinement.
    This generator yields the response from the final round as it is produced.
    """
    global shared_memory
    shared_memory = []  # Clear shared memory for each new request

    model, tokenizer = get_model_and_tokenizer()

    # ----- Round 1: Brainstorming -----
    print("\n--- Round 1: Brainstorming ---")
    prompt_round1 = prompt_brainstorm_text.format(user_prompt=user_prompt)
    input_ids_r1 = tokenizer.encode(prompt_round1, return_tensors="pt").to(model.device)
    streamer_r1 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs_r1 = dict(
        input_ids=input_ids_r1,
        streamer=streamer_r1,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temp,
        top_p=top_p,
    )
    thread_r1 = Thread(target=model.generate, kwargs=kwargs_r1)
    thread_r1.start()

    brainstorm_response = ""
    for text in streamer_r1:
        print(text, end="", flush=True)
        brainstorm_response += text
    store_in_memory(f"Brainstorm Response: {brainstorm_response[:200]}...")

    # ----- Round 2: Code Generation -----
    print("\n\n--- Round 2: Code Generation ---")
    prompt_round2 = prompt_code_generation_text.format(
        brainstorm_response=brainstorm_response,
        user_prompt=user_prompt
    )
    input_ids_r2 = tokenizer.encode(prompt_round2, return_tensors="pt").to(model.device)
    streamer_r2 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs_r2 = dict(
        input_ids=input_ids_r2,
        streamer=streamer_r2,
        max_new_tokens=max_new_tokens + 100,  # extra tokens for detail
        do_sample=True,  # needed for temperature/top_p to take effect
        temperature=temp,
        top_p=top_p,
    )
    thread_r2 = Thread(target=model.generate, kwargs=kwargs_r2)
    thread_r2.start()

    code_response = ""
    for text in streamer_r2:
        print(text, end="", flush=True)
        code_response += text
    store_in_memory(f"Code Generation Response: {code_response[:200]}...")

    # ----- Round 3: Synthesis & Refinement -----
    print("\n\n--- Round 3: Synthesis & Refinement ---")
    prompt_round3 = prompt_synthesis_text.format(code_response=code_response)
    input_ids_r3 = tokenizer.encode(prompt_round3, return_tensors="pt").to(model.device)
    streamer_r3 = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs_r3 = dict(
        input_ids=input_ids_r3,
        streamer=streamer_r3,
        max_new_tokens=max_new_tokens // 2,
        do_sample=True,  # needed for temperature/top_p to take effect
        temperature=temp,
        top_p=top_p,
    )
    thread_r3 = Thread(target=model.generate, kwargs=kwargs_r3)
    thread_r3.start()

    final_response = ""
    for text in streamer_r3:
        print(text, end="", flush=True)
        final_response += text
        yield final_response  # yield progressive updates

    store_in_memory(f"Final Synthesis Response: {final_response[:200]}...")

# --- Helper to Format History ---
def format_history(history):
    """
    Convert history (which might be a list of [user, assistant] pairs or already formatted dictionaries)
    into a list of OpenAI-style message dictionaries.
    """
    messages = []
    for item in history:
        # If item is a list or tuple, try to unpack it if it has exactly 2 elements.
        if isinstance(item, (list, tuple)):
            if len(item) == 2:
                user_msg, assistant_msg = item
                messages.append({"role": "user", "content": user_msg})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": assistant_msg})
            else:
                # If it doesn't have exactly two items, skip it.
                continue
        elif isinstance(item, dict):
            # Already a formatted message dictionary.
            messages.append(item)
        else:
            continue
    return messages
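# Example of what format_history() produces (illustrative):
#   format_history([["Hi", "Hello!"], ["Ping", ""]])
#   -> [{"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello!"},
#       {"role": "user", "content": "Ping"}]
# (an empty assistant string is dropped, which is how the in-progress turn is passed through)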
# --- Gradio Chat Interface Function ---
def gradio_interface(message, history, param_state, prompt_state):
    """
    This function is called by Gradio's ChatInterface.
    It uses the current saved generation parameters and prompt templates.
    """
    # Unpack parameter state (with fallback defaults)
    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        memory_top_k = int(param_state.get("memory_top_k", 2))
    except Exception:
        temp, top_p, max_new_tokens, memory_top_k = 0.5, 0.9, 300, 2

    # Unpack prompt state (with fallback defaults)
    prompt_brainstorm_text = prompt_state.get("prompt_brainstorm", default_prompt_brainstorm)
    prompt_code_generation_text = prompt_state.get("prompt_code_generation", default_prompt_code_generation)
    prompt_synthesis_text = prompt_state.get("prompt_synthesis", default_prompt_synthesis)

    # Append the new user message with an empty assistant reply (as a two-item list)
    history = history + [[message, ""]]

    # Call the multi-round agent as a generator (for streaming)
    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_brainstorm_text=prompt_brainstorm_text,
        prompt_code_generation_text=prompt_code_generation_text,
        prompt_synthesis_text=prompt_synthesis_text,
    ):
        # Update the last assistant message with the new partial response.
        history[-1][1] = partial_response
        # Yield the history formatted as OpenAI-style messages.
        yield format_history(history)

# --- UI Settings & Styling ---
ui_description = '''
<div>
  <h1 style="text-align: center;">DeepSeek Agent Swarm Chat</h1>
  <p style="text-align: center;">
    Multi-round agent:
    <br>- Brainstorming
    <br>- Advanced reasoning & code generation
    <br>- Synthesis & refinement
  </p>
</div>
'''

ui_license = """
<p/>
---
"""

ui_placeholder = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
  <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">DeepSeek Agent Swarm</h1>
  <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""

css = """
h1 {
  text-align: center;
[... CSS lines 278-283 are unchanged and not shown in this hunk ...]
  border-radius: 100vh;
}
"""

# --- Gradio UI ---
with gr.Blocks(css=css, title="DeepSeek Agent Swarm Chat") as demo:
    gr.Markdown(ui_description)

    # Hidden States to hold parameters and prompt configuration
    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
    })
    prompt_state = gr.State({
        "prompt_brainstorm": default_prompt_brainstorm,
        "prompt_code_generation": default_prompt_code_generation,
        "prompt_synthesis": default_prompt_synthesis,
    })

    # Create top-level Tabs
    with gr.Tabs():
        # --- Chat Tab ---
        with gr.Tab("Chat"):
            # Set type="messages" for OpenAI-style message dictionaries
            chatbot = gr.Chatbot(height=450, placeholder=ui_placeholder, label="Agent Swarm Output", type="messages")
            # Use ChatInterface and pass the hidden states as additional inputs.
            gr.ChatInterface(
                fn=gradio_interface,
                chatbot=chatbot,
                additional_inputs=[param_state, prompt_state],
                examples=[
                    ['How can we build a robust web service that scales efficiently under load?'],
                    ['Explain how to design a fault-tolerant distributed system.'],
                    ['Develop a streamlit app that visualizes real-time financial data.'],
                    ['Create a pun-filled birthday message with a coding twist.'],
                    ['Design a system that uses machine learning to optimize resource allocation.']
                ],
                cache_examples=False,
                type="messages",
            )

        # --- Parameters Tab ---
        with gr.Tab("Parameters"):
            gr.Markdown("### Generation Parameters")
            temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
            top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
            max_tokens_num = gr.Number(value=300, label="Max new tokens", precision=0)
            memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
            save_params_btn = gr.Button("Save Parameters")
            # When the user clicks Save, update the param_state
            save_params_btn.click(
                lambda t, p, m, k: {"temperature": t, "top_p": p, "max_new_tokens": m, "memory_top_k": k},
                inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider],
                outputs=param_state,
            )

        # --- Prompt Config Tab ---
        with gr.Tab("Prompt Config"):
            gr.Markdown("### Configure Prompt Templates")
            prompt_brainstorm_box = gr.Textbox(
                value=default_prompt_brainstorm,
                label="Brainstorm Prompt",
                lines=8,
            )
            prompt_code_generation_box = gr.Textbox(
                value=default_prompt_code_generation,
                label="Code Generation Prompt",
                lines=8,
            )
            prompt_synthesis_box = gr.Textbox(
                value=default_prompt_synthesis,
                label="Synthesis Prompt",
                lines=8,
            )
            save_prompts_btn = gr.Button("Save Prompts")
            # When clicked, update the prompt_state with new values
            save_prompts_btn.click(
                lambda b, c, s: {
                    "prompt_brainstorm": b,
                    "prompt_code_generation": c,
                    "prompt_synthesis": s,
                },
                inputs=[prompt_brainstorm_box, prompt_code_generation_box, prompt_synthesis_box],
                outputs=prompt_state,
            )

    gr.Markdown(ui_license)

if __name__ == "__main__":
    demo.launch()
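For completeness, a minimal way to exercise the three-round pipeline without the Gradio UI might look like the sketch below. This is an assumption, not part of the commit: it presumes the 4-bit model fits on the local GPU and that the spaces GPU decorator is inert outside a Space.

# Hypothetical local smoke test (separate script, not part of app.py):
from app import (swarm_agent_iterative, default_prompt_brainstorm,
                 default_prompt_code_generation, default_prompt_synthesis)

final = ""
for partial in swarm_agent_iterative(
    user_prompt="Write a function that reverses a string.",
    temp=0.5, top_p=0.9, max_new_tokens=128, memory_top_k=2,
    prompt_brainstorm_text=default_prompt_brainstorm,
    prompt_code_generation_text=default_prompt_code_generation,
    prompt_synthesis_text=default_prompt_synthesis,
):
    final = partial  # the generator yields the cumulative round-3 text
print(final)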