Update app.py

app.py CHANGED
@@ -18,293 +18,186 @@ except Exception as e:
...
 }
 }
-    temperature: float = Field(
-        0.7,
-        ge=0.0,
-        le=2.0,
-        description="Controls randomness in the output. Higher values (e.g., 0.8) make the output more random, lower values (e.g., 0.2) make it more focused and deterministic."
-    )
-    max_new_tokens: int = Field(
-        100,
-        ge=1,
-        le=2048,
-        description="Maximum number of tokens to generate"
-    )
-    top_p: float = Field(
-        0.9,
-        ge=0.0,
-        le=1.0,
-        description="Nucleus sampling parameter. Only tokens with cumulative probability < top_p are considered."
-    )
-    top_k: int = Field(
-        50,
-        ge=0,
-        description="Only consider the top k tokens for text generation"
-    )
-    strategy: str = Field(
-        "default",
-        description="Generation strategy to use",
-        examples=["default", "majority_voting", "best_of_n", "beam_search", "dvts"]
-    )
-    num_samples: int = Field(
-        5,
-        ge=1,
-        le=10,
-        description="Number of samples to generate (used in majority_voting and best_of_n strategies)"
-    )
-
-class GenerationRequest(BaseModel):
-    """Request model for text generation."""
-    context: Optional[str] = Field(
-        None,
-        description="Additional context to guide the generation",
-        examples=["You are a helpful assistant skilled in Python programming"]
-    )
-    messages: List[ChatMessage] = Field(
-        ...,
-        description="Chat history including the current message",
-        min_items=1
-    )
-    config: Optional[GenerationConfig] = Field(
-        None,
-        description="Generation configuration parameters"
-    )
-    stream: bool = Field(
-        False,
-        description="Whether to stream the response token by token"
-    )
-
-    model_config = ConfigDict(
-        json_schema_extra={
-            "example": {
-                "context": "You are a helpful assistant",
-                "messages": [
-                    {"role": "user", "content": "What is the capital of France?"}
-                ],
-                "config": {
-                    "temperature": 0.7,
-                    "max_new_tokens": 100
-                },
-                "stream": False
-            }
-        }
-    )
-
-class GenerationResponse(BaseModel):
-    """Response model for text generation."""
-    id: str = Field(..., description="Unique generation ID")
-    content: str = Field(..., description="Generated text content")
-    created_at: datetime = Field(
-        default_factory=datetime.now,
-        description="Timestamp of generation"
-    )
-
-# Model and cache management
-async def get_prm_model_path():
-    """Download and cache the PRM model."""
-    return await asyncio.to_thread(
-        hf_hub_download,
-        repo_id="tensorblock/Llama3.1-8B-PRM-Mistral-Data-GGUF",
-        filename="Llama3.1-8B-PRM-Mistral-Data-Q4_K_M.gguf"
-    )
-
-# Initialize generator globally
-generator = None
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """Lifecycle management for the FastAPI application."""
-    # Startup: Initialize generator
-    global generator
-    try:
-        prm_model_path = await get_prm_model_path()
-        generator = LlamaGenerator(
-            llama_model_name="meta-llama/Llama-3.2-1B-Instruct",
-            prm_model_path=prm_model_path,
-            default_generation_config=GenerationConfig(
-                max_new_tokens=100,
-                temperature=0.7
-            )
-        )
-        yield
-    finally:
-        # Shutdown: Clean up resources
-        if generator:
-            await asyncio.to_thread(generator.cleanup)
-
-# FastAPI application
-app = FastAPI(
-    title="Inference Deluxe Service",
-    description="""
-    A service for generating text using LLaMA models with various generation strategies.
-
-    Generation Strategies:
-    - default: Standard autoregressive generation
-    - majority_voting: Generates multiple responses and selects the most common one
-    - best_of_n: Generates multiple responses and selects the best based on a scoring metric
-    - beam_search: Uses beam search for more coherent text generation
-    - dvts: Dynamic vocabulary tree search for efficient generation
-    """,
-    version="1.0.0",
-    lifespan=lifespan
-)
-
-# CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-async def get_generator():
-    """Dependency to get the generator instance."""
-    if not generator:
-        raise HTTPException(
-            status_code=503,
-            detail="Generator not initialized"
-        )
-    return generator
-
-@app.post(
-    "/generate",
-    response_model=GenerationResponse,
-    tags=["generation"],
-    summary="Generate text response",
-    response_description="Generated text with unique identifier"
-)
-async def generate(
-    request: GenerationRequest,
-    generator: Any = Depends(get_generator)
-):
-    """
-    Generate a text response based on the provided context and chat history.
-    """
-    try:
-        chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
-        user_input = request.messages[-1].content
-
-        # Extract or set defaults for additional arguments
-        config = request.config or GenerationConfig()
-        model_kwargs = {
-            "temperature": config.temperature if hasattr(config, "temperature") else 0.7,
-            "max_new_tokens": config.max_new_tokens if hasattr(config, "max_new_tokens") else 100,
-            # Add other model kwargs as needed
-        }
-
-        # Explicitly pass additional required arguments
-        response = await asyncio.to_thread(
-            generator.generate_with_context,
-            context=request.context or "",
-            user_input=user_input,
-            chat_history=chat_history,
-            model_kwargs=model_kwargs,
-            max_history_turns=config.max_history_turns if hasattr(config, "max_history_turns") else 3,
-            strategy=config.strategy if hasattr(config, "strategy") else "default",
-            num_samples=config.num_samples if hasattr(config, "num_samples") else 5,
-            depth=config.depth if hasattr(config, "depth") else 3,
-            breadth=config.breadth if hasattr(config, "breadth") else 2,
-        )
...
     await websocket.accept()
-
     try:
         while True:
...
-            user_input = request.messages[-1].content
-
-            config = request.config or GenerationConfig()
-
-            async for token in generator.generate_stream(
-                prompt=generator.prompt_builder.format(
-                    context=request.context or "",
-                    user_input=user_input,
-                    chat_history=chat_history
-                ),
-                config=config
-            ):
-                await websocket.send_text(json.dumps({
-                    "token": token,
-                    "finished": False
-                }))
-
-            await websocket.send_text(json.dumps({
-                "token": "",
-                "finished": True
-            }))
-
-    except Exception as e:
-        await websocket.send_text(json.dumps({
-            "error": str(e)
-        }))
-    finally:
-        await websocket.close()
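For reference, the deleted endpoint's request shape is fully described by the removed Pydantic models. A client call against the old service would have looked roughly like the sketch below, reconstructed from the removed json_schema_extra example (not part of this commit; any HTTP client works, `requests` is only an assumption):

    import requests

    payload = {
        "context": "You are a helpful assistant",
        "messages": [{"role": "user", "content": "What is the capital of France?"}],
        "config": {"temperature": 0.7, "max_new_tokens": 100, "strategy": "default"},
        "stream": False,
    }
    resp = requests.post("http://localhost:8000/generate", json=payload, timeout=60)
    print(resp.json())  # GenerationResponse: {"id": ..., "content": ..., "created_at": ...}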
+########
+
+html = """
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>AI State Machine</title>
+    <style>
+        body { font-family: Arial, sans-serif; text-align: center; }
+        #chatbox { width: 80%; height: 300px; border: 1px solid #ccc; overflow-y: auto; margin: 20px auto; padding: 10px; }
+        #inputbox { width: 70%; padding: 5px; }
+        button { padding: 5px 10px; }
+    </style>
+</head>
+<body>
+    <h2>AI State Machine</h2>
+    <div id="chatbox"></div>
+    <input type="text" id="inputbox" placeholder="Type your message...">
+    <button onclick="sendMessage()">Send</button>
+
+    <script>
+        let ws = new WebSocket("ws://localhost:8000/ws");
+
+        ws.onmessage = (event) => {
+            let chatbox = document.getElementById("chatbox");
+            chatbox.innerHTML += `<p>${event.data}</p>`;
+            chatbox.scrollTop = chatbox.scrollHeight;
+        };
+
+        function sendMessage() {
+            let input = document.getElementById("inputbox");
+            let message = input.value.trim();
+            if (message) {
+                ws.send(message);
+                input.value = "";
             }
         }
+    </script>
+</body>
+</html>
+"""
+######
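Taken together, the added Python below implements a small three-state machine. Summarizing the transitions encoded in consciousness() and receive_message() (an editorial reading of the added code, not part of the commit):

    state      trigger                                      next state
    awake      5 heartbeats with no human input             research
    research   research_tasks exhausted                     sleeping
    sleeping   20 heartbeats since the last human message   research
    any        human message arriving over /ws              awake

Note that heartbeat_count is reset only by human input, so the sleep-to-research threshold counts beats since the last message, not since sleep began.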
+import asyncio
+import queue
+import threading
+import random
+import time
+from fastapi import FastAPI, WebSocket
+from fastapi.responses import HTMLResponse
+import uvicorn
+
+# FastAPI App
+app = FastAPI()
+
+class AIStateManager:
+    def __init__(self):
+        self.state = "awake"
+        self.msg_queue = queue.Queue()
+        self.heartbeat_count = 0
+        # Reentrant lock: set_state() and add_message() are called while the
+        # lock is already held, so a plain threading.Lock() would deadlock.
+        self.lock = threading.RLock()
+        self.clients = set()
+
+        # Research Task List
+        self.research_tasks = ["Explore AI Ethics", "Find latest AI models", "Investigate quantum computing"]
+        self.current_task = None
+
+    def set_state(self, new_state):
+        with self.lock:
+            print(f"[STATE CHANGE] {self.state} → {new_state}")
+            self.state = new_state
+            self.add_message("system", f"State changed to {new_state}")
+
+    def receive_message(self, sender, message):
+        """Adds messages to the queue and resets the heartbeat on human input."""
+        with self.lock:
+            self.msg_queue.put((sender, message))
+            if sender == "human":
+                self.heartbeat_count = 0
+                if self.state != "awake":
+                    self.set_state("awake")
+
+    def add_message(self, sender, message):
+        """AI or system can add messages to the queue."""
+        with self.lock:
+            self.msg_queue.put((sender, message))
+
+    def process_messages(self):
+        """Drains the queue, logging each message and broadcasting it to clients."""
+        while not self.msg_queue.empty():
+            sender, message = self.msg_queue.get()
+            print(f"[{sender.upper()}] {message}")
+            asyncio.create_task(self.broadcast(f"[{sender.upper()}] {message}"))
+
+    async def broadcast(self, message):
+        """Sends a message to all connected WebSocket clients."""
+        for ws in list(self.clients):  # copy: the set may change while we await
+            await ws.send_text(message)
+
+    async def heartbeat(self):
+        """Basic heartbeat loop: one beat per second."""
+        while True:
+            await asyncio.sleep(1)  # One 'beat'
+            with self.lock:
+                self.heartbeat_count += 1
+            self.process_messages()
+
+    async def consciousness(self):
+        """Controls the awake/research/sleep cycle."""
+        while True:
+            await asyncio.sleep(2)  # Consciousness checks every 2 sec
+            with self.lock:
+                if self.state == "awake":
+                    if self.heartbeat_count >= 5:  # 5 beats without human input
+                        self.set_state("research")
+                        asyncio.create_task(self.run_research())
+
+                elif self.state == "research":
+                    if not self.research_tasks:  # If no tasks left, move to sleep
+                        self.set_state("sleeping")
+                        asyncio.create_task(self.run_sleeping())
+
+                elif self.state == "sleeping":
+                    if self.heartbeat_count >= 20:
+                        self.set_state("research")  # Restart research after sleep
+                        asyncio.create_task(self.run_research())
+
+    async def run_research(self):
+        """Works through the research tasks in order."""
+        while self.state == "research" and self.research_tasks:
+            await asyncio.sleep(3)
+            self.current_task = self.research_tasks.pop(0)
+            self.add_message("ai", f"Researching: {self.current_task}")
+            if random.random() < 0.3:  # AI might ask a follow-up question
+                self.add_message("ai", f"Question: What do you think about {self.current_task}?")
+            self.process_messages()
+
+    async def run_sleeping(self):
+        """Runs self-training until the state changes."""
+        while self.state == "sleeping":
+            await asyncio.sleep(5)
+            self.add_message("system", "Self-training in progress...")
+            self.process_messages()
+
+    def modify_research_tasks(self):
+        """Background thread that edits research tasks dynamically (the 'subconscious')."""
+        while True:
+            time.sleep(10)  # Runs every 10 seconds
+            with self.lock:
+                if self.state == "research" and random.random() < 0.5:
+                    new_task = f"Investigate {random.choice(['AI Bias', 'Neural Networks', 'Data Privacy'])}"
+                    self.research_tasks.append(new_task)
+                    self.add_message("system", f"New research task added: {new_task}")
+
+# Initialize AI Manager
+ai_manager = AIStateManager()
+
+# Start heartbeat and consciousness on the server's own event loop: broadcast()
+# must send on the loop that owns the WebSocket connections, so running these
+# coroutines in separate threads via asyncio.run() would schedule sends on the
+# wrong loop. Only the synchronous subconscious editor runs in a daemon thread.
+@app.on_event("startup")
+async def start_background_loops():
+    asyncio.create_task(ai_manager.heartbeat())
+    asyncio.create_task(ai_manager.consciousness())
+
+threading.Thread(target=ai_manager.modify_research_tasks, daemon=True).start()
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    """WebSocket connection handler."""
     await websocket.accept()
+    ai_manager.clients.add(websocket)
     try:
         while True:
+            data = await websocket.receive_text()
+            ai_manager.receive_message("human", data)
+    except Exception:  # treat any receive error as a disconnect
+        ai_manager.clients.discard(websocket)
+
+@app.get("/")
+async def get():
+    """Serve the frontend HTML."""
+    return HTMLResponse(html)
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="localhost", port=8000)
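To exercise the new endpoint without the browser frontend, here is a minimal client sketch, assuming the server is running locally on port 8000 and the third-party `websockets` package is installed (it is not a dependency of the app itself):

    import asyncio
    import websockets

    async def chat():
        async with websockets.connect("ws://localhost:8000/ws") as ws:
            await ws.send("hello")      # any human message resets the heartbeat and wakes the AI
            while True:
                print(await ws.recv())  # broadcasts arrive as "[SENDER] message"

    asyncio.run(chat())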