Create app.py
app.py (new file, +887 lines)
import os
import gradio as gr
import torch
from TTS.api import TTS
import spaces  # assumed custom module providing GPU decorators
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
from threading import Thread
import logging
from typing import Tuple, List, Dict, Generator
import time

# NEW: Import whisper for speech-to-text.
import whisper

# ===========================
# Global Environment Settings
# ===========================
os.environ["COQUI_TOS_AGREED"] = "1"
# Global device override (will be updated from UI later)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Whisper model (this may take a moment at startup)
whisper_model = whisper.load_model("base")

# Global dictionary for storing saved voice clones.
voice_bank: Dict[str, str] = {}

# ---------------------------
# Simple Response Cache
# ---------------------------
response_cache: Dict[str, str] = {}

# ===========================
# Voice Cloning Setup
# ===========================
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

@spaces.GPU(enable_queue=True)
def clone(text, audio):
    """
    Generate a voice-cloned audio file given text and a reference audio file.
    Returns the path to the output audio file.
    """
    try:
        tts.tts_to_file(text=text, speaker_wav=audio, language="en", file_path="./output.wav")
        return "./output.wav"
    except Exception as e:
        logging.error(f"TTS cloning failed: {e}")
        return None
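
# Illustrative usage sketch (not executed here; the reference file name is hypothetical):
#   clone("Hello there!", "reference_voice.wav")  # -> "./output.wav" on success, None on failure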

def save_voice(voice_name: str, voice_audio: str) -> None:
    """
    Save a cloned voice under the given name.
    """
    global voice_bank
    if voice_name and voice_audio:
        voice_bank[voice_name] = voice_audio

def get_voice_options() -> List[str]:
    """
    Returns a list of saved voice names.
    """
    return list(voice_bank.keys())

def refresh_voice_list() -> gr.update:
    """
    Returns an update with the latest voice list.
    """
    options = get_voice_options()
    new_val = options[0] if options else ""
    return gr.update(choices=options, value=new_val)

# ===========================
# Deep Agent Chat Setup
# ===========================
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

MODEL_ID = "unsloth/DeepSeek-R1-Distill-Qwen-7B-unsloth-bnb-4bit"
models: Dict[str, AutoModelForCausalLM] = {}
tokenizers: Dict[str, AutoTokenizer] = {}

bnb_config_4bit = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

def get_model_and_tokenizer() -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    # Warm-up: if the model isn't loaded, load it now.
    if "7B" not in models:
        logging.info(f"Loading 7B model: {MODEL_ID} on demand")
        try:
            tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_ID,
                quantization_config=bnb_config_4bit,
                torch_dtype=torch.bfloat16,
                device_map='auto',
                trust_remote_code=True,
            )
            model.eval()
            models["7B"] = model
            tokenizers["7B"] = tokenizer
            logging.info("Loaded 7B model on demand.")
        except Exception as e:
            logging.error(f"Failed to load model and tokenizer: {e}")
            raise e
    return models["7B"], tokenizers["7B"]

# ---------------------------
# Prompt Templates
# ---------------------------
default_prompts = {
    "coding": {
        "brainstorm": (
            "**Round 1: Brainstorm & Analysis**\n"
            "Please analyze the following coding challenge or question. Consider the overall problem, "
            "potential edge cases, and any assumptions you might need to make. Explain your reasoning as you think aloud.\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Strategy**\n"
            "Based on your initial analysis, please break down the problem into logical steps. "
            "Outline a plan or strategy that could be used to solve the challenge, highlighting key algorithms, structures, or design considerations.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**User Request:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Implementation**\n"
            "Taking into account the steps outlined previously, synthesize a coherent solution. "
            "Provide a detailed explanation of how the code addresses the problem while encouraging best practices and clear logic.\n\n"
            "**Detailed Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Output**\n"
            "Review your solution and provide a final, well-rounded response that summarizes your reasoning and the implementation strategy. "
            "Explain any key decisions made during the process and how they contribute to an effective solution.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    },
    "math": {
        "brainstorm": (
            "**Round 1: Problem Analysis & Exploration**\n"
            "Carefully analyze the mathematical problem provided. Describe the underlying concepts and any assumptions you are making. "
            "Detail your initial reasoning and potential methods to tackle the problem.\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Reasoning & Methodology**\n"
            "Based on your initial exploration, break down the problem into sequential steps or methodologies. "
            "Explain the reasoning behind each step and how they connect to solve the problem.\n\n"
            "**Initial Analysis:**\n{brainstorm_response}\n\n"
            "**Problem:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Synthesis & Step-by-Step Solution**\n"
            "Integrate your previous reasoning into a structured solution. Clearly explain each step of your calculation or proof, "
            "ensuring that your logical progression is easy to follow.\n\n"
            "**Detailed Methodology:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Explanation**\n"
            "Present your final solution along with a detailed explanation of the reasoning behind each step. "
            "Discuss any assumptions and insights that helped you arrive at the final answer.\n\n"
            "**Final Solution:**\n{final_response}\n"
        )
    },
    "writing": {
        "brainstorm": (
            "**Round 1: Creative Exploration & Conceptualization**\n"
            "Read the following writing prompt and explore its themes, tone, and potential narrative directions. "
            "Outline your initial thoughts and reasoning behind various creative choices.\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "round2": (
            "**Round 2: Detailed Outline & Narrative Structure**\n"
            "Based on your brainstorming, create a detailed outline that organizes the narrative or essay. "
            "Explain the reasoning behind your structure, the flow of ideas, and how you plan to incorporate creative elements.\n\n"
            "**Initial Brainstorming:**\n{brainstorm_response}\n\n"
            "**Writing Prompt:**\n{user_prompt}\n"
        ),
        "synthesis": (
            "**Round 3: Draft Synthesis & Refinement**\n"
            "Integrate your outline and creative ideas into a coherent draft. Provide a well-rounded narrative that is both engaging and logically structured. "
            "Explain your thought process as you refine the narrative.\n\n"
            "**Outline & Strategy:**\n{round2_response}\n"
        ),
        "rationale": (
            "**Round 4: Reflection & Final Editing**\n"
            "Review your draft and provide a final version that reflects thoughtful editing and creative reasoning. "
            "Explain the choices made in refining the text, from structure to stylistic decisions.\n\n"
            "**Final Draft:**\n{final_response}\n"
        )
    }
}

# The prompt state now contains both default and custom modes.
initial_prompt_state = {
    "default": default_prompts,
    "custom": {}  # custom modes will be added here as {mode_name: [round_prompt1, round_prompt2, ...]}
}

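# Illustrative shape of a custom mode entry (hypothetical "review" mode, not created by default);
# the first round's prompt can reference {user_prompt}, later rounds can also reference {prev_response}:
#   initial_prompt_state["custom"]["review"] = [
#       "Summarize the request:\n{user_prompt}",
#       "Critique and improve the previous answer:\n{prev_response}",
#   ]
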
def detect_domain(user_prompt: str) -> str:
    prompt_lower = user_prompt.lower()
    math_keywords = ["solve", "integral", "derivative", "equation", "proof", "calculate", "sum", "product"]
    writing_keywords = ["write", "story", "essay", "novel", "poem", "article", "narrative", "creative"]
    coding_keywords = ["code", "program", "debug", "compile", "algorithm", "function"]

    if any(kw in prompt_lower for kw in math_keywords):
        logging.info("Domain detected as: math")
        return "math"
    elif any(kw in prompt_lower for kw in writing_keywords):
        logging.info("Domain detected as: writing")
        return "writing"
    elif any(kw in prompt_lower for kw in coding_keywords):
        logging.info("Domain detected as: coding")
        return "coding"
    else:
        logging.info("No specific domain detected; defaulting to coding")
        return "coding"
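
# Illustrative (not executed): detect_domain("Write a short story about a robot") returns "writing";
# math keywords are checked first, then writing, then coding, and anything unmatched falls back to "coding".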

class MemoryManager:
    def __init__(self) -> None:
        self.shared_memory: List[str] = []

    def store(self, item: str) -> None:
        self.shared_memory.append(item)
        logging.info(f"[Memory Stored]: {item[:50]}...")

    def retrieve(self, query: str, top_k: int = 3) -> List[str]:
        query_lower = query.lower()
        relevant = [item for item in self.shared_memory if query_lower in item.lower()]
        if not relevant:
            logging.info("[Memory Retrieval]: No relevant memories found.")
        else:
            logging.info(f"[Memory Retrieval]: Found {len(relevant)} relevant memories.")
        return relevant[-top_k:]

global_memory_manager = MemoryManager()
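
# Illustrative (not executed): after several store("Round 4 Response: ...") calls,
# retrieve("Round 4 Response:", top_k=2) returns up to the two most recent entries containing that substring.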

def generate_response(model, tokenizer, prompt: str, max_tokens: int, temperature: float, top_p: float,
                      repetition_penalty: float = 1.0, num_beams: int = 1) -> str:
    # Check cache first
    cache_key = f"{prompt}-{max_tokens}-{temperature}-{top_p}-{repetition_penalty}-{num_beams}"
    if cache_key in response_cache:
        logging.info("Returning cached response.")
        return response_cache[cache_key]

    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams,
    )
    thread = Thread(target=model.generate, kwargs=kwargs)
    with torch.no_grad():
        thread.start()
    response = ""
    try:
        for text in streamer:
            response += text
    except Exception as e:
        logging.error(f"Error during generation: {e}")
        raise e
    thread.join()
    # Cache the response
    response_cache[cache_key] = response
    return response
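
# Illustrative (not executed): a repeat call with an identical prompt and identical sampling settings
# hits response_cache and returns the previously generated string without re-running the model.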

class MultiRoundAgent:
    def __init__(self, model, tokenizer, prompt_templates, memory_manager: MemoryManager):
        """
        prompt_templates can be a dict (for default modes) or a list (for custom modes)
        """
        self.model = model
        self.tokenizer = tokenizer
        self.prompt_templates = prompt_templates
        self.memory_manager = memory_manager

    def run_pipeline(self, user_prompt: str, params: Dict, show_raw: bool = False) -> Generator[str, None, None]:
        if isinstance(self.prompt_templates, dict):
            # Default fixed 4-round pipeline
            logging.info("--- Round 1 ---")
            prompt_r1 = self.prompt_templates["brainstorm"].format(user_prompt=user_prompt)
            r1 = generate_response(self.model, self.tokenizer, prompt_r1, params.get("max_new_tokens"), params.get("temp"),
                                   params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 1 Response: {r1}")

            logging.info("--- Round 2 ---")
            prompt_r2 = self.prompt_templates["round2"].format(brainstorm_response=r1, user_prompt=user_prompt)
            r2 = generate_response(self.model, self.tokenizer, prompt_r2, params.get("max_new_tokens") + 100,
                                   params.get("temp"), params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 2 Response: {r2}")

            logging.info("--- Round 3 ---")
            prompt_r3 = self.prompt_templates["synthesis"].format(round2_response=r2)
            input_ids_r3 = self.tokenizer.encode(prompt_r3, return_tensors="pt").to(self.model.device)
            streamer_r3 = TextIteratorStreamer(self.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
            kwargs_r3 = dict(
                input_ids=input_ids_r3,
                streamer=streamer_r3,
                max_new_tokens=params.get("max_new_tokens") // 2,
                temperature=params.get("temp"),
                top_p=params.get("top_p"),
                repetition_penalty=params.get("repetition_penalty"),
                num_beams=params.get("num_beams")
            )
            thread_r3 = Thread(target=self.model.generate, kwargs=kwargs_r3)
            with torch.no_grad():
                thread_r3.start()
            r3 = ""
            try:
                for text in streamer_r3:
                    r3 += text
                    yield r3  # Progressive updates
            except Exception as e:
                logging.error(f"Error during Round 3 streaming: {e}")
                raise e
            thread_r3.join()
            self.memory_manager.store(f"Final Synthesis Response: {r3}")

            logging.info("--- Round 4 ---")
            prompt_r4 = self.prompt_templates["rationale"].format(final_response=r3)
            r4 = generate_response(self.model, self.tokenizer, prompt_r4, 300, params.get("temp"),
                                   params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
            self.memory_manager.store(f"Round 4 Response: {r4}")

            final_output = (f"{r4}\n\n[Raw Outputs]\nRound 1:\n{r1}\n\nRound 2:\n{r2}\n\nRound 3:\n{r3}\n\nRound 4:\n{r4}\n") if show_raw else r4
            yield final_output

        elif isinstance(self.prompt_templates, list):
            # Custom mode: iterate over rounds.
            prev_response = ""
            full_output = ""
            total_rounds = len(self.prompt_templates)
            for idx, round_template in enumerate(self.prompt_templates):
                round_num = idx + 1
                logging.info(f"--- Custom Mode: Round {round_num} of {total_rounds} ---")
                if idx == 0:
                    prompt = round_template.format(user_prompt=user_prompt)
                else:
                    prompt = round_template.format(user_prompt=user_prompt, prev_response=prev_response)
                response = generate_response(self.model, self.tokenizer, prompt, params.get("max_new_tokens"),
                                             params.get("temp"), params.get("top_p"), params.get("repetition_penalty"), params.get("num_beams"))
                self.memory_manager.store(f"Custom Mode Round {round_num} Response: {response}")
                full_output += f"\n--- Round {round_num} ---\n{response}"
                prev_response = response
                yield full_output
        else:
            yield "Invalid prompt template format."

@spaces.GPU(duration=180)
def swarm_agent_iterative(user_prompt: str, temp: float, top_p: float, max_new_tokens: int, memory_top_k: int,
                          prompt_templates, domain: str, show_raw: bool, repetition_penalty: float, num_beams: int) -> Generator[str, None, None]:
    model, tokenizer = get_model_and_tokenizer()
    agent = MultiRoundAgent(model, tokenizer, prompt_templates, global_memory_manager)
    params = {
        "temp": temp,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "num_beams": num_beams
    }
    return agent.run_pipeline(user_prompt, params, show_raw)

def handle_explanation_request(user_prompt: str, history: List) -> str:
    retrieved = global_memory_manager.retrieve("Round 4 Response:", top_k=3)
    explanation_prompt = "Below are previous final outputs and related context from our conversation:\n"
    if retrieved:
        for item in retrieved:
            explanation_prompt += f"- {item}\n"
    else:
        explanation_prompt += "No stored final output found.\n"
    explanation_prompt += "\nRecent related exchanges:\n"
    for chat in history:
        if ("explain" in chat[0].lower()) or (chat[1] and "explain" in chat[1].lower()):
            explanation_prompt += f"User: {chat[0]}\nAssistant: {chat[1]}\n"
    explanation_prompt += "\nBased on the above context, please provide a detailed explanation of the creative choices."
    model, tokenizer = get_model_and_tokenizer()
    explanation = generate_response(model, tokenizer, explanation_prompt, 300, 0.7, 0.9)
    return explanation

def format_history(history: List) -> List[Dict[str, str]]:
    messages = []
    for item in history:
        if isinstance(item, (list, tuple)) and len(item) == 2:
            user_msg, assistant_msg = item
            if user_msg == "__final_agent_response__":
                continue
            messages.append({"role": "user", "content": user_msg})
            if assistant_msg:
                messages.append({"role": "assistant", "content": assistant_msg})
        elif isinstance(item, dict):
            messages.append(item)
    return messages
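
# Illustrative (not executed): format_history([["hi", "hello!"]]) ->
#   [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello!"}]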

def gradio_interface(message: str, history: List, param_state: Dict, prompt_state: Dict, mode: str) -> Generator[List[Dict[str, str]], None, None]:
    if "explain" in message.lower():
        explanation = handle_explanation_request(message, history)
        history = history + [[message, explanation]]
        yield format_history(history)
        return

    try:
        temp = float(param_state.get("temperature", 0.5))
        top_p = float(param_state.get("top_p", 0.9))
        max_new_tokens = int(param_state.get("max_new_tokens", 300))
        repetition_penalty = float(param_state.get("repetition_penalty", 1.0))
        num_beams = int(param_state.get("num_beams", 1))
        memory_top_k = int(param_state.get("memory_top_k", 2))
        show_raw = bool(param_state.get("show_raw_output", False))
    except Exception as e:
        logging.error(f"Parameter conversion error: {e}")
        temp, top_p, max_new_tokens, repetition_penalty, num_beams, memory_top_k, show_raw = 0.5, 0.9, 300, 1.0, 1, 2, False

    if mode in prompt_state.get("default", {}):
        prompt_templates = prompt_state["default"][mode]
    elif mode in prompt_state.get("custom", {}):
        prompt_templates = prompt_state["custom"][mode]
    else:
        detected = detect_domain(message)
        prompt_templates = prompt_state["default"].get(detected, prompt_state["default"]["coding"])
        mode = detected

    history = history + [[message, ""]]
    # Show a loading status
    yield format_history(history)
    for partial_response in swarm_agent_iterative(
        user_prompt=message,
        temp=temp,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        memory_top_k=memory_top_k,
        prompt_templates=prompt_templates,
        domain=mode,
        show_raw=show_raw,
        repetition_penalty=repetition_penalty,
        num_beams=num_beams
    ):
        history[-1][1] = partial_response
        yield format_history(history)
    yield format_history(history)

def generate_agent_audio(latest_text: str, voice_reference: str) -> str:
    """
    Generate an audio response using the cloned voice.
    If the provided voice_reference is a key in the voice bank, its stored file path is used.
    """
    if latest_text:
        if voice_reference in voice_bank:
            audio_path = clone(latest_text, voice_bank[voice_reference])
        else:
            audio_path = clone(latest_text, voice_reference)
        return audio_path
    return None

# NEW: Speech-to-Text Function using Whisper.
def transcribe_audio(audio_file: str) -> str:
    """
    Transcribe the provided audio file to text using the Whisper model.
    """
    try:
        result = whisper_model.transcribe(audio_file)
        transcription = result.get("text", "").strip()
        logging.info(f"Transcription result: {transcription}")
        return transcription
    except Exception as e:
        logging.error(f"Transcription error: {e}")
        return "Transcription failed."
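
# Illustrative (not executed; the file name is hypothetical): transcribe_audio("question.wav") returns
# the stripped "text" field from Whisper's result dict, or "Transcription failed." on error.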

# ---------------------------
# Warm-Up Model Function
# ---------------------------
def warmup_model():
    try:
        get_model_and_tokenizer()
        logging.info("Model warm-up complete.")
    except Exception as e:
        logging.error(f"Model warm-up failed: {e}")

warmup_model()

# ===========================
# Custom Gradio Theme
# ===========================
theme = gr.themes.Soft(
    primary_hue="pink",
    secondary_hue="pink",
    neutral_hue="purple",
    font=['IBM Plex Sans', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    background_fill_primary='white',
    shadow_drop='rgba(0,0,0,0.05) 0px 1px 2px 0px',
    shadow_drop_lg='0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)',
    shadow_spread='3px',
    block_background_fill='*background_fill_primary',
    block_border_width='1px',
    block_border_width_dark='1px',
    block_label_background_fill='*background_fill_primary',
    block_label_background_fill_dark='*background_fill_secondary',
    block_label_text_color='*neutral_500',
    block_label_text_color_dark='*neutral_200',
    block_label_margin='0',
    block_label_padding='*spacing_sm *spacing_lg',
    block_label_radius='calc(*radius_sm - 1px) 0 calc(*radius_sm - 1px) 0',
    block_label_text_size='*text_sm',
    block_label_text_weight='400',
    block_title_background_fill='none',
    block_title_background_fill_dark='none',
    block_title_text_color='*neutral_500',
    block_title_text_color_dark='*neutral_200',
    block_title_padding='0',
    block_title_radius='none',
    block_title_text_weight='400',
    panel_border_width='0',
    panel_border_width_dark='0',
    checkbox_background_color_selected='*color_accent',
    checkbox_background_color_selected_dark='*color_accent',
    checkbox_border_color='*neutral_300',
    checkbox_border_color_dark='*neutral_700',
    checkbox_border_color_focus='*color_accent',
    checkbox_border_color_focus_dark='*color_accent',
    checkbox_border_color_selected='*color_accent',
    checkbox_border_color_selected_dark='*color_accent',
    checkbox_border_width='*input_border_width',
    checkbox_shadow='*input_shadow',
    checkbox_label_background_fill_selected='*checkbox_label_background_fill',
    checkbox_label_background_fill_selected_dark='*checkbox_label_background_fill',
    checkbox_label_shadow='none',
    checkbox_label_text_color_selected='*checkbox_label_text_color',
    input_background_fill='*neutral_100',
    input_border_color='*border_color_primary',
    input_shadow='none',
    input_shadow_dark='none',
    input_shadow_focus='*input_shadow',
    input_shadow_focus_dark='*input_shadow',
    slider_color='*color_accent',
    slider_color_dark='*color_accent',
    button_primary_background_fill_hover='*primary_600',
    button_primary_background_fill_hover_dark='*primary_700',
    button_primary_shadow='none',
    button_primary_shadow_hover='*button_primary_shadow',
    button_primary_shadow_active='*button_primary_shadow',
    button_primary_shadow_dark='none',
    button_secondary_background_fill='*neutral_200',
    button_secondary_background_fill_hover='*neutral_300',
    button_secondary_background_fill_hover_dark='*neutral_700',
    button_secondary_text_color='black',
    button_secondary_shadow='*button_primary_shadow',
    button_secondary_shadow_hover='*button_secondary_shadow',
    button_secondary_shadow_active='*button_secondary_shadow',
    button_secondary_shadow_dark='*button_primary_shadow'
)

# ===========================
# Combined Gradio Interface
# ===========================
with gr.Blocks(theme=theme, title="Combined Voice Clone & Agent Chat") as demo:
    # Shared states for project settings, prompt configuration, and voice selection.
    param_state = gr.State({
        "temperature": 0.5,
        "top_p": 0.9,
        "max_new_tokens": 300,
        "memory_top_k": 2,
        "show_raw_output": False,
        "repetition_penalty": 1.0,
        "num_beams": 1,
        "use_cpu": False  # Toggle for device override
    })
    prompt_state = gr.State(initial_prompt_state)
    selected_voice = gr.State(value="")  # holds the currently selected voice

    # A status display to show device info.
    device_status = gr.Markdown(f"**Running on:** {device.upper()}")

    with gr.Tabs():
        # ----- Tab 1: Voice Setup -----
        with gr.Tab("Voice Setup"):
            gr.Markdown("<h2 style='text-align: center; padding-top: 10px;'>Voice Setup</h2>")
            with gr.Column(variant="panel"):
                gr.Markdown("<p style='text-align: center;'>Clone a voice and save it with a custom name. Test TTS using your cloned voices.</p>")
                with gr.Row():
                    text_input = gr.Textbox(label='Text to Clone', placeholder="Enter the text to speak...", elem_classes="full-width")
                with gr.Row():
                    audio_input = gr.Audio(label='Voice Reference Audio', type='filepath')
                with gr.Row():
                    clone_btn = gr.Button("Clone Voice")
                with gr.Row():
                    output_audio = gr.Audio(label='Cloned Voice Output', type='filepath')
                clone_btn.click(fn=clone, inputs=[text_input, audio_input], outputs=output_audio)
                with gr.Row():
                    voice_name_input = gr.Textbox(label="Voice Name", placeholder="Enter a name for this voice clone")
                with gr.Row():
                    save_voice_btn = gr.Button("Save Voice")
                save_voice_btn.click(fn=save_voice, inputs=[voice_name_input, output_audio], outputs=[])
                with gr.Row():
                    refresh_voice_btn_setup = gr.Button("Refresh Voice List")
                    voice_dropdown_setup = gr.Dropdown(choices=get_voice_options(), label="Select Saved Voice", interactive=True)
                    set_voice_btn = gr.Button("Set Selected Voice")
                refresh_voice_btn_setup.click(fn=refresh_voice_list, outputs=voice_dropdown_setup)
                set_voice_btn.click(fn=lambda x: x, inputs=[voice_dropdown_setup], outputs=selected_voice)
                gr.Markdown("<p style='text-align: center;'>(The selected voice will be used for TTS responses in Chat.)</p>")
                gr.Markdown("<hr>")
                gr.Markdown("<h3 style='text-align: center;'>TTS Test</h3>")
                with gr.Row():
                    tts_test_input = gr.Textbox(label="Test Text", placeholder="Enter text to test TTS...", elem_classes="full-width")
                with gr.Row():
                    tts_test_btn = gr.Button("Test TTS")
                    tts_test_output = gr.Audio(label="TTS Output", type="filepath")
                tts_test_btn.click(fn=lambda txt, override, sel: generate_agent_audio(txt, override if override else sel),
                                   inputs=[tts_test_input, audio_input, selected_voice],
                                   outputs=tts_test_output)

        # ----- Tab 2: Chat -----
        with gr.Tab("Chat"):
            gr.Markdown("""
            <div style="text-align: center; padding: 10px;">
                <h1>DeepSeek Agent Swarm Chat</h1>
                <p>Multi-round agent with prompt chaining. Ask me anything!</p>
            </div>
            """)
            with gr.Column():
                with gr.Row():
                    mode_selector = gr.Radio(choices=["coding", "math", "writing"], value="coding", label="Select Mode")
                with gr.Row():
                    chat_input = gr.Textbox(label="Your Message", placeholder="Type your message here...", lines=2, elem_id="msg_input")
                with gr.Row():
                    chat_audio_input = gr.Audio(label="Or record/upload your message", type="filepath")
                    transcribe_btn = gr.Button("Transcribe Audio")
                transcribe_btn.click(fn=transcribe_audio, inputs=chat_audio_input, outputs=chat_input)
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    export_btn = gr.Button("Generate Chat Transcript")
                chatbot = gr.Chatbot(height=450, label="Agent Swarm Output", type="messages")
                with gr.Row():
                    use_tts_checkbox = gr.Checkbox(label="Generate Audio Response using TTS", value=False)
                    chat_voice_dropdown = gr.Dropdown(choices=get_voice_options(), label="Select Voice for TTS", interactive=True)
                    refresh_voice_btn_chat = gr.Button("Refresh Voice List")
                refresh_voice_btn_chat.click(fn=refresh_voice_list, outputs=chat_voice_dropdown)
                agent_audio = gr.Audio(label="Agent Audio Response", type="filepath")

                def chat_wrapper(message, history, param_state, prompt_state, mode):
                    final_history = []
                    history.append(["", "**Generating response...**"])
                    for h in gradio_interface(message, history, param_state, prompt_state, mode):
                        final_history = h
                    return final_history

                send_btn.click(fn=chat_wrapper,
                               inputs=[chat_input, chatbot, param_state, prompt_state, mode_selector],
                               outputs=[chatbot])

                def conditional_tts(latest_text, use_tts, selected_voice_val):
                    if use_tts:
                        return generate_agent_audio(latest_text, selected_voice_val)
                    return None

                def get_latest_text(chat_history):
                    for msg in reversed(chat_history):
                        if msg.get("role") == "assistant" and msg.get("content"):
                            return msg["content"]
                    return ""
                latest_text_state = gr.State(value="")
                gen_audio_btn = gr.Button("Generate Audio from Agent Response")
                gen_audio_btn.click(fn=lambda chat: get_latest_text(chat),
                                    inputs=[chatbot],
                                    outputs=latest_text_state)
                gen_audio_btn.click(fn=conditional_tts,
                                    inputs=[latest_text_state, use_tts_checkbox, chat_voice_dropdown],
                                    outputs=agent_audio)
                def export_transcript(history):
                    transcript = ""
                    for item in history:
                        if isinstance(item, list) and len(item) == 2:
                            transcript += f"User: {item[0]}\nAssistant: {item[1]}\n\n"
                    return transcript
                export_btn.click(fn=export_transcript, inputs=[chatbot], outputs=chatbot)

        # ----- Tab 3: Project Settings -----
        with gr.Tab("Project Settings"):
            gr.Markdown("<h2 style='text-align: center;'>Project Settings</h2>")
            with gr.Tabs():
                with gr.Tab("Generation Parameters"):
                    gr.Markdown("<h3>Generation Parameters</h3>")
                    with gr.Row():
                        temp_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, value=0.5, label="Temperature")
                        top_p_slider = gr.Slider(minimum=0.01, maximum=1.0, step=0.05, value=0.9, label="Top P")
                    with gr.Row():
                        max_tokens_num = gr.Number(value=300, label="Max New Tokens", precision=0)
                        memory_topk_slider = gr.Slider(minimum=1, maximum=5, step=1, value=2, label="Memory Retrieval Top K")
                    with gr.Row():
                        rep_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Repetition Penalty")
                        num_beams_slider = gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Number of Beams")
                    with gr.Row():
                        show_raw_checkbox = gr.Checkbox(value=False, label="Show Raw Output")
                        use_cpu_checkbox = gr.Checkbox(value=False, label="Force Use CPU")
                    save_params_btn = gr.Button("Save Generation Parameters")
                    def save_params(t, p, m, k, rp, nb, s, use_cpu):
                        global device
                        if use_cpu:
                            device = "cpu"
                        else:
                            device = "cuda" if torch.cuda.is_available() else "cpu"
                        return {
                            "temperature": t,
                            "top_p": p,
                            "max_new_tokens": m,
                            "memory_top_k": k,
                            "repetition_penalty": rp,
                            "num_beams": nb,
                            "show_raw_output": s,
                            "use_cpu": use_cpu
                        }
                    save_params_btn.click(
                        save_params,
                        inputs=[temp_slider, top_p_slider, max_tokens_num, memory_topk_slider, rep_penalty_slider, num_beams_slider, show_raw_checkbox, use_cpu_checkbox],
                        outputs=param_state,
                    )
                    save_params_btn.click(fn=lambda params: f"**Running on:** {device.upper()}", inputs=param_state, outputs=device_status)
                    gr.Markdown("Note: Repetition penalty and number of beams affect generation diversity and quality.")
                with gr.Tab("Prompt Config (Default Modes)"):
                    gr.Markdown("<h3>Prompt Configurations for Default Modes</h3>")
                    with gr.Tabs():
                        with gr.Tab("Coding"):
                            prompt_brainstorm_box_code = gr.Textbox(
                                value=default_prompts["coding"]["brainstorm"],
                                label="Brainstorm Prompt (Coding)",
                                lines=8,
                            )
                            prompt_round2_box_code = gr.Textbox(
                                value=default_prompts["coding"]["round2"],
                                label="Round 2 Prompt (Coding)",
                                lines=8,
                            )
                            prompt_synthesis_box_code = gr.Textbox(
                                value=default_prompts["coding"]["synthesis"],
                                label="Synthesis Prompt (Coding)",
                                lines=8,
                            )
                            prompt_rationale_box_code = gr.Textbox(
                                value=default_prompts["coding"]["rationale"],
                                label="Rationale Prompt (Coding)",
                                lines=8,
                            )
                        with gr.Tab("Math"):
                            prompt_brainstorm_box_math = gr.Textbox(
                                value=default_prompts["math"]["brainstorm"],
                                label="Brainstorm Prompt (Math)",
                                lines=8,
                            )
                            prompt_round2_box_math = gr.Textbox(
                                value=default_prompts["math"]["round2"],
                                label="Round 2 Prompt (Math)",
                                lines=8,
                            )
                            prompt_synthesis_box_math = gr.Textbox(
                                value=default_prompts["math"]["synthesis"],
                                label="Synthesis Prompt (Math)",
                                lines=8,
                            )
                            prompt_rationale_box_math = gr.Textbox(
                                value=default_prompts["math"]["rationale"],
                                label="Rationale Prompt (Math)",
                                lines=8,
                            )
                        with gr.Tab("Writing"):
                            prompt_brainstorm_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["brainstorm"],
                                label="Brainstorm Prompt (Writing)",
                                lines=8,
                            )
                            prompt_round2_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["round2"],
                                label="Round 2 Prompt (Writing)",
                                lines=8,
                            )
                            prompt_synthesis_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["synthesis"],
                                label="Synthesis Prompt (Writing)",
                                lines=8,
                            )
                            prompt_rationale_box_writing = gr.Textbox(
                                value=default_prompts["writing"]["rationale"],
                                label="Rationale Prompt (Writing)",
                                lines=8,
                            )
                    save_prompts_btn = gr.Button("Save Default Prompt Configurations")
                    def save_default_prompts(code_brain, code_r2, code_syn, code_rat, math_brain, math_r2, math_syn, math_rat, writing_brain, writing_r2, writing_syn, writing_rat):
                        return {
                            "default": {
                                "coding": {
                                    "brainstorm": code_brain,
                                    "round2": code_r2,
                                    "synthesis": code_syn,
                                    "rationale": code_rat,
                                },
                                "math": {
                                    "brainstorm": math_brain,
                                    "round2": math_r2,
                                    "synthesis": math_syn,
                                    "rationale": math_rat,
                                },
                                "writing": {
                                    "brainstorm": writing_brain,
                                    "round2": writing_r2,
                                    "synthesis": writing_syn,
                                    "rationale": writing_rat,
                                }
                            },
                            "custom": prompt_state.value.get("custom", {})
                        }
                    save_prompts_btn.click(
                        save_default_prompts,
                        inputs=[prompt_brainstorm_box_code, prompt_round2_box_code, prompt_synthesis_box_code, prompt_rationale_box_code,
                                prompt_brainstorm_box_math, prompt_round2_box_math, prompt_synthesis_box_math, prompt_rationale_box_math,
                                prompt_brainstorm_box_writing, prompt_round2_box_writing, prompt_synthesis_box_writing, prompt_rationale_box_writing],
                        outputs=prompt_state,
                    )
                with gr.Tab("Custom Modes"):
                    gr.Markdown("<h3>Create / Edit Custom Modes</h3>")
                    gr.Markdown(
                        "Define a custom mode by providing a unique mode name, selecting the number of rounds (up to 10), "
                        "and editing the prompt for each round. In custom mode prompts, you can use the placeholders `{user_prompt}` "
                        "(for the first round) and `{prev_response}` (for subsequent rounds)."
                    )
                    with gr.Row():
                        custom_mode_name = gr.Textbox(label="Custom Mode Name", placeholder="Enter a unique mode name")
                        custom_round_count = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Rounds")
                    custom_round1 = gr.Textbox(label="Round 1 Prompt", lines=4, placeholder="e.g., Use {user_prompt} here")
                    custom_round2 = gr.Textbox(label="Round 2 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round3 = gr.Textbox(label="Round 3 Prompt", lines=4, placeholder="e.g., Use {user_prompt} and {prev_response}")
                    custom_round4 = gr.Textbox(label="Round 4 Prompt", lines=4, placeholder="Optional")
                    custom_round5 = gr.Textbox(label="Round 5 Prompt", lines=4, placeholder="Optional")
                    custom_round6 = gr.Textbox(label="Round 6 Prompt", lines=4, placeholder="Optional")
                    custom_round7 = gr.Textbox(label="Round 7 Prompt", lines=4, placeholder="Optional")
                    custom_round8 = gr.Textbox(label="Round 8 Prompt", lines=4, placeholder="Optional")
                    custom_round9 = gr.Textbox(label="Round 9 Prompt", lines=4, placeholder="Optional")
                    custom_round10 = gr.Textbox(label="Round 10 Prompt", lines=4, placeholder="Optional")

                    def save_custom_mode(name, round_count, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, current_prompt_state):
                        if not name:
                            return gr.update(), current_prompt_state
                        rounds = []
                        round_prompts = [r1, r2, r3, r4, r5, r6, r7, r8, r9, r10]
                        for i in range(round_count):
                            if round_prompts[i].strip():
                                rounds.append(round_prompts[i])
                        custom_modes = current_prompt_state.get("custom", {})
                        custom_modes[name] = rounds
                        new_prompt_state = {
                            "default": current_prompt_state.get("default", {}),
                            "custom": custom_modes
                        }
                        return gr.update(value=""), new_prompt_state

                    save_custom_mode_btn = gr.Button("Save Custom Mode")
                    save_custom_mode_btn.click(
                        save_custom_mode,
                        inputs=[custom_mode_name, custom_round_count, custom_round1, custom_round2, custom_round3, custom_round4,
                                custom_round5, custom_round6, custom_round7, custom_round8, custom_round9, custom_round10,
                                prompt_state],
                        outputs=[custom_mode_name, prompt_state]
                    )

                    def update_mode_choices(current_prompt_state):
                        default_modes = list(current_prompt_state.get("default", {}).keys())
                        custom_modes = list(current_prompt_state.get("custom", {}).keys())
                        all_modes = default_modes + custom_modes
                        default_choice = default_modes[0] if default_modes else (custom_modes[0] if custom_modes else "")
                        return gr.update(choices=all_modes, value=default_choice)

                    refresh_mode_selector_btn = gr.Button("Refresh Mode List")
                    refresh_mode_selector_btn.click(fn=update_mode_choices, inputs=prompt_state, outputs=mode_selector)

            gr.Markdown("<hr>")
            gr.Markdown("<p style='text-align: center;'>These settings affect the entire project.</p>")

    gr.Markdown("<hr><p style='text-align: center;'>Agent Chat using DeepSeek Agent Swarm</p>")

if __name__ == "__main__":
    demo.launch(share=True)
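
# Assumed runtime dependencies for this Space (not pinned anywhere in this file): gradio, torch,
# transformers, bitsandbytes, accelerate, TTS (Coqui), openai-whisper, and the Hugging Face `spaces`
# package. Run locally with `python app.py`; launch(share=True) should also expose a public link.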