"""
TTS Lifecycle Manager for Flare
===============================
Manages TTS instances lifecycle per session
"""
import asyncio
from typing import Dict, Optional, Any, List
from datetime import datetime
import traceback
import base64

from chat_session.event_bus import EventBus, Event, EventType, publish_error
from chat_session.resource_manager import ResourceManager, ResourceType
from tts.tts_factory import TTSFactory
from tts.tts_interface import TTSInterface
from tts.tts_preprocessor import TTSPreprocessor
from utils.logger import log_info, log_error, log_debug, log_warning

# NOTE(review): the symbol prefixes inside the log messages below look like
# UTF-8 emoji that were decoded with the wrong code page. They are runtime
# strings, so they are reproduced as found - confirm the file encoding before
# normalising them.

# Maximum number of finished jobs remembered per session.
_MAX_COMPLETED_JOBS = 10


class TTSJob:
    """A single text-to-speech synthesis request and its outcome.

    A job starts with ``completed_at``, ``audio_data`` and ``error`` all set
    to ``None``. ``complete()`` stores the audio, ``fail()`` stores an error
    message, and both stamp ``completed_at``.
    """

    def __init__(self, job_id: str, session_id: str, text: str, is_welcome: bool = False):
        """Create a pending job.

        Args:
            job_id: Identifier of this job.
            session_id: Session the job belongs to.
            text: Raw text to synthesize.
            is_welcome: Flag copied into the events published for this job.
        """
        self.job_id = job_id
        self.session_id = session_id
        self.text = text
        self.is_welcome = is_welcome
        # Naive UTC timestamps. datetime.utcnow() is deprecated since
        # Python 3.12; it is kept so that isoformat() output and arithmetic
        # with other naive timestamps in this module stay unchanged.
        self.created_at = datetime.utcnow()
        self.completed_at: Optional[datetime] = None
        self.audio_data: Optional[bytes] = None
        self.error: Optional[str] = None
        # Number of TTS_CHUNK_READY events published for this job.
        self.chunks_sent = 0

    def complete(self, audio_data: bytes):
        """Mark job as completed and keep the synthesized audio."""
        self.audio_data = audio_data
        self.completed_at = datetime.utcnow()

    def fail(self, error: str):
        """Mark job as failed with the given error message."""
        self.error = error
        self.completed_at = datetime.utcnow()


class TTSSession:
    """Per-session TTS state: the TTS instance, preprocessor, jobs and counters."""

    def __init__(self, session_id: str, tts_instance: TTSInterface):
        """Create the wrapper for ``session_id`` around ``tts_instance``.

        The preprocessor is not created here; it starts as ``None`` and is
        assigned by the lifecycle manager.
        """
        self.session_id = session_id
        self.tts_instance = tts_instance
        self.preprocessor: Optional[TTSPreprocessor] = None
        # Jobs currently being processed, keyed by job_id.
        self.active_jobs: Dict[str, TTSJob] = {}
        # Most recent finished jobs, oldest first.
        self.completed_jobs: List[TTSJob] = []
        self.created_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()
        # Running totals over the lifetime of the session.
        self.total_jobs = 0
        self.total_chars = 0

    def update_activity(self):
        """Update last activity timestamp"""
        self.last_activity = datetime.utcnow()


