"""
State Orchestrator for Flare Realtime Chat
==========================================
Central state machine and flow control.

The orchestrator keeps one ``SessionContext`` per session id, reacts to
events published on the ``EventBus`` and drives each session through the
``ConversationState`` machine whose allowed edges are listed in
``StateOrchestrator.VALID_TRANSITIONS``.
"""
import asyncio
from typing import Dict, Optional, Set, Any
from enum import Enum
from datetime import datetime
import traceback
from dataclasses import dataclass, field

from event_bus import EventBus, Event, EventType, publish_state_transition, publish_error
from session import Session
from utils.logger import log_info, log_error, log_debug, log_warning


class ConversationState(Enum):
    """Conversation states"""
    IDLE = "idle"
    INITIALIZING = "initializing"
    PREPARING_WELCOME = "preparing_welcome"
    PLAYING_WELCOME = "playing_welcome"
    LISTENING = "listening"
    PROCESSING_SPEECH = "processing_speech"
    PREPARING_RESPONSE = "preparing_response"
    PLAYING_RESPONSE = "playing_response"
    ERROR = "error"
    ENDED = "ended"


@dataclass
class SessionContext:
    """Context for a conversation session.

    Holds the current ``ConversationState`` plus optional handles to the
    per-session components. The orchestrator mutates ``state`` in place;
    the context object itself stays in ``StateOrchestrator.sessions`` until
    the session ends.
    """
    session_id: str
    session: Session
    # Current position in the state machine; updated by StateOrchestrator.transition_to.
    state: ConversationState = ConversationState.IDLE
    # Component handles; not set anywhere in this module.
    stt_instance: Optional[Any] = None
    tts_instance: Optional[Any] = None
    llm_context: Optional[Any] = None
    audio_buffer: Optional[Any] = None
    websocket_connection: Optional[Any] = None
    # Naive UTC timestamps (datetime.utcnow).
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_activity: datetime = field(default_factory=datetime.utcnow)
    # Free-form per-session values, e.g. "has_welcome" / "welcome_text".
    metadata: Dict[str, Any] = field(default_factory=dict)

    def update_activity(self):
        """Update last activity timestamp"""
        self.last_activity = datetime.utcnow()

    async def cleanup(self):
        """Cleanup all session resources.

        Currently only logs; releasing the component handles is left to
        resource managers outside this module.
        """
        log_debug(f"๐Ÿงน Cleaning up session context", session_id=self.session_id)


