ciyidogan committed on
Commit
84dee41
·
verified ·
1 Parent(s): 9e188ed

Delete state_orchestrator.py

Browse files
Files changed (1) hide show
  1. state_orchestrator.py +0 -622
state_orchestrator.py DELETED
@@ -1,622 +0,0 @@
1
- """
2
- State Orchestrator for Flare Realtime Chat
3
- ==========================================
4
- Central state machine and flow control
5
- """
6
- import asyncio
7
- from typing import Dict, Optional, Set, Any
8
- from enum import Enum
9
- from datetime import datetime
10
- import traceback
11
- from dataclasses import dataclass, field
12
-
13
- from event_bus import EventBus, Event, EventType, publish_state_transition, publish_error
14
- from session import Session
15
- from utils.logger import log_info, log_error, log_debug, log_warning
16
-
17
-
18
class ConversationState(Enum):
    """States a realtime conversation can be in.

    The string values are what gets published in state-transition events.
    """

    # Session exists but no conversation is running.
    IDLE = "idle"
    INITIALIZING = "initializing"

    # Optional welcome message: synthesis, then playback.
    PREPARING_WELCOME = "preparing_welcome"
    PLAYING_WELCOME = "playing_welcome"

    # Main listen -> process -> respond loop.
    LISTENING = "listening"
    PROCESSING_SPEECH = "processing_speech"
    PREPARING_RESPONSE = "preparing_response"
    PLAYING_RESPONSE = "playing_response"

    # Failure and terminal states.
    ERROR = "error"
    ENDED = "ended"
30
-
31
-
32
@dataclass
class SessionContext:
    """Per-session bookkeeping held by the state orchestrator.

    Only ``session_id`` and ``session`` are required; every other field has a
    default so a context can be created as soon as a session starts.
    """

    session_id: str
    session: Session
    # A new context always starts idle.
    state: ConversationState = ConversationState.IDLE
    # Slots for per-session resources; this class only stores them.
    stt_instance: Optional[Any] = None
    tts_instance: Optional[Any] = None
    llm_context: Optional[Any] = None
    audio_buffer: Optional[Any] = None
    websocket_connection: Optional[Any] = None
    # Naive UTC timestamps (datetime.utcnow).
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_activity: datetime = field(default_factory=datetime.utcnow)
    # Free-form per-session values; a fresh dict per instance.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def update_activity(self):
        """Stamp ``last_activity`` with the current (naive UTC) time."""
        self.last_activity = datetime.utcnow()

    async def cleanup(self):
        """Release session resources.

        Currently this only emits a debug log line; the original note says the
        actual cleanup is to be implemented by resource managers.
        """
        log_debug("🧹 Cleaning up session context", session_id=self.session_id)