class TTSLifecycleManager:
    """Manages TTS instances lifecycle.

    On construction the manager subscribes to ``TTS_STARTED`` and
    ``SESSION_ENDED`` on the event bus and registers a ``TTS_INSTANCE`` pool
    with the resource manager. For each ``TTS_STARTED`` event it synthesizes
    the text and publishes the audio as base64 ``TTS_CHUNK_READY`` events
    followed by one ``TTS_COMPLETED`` event.
    """

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        """Store collaborators, then wire up event handlers and the pool."""
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        self.tts_sessions: Dict[str, TTSSession] = {}
        # Chunk length in base64 characters. 16384 is a multiple of 4, so
        # every chunk is a whole number of base64 groups.
        self.chunk_size = 16384  # 16KB chunks for base64
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self):
        """Subscribe to TTS-related events"""
        self.event_bus.subscribe(EventType.TTS_STARTED, self._handle_tts_start)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self):
        """Register the TTS instance pool with the resource manager."""
        self.resource_manager.register_pool(
            resource_type=ResourceType.TTS_INSTANCE,
            factory=self._create_tts_instance,
            max_idle=3,
            max_age_seconds=600  # 10 minutes
        )

    async def _create_tts_instance(self) -> Optional[TTSInterface]:
        """Pool factory: build a TTS instance.

        Returns:
            The provider created by ``TTSFactory.create_provider()``, or
            ``None`` when the factory returns a falsy value or raises.
        """
        try:
            tts_instance = TTSFactory.create_provider()
            if not tts_instance:
                log_warning("โš ๏ธ No TTS provider configured")
                return None

            log_debug("๐Ÿ”Š Created new TTS instance")
            return tts_instance
        except Exception as e:
            # Best effort: a broken provider is reported as "no instance".
            log_error("โŒ Failed to create TTS instance", error=str(e))
            return None

    async def _handle_tts_start(self, event: Event):
        """Handle a TTS synthesis request.

        Reads ``text``, ``is_welcome`` and (for a new session) ``locale``
        from ``event.data``. Empty text is ignored with a warning. Any
        exception raised while setting up or running the job is logged and
        reported through ``publish_error`` with type ``tts_error``.
        """
        session_id = event.session_id
        text = event.data.get("text", "")
        is_welcome = event.data.get("is_welcome", False)

        if not text:
            log_warning("โš ๏ธ Empty text for TTS", session_id=session_id)
            return

        try:
            log_info(
                "๐Ÿ”Š Starting TTS",
                session_id=session_id,
                text_length=len(text),
                is_welcome=is_welcome
            )

            # Get or create session
            tts_session = self.tts_sessions.get(session_id)
            if tts_session is None:
                # Acquire TTS instance from pool
                tts_instance = await self.resource_manager.acquire(
                    resource_id=f"tts_{session_id}",
                    session_id=session_id,
                    resource_type=ResourceType.TTS_INSTANCE,
                    cleanup_callback=self._cleanup_tts_instance
                )

                if not tts_instance:
                    # No TTS available
                    await self._handle_no_tts(session_id, text, is_welcome)
                    return

                # Create session; locale comes from event data or defaults to "tr"
                tts_session = TTSSession(session_id, tts_instance)
                locale = event.data.get("locale", "tr")
                tts_session.preprocessor = TTSPreprocessor(language=locale)
                self.tts_sessions[session_id] = tts_session

            # Create job; the id uses the job counter before it is incremented
            job_id = f"{session_id}_{tts_session.total_jobs}"
            job = TTSJob(job_id, session_id, text, is_welcome)
            tts_session.active_jobs[job_id] = job
            tts_session.total_jobs += 1
            tts_session.total_chars += len(text)
            tts_session.update_activity()

            # Process TTS
            await self._process_tts_job(tts_session, job)

        except Exception as e:
            log_error(
                "โŒ Failed to start TTS",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

            # Publish error event
            await publish_error(
                session_id=session_id,
                error_type="tts_error",
                error_message=f"Failed to synthesize speech: {str(e)}"
            )

    async def _process_tts_job(self, tts_session: TTSSession, job: TTSJob):
        """Preprocess, synthesize and stream one job.

        Failures do not propagate: the job is marked failed and an error
        event is published (``tts_quota_exceeded`` when the message contains
        "quota" or "limit", otherwise ``tts_error``). If the job was failed
        from outside while synthesis was awaited (session cleanup does this),
        the audio is discarded instead of being streamed. In every case the
        job leaves ``active_jobs`` and is appended to the bounded
        ``completed_jobs`` history.
        """
        try:
            # Preprocess text
            processed_text = tts_session.preprocessor.preprocess(
                job.text,
                tts_session.tts_instance.get_preprocessing_flags()
            )

            log_debug(
                "๐Ÿ“ TTS preprocessed",
                session_id=job.session_id,
                original_length=len(job.text),
                processed_length=len(processed_text)
            )

            # Synthesize audio
            audio_data = await tts_session.tts_instance.synthesize(processed_text)

            if job.error is not None:
                # The job was failed while we were awaiting synthesis;
                # do not overwrite that state or publish audio for it.
                log_debug(
                    "TTS job cancelled during synthesis, skipping audio",
                    session_id=job.session_id,
                    job_id=job.job_id
                )
                return

            if not audio_data:
                raise ValueError("TTS returned empty audio data")

            job.complete(audio_data)

            log_info(
                "โœ… TTS synthesis complete",
                session_id=job.session_id,
                audio_size=len(audio_data),
                duration_ms=(datetime.utcnow() - job.created_at).total_seconds() * 1000
            )

            # Stream audio chunks
            await self._stream_audio_chunks(tts_session, job)

        except Exception as e:
            job.fail(str(e))

            # Handle specific TTS errors
            error_message = str(e)
            lowered = error_message.lower()
            if "quota" in lowered or "limit" in lowered:
                log_error("โŒ TTS quota exceeded", session_id=job.session_id)
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_quota_exceeded",
                    error_message="TTS service quota exceeded"
                )
            else:
                log_error(
                    "โŒ TTS synthesis failed",
                    session_id=job.session_id,
                    error=error_message
                )
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_error",
                    error_message=error_message
                )
        finally:
            # Runs for success, failure and cancellation alike, so finished
            # jobs never linger in active_jobs.
            self._retire_job(tts_session, job)

    @staticmethod
    def _retire_job(tts_session: TTSSession, job: TTSJob) -> None:
        """Move ``job`` from the active map to the bounded history list."""
        tts_session.active_jobs.pop(job.job_id, None)
        tts_session.completed_jobs.append(job)
        # Keep only the most recent _MAX_COMPLETED_JOBS entries.
        del tts_session.completed_jobs[:-_MAX_COMPLETED_JOBS]

    async def _stream_audio_chunks(self, tts_session: TTSSession, job: TTSJob):
        """Publish the job's audio as base64 chunks, then a completion event.

        The audio is base64-encoded once and cut into pieces of
        ``self.chunk_size`` characters. Each piece is published as a
        ``TTS_CHUNK_READY`` event, followed by a single ``TTS_COMPLETED``
        event. Does nothing when the job has no audio. Streaming stops
        early, without a completion event, if the job is marked failed while
        chunks are being sent.
        """
        if not job.audio_data:
            return

        # Convert to base64
        audio_base64 = base64.b64encode(job.audio_data).decode('utf-8')
        total_length = len(audio_base64)
        # Ceiling division
        total_chunks = (total_length + self.chunk_size - 1) // self.chunk_size

        log_debug(
            "๐Ÿ“ค Streaming TTS audio",
            session_id=job.session_id,
            total_size=len(job.audio_data),
            base64_size=total_length,
            chunks=total_chunks
        )

        # Stream chunks
        offsets = range(0, total_length, self.chunk_size)
        for chunk_index, start in enumerate(offsets):
            if job.error is not None:
                # Job was failed from outside (e.g. session cleanup).
                log_debug(
                    "TTS streaming aborted, job was cancelled",
                    session_id=job.session_id,
                    chunks_sent=job.chunks_sent
                )
                return

            chunk = audio_base64[start:start + self.chunk_size]
            is_last = chunk_index == total_chunks - 1

            await self.event_bus.publish(Event(
                type=EventType.TTS_CHUNK_READY,
                session_id=job.session_id,
                data={
                    "audio_data": chunk,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "is_last": is_last,
                    "mime_type": "audio/mpeg",
                    "is_welcome": job.is_welcome
                },
                priority=8  # Higher priority for audio chunks
            ))

            job.chunks_sent += 1

            # Small delay between chunks to prevent overwhelming
            await asyncio.sleep(0.01)

        # Notify completion
        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=job.session_id,
            data={
                "job_id": job.job_id,
                "total_chunks": total_chunks,
                "is_welcome": job.is_welcome
            }
        ))

        log_info(
            "โœ… TTS streaming complete",
            session_id=job.session_id,
            chunks_sent=job.chunks_sent
        )

    async def _handle_no_tts(self, session_id: str, text: str, is_welcome: bool):
        """Handle case when TTS is not available.

        Publishes ``TTS_COMPLETED`` with ``no_audio=True`` and the original
        text instead of any audio chunks.
        """
        log_warning("โš ๏ธ No TTS available, skipping audio generation", session_id=session_id)

        # Just notify completion without audio
        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=session_id,
            data={
                "no_audio": True,
                "text": text,
                "is_welcome": is_welcome
            }
        ))

    async def _handle_session_ended(self, event: Event):
        """Clean up TTS resources when session ends"""
        await self._cleanup_session(event.session_id)

    async def _cleanup_session(self, session_id: str):
        """Forget the session, fail its unfinished jobs and release its resource.

        Unknown session ids are ignored. Errors during cleanup are logged and
        not re-raised.
        """
        tts_session = self.tts_sessions.pop(session_id, None)
        if not tts_session:
            return

        try:
            # Cancel any active jobs
            for job in tts_session.active_jobs.values():
                if not job.completed_at:
                    job.fail("Session ended")

            # Release resource under the same id used when acquiring it
            await self.resource_manager.release(f"tts_{session_id}", delay_seconds=120)

            log_info(
                "๐Ÿงน TTS session cleaned up",
                session_id=session_id,
                total_jobs=tts_session.total_jobs,
                total_chars=tts_session.total_chars
            )

        except Exception as e:
            log_error(
                "โŒ Error cleaning up TTS session",
                session_id=session_id,
                error=str(e)
            )

    async def _cleanup_tts_instance(self, tts_instance: TTSInterface):
        """Cleanup callback passed to the resource manager; only logs."""
        try:
            # TTS instances typically don't need special cleanup
            log_debug("๐Ÿงน TTS instance cleaned up")
        except Exception as e:
            log_error("โŒ Error cleaning up TTS instance", error=str(e))

    def get_stats(self) -> Dict[str, Any]:
        """Get TTS manager statistics.

        Returns:
            A dict with ``active_sessions``, ``total_active_jobs`` and
            ``sessions``, the latter mapping each session id to its job
            counts, character total, uptime in seconds and last-activity
            timestamp in ISO format.
        """
        now = datetime.utcnow()
        session_stats = {
            session_id: {
                "active_jobs": len(tts_session.active_jobs),
                "completed_jobs": len(tts_session.completed_jobs),
                "total_jobs": tts_session.total_jobs,
                "total_chars": tts_session.total_chars,
                "uptime_seconds": (now - tts_session.created_at).total_seconds(),
                "last_activity": tts_session.last_activity.isoformat()
            }
            for session_id, tts_session in self.tts_sessions.items()
        }

        return {
            "active_sessions": len(self.tts_sessions),
            "total_active_jobs": sum(len(s.active_jobs) for s in self.tts_sessions.values()),
            "sessions": session_stats
        }