# flare/tts/tts_lifecycle_manager.py
# Uploaded by ciyidogan ("Upload tts_lifecycle_manager.py"), commit 114bc80 (verified)
"""
TTS Lifecycle Manager for Flare
===============================
Manages TTS instances lifecycle per session
"""
import asyncio
from typing import Dict, Optional, Any, List
from datetime import datetime
import traceback
import base64
from chat_session.event_bus import EventBus, Event, EventType, publish_error
from chat_session.resource_manager import ResourceManager, ResourceType
from tts.tts_factory import TTSFactory
from tts.tts_interface import TTSInterface
from tts.tts_preprocessor import TTSPreprocessor
from utils.logger import log_info, log_error, log_debug, log_warning
class TTSJob:
    """TTS synthesis job.

    Holds the request (``job_id``, ``session_id``, ``text``, ``is_welcome``)
    together with its outcome: ``audio_data`` once completed, ``error`` once
    failed. ``completed_at`` is set by both :meth:`complete` and :meth:`fail`,
    so it means "finished", not "succeeded".
    """

    def __init__(self, job_id: str, session_id: str, text: str, is_welcome: bool = False):
        """Create a pending job.

        Args:
            job_id: Identifier of this job.
            session_id: Session the job belongs to.
            text: Text to synthesize.
            is_welcome: Flag carried along with the job and echoed in the
                events published for it.
        """
        self.job_id = job_id
        self.session_id = session_id
        self.text = text
        self.is_welcome = is_welcome
        # Naive UTC timestamp; other code subtracts it from datetime.utcnow(),
        # so it must stay naive unless every user is migrated together.
        self.created_at = datetime.utcnow()
        self.completed_at: Optional[datetime] = None
        self.audio_data: Optional[bytes] = None
        self.error: Optional[str] = None
        # Number of audio chunks published so far for this job.
        self.chunks_sent = 0

    def complete(self, audio_data: bytes) -> None:
        """Mark job as completed and store the synthesized audio."""
        self.audio_data = audio_data
        self.completed_at = datetime.utcnow()

    def fail(self, error: str) -> None:
        """Mark job as failed and store the error message."""
        self.error = error
        self.completed_at = datetime.utcnow()
class TTSSession:
    """TTS session wrapper.

    Bundles the TTS provider instance used for one session with that
    session's job bookkeeping and usage counters.
    """

    def __init__(self, session_id: str, tts_instance: TTSInterface):
        """Create an empty session around an already-acquired TTS instance.

        Args:
            session_id: Identifier of the owning session.
            tts_instance: Provider instance used to synthesize this
                session's audio.
        """
        self.session_id = session_id
        self.tts_instance = tts_instance
        # Not created here; the owner assigns it after construction.
        self.preprocessor: Optional[TTSPreprocessor] = None
        # Jobs currently being processed, keyed by job_id.
        self.active_jobs: Dict[str, TTSJob] = {}
        # Finished jobs, oldest first.
        self.completed_jobs: List[TTSJob] = []
        # Naive UTC timestamps (datetime.utcnow()).
        self.created_at = datetime.utcnow()
        self.last_activity = datetime.utcnow()
        # Running totals over the session's lifetime.
        self.total_jobs = 0
        self.total_chars = 0

    def update_activity(self) -> None:
        """Update last activity timestamp."""
        self.last_activity = datetime.utcnow()