55
-
56
-
57
class StateOrchestrator:
    """Central state machine and flow control for realtime conversations.

    The orchestrator subscribes to events on the shared event bus, keeps one
    ``SessionContext`` per session id and moves each session through the
    ``ConversationState`` graph described by ``VALID_TRANSITIONS``. Every
    successful state change is announced with a ``STATE_TRANSITION`` event.

    Follow-up events (TTS/STT/LLM requests) are only published when the state
    change they belong to was actually accepted.
    """

    # Seconds to wait before STT is restarted after an STT failure.
    STT_RECOVERY_DELAY = 2.0
    # Seconds spent in ERROR before the generic recovery resumes listening.
    DEFAULT_RECOVERY_DELAY = 1.0

    # Allowed transitions, keyed by current state.
    #
    # Besides the forward flow, every active state may drop back to IDLE
    # (conversation ended, session kept alive) and every non-terminal state
    # may go to ENDED (session teardown / critical error). Without those two
    # targets the conversation-end and session-end handlers could never
    # succeed. Error recovery always passes through ERROR.
    VALID_TRANSITIONS = {
        ConversationState.IDLE: {
            ConversationState.INITIALIZING,
            ConversationState.ENDED,
        },
        ConversationState.INITIALIZING: {
            ConversationState.PREPARING_WELCOME,
            ConversationState.LISTENING,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PREPARING_WELCOME: {
            ConversationState.PLAYING_WELCOME,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PLAYING_WELCOME: {
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.LISTENING: {
            ConversationState.PROCESSING_SPEECH,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PROCESSING_SPEECH: {
            ConversationState.PREPARING_RESPONSE,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PREPARING_RESPONSE: {
            ConversationState.PLAYING_RESPONSE,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PLAYING_RESPONSE: {
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.ERROR: {
            ConversationState.LISTENING,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.ENDED: set(),  # No transitions from ENDED
    }

    def __init__(self, event_bus: EventBus):
        """Store the event bus and subscribe all handlers.

        Args:
            event_bus: Bus used both to receive events and to publish
                follow-up events and state transitions.
        """
        self.event_bus = event_bus
        self.sessions: Dict[str, SessionContext] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to every event type the orchestrator reacts to."""
        subscriptions = (
            # Conversation events
            (EventType.CONVERSATION_STARTED, self._handle_conversation_started),
            (EventType.CONVERSATION_ENDED, self._handle_conversation_ended),
            # Session lifecycle
            (EventType.SESSION_STARTED, self._handle_session_started),
            (EventType.SESSION_ENDED, self._handle_session_ended),
            # STT events
            (EventType.STT_READY, self._handle_stt_ready),
            (EventType.STT_RESULT, self._handle_stt_result),
            (EventType.STT_ERROR, self._handle_stt_error),
            # TTS events
            (EventType.TTS_COMPLETED, self._handle_tts_completed),
            (EventType.TTS_ERROR, self._handle_tts_error),
            # Audio events
            (EventType.AUDIO_PLAYBACK_COMPLETED, self._handle_audio_playback_completed),
            # LLM events
            (EventType.LLM_RESPONSE_READY, self._handle_llm_response_ready),
            (EventType.LLM_ERROR, self._handle_llm_error),
            # Error events
            (EventType.CRITICAL_ERROR, self._handle_critical_error),
        )
        for event_type, handler in subscriptions:
            self.event_bus.subscribe(event_type, handler)

    async def _publish(
        self,
        event_type,
        session_id: str,
        data: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Publish an event for ``session_id``; ``data`` defaults to ``{}``."""
        await self.event_bus.publish(
            Event(
                type=event_type,
                session_id=session_id,
                data=data if data is not None else {},
            )
        )

    async def _recover_to_listening(
        self,
        session_id: str,
        stt_data: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Bring a session back to LISTENING after a failure and restart STT.

        The transition table only reaches LISTENING from a failed step via
        ERROR, so the session is moved to ERROR first when needed. STT is
        started only if the session really ends up in LISTENING.

        Returns:
            True if the session is in LISTENING and STT_STARTED was published.
        """
        context = self.sessions.get(session_id)
        if not context:
            return False

        if context.state != ConversationState.LISTENING:
            if context.state != ConversationState.ERROR:
                if not await self.transition_to(session_id, ConversationState.ERROR):
                    return False
            if not await self.transition_to(session_id, ConversationState.LISTENING):
                return False

        await self._publish(EventType.STT_STARTED, session_id, stt_data)
        return True

    async def _handle_conversation_started(self, event: Event) -> None:
        """Start a conversation inside an already existing session.

        Moves IDLE -> INITIALIZING, then either requests TTS for the welcome
        message or goes straight to LISTENING and starts STT.
        """
        session_id = event.session_id
        context = self.sessions.get(session_id)

        if not context:
            log_error(f"❌ Session not found for conversation start | session_id={session_id}")
            return

        log_info(f"🎤 Conversation started | session_id={session_id}")

        # First move from IDLE to INITIALIZING; stop if the session is not
        # idle (e.g. a conversation is already running).
        if not await self.transition_to(session_id, ConversationState.INITIALIZING):
            return

        welcome_text = context.metadata.get("welcome_text")
        if context.metadata.get("has_welcome") and welcome_text:
            # There is a welcome message: request TTS for it.
            if await self.transition_to(session_id, ConversationState.PREPARING_WELCOME):
                await self._publish(
                    EventType.TTS_STARTED,
                    session_id,
                    {"text": welcome_text, "is_welcome": True},
                )
        elif await self.transition_to(session_id, ConversationState.LISTENING):
            # No welcome message: listen right away.
            await self._publish(EventType.STT_STARTED, session_id)

    async def _handle_conversation_ended(self, event: Event) -> None:
        """End the conversation but keep the session alive in IDLE."""
        session_id = event.session_id
        context = self.sessions.get(session_id)

        if not context:
            log_warning(f"⚠️ Session not found for conversation end | session_id={session_id}")
            return

        log_info(f"🔚 Conversation ended | session_id={session_id}")

        # Stop STT and any ongoing TTS.
        await self._publish(
            EventType.STT_STOPPED, session_id, {"reason": "conversation_ended"}
        )
        await self._publish(
            EventType.TTS_STOPPED, session_id, {"reason": "conversation_ended"}
        )

        # Back to IDLE - the session itself stays alive. Only report success
        # when the session really is idle afterwards.
        already_idle = context.state == ConversationState.IDLE
        if already_idle or await self.transition_to(session_id, ConversationState.IDLE):
            log_info(f"💤 Session back to IDLE, ready for new conversation | session_id={session_id}")

    async def _handle_session_started(self, event: Event):
        """Create the ``SessionContext`` for a newly started session.

        Reads ``session``, ``has_welcome`` and ``welcome_text`` from the event
        data. The context starts in IDLE and stays there until a conversation
        is started.
        """
        session_id = event.session_id
        session_data = event.data or {}

        log_info("🎬 Session started", session_id=session_id)

        self.sessions[session_id] = SessionContext(
            session_id=session_id,
            session=session_data.get("session"),
            metadata={
                "has_welcome": session_data.get("has_welcome", False),
                "welcome_text": session_data.get("welcome_text", ""),
            },
        )

        log_info(f"📍 Session created in IDLE state | session_id={session_id}")

    async def _handle_session_ended(self, event: Event):
        """Tear the session down completely and forget it."""
        session_id = event.session_id

        log_info(f"🏁 Session ended | session_id={session_id}")

        context = self.sessions.get(session_id)
        if context:
            # May already be ENDED when a critical error triggered this.
            if context.state != ConversationState.ENDED:
                await self.transition_to(session_id, ConversationState.ENDED)

            # Stop all components.
            await self._publish(
                EventType.STT_STOPPED, session_id, {"reason": "session_ended"}
            )
            await self._publish(
                EventType.TTS_STOPPED, session_id, {"reason": "session_ended"}
            )

            await context.cleanup()

        self.sessions.pop(session_id, None)
        self.event_bus.clear_session_data(session_id)

        log_info(f"✅ Session fully cleaned up | session_id={session_id}")

    async def _handle_stt_ready(self, event: Event):
        """Log the STT-ready signal; no state change is required for it."""
        session_id = event.session_id
        log_debug(
            "🎤 STT ready",
            session_id=session_id,
            current_state=self.get_state(session_id),
        )

    async def _handle_stt_result(self, event: Event):
        """Handle an STT transcription result.

        Interim results are only logged. A non-empty final result received
        while LISTENING stops STT, moves to PROCESSING_SPEECH and hands the
        text to the LLM.
        """
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        result_data = event.data or {}
        text = (result_data.get("text") or "").strip()

        if not result_data.get("is_final", False):
            # Interim results never change state; just log them.
            if text:
                log_debug(f"📝 Interim transcription: '{text}'", session_id=session_id)
            return

        if not text:
            log_warning("⚠️ Empty final transcription", session_id=session_id)
            return

        if context.state != ConversationState.LISTENING:
            log_warning(
                "⚠️ STT result in unexpected state",
                session_id=session_id,
                state=context.state.value,
            )
            return

        log_info(f"💬 Final transcription: '{text}'", session_id=session_id)

        # Stop STT automatically once the utterance is complete.
        await self._publish(
            EventType.STT_STOPPED, session_id, {"reason": "utterance_completed"}
        )

        # The publish above is awaited, so the state may have changed.
        if not await self.transition_to(session_id, ConversationState.PROCESSING_SPEECH):
            return

        # Send to LLM.
        await self._publish(
            EventType.LLM_PROCESSING_STARTED, session_id, {"text": text}
        )

    async def _handle_llm_response_ready(self, event: Event):
        """Turn an LLM response into a TTS request."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        if context.state != ConversationState.PROCESSING_SPEECH:
            log_warning(
                "⚠️ LLM response in unexpected state",
                session_id=session_id,
                state=context.state.value,
            )
            return

        response_text = (event.data or {}).get("text", "")
        log_info("🤖 LLM response ready", session_id=session_id, length=len(response_text))

        if await self.transition_to(session_id, ConversationState.PREPARING_RESPONSE):
            await self._publish(
                EventType.TTS_STARTED, session_id, {"text": response_text}
            )

    async def _handle_tts_completed(self, event: Event):
        """Move from a PREPARING_* state to the matching PLAYING_* state."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        current_state = context.state
        log_info("🔊 TTS completed", session_id=session_id, state=current_state.value)

        if current_state == ConversationState.PREPARING_WELCOME:
            # The welcome audio is played by the frontend; only the state is
            # updated here. The frontend reports back with
            # audio_playback_completed once playback has finished.
            await self.transition_to(session_id, ConversationState.PLAYING_WELCOME)
        elif current_state == ConversationState.PREPARING_RESPONSE:
            await self.transition_to(session_id, ConversationState.PLAYING_RESPONSE)

    async def _handle_audio_playback_completed(self, event: Event):
        """After playback finishes, start listening for the next utterance."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        current_state = context.state
        log_info("🎵 Audio playback completed", session_id=session_id, state=current_state.value)

        if current_state not in (
            ConversationState.PLAYING_WELCOME,
            ConversationState.PLAYING_RESPONSE,
        ):
            return

        if not await self.transition_to(session_id, ConversationState.LISTENING):
            return

        # Start STT in single-utterance mode.
        await self._publish(
            EventType.STT_STARTED,
            session_id,
            {
                "locale": context.metadata.get("locale", "tr"),
                "single_utterance": True,   # single-utterance mode
                "interim_results": False,   # final results only
                "speech_timeout_ms": 2000,  # 2 seconds of silence
            },
        )

        # Send STT ready signal to frontend.
        await self._publish(EventType.STT_READY, session_id)

    async def _handle_stt_error(self, event: Event):
        """Enter ERROR on an STT failure, wait, then resume listening."""
        session_id = event.session_id

        log_error(
            "❌ STT error",
            session_id=session_id,
            error=(event.data or {}).get("message"),
        )

        context = self.sessions.get(session_id)
        if not context or context.state == ConversationState.ENDED:
            return

        if context.state != ConversationState.ERROR:
            if not await self.transition_to(session_id, ConversationState.ERROR):
                return

        await asyncio.sleep(self.STT_RECOVERY_DELAY)

        # The session may have ended or moved on during the delay.
        if self.get_state(session_id) != ConversationState.ERROR:
            return

        if await self.transition_to(session_id, ConversationState.LISTENING):
            await self._publish(EventType.STT_STARTED, session_id, {"retry": True})

    async def _handle_tts_error(self, event: Event):
        """Skip the failed TTS step and resume listening."""
        session_id = event.session_id

        log_error(
            "❌ TTS error",
            session_id=session_id,
            error=(event.data or {}).get("message"),
        )

        if self.get_state(session_id) in (
            ConversationState.PREPARING_WELCOME,
            ConversationState.PREPARING_RESPONSE,
        ):
            await self._recover_to_listening(session_id)

    async def _handle_llm_error(self, event: Event):
        """Resume listening after an LLM failure."""
        session_id = event.session_id

        log_error(
            "❌ LLM error",
            session_id=session_id,
            error=(event.data or {}).get("message"),
        )

        await self._recover_to_listening(session_id)

    async def _handle_critical_error(self, event: Event):
        """End the session after a critical error."""
        session_id = event.session_id

        log_error(
            "💥 Critical error",
            session_id=session_id,
            error=(event.data or {}).get("message"),
        )

        current_state = self.get_state(session_id)
        if current_state is not None and current_state != ConversationState.ENDED:
            await self.transition_to(session_id, ConversationState.ENDED)

        # Publish session end event so the session gets cleaned up.
        await self._publish(
            EventType.SESSION_ENDED, session_id, {"reason": "critical_error"}
        )

    async def transition_to(self, session_id: str, new_state: ConversationState) -> bool:
        """Move a session to ``new_state`` if ``VALID_TRANSITIONS`` allows it.

        On success the context's state and activity timestamp are updated and
        a ``STATE_TRANSITION`` event carrying ``old_state``, ``new_state`` and
        an ISO timestamp is published.

        Returns:
            True if the transition happened; False if the session is unknown,
            the transition is not allowed, or an error occurred. This method
            does not raise ``Exception`` subclasses; they are logged instead.
        """
        try:
            context = self.sessions.get(session_id)
            if not context:
                log_warning(f"❌ Session not found for state transition | session_id={session_id}")
                return False

            old_state = context.state
            if new_state not in self.VALID_TRANSITIONS.get(old_state, set()):
                log_warning(f"❌ Invalid state transition | session_id={session_id}, current={old_state.value}, requested={new_state.value}")
                return False

            context.state = new_state
            context.update_activity()

            log_info(f"✅ State transition | session_id={session_id}, {old_state.value} → {new_state.value}")

            await self._publish(
                EventType.STATE_TRANSITION,
                session_id,
                {
                    "old_state": old_state.value,  # Backend uses old_state/new_state
                    "new_state": new_state.value,
                    "timestamp": datetime.utcnow().isoformat(),
                },
            )
            return True

        except Exception as e:
            # Handlers rely on the boolean result, so failures are logged
            # and reported as False rather than propagated.
            log_error(f"❌ State transition error | session_id={session_id}", e)
            return False

    def get_state(self, session_id: str) -> Optional[ConversationState]:
        """Return the current state of a session, or None if it is unknown."""
        context = self.sessions.get(session_id)
        return context.state if context else None

    def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the data held for a session, or None if it is unknown.

        The result is a new dict containing the ``session`` object plus the
        context's metadata entries (``has_welcome``, ``welcome_text`` and any
        keys added later).
        """
        context = self.sessions.get(session_id)
        if not context:
            return None
        return {"session": context.session, **context.metadata}

    async def handle_error_recovery(self, session_id: str, error_type: str):
        """Run the recovery strategy registered for ``error_type``.

        Known types are ``stt_error``, ``tts_error``, ``llm_error`` and
        ``websocket_error``. Any other type goes to ERROR, waits
        ``DEFAULT_RECOVERY_DELAY`` seconds and then returns to LISTENING.
        Nothing happens for unknown or already ended sessions.
        """
        context = self.sessions.get(session_id)
        if not context or context.state == ConversationState.ENDED:
            return

        log_info(
            "🔧 Attempting error recovery",
            session_id=session_id,
            error_type=error_type,
            current_state=context.state.value,
        )

        context.update_activity()

        recovery_strategies = {
            "stt_error": self._recover_from_stt_error,
            "tts_error": self._recover_from_tts_error,
            "llm_error": self._recover_from_llm_error,
            "websocket_error": self._recover_from_websocket_error,
        }

        strategy = recovery_strategies.get(error_type)
        if strategy:
            await strategy(session_id)
            return

        # Default recovery: go to error state, then back to listening.
        if context.state != ConversationState.ERROR:
            if not await self.transition_to(session_id, ConversationState.ERROR):
                return
        await asyncio.sleep(self.DEFAULT_RECOVERY_DELAY)
        # The session may have ended or moved on during the delay.
        if self.get_state(session_id) == ConversationState.ERROR:
            await self.transition_to(session_id, ConversationState.LISTENING)

    async def _recover_from_stt_error(self, session_id: str):
        """Stop STT, wait ``STT_RECOVERY_DELAY`` seconds, then restart it."""
        await self._publish(
            EventType.STT_STOPPED, session_id, {"reason": "error_recovery"}
        )

        await asyncio.sleep(self.STT_RECOVERY_DELAY)

        await self._recover_to_listening(session_id, {"retry": True})

    async def _recover_from_tts_error(self, session_id: str):
        """Skip TTS and go back to listening."""
        await self._recover_to_listening(session_id)

    async def _recover_from_llm_error(self, session_id: str):
        """Go back to listening after an LLM failure."""
        await self._recover_to_listening(session_id)

    async def _recover_from_websocket_error(self, session_id: str):
        """End the session cleanly after a WebSocket failure."""
        current_state = self.get_state(session_id)
        if current_state is not None and current_state != ConversationState.ENDED:
            await self.transition_to(session_id, ConversationState.ENDED)

        await self._publish(
            EventType.SESSION_ENDED, session_id, {"reason": "websocket_error"}
        )