Upload 2 files
Files changed:
- routes/audio_routes.py  (+398, -263)
- routes/chat_handler.py  (+142, -54)

routes/audio_routes.py  (CHANGED)
@@ -1,263 +1,398 @@
(old side: the previous 263-line version of routes/audio_routes.py was removed. Only fragments survived extraction: the original module docstring line "Provides text-to-speech (TTS) and speech-to-text (STT) endpoints.", the fastapi/pydantic/typing/datetime/sys imports, and partial "from tts." imports. The deleted endpoint bodies are not recoverable.)

(new side: full new file content)
"""
Audio API endpoints for Flare (Refactored with Event-Driven Architecture)
========================================================================
Provides text-to-speech (TTS) and speech-to-text (STT) endpoints.
"""

from fastapi import APIRouter, HTTPException, Response, Body, Request, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from typing import Optional
from datetime import datetime
import sys
import base64

from utils.logger import log_info, log_error, log_warning, log_debug
from tts.tts_factory import TTSFactory
from tts.tts_preprocessor import TTSPreprocessor
from config.config_provider import ConfigProvider

router = APIRouter(tags=["audio"])

# ===================== Models =====================
class TTSRequest(BaseModel):
    text: str
    voice_id: Optional[str] = None
    language: Optional[str] = "tr-TR"
    session_id: Optional[str] = None  # For event-driven mode

class STTRequest(BaseModel):
    audio_data: str  # Base64 encoded audio
    language: Optional[str] = "tr-TR"
    format: Optional[str] = "webm"  # webm, wav, mp3
    session_id: Optional[str] = None  # For event-driven mode

+
# ===================== TTS Endpoints =====================
|
35 |
+
@router.post("/tts/generate")
|
36 |
+
async def generate_tts(request: TTSRequest, req: Request):
|
37 |
+
"""
|
38 |
+
Generate TTS audio from text
|
39 |
+
- If session_id is provided and event bus is available, uses event-driven mode
|
40 |
+
- Otherwise, uses direct TTS generation
|
41 |
+
"""
|
42 |
+
try:
|
43 |
+
# Check if we should use event-driven mode
|
44 |
+
if request.session_id and hasattr(req.app.state, 'event_bus'):
|
45 |
+
# Event-driven mode for realtime sessions
|
46 |
+
from event_bus import Event, EventType
|
47 |
+
|
48 |
+
log_info(f"π€ TTS request via event bus for session: {request.session_id}")
|
49 |
+
|
50 |
+
# Publish TTS event
|
51 |
+
await req.app.state.event_bus.publish(Event(
|
52 |
+
type=EventType.TTS_STARTED,
|
53 |
+
session_id=request.session_id,
|
54 |
+
data={
|
55 |
+
"text": request.text,
|
56 |
+
"voice_id": request.voice_id,
|
57 |
+
"language": request.language,
|
58 |
+
"is_api_call": True # Flag to indicate this is from REST API
|
59 |
+
}
|
60 |
+
))
|
61 |
+
|
62 |
+
# Return a response indicating audio will be streamed via WebSocket
|
63 |
+
return {
|
64 |
+
"status": "processing",
|
65 |
+
"message": "TTS audio will be streamed via WebSocket connection",
|
66 |
+
"session_id": request.session_id
|
67 |
+
}
|
68 |
+
|
69 |
+
else:
|
70 |
+
# Direct TTS generation (legacy mode)
|
71 |
+
tts_provider = TTSFactory.create_provider()
|
72 |
+
|
73 |
+
if not tts_provider:
|
74 |
+
log_info("π΅ TTS disabled - returning empty response")
|
75 |
+
return Response(
|
76 |
+
content=b"",
|
77 |
+
media_type="audio/mpeg",
|
78 |
+
headers={"X-TTS-Status": "disabled"}
|
79 |
+
)
|
80 |
+
|
81 |
+
log_info(f"π€ Direct TTS request: '{request.text[:50]}...' with provider: {tts_provider.get_provider_name()}")
|
82 |
+
|
83 |
+
# Preprocess text if needed
|
84 |
+
preprocessor = TTSPreprocessor(language=request.language)
|
85 |
+
processed_text = preprocessor.preprocess(
|
86 |
+
request.text,
|
87 |
+
tts_provider.get_preprocessing_flags()
|
88 |
+
)
|
89 |
+
|
90 |
+
log_debug(f"π Preprocessed text: {processed_text[:100]}...")
|
91 |
+
|
92 |
+
# Generate audio
|
93 |
+
audio_data = await tts_provider.synthesize(
|
94 |
+
text=processed_text,
|
95 |
+
voice_id=request.voice_id
|
96 |
+
)
|
97 |
+
|
98 |
+
log_info(f"β
TTS generated {len(audio_data)} bytes of audio")
|
99 |
+
|
100 |
+
# Return audio as binary response
|
101 |
+
return Response(
|
102 |
+
content=audio_data,
|
103 |
+
media_type="audio/mpeg",
|
104 |
+
headers={
|
105 |
+
"Content-Disposition": 'inline; filename="tts_output.mp3"',
|
106 |
+
"X-TTS-Provider": tts_provider.get_provider_name(),
|
107 |
+
"X-TTS-Language": request.language,
|
108 |
+
"Cache-Control": "no-cache"
|
109 |
+
}
|
110 |
+
)
|
111 |
+
|
112 |
+
except Exception as e:
|
113 |
+
log_error("β TTS generation error", e)
|
114 |
+
raise HTTPException(
|
115 |
+
status_code=500,
|
116 |
+
detail=f"TTS generation failed: {str(e)}"
|
117 |
+
)
|
118 |
+
|
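A minimal client-side sketch of the two TTS modes above, assuming the router is mounted without a prefix and the app is reachable at http://localhost:7860 (base URL and helper names are illustrative, not part of the upload):

# Illustrative client for /tts/generate (assumed base URL; adjust to your deployment).
import httpx

BASE_URL = "http://localhost:7860"  # assumption: where the FastAPI app is served

async def tts_direct(text: str) -> bytes:
    # Without session_id the endpoint synthesizes immediately and returns MP3 bytes.
    async with httpx.AsyncClient() as client:
        resp = await client.post(f"{BASE_URL}/tts/generate", json={"text": text, "language": "tr-TR"})
        resp.raise_for_status()
        return resp.content  # audio/mpeg body

async def tts_event_driven(text: str, session_id: str) -> dict:
    # With session_id the endpoint only publishes TTS_STARTED; audio arrives over the WebSocket.
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{BASE_URL}/tts/generate",
            json={"text": text, "session_id": session_id},
        )
        resp.raise_for_status()
        return resp.json()  # {"status": "processing", ...}
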
@router.get("/tts/voices")
async def get_tts_voices():
    """Get available TTS voices"""
    try:
        tts_provider = TTSFactory.create_provider()

        if not tts_provider:
            return {
                "voices": [],
                "provider": "none",
                "enabled": False
            }

        voices = tts_provider.get_supported_voices()

        # Convert dict to list format
        voice_list = [
            {"id": voice_id, "name": voice_name}
            for voice_id, voice_name in voices.items()
        ]

        return {
            "voices": voice_list,
            "provider": tts_provider.get_provider_name(),
            "enabled": True
        }

    except Exception as e:
        log_error("Error getting TTS voices", e)
        return {
            "voices": [],
            "provider": "error",
            "enabled": False,
            "error": str(e)
        }

@router.get("/tts/status")
async def get_tts_status():
    """Get TTS service status"""
    cfg = ConfigProvider.get()

    return {
        "enabled": cfg.global_config.tts_provider.name != "no_tts",
        "provider": cfg.global_config.tts_provider.name,
        "provider_config": {
            "name": cfg.global_config.tts_provider.name,
            "has_api_key": bool(cfg.global_config.tts_provider.api_key),
            "endpoint": cfg.global_config.tts_provider.endpoint
        }
    }

# ===================== STT Endpoints =====================
@router.post("/stt/transcribe")
async def transcribe_audio(request: STTRequest, req: Request):
    """
    Transcribe audio to text
    - If session_id is provided and event bus is available, uses event-driven mode
    - Otherwise, uses direct STT transcription
    """
    try:
        # Check if we should use event-driven mode
        if request.session_id and hasattr(req.app.state, 'event_bus'):
            # Event-driven mode for realtime sessions
            from event_bus import Event, EventType

            log_info(f"STT request via event bus for session: {request.session_id}")

            # Publish audio chunk event
            await req.app.state.event_bus.publish(Event(
                type=EventType.AUDIO_CHUNK_RECEIVED,
                session_id=request.session_id,
                data={
                    "audio_data": request.audio_data,  # Already base64
                    "format": request.format,
                    "language": request.language,
                    "is_api_call": True
                }
            ))

            # Return a response indicating transcription will be available via WebSocket
            return {
                "status": "processing",
                "message": "Transcription will be available via WebSocket connection",
                "session_id": request.session_id
            }

        else:
            # Direct STT transcription (legacy mode)
            from stt.stt_factory import STTFactory
            from stt.stt_interface import STTConfig

            # Create STT provider
            stt_provider = STTFactory.create_provider()

            if not stt_provider or not stt_provider.supports_realtime():
                log_warning("STT disabled or doesn't support transcription")
                raise HTTPException(
                    status_code=503,
                    detail="STT service not available"
                )

            # Get config
            cfg = ConfigProvider.get()
            stt_config = cfg.global_config.stt_provider.settings

            # Decode audio data
            audio_bytes = base64.b64decode(request.audio_data)

            # Create STT config
            config = STTConfig(
                language=request.language or stt_config.get("language", "tr-TR"),
                sample_rate=16000,
                encoding=request.format.upper() if request.format else "WEBM_OPUS",
                enable_punctuation=stt_config.get("enable_punctuation", True),
                enable_word_timestamps=False,
                model=stt_config.get("model", "latest_long"),
                use_enhanced=stt_config.get("use_enhanced", True),
                single_utterance=True,
                interim_results=False
            )

            # Start streaming session
            await stt_provider.start_streaming(config)

            # Process audio
            transcription = ""
            confidence = 0.0

            try:
                async for result in stt_provider.stream_audio(audio_bytes):
                    if result.is_final:
                        transcription = result.text
                        confidence = result.confidence
                        break
            finally:
                # Stop streaming
                await stt_provider.stop_streaming()

            log_info(f"STT transcription completed: '{transcription[:50]}...'")

            return {
                "text": transcription,
                "confidence": confidence,
                "language": request.language,
                "provider": stt_provider.get_provider_name()
            }

    except HTTPException:
        raise
    except Exception as e:
        log_error("STT transcription error", e)
        raise HTTPException(
            status_code=500,
            detail=f"Transcription failed: {str(e)}"
        )

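Direct-mode /stt/transcribe expects the audio as a base64 string in the JSON body. An illustrative client along the same lines (file path and base URL are placeholders):

# Illustrative client for /stt/transcribe; the file path and base URL are placeholders.
import base64
import httpx

BASE_URL = "http://localhost:7860"  # assumption

async def transcribe_file(path: str) -> dict:
    with open(path, "rb") as f:
        audio_b64 = base64.b64encode(f.read()).decode("ascii")
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            f"{BASE_URL}/stt/transcribe",
            json={"audio_data": audio_b64, "language": "tr-TR", "format": "webm"},
        )
        resp.raise_for_status()
        return resp.json()  # {"text": ..., "confidence": ..., "language": ..., "provider": ...}
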
@router.get("/stt/languages")
async def get_stt_languages():
    """Get supported STT languages"""
    try:
        from stt.stt_factory import STTFactory

        stt_provider = STTFactory.create_provider()

        if not stt_provider:
            return {
                "languages": [],
                "provider": "none",
                "enabled": False
            }

        languages = stt_provider.get_supported_languages()

        return {
            "languages": languages,
            "provider": stt_provider.get_provider_name(),
            "enabled": True
        }

    except Exception as e:
        log_error("Error getting STT languages", e)
        return {
            "languages": [],
            "provider": "error",
            "enabled": False,
            "error": str(e)
        }

@router.get("/stt/status")
async def get_stt_status():
    """Get STT service status"""
    cfg = ConfigProvider.get()

    return {
        "enabled": cfg.global_config.stt_provider.name != "no_stt",
        "provider": cfg.global_config.stt_provider.name,
        "provider_config": {
            "name": cfg.global_config.stt_provider.name,
            "has_api_key": bool(cfg.global_config.stt_provider.api_key),
            "endpoint": cfg.global_config.stt_provider.endpoint
        }
    }

# ===================== WebSocket Audio Stream Endpoint =====================
@router.websocket("/ws/audio/{session_id}")
async def audio_websocket(websocket: WebSocket, session_id: str):
    """
    WebSocket endpoint for streaming audio
    This is a dedicated audio stream separate from the main conversation WebSocket
    """
    try:
        await websocket.accept()
        log_info(f"Audio WebSocket connected for session: {session_id}")

        # The event bus lives on app.state; reach it through the WebSocket's app reference
        if not hasattr(websocket.app.state, 'event_bus'):
            await websocket.send_json({
                "type": "error",
                "message": "Event bus not initialized"
            })
            await websocket.close()
            return

        while True:
            try:
                # Receive audio data
                data = await websocket.receive_json()

                if data.get("type") == "audio_chunk":
                    # Forward to event bus
                    from event_bus import Event, EventType

                    await websocket.app.state.event_bus.publish(Event(
                        type=EventType.AUDIO_CHUNK_RECEIVED,
                        session_id=session_id,
                        data={
                            "audio_data": data.get("data"),
                            "timestamp": data.get("timestamp"),
                            "chunk_index": data.get("chunk_index", 0)
                        }
                    ))

                elif data.get("type") == "control":
                    action = data.get("action")

                    if action == "start_recording":
                        from event_bus import Event, EventType

                        await websocket.app.state.event_bus.publish(Event(
                            type=EventType.STT_STARTED,
                            session_id=session_id,
                            data={
                                "language": data.get("language", "tr-TR"),
                                "format": data.get("format", "webm")
                            }
                        ))

                    elif action == "stop_recording":
                        from event_bus import Event, EventType

                        await websocket.app.state.event_bus.publish(Event(
                            type=EventType.STT_STOPPED,
                            session_id=session_id,
                            data={"reason": "user_request"}
                        ))

            except WebSocketDisconnect:
                break
            except Exception as e:
                log_error("Error in audio WebSocket", e)
                await websocket.send_json({
                    "type": "error",
                    "message": str(e)
                })

    except Exception as e:
        log_error("Audio WebSocket error", e)
    finally:
        log_info(f"Audio WebSocket disconnected for session: {session_id}")
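A sketch of the message protocol the audio WebSocket above accepts, written as an illustrative client using the websockets package. The ws:// URL is an assumption about the deployment, and transcription results are expected to come back over the main conversation WebSocket rather than this audio stream:

# Illustrative client for /ws/audio/{session_id}; URL, chunking and timing are placeholders.
import asyncio
import base64
import json
import websockets

async def stream_audio(session_id: str, chunks: list[bytes]) -> None:
    uri = f"ws://localhost:7860/ws/audio/{session_id}"  # assumption
    async with websockets.connect(uri) as ws:
        # Ask the server to start an STT stream for this session.
        await ws.send(json.dumps({
            "type": "control", "action": "start_recording",
            "language": "tr-TR", "format": "webm",
        }))
        # Forward base64-encoded audio chunks as they are captured.
        for i, chunk in enumerate(chunks):
            await ws.send(json.dumps({
                "type": "audio_chunk",
                "data": base64.b64encode(chunk).decode("ascii"),
                "chunk_index": i,
            }))
        # Signal the end of the utterance.
        await ws.send(json.dumps({"type": "control", "action": "stop_recording"}))

# asyncio.run(stream_audio("some-session-id", [b"..."]))
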

routes/chat_handler.py  (CHANGED)
(old side: only fragments of the removed lines survived extraction. Recoverable removals, by hunk)
- @@ -1,12 +1,12 @@ : old docstring title "Flare – Chat Handler (…)" and the import "from fastapi import APIRouter, HTTPException, Header" (Request is added on the new side).
- @@ -173,8 +215,8 @@ : old signature "async def start_session(req: StartRequest):" with docstring """Create new session""".
- @@ -186,7 +228,6 @@ and @@ -198,58 +239,75 @@ : inline comments ("# Use project's default locale", "# Specific version requested", "# Find published version with highest version number", "# Sort by version number (no) and get the highest", "# Create session with version config - PARAMETRE DÜZELTMESİ" [parameter fix], "# Version config'i session'a ekle" [add the version config to the session]) plus the inline LLM-provider bootstrap (if not llm_provider: / from llm.llm_factory import LLMFactory / llm_provider = LLMFactory.create_provider() / log_info(...)), which moves into the new REST branch of start_session.
- @@ -259,18 +317,24 @@ : old docstring """Process chat message""" and comments "# Better error message", "# Mevcut kod devam ediyor..." [existing code continues].
- @@ -317,7 +380,7 @@ and @@ -372,7 +434,7 @@ : direct uses of session.intent_config, replaced by set_intent_config()/get_intent_config() on the new side.
- @@ -338,7 +401,6 @@, @@ -405,11 +467,9 @@, @@ -436,7 +496,6 @@, @@ -447,13 +506,42 @@ : redundant comments ("# Get parameter collection config", "# Decide which parameters to ask", "# Better timeout error", "# Better generic error").
- @@ -163,7 +205,7 @@ and @@ -522,4 +610,4 @@ : "version_no: Optional[int] = None" and the module-footer setup_llm_provider() call are removed and re-added unchanged.

(new side: resulting code, by hunk)
@@ -1,12 +1,12 @@
"""
Flare – Chat Handler (REST API Only - Realtime moved to Event-Driven)
====================================================================
"""

import re, json, sys, httpx, os
from datetime import datetime
from typing import Dict, List, Optional, Any
from fastapi import APIRouter, HTTPException, Header, Request
from pydantic import BaseModel
import requests

@@ -66,6 +66,48 @@ def setup_llm_provider():
        log_error("Failed to initialize LLM provider", e)
        raise

# ───────────────────────── LLM GENERATION ───────────────────────── #
async def llm_generate(s: Session, prompt: str, user_msg: str) -> str:
    """Call LLM provider with proper error handling"""
    global llm_provider

    if llm_provider is None:
        setup_llm_provider()

    try:
        # Get version config from session
        version = s.get_version_config()
        if not version:
            # Fallback: get from project config
            project = next((p for p in cfg.projects if p.name == s.project_name), None)
            if not project:
                raise ValueError(f"Project not found: {s.project_name}")
            version = next((v for v in project.versions if v.published), None)
            if not version:
                raise ValueError("No published version found")

        log_info(f"Calling LLM for session {s.session_id[:8]}...")
        log_info(f"Prompt preview (first 200 chars): {prompt[:200]}...")

        history = s.chat_history

        # Call the configured LLM provider
        raw = await llm_provider.generate(
            user_input=user_msg,
            system_prompt=prompt,
            context=history[-10:] if history else []
        )

        log_info(f"LLM raw response: {raw[:100]}...")
        return raw

    except requests.exceptions.Timeout:
        log_warning(f"LLM timeout for session {s.session_id[:8]}")
        raise HTTPException(status_code=504, detail="LLM request timed out")
    except Exception as e:
        log_error("LLM error", e)
        raise HTTPException(status_code=500, detail=f"LLM error: {str(e)}")

# ───────────────────────── PARAMETER EXTRACTION ───────────────────────── #
def _extract_parameters_from_response(raw: str, session: Session, intent_config) -> bool:
    """Extract parameters from the LLM response"""

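llm_generate and the handlers below lean on a Session object and a session_store that are defined outside this diff. Purely as an inferred outline of the attributes and methods these routes rely on, not the real class, Session looks roughly like this:

# Inferred outline only (hypothetical); the real Session/session_store live outside this diff.
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

@dataclass
class SessionSketch:
    session_id: str
    project_name: str
    version_no: int
    is_realtime: bool = False
    locale: Optional[str] = None
    state: str = "idle"
    current_intent: Optional[str] = None
    last_activity: Optional[str] = None
    chat_history: List[Dict[str, str]] = field(default_factory=list)
    _version_config: Any = None
    _intent_config: Any = None

    def add_message(self, role: str, content: str) -> None:
        self.chat_history.append({"role": role, "content": content})

    def add_turn(self, role: str, content: str) -> None:
        # Used by start_session for the assistant greeting; assumed equivalent to add_message.
        self.add_message(role, content)

    def set_version_config(self, version: Any) -> None:
        self._version_config = version

    def get_version_config(self) -> Any:
        return self._version_config

    def set_intent_config(self, intent_config: Any) -> None:
        self._intent_config = intent_config

    def get_intent_config(self) -> Any:
        return self._intent_config

    def is_expired(self) -> bool:
        return False  # placeholder; the real expiry rule is not visible in this diff
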
@@ -163,7 +205,7 @@ class ChatRequest(BaseModel):

class StartRequest(BaseModel):
    project_name: str
    version_no: Optional[int] = None
    is_realtime: bool = False
    locale: Optional[str] = None

@@ -173,8 +215,8 @@ class ChatResponse(BaseModel):

# ───────────────────────── API ENDPOINTS ───────────────────────── #
@router.post("/start_session", response_model=ChatResponse)
async def start_session(req: StartRequest, request: Request):
    """Create new session - supports both REST and realtime"""
    global llm_provider

    try:
@@ -186,7 +228,6 @@ async def start_session(req: StartRequest):
        # Determine locale
        session_locale = req.locale
        if not session_locale:
            session_locale = project.default_locale

        # Validate locale is supported by project
@@ -198,58 +239,75 @@ async def start_session(req: StartRequest):

        # Find version
        if req.version_no:
            version = next((v for v in project.versions if v.no == req.version_no), None)
            if not version:
                raise HTTPException(404, f"Version {req.version_no} not found for project '{req.project_name}'")
        else:
            published_versions = [v for v in project.versions if v.published]
            if not published_versions:
                raise HTTPException(404, f"No published version for project '{req.project_name}'")
            version = max(published_versions, key=lambda v: v.no)

        # Create session
        session = session_store.create_session(
            project_name=req.project_name,
            version_no=version.no,
            is_realtime=req.is_realtime,
            locale=session_locale
        )
        session.set_version_config(version)

        # For realtime sessions, publish event to start the flow
        if req.is_realtime and hasattr(request.app.state, 'event_bus'):
            from event_bus import Event, EventType

            await request.app.state.event_bus.publish(Event(
                type=EventType.SESSION_STARTED,
                session_id=session.session_id,
                data={
                    "session": session,
                    "has_welcome": bool(version.welcome_prompt),
                    "welcome_text": version.welcome_prompt or "Hoş geldiniz! Size nasıl yardımcı olabilirim?",
                    "locale": session_locale,
                    "project_name": req.project_name,
                    "version_no": version.no
                }
            ))

            # For realtime, return minimal response
            return ChatResponse(
                session_id=session.session_id,
                answer="[REALTIME_MODE] Connect via WebSocket to continue."
            )

        # For REST mode, process welcome prompt normally
        else:
            # Create LLM provider if not exists
            if not llm_provider:
                from llm.llm_factory import LLMFactory
                llm_provider = LLMFactory.create_provider()
                log_info(f"LLM Provider created: {type(llm_provider).__name__}")

            # Process welcome prompt
            greeting = "Hoş geldiniz! Size nasıl yardımcı olabilirim?"
            if version.welcome_prompt:
                log_info(f"Processing welcome prompt for session {session.session_id[:8]}...")
                try:
                    welcome_result = await llm_provider.generate(
                        user_input="",
                        system_prompt=version.welcome_prompt,
                        context=[]
                    )
                    if welcome_result and welcome_result.strip():
                        greeting = welcome_result.strip()
                except Exception as e:
                    log_error("Welcome prompt processing failed", e)

            session.add_turn("assistant", greeting)

            log_info(f"Session created for project '{req.project_name}' version {version.no}")

            return ChatResponse(session_id=session.session_id, answer=greeting)

    except HTTPException:
        raise
@@ -259,18 +317,24 @@ async def start_session(req: StartRequest):

@router.post("/chat")
async def chat(req: ChatRequest, x_session_id: str = Header(...)):
    """Process chat message - REST API only (realtime uses WebSocket)"""
    try:
        # Get session
        session = session_store.get_session(x_session_id)
        if not session:
            raise HTTPException(
                status_code=404,
                detail=get_user_friendly_error("session_not_found")
            )

        # Check if this is a realtime session
        if session.is_realtime:
            raise HTTPException(
                status_code=400,
                detail="This is a realtime session. Please use WebSocket connection instead."
            )

        # Session expiry check
        if session.is_expired():
            session_store.delete_session(x_session_id)
            raise HTTPException(
@@ -282,7 +346,6 @@ async def chat(req: ChatRequest, x_session_id: str = Header(...)):
        session.last_activity = datetime.utcnow().isoformat()
        session_store.update_session(session)

        # Add user message to history
        session.add_message("user", req.message)
        log_info(f"User [{session.session_id[:8]}...]: {req.message}")
@@ -317,7 +380,7 @@ async def chat(req: ChatRequest, x_session_id: str = Header(...)):

        if intent_config:
            session.current_intent = intent_name
            session.set_intent_config(intent_config)
            session.state = "collect_params"
            log_info(f"Intent detected: {intent_name}")

@@ -338,7 +401,6 @@ async def chat(req: ChatRequest, x_session_id: str = Header(...)):
                return {"response": response, "intent": intent_name, "state": "completed"}
            else:
                # Need to collect more parameters
                collection_config = cfg.global_config.llm_provider.settings.get("parameter_collection_config", {})
                max_params = collection_config.get("max_params_per_question", 2)

@@ -372,7 +434,7 @@ async def chat(req: ChatRequest, x_session_id: str = Header(...)):

        elif session.state == "collect_params":
            # Continue parameter collection
            intent_config = session.get_intent_config()

            # Try to extract parameters from user message
            param_prompt = f"""
@@ -405,11 +467,9 @@ async def chat(req: ChatRequest, x_session_id: str = Header(...)):
                return {"response": response, "intent": session.current_intent, "state": "completed"}
            else:
                # Still need more parameters
                collection_config = cfg.global_config.llm_provider.settings.get("parameter_collection_config", {})
                max_params = collection_config.get("max_params_per_question", 2)

                params_to_ask = missing_params[:max_params]

                param_prompt = build_parameter_prompt(
@@ -436,7 +496,6 @@ async def chat(req: ChatRequest, x_session_id: str = Header(...)):
    except HTTPException:
        raise
    except requests.exceptions.Timeout:
        log_error(f"Timeout in chat for session {x_session_id[:8]}")
        return {
            "response": get_user_friendly_error("llm_timeout"),
@@ -447,13 +506,42 @@ async def chat(req: ChatRequest, x_session_id: str = Header(...)):
        log_error("Chat error", e)
        import traceback
        traceback.print_exc()
        return {
            "response": get_user_friendly_error("internal_error"),
            "state": "error",
            "error": True
        }

@router.post("/end_session")
async def end_session(x_session_id: str = Header(...), request: Request = None):
    """End a session - works for both REST and realtime"""
    try:
        session = session_store.get_session(x_session_id)
        if not session:
            raise HTTPException(404, "Session not found")

        # For realtime sessions, publish end event
        if session.is_realtime and request and hasattr(request.app.state, 'event_bus'):
            from event_bus import Event, EventType

            await request.app.state.event_bus.publish(Event(
                type=EventType.SESSION_ENDED,
                session_id=x_session_id,
                data={"reason": "user_request"}
            ))

        # Delete session
        session_store.delete_session(x_session_id)

        return {"message": "Session ended successfully"}

    except HTTPException:
        raise
    except Exception as e:
        log_error("Error ending session", e)
        raise HTTPException(500, f"Failed to end session: {str(e)}")

# ───────────────────────── HELPER FUNCTIONS ───────────────────────── #
def get_user_friendly_error(error_type: str, context: dict = None) -> str:
    """Get user-friendly error messages"""
    error_messages = {
@@ -522,4 +610,4 @@ def validate_parameter_with_message(param_config, value, locale="tr") -> tuple[b
    return False, "Değer kontrol edilirken bir hata oluştu."

# Initialize LLM on module load
setup_llm_provider()
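Taken together, the REST flow is: POST /start_session, then POST /chat with the returned session id in an X-Session-ID header, then POST /end_session (realtime sessions instead get a SESSION_STARTED event published and continue over WebSocket). A minimal illustrative client follows; the base URL, project name, and the assumption that ChatRequest carries a single message field (the handler reads req.message) are editorial placeholders, not part of the upload.

# Illustrative REST flow only; base URL, project name and payload shape are assumptions.
import asyncio
import httpx

BASE_URL = "http://localhost:7860"  # assumption: where the FastAPI app is served

async def demo() -> None:
    async with httpx.AsyncClient(timeout=60.0) as client:
        # 1) Create a non-realtime session; the response carries the greeting.
        start = await client.post(
            f"{BASE_URL}/start_session",
            json={"project_name": "demo-project", "is_realtime": False},
        )
        start.raise_for_status()
        session_id = start.json()["session_id"]

        # 2) Chat; the session id travels in the X-Session-ID header.
        reply = await client.post(
            f"{BASE_URL}/chat",
            json={"message": "Merhaba"},
            headers={"X-Session-ID": session_id},
        )
        print(reply.json().get("response"))

        # 3) End the session when done.
        await client.post(f"{BASE_URL}/end_session", headers={"X-Session-ID": session_id})

if __name__ == "__main__":
    asyncio.run(demo())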