class TTSLifecycleManager:
    """Manages TTS instances lifecycle.

    Reacts to ``TTS_STARTED`` and ``SESSION_ENDED`` events. For each TTS
    request it obtains a TTS instance through the resource manager (one
    ``TTSSession`` per session id), synthesizes the text, and publishes the
    audio as base64 ``TTS_CHUNK_READY`` events followed by ``TTS_COMPLETED``.
    """

    # Maximum number of finished jobs remembered per session.
    MAX_COMPLETED_JOBS = 10

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        """Store collaborators, then subscribe to events and register the pool."""
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        self.tts_sessions: Dict[str, TTSSession] = {}
        # Chunk length in base64 characters (not raw bytes). 16384 is a
        # multiple of 4, so every chunk is valid base64 on its own.
        self.chunk_size = 16384  # 16KB chunks for base64
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self):
        """Subscribe to TTS-related events."""
        self.event_bus.subscribe(EventType.TTS_STARTED, self._handle_tts_start)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self):
        """Setup TTS instance pool.

        NOTE(review): the exact meaning of ``max_idle`` / ``max_age_seconds``
        is defined by ``ResourceManager.register_pool`` — confirm there.
        """
        self.resource_manager.register_pool(
            resource_type=ResourceType.TTS_INSTANCE,
            factory=self._create_tts_instance,
            max_idle=3,
            max_age_seconds=600  # 10 minutes
        )

    async def _create_tts_instance(self) -> Optional[TTSInterface]:
        """Factory for creating TTS instances.

        Returns:
            A provider from ``TTSFactory.create_provider()``, or ``None`` when
            the factory returns nothing or raises (the error is logged).
        """
        try:
            tts_instance = TTSFactory.create_provider()
            if not tts_instance:
                log_warning("⚠️ No TTS provider configured")
                return None
            log_debug("🔊 Created new TTS instance")
            return tts_instance
        except Exception as e:
            # Best effort: callers treat None as "TTS unavailable".
            log_error("❌ Failed to create TTS instance", error=str(e))
            return None

    async def _handle_tts_start(self, event: Event):
        """Handle TTS synthesis request.

        Reads ``text``, ``is_welcome`` and (for new sessions) ``locale`` from
        ``event.data``. Empty text is ignored with a warning. When no TTS
        instance can be acquired, completion is published without audio.
        Any exception is logged and reported through ``publish_error``.
        """
        session_id = event.session_id
        text = event.data.get("text", "")
        is_welcome = event.data.get("is_welcome", False)

        if not text:
            log_warning("⚠️ Empty text for TTS", session_id=session_id)
            return

        try:
            log_info(
                "🔊 Starting TTS",
                session_id=session_id,
                text_length=len(text),
                is_welcome=is_welcome
            )

            # Get or create session
            tts_session = self.tts_sessions.get(session_id)
            if tts_session is None:
                # Acquire TTS instance from pool
                resource_id = f"tts_{session_id}"
                tts_instance = await self.resource_manager.acquire(
                    resource_id=resource_id,
                    session_id=session_id,
                    resource_type=ResourceType.TTS_INSTANCE,
                    cleanup_callback=self._cleanup_tts_instance
                )
                if not tts_instance:
                    # No TTS available
                    await self._handle_no_tts(session_id, text, is_welcome)
                    return

                # Create session; the locale is fixed at session creation
                # and defaults to "tr" when the event does not carry one.
                tts_session = TTSSession(session_id, tts_instance)
                locale = event.data.get("locale", "tr")
                tts_session.preprocessor = TTSPreprocessor(language=locale)
                self.tts_sessions[session_id] = tts_session

            # Create job; the id embeds the per-session running job counter.
            job_id = f"{session_id}_{tts_session.total_jobs}"
            job = TTSJob(job_id, session_id, text, is_welcome)
            tts_session.active_jobs[job_id] = job
            tts_session.total_jobs += 1
            tts_session.total_chars += len(text)
            tts_session.update_activity()

            # Process TTS
            await self._process_tts_job(tts_session, job)

        except Exception as e:
            log_error(
                "❌ Failed to start TTS",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )
            # Publish error event
            await publish_error(
                session_id=session_id,
                error_type="tts_error",
                error_message=f"Failed to synthesize speech: {e}"
            )

    @staticmethod
    def _archive_job(tts_session: TTSSession, job: TTSJob,
                     max_history: int = MAX_COMPLETED_JOBS) -> None:
        """Move a finished job from ``active_jobs`` to the bounded history."""
        tts_session.active_jobs.pop(job.job_id, None)
        tts_session.completed_jobs.append(job)
        # Keep only the most recent `max_history` finished jobs.
        overflow = len(tts_session.completed_jobs) - max_history
        if overflow > 0:
            del tts_session.completed_jobs[:overflow]

    async def _process_tts_job(self, tts_session: TTSSession, job: TTSJob):
        """Process a TTS job: preprocess, synthesize, stream, then archive.

        Failures never propagate: the job is marked failed and an error event
        is published (``tts_quota_exceeded`` when the message mentions
        "quota" or "limit", ``tts_error`` otherwise). Whether it succeeds or
        fails, the job leaves ``active_jobs`` and joins ``completed_jobs``.
        """
        try:
            # Preprocess text
            processed_text = tts_session.preprocessor.preprocess(
                job.text,
                tts_session.tts_instance.get_preprocessing_flags()
            )
            log_debug(
                "📝 TTS preprocessed",
                session_id=job.session_id,
                original_length=len(job.text),
                processed_length=len(processed_text)
            )

            # Synthesize audio
            audio_data = await tts_session.tts_instance.synthesize(processed_text)
            if not audio_data:
                raise ValueError("TTS returned empty audio data")
            job.complete(audio_data)
            log_info(
                "✅ TTS synthesis complete",
                session_id=job.session_id,
                audio_size=len(audio_data),
                duration_ms=(datetime.utcnow() - job.created_at).total_seconds() * 1000
            )

            # Stream audio chunks
            await self._stream_audio_chunks(tts_session, job)

            # Move to completed
            self._archive_job(tts_session, job)

        except Exception as e:
            job.fail(str(e))
            # A failed job is finished too. Without this it would stay in
            # active_jobs for the rest of the session and skew the stats.
            self._archive_job(tts_session, job)

            # Handle specific TTS errors
            error_message = str(e)
            lowered = error_message.lower()
            if "quota" in lowered or "limit" in lowered:
                log_error("❌ TTS quota exceeded", session_id=job.session_id)
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_quota_exceeded",
                    error_message="TTS service quota exceeded"
                )
            else:
                log_error(
                    "❌ TTS synthesis failed",
                    session_id=job.session_id,
                    error=error_message
                )
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_error",
                    error_message=error_message
                )

    async def _stream_audio_chunks(self, tts_session: TTSSession, job: TTSJob):
        """Stream audio data as chunks.

        Base64-encodes ``job.audio_data`` and publishes it in slices of
        ``self.chunk_size`` characters as ``TTS_CHUNK_READY`` events, then
        publishes ``TTS_COMPLETED``. Does nothing when the job has no audio.
        """
        if not job.audio_data:
            return

        # Convert to base64
        audio_base64 = base64.b64encode(job.audio_data).decode('utf-8')
        total_length = len(audio_base64)
        # Ceiling division: a trailing partial slice still counts as a chunk.
        total_chunks = (total_length + self.chunk_size - 1) // self.chunk_size

        log_debug(
            "📤 Streaming TTS audio",
            session_id=job.session_id,
            total_size=len(job.audio_data),
            base64_size=total_length,
            chunks=total_chunks
        )

        # Stream chunks
        for chunk_index, start in enumerate(range(0, total_length, self.chunk_size)):
            chunk = audio_base64[start:start + self.chunk_size]
            is_last = chunk_index == total_chunks - 1
            await self.event_bus.publish(Event(
                type=EventType.TTS_CHUNK_READY,
                session_id=job.session_id,
                data={
                    "audio_data": chunk,
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "is_last": is_last,
                    # NOTE(review): hard-coded; assumes the provider returns
                    # MPEG audio — confirm against the configured provider.
                    "mime_type": "audio/mpeg",
                    "is_welcome": job.is_welcome
                },
                priority=8  # Higher priority for audio chunks
            ))
            job.chunks_sent += 1
            # Small delay between chunks to prevent overwhelming
            await asyncio.sleep(0.01)

        # Notify completion
        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=job.session_id,
            data={
                "job_id": job.job_id,
                "total_chunks": total_chunks,
                "is_welcome": job.is_welcome
            }
        ))
        log_info(
            "✅ TTS streaming complete",
            session_id=job.session_id,
            chunks_sent=job.chunks_sent
        )

    async def _handle_no_tts(self, session_id: str, text: str, is_welcome: bool):
        """Handle case when TTS is not available.

        Publishes ``TTS_COMPLETED`` with ``no_audio=True`` and the original
        text instead of any audio chunks.
        """
        log_warning("⚠️ No TTS available, skipping audio generation", session_id=session_id)
        # Just notify completion without audio
        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=session_id,
            data={
                "no_audio": True,
                "text": text,
                "is_welcome": is_welcome
            }
        ))

    async def _handle_session_ended(self, event: Event):
        """Clean up TTS resources when session ends."""
        await self._cleanup_session(event.session_id)

    async def _cleanup_session(self, session_id: str):
        """Clean up TTS session.

        Forgets the session, marks its unfinished jobs as failed with
        "Session ended", and releases the session's TTS resource. Unknown
        session ids are ignored; errors are logged rather than raised.
        """
        tts_session = self.tts_sessions.pop(session_id, None)
        if not tts_session:
            return

        try:
            # Cancel any active jobs
            for job in tts_session.active_jobs.values():
                if not job.completed_at:
                    job.fail("Session ended")

            # Release resource.
            # NOTE(review): delay_seconds semantics belong to ResourceManager;
            # presumably a grace period before teardown — confirm.
            resource_id = f"tts_{session_id}"
            await self.resource_manager.release(resource_id, delay_seconds=120)

            log_info(
                "🧹 TTS session cleaned up",
                session_id=session_id,
                total_jobs=tts_session.total_jobs,
                total_chars=tts_session.total_chars
            )
        except Exception as e:
            log_error(
                "❌ Error cleaning up TTS session",
                session_id=session_id,
                error=str(e)
            )

    async def _cleanup_tts_instance(self, tts_instance: TTSInterface):
        """Cleanup callback for TTS instance (currently only logs)."""
        try:
            # TTS instances typically don't need special cleanup
            log_debug("🧹 TTS instance cleaned up")
        except Exception as e:
            log_error("❌ Error cleaning up TTS instance", error=str(e))

    def get_stats(self) -> Dict[str, Any]:
        """Get TTS manager statistics.

        Returns:
            A dict with ``active_sessions``, ``total_active_jobs`` and a
            per-session ``sessions`` mapping (job counts, character total,
            uptime in seconds, ISO-formatted last activity).
        """
        session_stats = {
            session_id: {
                "active_jobs": len(tts_session.active_jobs),
                "completed_jobs": len(tts_session.completed_jobs),
                "total_jobs": tts_session.total_jobs,
                "total_chars": tts_session.total_chars,
                "uptime_seconds": (datetime.utcnow() - tts_session.created_at).total_seconds(),
                "last_activity": tts_session.last_activity.isoformat()
            }
            for session_id, tts_session in self.tts_sessions.items()
        }
        return {
            "active_sessions": len(self.tts_sessions),
            "total_active_jobs": sum(len(s.active_jobs) for s in self.tts_sessions.values()),
            "sessions": session_stats
        }