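# NOTE(review): the stray text "Spaces: Building" preceded the module
# docstring. It looks like hosting-page status text rather than code;
# delete this note once that is confirmed.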
"""
State Orchestrator for Flare Realtime Chat
==========================================
Central state machine and flow control
"""
import asyncio | |
from typing import Dict, Optional, Set, Any | |
from enum import Enum | |
from datetime import datetime | |
import traceback | |
from dataclasses import dataclass, field | |
from .event_bus import EventBus, Event, EventType, publish_state_transition, publish_error | |
from .session import Session | |
from utils.logger import log_info, log_error, log_debug, log_warning | |
class ConversationState(Enum):
    """States a session's conversation can be in.

    Each member's value is the lowercase string form of its name; the
    orchestrator sends these values in state-transition events and logs.
    """

    # No conversation running / conversation being set up
    IDLE = "idle"
    INITIALIZING = "initializing"

    # Welcome message phase
    PREPARING_WELCOME = "preparing_welcome"
    PLAYING_WELCOME = "playing_welcome"

    # Listen -> process -> respond loop
    LISTENING = "listening"
    PROCESSING_SPEECH = "processing_speech"
    PREPARING_RESPONSE = "preparing_response"
    PLAYING_RESPONSE = "playing_response"

    # Failure and terminal states
    ERROR = "error"
    ENDED = "ended"
@dataclass
class SessionContext:
    """Per-session state held by the orchestrator.

    Declared as a dataclass so that it can be built with keyword arguments
    and so that the ``field(default_factory=...)`` defaults below produce a
    fresh timestamp / dict for every instance.
    """

    # Identifier the orchestrator uses as the key for this context.
    session_id: str
    # Session object handed over in the session-start event data.
    session: Session
    # Current position in the conversation state machine.
    state: ConversationState = ConversationState.IDLE
    # Resource slots; nothing in this module assigns them
    # (presumably filled in by resource managers - verify against callers).
    stt_instance: Optional[Any] = None
    tts_instance: Optional[Any] = None
    llm_context: Optional[Any] = None
    audio_buffer: Optional[Any] = None
    websocket_connection: Optional[Any] = None
    # Naive UTC timestamps (datetime.utcnow).
    created_at: datetime = field(default_factory=datetime.utcnow)
    last_activity: datetime = field(default_factory=datetime.utcnow)
    # Free-form per-session values, e.g. "has_welcome" / "welcome_text".
    metadata: Dict[str, Any] = field(default_factory=dict)

    def update_activity(self) -> None:
        """Set ``last_activity`` to the current (naive) UTC time."""
        self.last_activity = datetime.utcnow()

    async def cleanup(self) -> None:
        """Release session resources.

        Currently only logs; the actual cleanup is left to resource managers.
        """
        log_debug("🧹 Cleaning up session context", session_id=self.session_id)
class StateOrchestrator:
    """Central state machine for conversation flow.

    Keeps one ``SessionContext`` per session id, subscribes to the event bus
    and moves each session through ``ConversationState`` in response to
    events. Every state change goes through :meth:`transition_to`, which
    validates it against ``VALID_TRANSITIONS`` and publishes a
    ``STATE_TRANSITION`` event.
    """

    # Allowed target states for each current state.
    # Besides the normal conversation loop:
    #   * every active state may fall back to IDLE (conversation ended while
    #     the session stays alive, see _handle_conversation_ended);
    #   * every state except ENDED may go to ENDED (session teardown or
    #     critical error, see _handle_session_ended/_handle_critical_error).
    VALID_TRANSITIONS = {
        ConversationState.IDLE: {
            ConversationState.INITIALIZING,
            ConversationState.ENDED,
        },
        ConversationState.INITIALIZING: {
            ConversationState.PREPARING_WELCOME,
            ConversationState.LISTENING,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PREPARING_WELCOME: {
            ConversationState.PLAYING_WELCOME,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PLAYING_WELCOME: {
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.LISTENING: {
            ConversationState.PROCESSING_SPEECH,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PROCESSING_SPEECH: {
            ConversationState.PREPARING_RESPONSE,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PREPARING_RESPONSE: {
            ConversationState.PLAYING_RESPONSE,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.PLAYING_RESPONSE: {
            ConversationState.LISTENING,
            ConversationState.ERROR,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.ERROR: {
            ConversationState.LISTENING,
            ConversationState.IDLE,
            ConversationState.ENDED,
        },
        ConversationState.ENDED: set(),  # No transitions from ENDED
    }

    def __init__(self, event_bus: EventBus):
        """Store the event bus, start with no sessions and subscribe handlers."""
        self.event_bus = event_bus
        self.sessions: Dict[str, SessionContext] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self) -> None:
        """Subscribe to relevant events."""
        subscriptions = (
            # Conversation events
            (EventType.CONVERSATION_STARTED, self._handle_conversation_started),
            (EventType.CONVERSATION_ENDED, self._handle_conversation_ended),
            # Session lifecycle
            (EventType.SESSION_STARTED, self._handle_session_started),
            (EventType.SESSION_ENDED, self._handle_session_ended),
            # WebSocket events
            (EventType.WEBSOCKET_DISCONNECTED, self._handle_websocket_disconnected),
            # STT events
            (EventType.STT_READY, self._handle_stt_ready),
            (EventType.STT_RESULT, self._handle_stt_result),
            (EventType.STT_ERROR, self._handle_stt_error),
            # TTS events
            (EventType.TTS_COMPLETED, self._handle_tts_completed),
            (EventType.TTS_ERROR, self._handle_tts_error),
            # Audio events
            (EventType.AUDIO_PLAYBACK_COMPLETED, self._handle_audio_playback_completed),
            # LLM events
            (EventType.LLM_RESPONSE_READY, self._handle_llm_response_ready),
            (EventType.LLM_ERROR, self._handle_llm_error),
            # Error events
            (EventType.CRITICAL_ERROR, self._handle_critical_error),
        )
        for event_type, handler in subscriptions:
            self.event_bus.subscribe(event_type, handler)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _publish(self, event_type, session_id: str,
                       data: Optional[Dict[str, Any]] = None) -> None:
        """Publish an event for ``session_id`` (``data`` defaults to ``{}``)."""
        await self.event_bus.publish(Event(
            type=event_type,
            session_id=session_id,
            data=data if data is not None else {}
        ))

    async def _stop_audio_pipeline(self, session_id: str, reason: str) -> None:
        """Publish STT_STOPPED and then TTS_STOPPED with the given reason."""
        await self._publish(EventType.STT_STOPPED, session_id, {"reason": reason})
        await self._publish(EventType.TTS_STOPPED, session_id, {"reason": reason})

    async def _recover_to_listening(self, session_id: str,
                                    stt_data: Optional[Dict[str, Any]] = None,
                                    delay: float = 0.0) -> bool:
        """Route a session through ERROR back to LISTENING and restart STT.

        Going through ERROR keeps the recovery inside ``VALID_TRANSITIONS``:
        the preparing/processing states cannot jump to LISTENING directly.
        STT_STARTED is only published when the session really ended up in
        LISTENING.

        Args:
            session_id: Session to recover.
            stt_data: Payload for the STT_STARTED event (``{}`` when None).
            delay: Seconds to wait in ERROR before resuming.

        Returns:
            True if the session is now LISTENING and STT was restarted.
        """
        context = self.sessions.get(session_id)
        if not context:
            return False

        if context.state != ConversationState.ERROR:
            if not await self.transition_to(session_id, ConversationState.ERROR):
                return False

        if delay > 0:
            await asyncio.sleep(delay)

        # The session may have been ended, reset or removed in the meantime.
        if self.get_state(session_id) != ConversationState.ERROR:
            return False

        if not await self.transition_to(session_id, ConversationState.LISTENING):
            return False

        await self._publish(EventType.STT_STARTED, session_id, stt_data)
        return True

    # ------------------------------------------------------------------
    # Event handlers
    # ------------------------------------------------------------------

    async def _handle_websocket_disconnected(self, event: Event):
        """Handle WebSocket disconnection."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        log_info(f"🔌 Handling WebSocket disconnect | session_id={session_id}, state={context.state.value}")

        # If a conversation is active, end it first
        if context.state not in (ConversationState.IDLE, ConversationState.ENDED):
            await self._handle_conversation_ended(Event(
                type=EventType.CONVERSATION_ENDED,
                session_id=session_id,
                data={"reason": "websocket_disconnected"}
            ))

    async def _handle_conversation_started(self, event: Event) -> None:
        """Handle conversation start within existing session.

        Moves IDLE -> INITIALIZING and then either requests TTS for the
        welcome message or goes straight to LISTENING and starts STT. If the
        session cannot enter INITIALIZING nothing is published.
        """
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            log_error(f"❌ Session not found for conversation start | session_id={session_id}")
            return

        log_info(f"🎤 Conversation started | session_id={session_id}")

        # First move from IDLE to INITIALIZING; bail out if that is refused
        # so we do not start TTS/STT for a session in some other state.
        if not await self.transition_to(session_id, ConversationState.INITIALIZING):
            log_warning(
                f"⚠️ Conversation start ignored | session_id={session_id}, state={context.state.value}"
            )
            return

        welcome_text = context.metadata.get("welcome_text")
        if context.metadata.get("has_welcome") and welcome_text:
            # There is a welcome message: request TTS for it
            if await self.transition_to(session_id, ConversationState.PREPARING_WELCOME):
                await self._publish(EventType.TTS_STARTED, session_id, {
                    "text": welcome_text,
                    "is_welcome": True
                })
        else:
            # No welcome message: go straight to LISTENING and start STT
            if await self.transition_to(session_id, ConversationState.LISTENING):
                await self._publish(EventType.STT_STARTED, session_id, {})

    async def _handle_conversation_ended(self, event: Event) -> None:
        """Handle conversation end - but keep session alive.

        Publishes STT_STOPPED and TTS_STOPPED, then returns the session to
        IDLE so that a new conversation can be started on it.
        """
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            log_warning(f"⚠️ Session not found for conversation end | session_id={session_id}")
            return

        log_info(f"🔚 Conversation ended | session_id={session_id}")

        # Stop STT if running, and any ongoing TTS
        await self._stop_audio_pipeline(session_id, "conversation_ended")

        # Transition back to IDLE - session still alive!
        if context.state != ConversationState.IDLE:
            if not await self.transition_to(session_id, ConversationState.IDLE):
                log_warning(
                    f"⚠️ Could not return session to IDLE | session_id={session_id}, state={context.state.value}"
                )
                return
        log_info(f"💤 Session back to IDLE, ready for new conversation | session_id={session_id}")

    async def _handle_session_started(self, event: Event):
        """Handle session start: create and register the session context."""
        session_id = event.session_id
        session_data = event.data
        log_info("🎬 Session started", session_id=session_id)

        # Create session context
        self.sessions[session_id] = SessionContext(
            session_id=session_id,
            session=session_data.get("session"),
            metadata={
                "has_welcome": session_data.get("has_welcome", False),
                "welcome_text": session_data.get("welcome_text", "")
            }
        )

        # The session stays in IDLE (SessionContext's default state) until a
        # conversation is started.
        log_info(f"📍 Session created in IDLE state | session_id={session_id}")

    async def _handle_session_ended(self, event: Event):
        """Handle session end - complete cleanup.

        The session is removed from ``self.sessions`` and the event bus is
        told to clear its session data even if stopping components or the
        context cleanup raises.
        """
        session_id = event.session_id
        log_info(f"🏁 Session ended | session_id={session_id}")

        context = self.sessions.get(session_id)
        try:
            if context:
                # Move to ENDED unless we are already there
                if context.state != ConversationState.ENDED:
                    if not await self.transition_to(session_id, ConversationState.ENDED):
                        log_warning(f"Could not transition to ENDED state | session_id={session_id}")

                # Stop all components
                await self._stop_audio_pipeline(session_id, "session_ended")

                # Cleanup session context
                await context.cleanup()
        finally:
            # Remove session
            self.sessions.pop(session_id, None)
            # Clear event bus session data
            self.event_bus.clear_session_data(session_id)

        log_info(f"✅ Session fully cleaned up | session_id={session_id}")

    async def _handle_stt_ready(self, event: Event):
        """Handle STT ready signal.

        Only logs: no state change is made when STT reports ready.
        """
        session_id = event.session_id
        current_state = self.get_state(session_id)
        log_debug(
            "🎤 STT ready",
            session_id=session_id,
            current_state=current_state.value if current_state else None
        )

    async def _handle_stt_result(self, event: Event):
        """Handle STT transcription result.

        Results are only accepted in LISTENING. An empty transcript restarts
        STT; a non-empty one moves to PROCESSING_SPEECH and is sent to the LLM.
        """
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        current_state = context.state
        # In batch mode the result is always a final result
        text = event.data.get("text", "").strip()

        if current_state != ConversationState.LISTENING:
            log_warning(
                "⚠️ STT result in unexpected state",
                session_id=session_id,
                state=current_state.value
            )
            return

        if not text:
            log_info("💬 No speech detected", session_id=session_id)
            # Empty transcript: restart STT
            await self._publish(EventType.STT_STARTED, session_id, {})
            return

        log_info(f"💬 Transcription: '{text}'", session_id=session_id)

        # Transition to processing
        await self.transition_to(session_id, ConversationState.PROCESSING_SPEECH)

        # Send to LLM
        await self._publish(EventType.LLM_PROCESSING_STARTED, session_id, {"text": text})

    async def _handle_llm_response_ready(self, event: Event):
        """Handle LLM response: accepted only in PROCESSING_SPEECH, then TTS is requested."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            log_error("❌ No context for LLM response", session_id=session_id)
            return

        current_state = context.state
        log_debug(f"🔍 LLM response handler - current state: {current_state.value}")

        if current_state != ConversationState.PROCESSING_SPEECH:
            log_warning(
                "⚠️ LLM response in unexpected state",
                session_id=session_id,
                state=current_state.value
            )
            return

        response_text = event.data.get("text", "")
        log_info("🤖 LLM response ready", session_id=session_id, length=len(response_text))

        # Transition to preparing response
        await self.transition_to(session_id, ConversationState.PREPARING_RESPONSE)

        # Request TTS
        await self._publish(EventType.TTS_STARTED, session_id, {"text": response_text})

    async def _handle_tts_completed(self, event: Event):
        """Handle TTS completion: PREPARING_* becomes the matching PLAYING_* state."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        current_state = context.state
        log_info("🔊 TTS completed", session_id=session_id, state=current_state.value)

        if current_state == ConversationState.PREPARING_WELCOME:
            # The welcome audio is played by the frontend; we only update the
            # state here. The frontend sends audio_playback_completed when
            # playback has finished.
            await self.transition_to(session_id, ConversationState.PLAYING_WELCOME)
        elif current_state == ConversationState.PREPARING_RESPONSE:
            await self.transition_to(session_id, ConversationState.PLAYING_RESPONSE)

    async def _handle_audio_playback_completed(self, event: Event):
        """Handle audio playback completion: resume listening and start STT."""
        session_id = event.session_id
        context = self.sessions.get(session_id)
        if not context:
            return

        current_state = context.state
        log_info("🎵 Audio playback completed", session_id=session_id, state=current_state.value)

        if current_state not in (ConversationState.PLAYING_WELCOME,
                                 ConversationState.PLAYING_RESPONSE):
            return

        # Transition to listening
        await self.transition_to(session_id, ConversationState.LISTENING)

        # Start STT - a new utterance in batch mode
        await self._publish(EventType.STT_STARTED, session_id, {
            "locale": context.metadata.get("locale", "tr"),
            "batch_mode": True,            # batch mode enabled
            "silence_threshold_ms": 1000   # 1 second of silence
        })

        # Send STT ready signal to frontend
        await self._publish(EventType.STT_READY, session_id, {})

    async def _handle_stt_error(self, event: Event):
        """Handle STT errors: ERROR, wait 2 seconds, then LISTENING with an STT retry."""
        session_id = event.session_id
        log_error(
            "❌ STT error",
            session_id=session_id,
            error=event.data.get("message")
        )
        # The retry only happens if the session is still in ERROR after the delay.
        await self._recover_to_listening(session_id, stt_data={"retry": True}, delay=2.0)

    async def _handle_tts_error(self, event: Event):
        """Handle TTS errors: skip the audio and resume listening."""
        session_id = event.session_id
        log_error(
            "❌ TTS error",
            session_id=session_id,
            error=event.data.get("message")
        )

        # Only act while TTS output is actually being prepared
        if self.get_state(session_id) in (ConversationState.PREPARING_WELCOME,
                                          ConversationState.PREPARING_RESPONSE):
            await self._recover_to_listening(session_id)

    async def _handle_llm_error(self, event: Event):
        """Handle LLM errors: go back to listening and restart STT."""
        session_id = event.session_id
        log_error(
            "❌ LLM error",
            session_id=session_id,
            error=event.data.get("message")
        )
        await self._recover_to_listening(session_id)

    async def _handle_critical_error(self, event: Event):
        """Handle critical errors: mark the session ENDED and publish SESSION_ENDED."""
        session_id = event.session_id
        log_error(
            "💥 Critical error",
            session_id=session_id,
            error=event.data.get("message")
        )

        # End session (skip the transition if it already ended)
        if self.get_state(session_id) != ConversationState.ENDED:
            await self.transition_to(session_id, ConversationState.ENDED)

        # Publish session end event
        await self._publish(EventType.SESSION_ENDED, session_id, {"reason": "critical_error"})

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def transition_to(self, session_id: str, new_state: ConversationState) -> bool:
        """Transition to a new state with validation.

        Args:
            session_id: Session whose state should change.
            new_state: Requested target state.

        Returns:
            True if the session's state was changed to ``new_state``; False
            if the session is unknown, the transition is not listed in
            ``VALID_TRANSITIONS`` or an unexpected error occurred before the
            state was changed. Exceptions are logged, not raised.
        """
        try:
            context = self.sessions.get(session_id)
            if not context:
                log_warning(f"❌ Session not found for state transition | session_id={session_id}")
                return False

            old_state = context.state
            if new_state not in self.VALID_TRANSITIONS.get(old_state, set()):
                log_warning(f"❌ Invalid state transition | session_id={session_id}, current={old_state.value}, requested={new_state.value}")
                return False

            # Update state
            context.state = new_state
            context.update_activity()
            log_info(f"✅ State transition | session_id={session_id}, {old_state.value} → {new_state.value}")
        except Exception as e:
            log_error(f"❌ State transition error | session_id={session_id}", e)
            return False

        # The state is already committed at this point, so a failure while
        # notifying listeners is logged but must not be reported as
        # "transition did not happen".
        try:
            await self._publish(EventType.STATE_TRANSITION, session_id, {
                "old_state": old_state.value,  # Backend uses old_state/new_state
                "new_state": new_state.value,
                "timestamp": datetime.utcnow().isoformat()
            })
        except Exception as e:
            log_error(f"❌ State transition error | session_id={session_id}", e)
        return True

    def get_state(self, session_id: str) -> Optional[ConversationState]:
        """Get current state for a session, or None if the session is unknown."""
        context = self.sessions.get(session_id)
        return context.state if context else None

    def get_session_data(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Get session data.

        Returns the ``metadata`` dict of the session's context (the same
        object, not a copy), or None if the session is unknown.
        """
        context = self.sessions.get(session_id)
        return context.metadata if context else None

    async def handle_error_recovery(self, session_id: str, error_type: str):
        """Handle error recovery strategies.

        Args:
            session_id: Session to recover.
            error_type: One of "stt_error", "tts_error", "llm_error",
                "websocket_error"; any other value uses the default recovery
                (ERROR, wait 1 second, LISTENING).
        """
        context = self.sessions.get(session_id)
        if not context or context.state == ConversationState.ENDED:
            return

        log_info(
            "🔧 Attempting error recovery",
            session_id=session_id,
            error_type=error_type,
            current_state=context.state.value
        )

        # Update activity
        context.update_activity()

        # Define recovery strategies
        recovery_strategies = {
            "stt_error": self._recover_from_stt_error,
            "tts_error": self._recover_from_tts_error,
            "llm_error": self._recover_from_llm_error,
            "websocket_error": self._recover_from_websocket_error
        }

        strategy = recovery_strategies.get(error_type)
        if strategy:
            await strategy(session_id)
            return

        # Default recovery: go to error state then back to listening
        in_error = (context.state == ConversationState.ERROR
                    or await self.transition_to(session_id, ConversationState.ERROR))
        if not in_error:
            return
        await asyncio.sleep(1.0)
        # Only resume if nothing else moved the session while we waited
        if self.get_state(session_id) == ConversationState.ERROR:
            await self.transition_to(session_id, ConversationState.LISTENING)

    async def _recover_from_stt_error(self, session_id: str):
        """Recover from STT error: stop STT, wait, restart."""
        await self._publish(EventType.STT_STOPPED, session_id, {"reason": "error_recovery"})
        await self._recover_to_listening(session_id, stt_data={"retry": True}, delay=2.0)

    async def _recover_from_tts_error(self, session_id: str):
        """Recover from TTS error: skip TTS and go back to listening."""
        await self._recover_to_listening(session_id)

    async def _recover_from_llm_error(self, session_id: str):
        """Recover from LLM error: go back to listening."""
        await self._recover_to_listening(session_id)

    async def _recover_from_websocket_error(self, session_id: str):
        """Recover from WebSocket error: end the session cleanly."""
        if self.get_state(session_id) != ConversationState.ENDED:
            await self.transition_to(session_id, ConversationState.ENDED)
        await self._publish(EventType.SESSION_ENDED, session_id, {"reason": "websocket_error"})