class StateOrchestrator:
    """Central state machine for conversation flow.

    Subscribes to session, STT, TTS, audio, LLM and error events on the
    event bus and translates them into state transitions plus follow-up
    requests (``STT_STARTED``, ``TTS_STARTED``, ``LLM_PROCESSING_STARTED``...).
    """

    # Valid state transitions. ENDED is reachable from every non-terminal
    # state so a session can always be closed (e.g. client disconnects
    # mid-response); ENDED itself is terminal.
    VALID_TRANSITIONS: Dict[ConversationState, Set[ConversationState]] = {
        ConversationState.IDLE: {
            ConversationState.INITIALIZING,
            ConversationState.ENDED,
        },
        ConversationState.INITIALIZING: {
            ConversationState.PREPARING_WELCOME,
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.ENDED,
        },
        ConversationState.PREPARING_WELCOME: {
            ConversationState.PLAYING_WELCOME,
            ConversationState.ERROR,
            ConversationState.ENDED,
        },
        ConversationState.PLAYING_WELCOME: {
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.ENDED,
        },
        ConversationState.LISTENING: {
            ConversationState.PROCESSING_SPEECH,
            ConversationState.ERROR,
            ConversationState.ENDED,
        },
        ConversationState.PROCESSING_SPEECH: {
            ConversationState.PREPARING_RESPONSE,
            ConversationState.ERROR,
            ConversationState.ENDED,
        },
        ConversationState.PREPARING_RESPONSE: {
            ConversationState.PLAYING_RESPONSE,
            ConversationState.ERROR,
            ConversationState.ENDED,
        },
        ConversationState.PLAYING_RESPONSE: {
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.ENDED,
        },
        ConversationState.ERROR: {
            ConversationState.LISTENING,
            ConversationState.ENDED,
        },
        ConversationState.ENDED: set(),  # No transitions from ENDED
    }

    def __init__(self, event_bus: EventBus):
        self.event_bus = event_bus
        # session_id -> SessionContext (the state lives on the context).
        self.sessions: Dict[str, SessionContext] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Subscribe to relevant events"""
        # Session lifecycle
        self.event_bus.subscribe(EventType.SESSION_STARTED, self._handle_session_started)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

        # STT events
        self.event_bus.subscribe(EventType.STT_READY, self._handle_stt_ready)
        self.event_bus.subscribe(EventType.STT_RESULT, self._handle_stt_result)
        self.event_bus.subscribe(EventType.STT_ERROR, self._handle_stt_error)

        # TTS events
        self.event_bus.subscribe(EventType.TTS_COMPLETED, self._handle_tts_completed)
        self.event_bus.subscribe(EventType.TTS_ERROR, self._handle_tts_error)

        # Audio events
        self.event_bus.subscribe(EventType.AUDIO_PLAYBACK_COMPLETED, self._handle_audio_playback_completed)

        # LLM events
        self.event_bus.subscribe(EventType.LLM_RESPONSE_READY, self._handle_llm_response_ready)
        self.event_bus.subscribe(EventType.LLM_ERROR, self._handle_llm_error)

        # Error events
        self.event_bus.subscribe(EventType.CRITICAL_ERROR, self._handle_critical_error)

    async def _start_stt(self, session_id: str, data: Optional[Dict[str, Any]] = None):
        """Publish an STT_STARTED request for the session."""
        await self.event_bus.publish(Event(
            type=EventType.STT_STARTED,
            session_id=session_id,
            data=data if data is not None else {}
        ))

    async def _resume_listening(self, session_id: str, stt_data: Optional[Dict[str, Any]] = None) -> bool:
        """Bring the session back to LISTENING and request STT.

        If LISTENING is not directly reachable from the current state, the
        session is routed through ERROR (the only state with an edge to
        LISTENING besides the playback states).

        Returns:
            True if the session ended up in LISTENING and STT_STARTED was
            published; False if the session is unknown, ended, or the
            transitions were rejected.
        """
        current_state = self.get_state(session_id)
        if current_state is None or current_state == ConversationState.ENDED:
            return False

        if current_state != ConversationState.LISTENING:
            allowed = self.VALID_TRANSITIONS.get(current_state, set())
            if ConversationState.LISTENING not in allowed:
                if not await self.transition_to(session_id, ConversationState.ERROR):
                    return False
            if not await self.transition_to(session_id, ConversationState.LISTENING):
                return False

        await self._start_stt(session_id, stt_data)
        return True

    async def _handle_session_started(self, event: Event):
        """Create the session context and start the welcome or listening flow.

        Reads ``session``, ``has_welcome`` and ``welcome_text`` from
        ``event.data``.
        """
        session_id = event.session_id
        session_data = event.data or {}

        log_info(f"๐ŸŽฌ Session started", session_id=session_id)

        # Create session context
        context = SessionContext(
            session_id=session_id,
            session=session_data.get("session"),
            metadata={
                "has_welcome": session_data.get("has_welcome", False),
                "welcome_text": session_data.get("welcome_text", "")
            }
        )
        self.sessions[session_id] = context

        # Transition to INITIALIZING
        if not await self.transition_to(session_id, ConversationState.INITIALIZING):
            return

        if session_data.get("has_welcome"):
            if await self.transition_to(session_id, ConversationState.PREPARING_WELCOME):
                # Request TTS for welcome message
                await self.event_bus.publish(Event(
                    type=EventType.TTS_STARTED,
                    session_id=session_id,
                    data={
                        "text": session_data.get("welcome_text", ""),
                        "is_welcome": True
                    }
                ))
        else:
            # No welcome, go straight to listening
            if await self.transition_to(session_id, ConversationState.LISTENING):
                await self._start_stt(session_id)

    async def _handle_session_ended(self, event: Event):
        """Move the session to ENDED, stop STT and drop all session state."""
        session_id = event.session_id

        log_info(f"๐Ÿ Session ended", session_id=session_id)

        context = self.sessions.get(session_id)

        # SESSION_ENDED can follow an explicit transition to ENDED (critical /
        # websocket error paths); ENDED -> ENDED is not a valid transition.
        if context is not None and context.state != ConversationState.ENDED:
            await self.transition_to(session_id, ConversationState.ENDED)

        # Stop all components
        await self.event_bus.publish(Event(
            type=EventType.STT_STOPPED,
            session_id=session_id,
            data={"reason": "session_ended"}
        ))

        try:
            if context is not None:
                await context.cleanup()
        finally:
            # Always forget the session, even if cleanup fails.
            self.sessions.pop(session_id, None)
            self.event_bus.clear_session_data(session_id)

    async def _handle_stt_ready(self, event: Event):
        """Handle STT ready signal (logged only; no state change needed)."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        log_debug(f"๐ŸŽค STT ready", session_id=session_id, current_state=current_state)

        # In LISTENING / PLAYING_WELCOME the session is already in the right
        # state, and in any other state there is nothing to do either.

    async def _handle_stt_result(self, event: Event):
        """Handle STT transcription result.

        Only final results received while LISTENING are acted on: STT is
        stopped, the session moves to PROCESSING_SPEECH and the text is sent
        to the LLM. Partial results are ignored.
        """
        session_id = event.session_id
        current_state = self.get_state(session_id)

        if current_state != ConversationState.LISTENING:
            log_warning(
                f"โš ๏ธ STT result in unexpected state",
                session_id=session_id,
                state=current_state
            )
            return

        result_data = event.data or {}
        if not result_data.get("is_final", False):
            return

        text = result_data.get("text", "")
        log_info(f"๐Ÿ’ฌ Final transcription: '{text}'", session_id=session_id)

        # Stop STT
        await self.event_bus.publish(Event(
            type=EventType.STT_STOPPED,
            session_id=session_id,
            data={"reason": "final_result"}
        ))

        # Transition to processing, then send to LLM
        if await self.transition_to(session_id, ConversationState.PROCESSING_SPEECH):
            await self.event_bus.publish(Event(
                type=EventType.LLM_PROCESSING_STARTED,
                session_id=session_id,
                data={"text": text}
            ))

    async def _handle_llm_response_ready(self, event: Event):
        """Handle LLM response: move to PREPARING_RESPONSE and request TTS."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        if current_state != ConversationState.PROCESSING_SPEECH:
            log_warning(
                f"โš ๏ธ LLM response in unexpected state",
                session_id=session_id,
                state=current_state
            )
            return

        response_text = (event.data or {}).get("text", "")

        log_info(f"๐Ÿค– LLM response ready", session_id=session_id, length=len(response_text))

        if await self.transition_to(session_id, ConversationState.PREPARING_RESPONSE):
            # Request TTS
            await self.event_bus.publish(Event(
                type=EventType.TTS_STARTED,
                session_id=session_id,
                data={"text": response_text}
            ))

    async def _handle_tts_completed(self, event: Event):
        """Handle TTS completion: PREPARING_* -> matching PLAYING_* state."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        log_info(f"๐Ÿ”Š TTS completed", session_id=session_id, state=current_state)

        if current_state == ConversationState.PREPARING_WELCOME:
            await self.transition_to(session_id, ConversationState.PLAYING_WELCOME)
        elif current_state == ConversationState.PREPARING_RESPONSE:
            await self.transition_to(session_id, ConversationState.PLAYING_RESPONSE)

    async def _handle_audio_playback_completed(self, event: Event):
        """Handle audio playback completion: back to LISTENING and restart STT."""
        session_id = event.session_id
        current_state = self.get_state(session_id)

        log_info(f"๐ŸŽต Audio playback completed", session_id=session_id, state=current_state)

        if current_state in (ConversationState.PLAYING_WELCOME, ConversationState.PLAYING_RESPONSE):
            if await self.transition_to(session_id, ConversationState.LISTENING):
                await self._start_stt(session_id)

    async def _handle_stt_error(self, event: Event):
        """Handle STT errors: enter ERROR, wait 2s, then retry listening."""
        session_id = event.session_id
        error_data = event.data or {}

        log_error(
            f"โŒ STT error",
            session_id=session_id,
            error=error_data.get("message")
        )

        current_state = self.get_state(session_id)
        if current_state is None or current_state == ConversationState.ENDED:
            return

        # A previous error may already have put us in ERROR; ERROR -> ERROR is invalid.
        if current_state != ConversationState.ERROR:
            if not await self.transition_to(session_id, ConversationState.ERROR):
                return

        # Try recovery after delay
        await asyncio.sleep(2.0)

        # Another handler may have recovered or ended the session meanwhile.
        if self.get_state(session_id) == ConversationState.ERROR:
            await self._resume_listening(session_id, {"retry": True})

    async def _handle_tts_error(self, event: Event):
        """Handle TTS errors: skip the audio and go back to listening."""
        session_id = event.session_id
        error_data = event.data or {}

        log_error(
            f"โŒ TTS error",
            session_id=session_id,
            error=error_data.get("message")
        )

        current_state = self.get_state(session_id)
        if current_state in (ConversationState.PREPARING_WELCOME, ConversationState.PREPARING_RESPONSE):
            # PREPARING_* has no direct edge to LISTENING; the helper routes via ERROR.
            await self._resume_listening(session_id)

    async def _handle_llm_error(self, event: Event):
        """Handle LLM errors: abandon the turn and go back to listening."""
        session_id = event.session_id
        error_data = event.data or {}

        log_error(
            f"โŒ LLM error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # Only meaningful while waiting on the LLM (same guard as
        # _handle_llm_response_ready); late errors must not restart STT.
        if self.get_state(session_id) == ConversationState.PROCESSING_SPEECH:
            await self._resume_listening(session_id)

    async def _handle_critical_error(self, event: Event):
        """Handle critical errors: end the session and announce SESSION_ENDED."""
        session_id = event.session_id
        error_data = event.data or {}

        log_error(
            f"๐Ÿ’ฅ Critical error",
            session_id=session_id,
            error=error_data.get("message")
        )

        # End session
        await self.transition_to(session_id, ConversationState.ENDED)

        # Publish session end event (its handler performs the cleanup)
        await self.event_bus.publish(Event(
            type=EventType.SESSION_ENDED,
            session_id=session_id,
            data={"reason": "critical_error"}
        ))

    async def transition_to(self, session_id: str, new_state: ConversationState) -> bool:
        """Transition a session to ``new_state`` if the edge is allowed.

        On an invalid transition the error is logged and published via
        ``publish_error`` and the state is left unchanged.

        Returns:
            True if the transition happened, False if the session is
            unknown or the transition is not in ``VALID_TRANSITIONS``.
        """
        context = self.sessions.get(session_id)
        if context is None:
            log_warning(f"โš ๏ธ Session not found for transition", session_id=session_id)
            return False

        current_state = context.state

        # Check if transition is valid
        if new_state not in self.VALID_TRANSITIONS.get(current_state, set()):
            log_error(
                f"โŒ Invalid state transition",
                session_id=session_id,
                from_state=current_state.value,
                to_state=new_state.value
            )
            await publish_error(
                session_id=session_id,
                error_type="invalid_transition",
                error_message=f"Cannot transition from {current_state.value} to {new_state.value}"
            )
            return False

        # Update state on the context (never replace the context itself).
        context.state = new_state
        context.update_activity()

        log_info(
            f"๐Ÿ”„ State transition",
            session_id=session_id,
            from_state=current_state.value,
            to_state=new_state.value
        )

        # Publish state transition event
        await publish_state_transition(
            session_id=session_id,
            from_state=current_state.value,
            to_state=new_state.value
        )
        return True

    def get_state(self, session_id: str) -> Optional[ConversationState]:
        """Get current state for a session (None if the session is unknown)."""
        context = self.sessions.get(session_id)
        return context.state if context is not None else None

    def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session data: the context's ``metadata`` dict, or None if unknown."""
        context = self.sessions.get(session_id)
        return context.metadata if context is not None else None

    async def handle_error_recovery(self, session_id: str, error_type: str):
        """Handle error recovery strategies.

        Dispatches on ``error_type`` ("stt_error", "tts_error", "llm_error",
        "websocket_error"); any other value goes to ERROR, waits 1s and then
        returns to LISTENING if nothing else changed the state meanwhile.
        """
        context = self.sessions.get(session_id)
        if not context or context.state == ConversationState.ENDED:
            return

        log_info(
            f"๐Ÿ”ง Attempting error recovery",
            session_id=session_id,
            error_type=error_type,
            current_state=context.state.value
        )

        # Update activity
        context.update_activity()

        # Define recovery strategies
        recovery_strategies = {
            "stt_error": self._recover_from_stt_error,
            "tts_error": self._recover_from_tts_error,
            "llm_error": self._recover_from_llm_error,
            "websocket_error": self._recover_from_websocket_error
        }

        strategy = recovery_strategies.get(error_type)
        if strategy:
            await strategy(session_id)
            return

        # Default recovery: go to error state then back to listening
        if context.state != ConversationState.ERROR:
            if not await self.transition_to(session_id, ConversationState.ERROR):
                return
        await asyncio.sleep(1.0)
        if self.get_state(session_id) == ConversationState.ERROR:
            await self.transition_to(session_id, ConversationState.LISTENING)

    async def _recover_from_stt_error(self, session_id: str):
        """Recover from STT error: stop STT, wait 2s, restart it."""
        await self.event_bus.publish(Event(
            type=EventType.STT_STOPPED,
            session_id=session_id,
            data={"reason": "error_recovery"}
        ))

        await asyncio.sleep(2.0)

        await self._resume_listening(session_id, {"retry": True})

    async def _recover_from_tts_error(self, session_id: str):
        """Recover from TTS error: skip TTS, go directly to listening."""
        await self._resume_listening(session_id)

    async def _recover_from_llm_error(self, session_id: str):
        """Recover from LLM error: go back to listening."""
        await self._resume_listening(session_id)

    async def _recover_from_websocket_error(self, session_id: str):
        """Recover from WebSocket error: end session cleanly."""
        await self.transition_to(session_id, ConversationState.ENDED)

        await self.event_bus.publish(Event(
            type=EventType.SESSION_ENDED,
            session_id=session_id,
            data={"reason": "websocket_error"}
        ))