""" State Orchestrator for Flare Realtime Chat ========================================== Central state machine and flow control """ import asyncio from typing import Dict, Optional, Set, Any from enum import Enum from datetime import datetime import traceback from dataclasses import dataclass, field from .event_bus import EventBus, Event, EventType, publish_state_transition, publish_error from .session import Session from utils.logger import log_info, log_error, log_debug, log_warning class ConversationState(Enum): """Conversation states""" IDLE = "idle" INITIALIZING = "initializing" PREPARING_WELCOME = "preparing_welcome" PLAYING_WELCOME = "playing_welcome" LISTENING = "listening" PROCESSING_SPEECH = "processing_speech" PREPARING_RESPONSE = "preparing_response" PLAYING_RESPONSE = "playing_response" ERROR = "error" ENDED = "ended" @dataclass class SessionContext: """Context for a conversation session""" session_id: str session: Session state: ConversationState = ConversationState.IDLE stt_instance: Optional[Any] = None tts_instance: Optional[Any] = None llm_context: Optional[Any] = None audio_buffer: Optional[Any] = None websocket_connection: Optional[Any] = None created_at: datetime = field(default_factory=datetime.utcnow) last_activity: datetime = field(default_factory=datetime.utcnow) metadata: Dict[str, Any] = field(default_factory=dict) def update_activity(self): """Update last activity timestamp""" self.last_activity = datetime.utcnow() async def cleanup(self): """Cleanup all session resources""" # Cleanup will be implemented by resource managers log_debug(f"🧹 Cleaning up session context", session_id=self.session_id) class StateOrchestrator: """Central state machine for conversation flow""" # Valid state transitions VALID_TRANSITIONS = { ConversationState.IDLE: {ConversationState.INITIALIZING}, ConversationState.INITIALIZING: {ConversationState.PREPARING_WELCOME, ConversationState.LISTENING}, ConversationState.PREPARING_WELCOME: {ConversationState.PLAYING_WELCOME, ConversationState.ERROR}, ConversationState.PLAYING_WELCOME: {ConversationState.LISTENING, ConversationState.ERROR}, ConversationState.LISTENING: {ConversationState.PROCESSING_SPEECH, ConversationState.ERROR, ConversationState.ENDED}, ConversationState.PROCESSING_SPEECH: {ConversationState.PREPARING_RESPONSE, ConversationState.ERROR}, ConversationState.PREPARING_RESPONSE: {ConversationState.PLAYING_RESPONSE, ConversationState.ERROR}, ConversationState.PLAYING_RESPONSE: {ConversationState.LISTENING, ConversationState.ERROR}, ConversationState.ERROR: {ConversationState.LISTENING, ConversationState.ENDED}, ConversationState.ENDED: set() # No transitions from ENDED } def __init__(self, event_bus: EventBus): self.event_bus = event_bus self.sessions: Dict[str, SessionContext] = {} self._setup_event_handlers() def _setup_event_handlers(self): """Subscribe to relevant events""" # Conversation events self.event_bus.subscribe(EventType.CONVERSATION_STARTED, self._handle_conversation_started) self.event_bus.subscribe(EventType.CONVERSATION_ENDED, self._handle_conversation_ended) # Session lifecycle self.event_bus.subscribe(EventType.SESSION_STARTED, self._handle_session_started) self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended) # ✅ WebSocket events self.event_bus.subscribe(EventType.WEBSOCKET_DISCONNECTED, self._handle_websocket_disconnected) # STT events self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready) self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result) self.event_bus.subscribe(EventType.STT_ERROR, self._handle_stt_error) # TTS events self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed) self.event_bus.subscribe(EventType.TTS_ERROR, self._handle_tts_error) # Audio events self.event_bus.subscribe(EventType.AUDIO_PLAYBACK_COMPLETED, self._handle_audio_playback_completed) # LLM events self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response_ready) self.event_bus.subscribe(EventType.LLM_ERROR, self._handle_llm_error) # Error events self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_critical_error) async def _handle_websocket_disconnected(self, event: Event): """Handle WebSocket disconnection""" session_id = event.session_id context = self.sessions.get(session_id) if not context: return log_info(f"🔌 Handling WebSocket disconnect | session_id={session_id}, state={context.state.value}") # Eğer conversation aktifse, önce conversation'ı sonlandır if context.state not in [ConversationState.IDLE, ConversationState.ENDED]: await self._handle_conversation_ended(Event( type=EventType.CONVERSATION_ENDED, session_id=session_id, data={"reason": "websocket_disconnected"} )) async def _handle_conversation_started(self, event: Event) -> None: """Handle conversation start within existing session""" session_id = event.session_id context = self.sessions.get(session_id) if not context: log_error(f"❌ Session not found for conversation start | session_id={session_id}") return log_info(f"🎤 Conversation started | session_id={session_id}") # İlk olarak IDLE'dan INITIALIZING'e geç await self.transition_to(session_id, ConversationState.INITIALIZING) # Welcome mesajı varsa if context.metadata.get("has_welcome") and context.metadata.get("welcome_text"): await self.transition_to(session_id, ConversationState.PREPARING_WELCOME) # Request TTS for welcome message await self.event_bus.publish(Event( type=EventType.TTS_STARTED, session_id=session_id, data={ "text": context.metadata.get("welcome_text", ""), "is_welcome": True } )) else: # Welcome yoksa direkt LISTENING'e geç await self.transition_to(session_id, ConversationState.LISTENING) # Start STT await self.event_bus.publish( Event( type=EventType.STT_STARTED, data={}, session_id=session_id ) ) async def _handle_conversation_ended(self, event: Event) -> None: """Handle conversation end - but keep session alive""" session_id = event.session_id context = self.sessions.get(session_id) if not context: log_warning(f"⚠️ Session not found for conversation end | session_id={session_id}") return log_info(f"🔚 Conversation ended | session_id={session_id}") # Stop STT if running await self.event_bus.publish(Event( type=EventType.STT_STOPPED, session_id=session_id, data={"reason": "conversation_ended"} )) # Stop any ongoing TTS await self.event_bus.publish(Event( type=EventType.TTS_STOPPED, session_id=session_id, data={"reason": "conversation_ended"} )) # Transition back to IDLE - session still alive! await self.transition_to(session_id, ConversationState.IDLE) log_info(f"💤 Session back to IDLE, ready for new conversation | session_id={session_id}") async def _handle_session_started(self, event: Event): """Handle session start""" session_id = event.session_id session_data = event.data log_info(f"🎬 Session started", session_id=session_id) # Create session context context = SessionContext( session_id=session_id, session=session_data.get("session"), metadata={ "has_welcome": session_data.get("has_welcome", False), "welcome_text": session_data.get("welcome_text", "") } ) self.sessions[session_id] = context # Session başladığında IDLE state'te kalmalı # Conversation başlayana kadar bekleyeceğiz # Zaten SessionContext default state'i IDLE log_info(f"📍 Session created in IDLE state | session_id={session_id}") async def _handle_session_ended(self, event: Event): """Handle session end - complete cleanup""" session_id = event.session_id log_info(f"🏁 Session ended | session_id={session_id}") # Get context for cleanup context = self.sessions.get(session_id) if context: # Try to transition to ENDED if possible try: await self.transition_to(session_id, ConversationState.ENDED) except Exception as e: log_warning(f"Could not transition to ENDED state: {e}") # Stop all components await self.event_bus.publish(Event( type=EventType.STT_STOPPED, session_id=session_id, data={"reason": "session_ended"} )) await self.event_bus.publish(Event( type=EventType.TTS_STOPPED, session_id=session_id, data={"reason": "session_ended"} )) # Cleanup session context await context.cleanup() # Remove session self.sessions.pop(session_id, None) # Clear event bus session data self.event_bus.clear_session_data(session_id) log_info(f"✅ Session fully cleaned up | session_id={session_id}") async def _handle_stt_ready(self, event: Event): """Handle STT ready signal""" session_id = event.session_id current_state = self.get_state(session_id) log_debug(f"🎤 STT ready", session_id=session_id, current_state=current_state) # Only process if we're expecting STT to be ready if current_state in [ConversationState.LISTENING, ConversationState.PLAYING_WELCOME]: # STT is ready, we're already in the right state pass async def _handle_stt_result(self, event: Event): """Handle STT transcription result""" session_id = event.session_id context = self.sessions.get(session_id) if not context: return current_state = context.state result_data = event.data # Batch mode'da her zaman final result gelir text = result_data.get("text", "").strip() if current_state != ConversationState.LISTENING: log_warning( f"⚠️ STT result in unexpected state", session_id=session_id, state=current_state.value ) return if text: log_info(f"💬 Transcription: '{text}'", session_id=session_id) else: log_info(f"💬 No speech detected", session_id=session_id) # Boş transcript'te STT'yi yeniden başlat await self.event_bus.publish(Event( type=EventType.STT_STARTED, session_id=session_id, data={} )) return # Transition to processing await self.transition_to(session_id, ConversationState.PROCESSING_SPEECH) # Send to LLM await self.event_bus.publish(Event( type=EventType.LLM_PROCESSING_STARTED, session_id=session_id, data={"text": text} )) async def _handle_llm_response_ready(self, event: Event): """Handle LLM response""" session_id = event.session_id current_state = self.get_state(session_id) # Debug için context'i direkt alalım context = self.sessions.get(session_id) if not context: log_error(f"❌ No context for LLM response", session_id=session_id) return current_state = context.state log_debug(f"🔍 LLM response handler - current state: {current_state.value}") if current_state != ConversationState.PROCESSING_SPEECH: log_warning( f"⚠️ LLM response in unexpected state", session_id=session_id, state=current_state ) return response_text = event.data.get("text", "") log_info(f"🤖 LLM response ready", session_id=session_id, length=len(response_text)) # Transition to preparing response await self.transition_to(session_id, ConversationState.PREPARING_RESPONSE) # Request TTS await self.event_bus.publish(Event( type=EventType.TTS_STARTED, session_id=session_id, data={"text": response_text} )) async def _handle_tts_completed(self, event: Event): """Handle TTS completion""" session_id = event.session_id context = self.sessions.get(session_id) if not context: return current_state = context.state log_info(f"🔊 TTS completed", session_id=session_id, state=current_state.value) if current_state == ConversationState.PREPARING_WELCOME: await self.transition_to(session_id, ConversationState.PLAYING_WELCOME) # Welcome audio frontend'te çalınacak, biz sadece state'i güncelliyoruz # Frontend audio bitince bize audio_playback_completed gönderecek elif current_state == ConversationState.PREPARING_RESPONSE: await self.transition_to(session_id, ConversationState.PLAYING_RESPONSE) async def _handle_audio_playback_completed(self, event: Event): """Handle audio playback completion""" session_id = event.session_id context = self.sessions.get(session_id) if not context: return current_state = context.state log_info(f"🎵 Audio playback completed", session_id=session_id, state=current_state.value) if current_state in [ConversationState.PLAYING_WELCOME, ConversationState.PLAYING_RESPONSE]: # Transition to listening await self.transition_to(session_id, ConversationState.LISTENING) # ✅ STT'yi başlat - batch mode için yeni utterance locale = context.metadata.get("locale", "tr") await self.event_bus.publish(Event( type=EventType.STT_STARTED, session_id=session_id, data={ "locale": locale, "batch_mode": True, # ✅ Batch mode aktif "silence_threshold_ms": 1000 # 1 saniye sessizlik } )) # Send STT ready signal to frontend await self.event_bus.publish(Event( type=EventType.STT_READY, session_id=session_id, data={} )) async def _handle_stt_error(self, event: Event): """Handle STT errors""" session_id = event.session_id error_data = event.data log_error( f"❌ STT error", session_id=session_id, error=error_data.get("message") ) # Try to recover by transitioning back to listening current_state = self.get_state(session_id) if current_state != ConversationState.ENDED: await self.transition_to(session_id, ConversationState.ERROR) # Try recovery after delay await asyncio.sleep(2.0) if self.get_state(session_id) == ConversationState.ERROR: await self.transition_to(session_id, ConversationState.LISTENING) # Restart STT await self.event_bus.publish(Event( type=EventType.STT_STARTED, session_id=session_id, data={"retry": True} )) async def _handle_tts_error(self, event: Event): """Handle TTS errors""" session_id = event.session_id error_data = event.data log_error( f"❌ TTS error", session_id=session_id, error=error_data.get("message") ) # Skip TTS and go to listening current_state = self.get_state(session_id) if current_state in [ConversationState.PREPARING_WELCOME, ConversationState.PREPARING_RESPONSE]: await self.transition_to(session_id, ConversationState.LISTENING) # Start STT await self.event_bus.publish(Event( type=EventType.STT_STARTED, session_id=session_id, data={} )) async def _handle_llm_error(self, event: Event): """Handle LLM errors""" session_id = event.session_id error_data = event.data log_error( f"❌ LLM error", session_id=session_id, error=error_data.get("message") ) # Go back to listening await self.transition_to(session_id, ConversationState.LISTENING) # Start STT await self.event_bus.publish(Event( type=EventType.STT_STARTED, session_id=session_id, data={} )) async def _handle_critical_error(self, event: Event): """Handle critical errors""" session_id = event.session_id error_data = event.data log_error( f"💥 Critical error", session_id=session_id, error=error_data.get("message") ) # End session await self.transition_to(session_id, ConversationState.ENDED) # Publish session end event await self.event_bus.publish(Event( type=EventType.SESSION_ENDED, session_id=session_id, data={"reason": "critical_error"} )) async def transition_to(self, session_id: str, new_state: ConversationState) -> bool: """ Transition to a new state with validation """ try: # Get session context context = self.sessions.get(session_id) if not context: log_info(f"❌ Session not found for state transition | session_id={session_id}") return False # Get current state from context current_state = context.state # Check if transition is valid if new_state not in self.VALID_TRANSITIONS.get(current_state, set()): log_info(f"❌ Invalid state transition | session_id={session_id}, current={current_state.value}, requested={new_state.value}") return False # Update state old_state = current_state context.state = new_state context.last_activity = datetime.utcnow() log_info(f"✅ State transition | session_id={session_id}, {old_state.value} → {new_state.value}") # Emit state transition event with correct field names await self.event_bus.publish( Event( type=EventType.STATE_TRANSITION, data={ "old_state": old_state.value, # Backend uses old_state/new_state "new_state": new_state.value, "timestamp": datetime.utcnow().isoformat() }, session_id=session_id ) ) return True except Exception as e: log_error(f"❌ State transition error | session_id={session_id}", e) return False def get_state(self, session_id: str) -> Optional[ConversationState]: """Get current state for a session""" return self.sessions.get(session_id) def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]: """Get session data""" return self.session_data.get(session_id) async def handle_error_recovery(self, session_id: str, error_type: str): """Handle error recovery strategies""" context = self.sessions.get(session_id) if not context or context.state == ConversationState.ENDED: return log_info( f"🔧 Attempting error recovery", session_id=session_id, error_type=error_type, current_state=context.state.value ) # Update activity context.update_activity() # Define recovery strategies recovery_strategies = { "stt_error": self._recover_from_stt_error, "tts_error": self._recover_from_tts_error, "llm_error": self._recover_from_llm_error, "websocket_error": self._recover_from_websocket_error } strategy = recovery_strategies.get(error_type) if strategy: await strategy(session_id) else: # Default recovery: go to error state then back to listening await self.transition_to(session_id, ConversationState.ERROR) await asyncio.sleep(1.0) await self.transition_to(session_id, ConversationState.LISTENING) async def _recover_from_stt_error(self, session_id: str): """Recover from STT error""" # Stop STT, wait, restart await self.event_bus.publish(Event( type=EventType.STT_STOPPED, session_id=session_id, data={"reason": "error_recovery"} )) await asyncio.sleep(2.0) await self.transition_to(session_id, ConversationState.LISTENING) await self.event_bus.publish(Event( type=EventType.STT_STARTED, session_id=session_id, data={"retry": True} )) async def _recover_from_tts_error(self, session_id: str): """Recover from TTS error""" # Skip TTS, go directly to listening await self.transition_to(session_id, ConversationState.LISTENING) await self.event_bus.publish(Event( type=EventType.STT_STARTED, session_id=session_id, data={} )) async def _recover_from_llm_error(self, session_id: str): """Recover from LLM error""" # Go back to listening await self.transition_to(session_id, ConversationState.LISTENING) await self.event_bus.publish(Event( type=EventType.STT_STARTED, session_id=session_id, data={} )) async def _recover_from_websocket_error(self, session_id: str): """Recover from WebSocket error""" # End session cleanly await self.transition_to(session_id, ConversationState.ENDED) await self.event_bus.publish(Event( type=EventType.SESSION_ENDED, session_id=session_id, data={"reason": "websocket_error"} ))