Update app.py
Browse files
app.py
CHANGED
@@ -3,11 +3,10 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
3 |
import torch
|
4 |
import spaces # Import the spaces library
|
5 |
|
6 |
-
# Model IDs from Hugging Face Hub (1.5B, 7B, and 14B)
|
7 |
model_ids = {
|
8 |
"1.5B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
9 |
"7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
|
10 |
-
"14B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B"
|
11 |
}
|
12 |
|
13 |
# Function to load model and tokenizer (slightly adjusted device_map)
|
@@ -21,7 +20,7 @@ def load_model_and_tokenizer(model_id):
|
|
21 |
)
|
22 |
return model, tokenizer
|
23 |
|
24 |
-
# Load the models and tokenizers
|
25 |
models = {}
|
26 |
tokenizers = {}
|
27 |
for size, model_id in model_ids.items():
|
@@ -81,29 +80,22 @@ def swarm_agent_sequential_rag(user_prompt):
|
|
81 |
print(f"7B Response:\n{response_7b}")
|
82 |
store_in_memory(f"7B Model Elaborated Response: {response_7b[:200]}...")
|
83 |
|
84 |
-
# 14B Model
|
85 |
-
print("\n[14B Model - Final Refinement] - GPU Accelerated") # Added GPU indication
|
86 |
-
retrieved_memory_14b = retrieve_from_memory(response_7b)
|
87 |
-
context_14b = "\n".join([f"- {mem}" for mem in retrieved_memory_14b]) if retrieved_memory_14b else "No relevant context found in memory."
|
88 |
-
prompt_14b = f"Context from Shared Memory:\n{context_14b}\n\nYou are a high-level reasoner and refiner. Take the following elaborated response and refine it to be a final, well-reasoned, and polished answer, considering the context above. \n\nElaborated Response:\n{response_7b}\n\nFinal Answer:"
|
89 |
-
input_ids_14b = tokenizers["14B"].encode(prompt_14b, return_tensors="pt").to(models["14B"].device)
|
90 |
-
output_14b = models["14B"].generate(input_ids_14b, max_new_tokens=400, temperature=0.6, do_sample=True) # Reverted to original max_new_tokens
|
91 |
-
response_14b = tokenizers["14B"].decode(output_14b[0], skip_special_tokens=True)
|
92 |
-
print(f"14B Response (Final):\n{response_14b}")
|
93 |
|
94 |
-
return
|
95 |
|
96 |
|
97 |
-
# --- Gradio Interface ---
|
98 |
def gradio_interface(user_prompt):
|
99 |
-
|
|
|
100 |
|
101 |
iface = gr.Interface(
|
102 |
fn=gradio_interface,
|
103 |
inputs=gr.Textbox(lines=5, placeholder="Enter your task here..."),
|
104 |
outputs=gr.Textbox(lines=10, placeholder="Agent Swarm Output will appear here..."),
|
105 |
-
title="DeepSeek Agent Swarm (ZeroGPU Demo)",
|
106 |
-
description="Agent swarm using DeepSeek-R1-Distill models (1.5B, 7B, 14B) with shared memory. **GPU accelerated using ZeroGPU!** (Requires Pro Space)",
|
107 |
)
|
108 |
|
109 |
if __name__ == "__main__":
|
|
|
3 |
import torch
|
4 |
import spaces # Import the spaces library
|
5 |
|
6 |
+
# Model IDs from Hugging Face Hub (now only 1.5B and 7B)
|
7 |
model_ids = {
|
8 |
"1.5B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
9 |
"7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
|
|
|
10 |
}
|
11 |
|
12 |
# Function to load model and tokenizer (slightly adjusted device_map)
|
|
|
20 |
)
|
21 |
return model, tokenizer
|
22 |
|
23 |
+
# Load the selected models and tokenizers
|
24 |
models = {}
|
25 |
tokenizers = {}
|
26 |
for size, model_id in model_ids.items():
|
|
|
80 |
print(f"7B Response:\n{response_7b}")
|
81 |
store_in_memory(f"7B Model Elaborated Response: {response_7b[:200]}...")
|
82 |
|
83 |
+
# No 14B Model Stage anymore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
|
85 |
+
return response_7b # Now returns the 7B model's response as final
|
86 |
|
87 |
|
88 |
+
# --- Gradio Interface --- (Modified to reflect 2-model output)
|
89 |
def gradio_interface(user_prompt):
|
90 |
+
final_response = swarm_agent_sequential_rag(user_prompt) # Get the final response (from 7B now)
|
91 |
+
return final_response
|
92 |
|
93 |
iface = gr.Interface(
|
94 |
fn=gradio_interface,
|
95 |
inputs=gr.Textbox(lines=5, placeholder="Enter your task here..."),
|
96 |
outputs=gr.Textbox(lines=10, placeholder="Agent Swarm Output will appear here..."),
|
97 |
+
title="DeepSeek Agent Swarm (ZeroGPU Demo - 2 Models)", # Updated title
|
98 |
+
description="Agent swarm using DeepSeek-R1-Distill models (1.5B, 7B) with shared memory. **GPU accelerated using ZeroGPU!** (Requires Pro Space)", # Updated description
|
99 |
)
|
100 |
|
101 |
if __name__ == "__main__":
|