# Spaces: Building (page-status text captured along with this file; not part of the module)
"""
TTS Lifecycle Manager for Flare
===============================
Manages the lifecycle of TTS instances per session
"""
import asyncio | |
from typing import Dict, Optional, Any, List | |
from datetime import datetime | |
import traceback | |
import base64 | |
from chat_session.event_bus import EventBus, Event, EventType, publish_error | |
from chat_session.resource_manager import ResourceManager, ResourceType | |
from tts.tts_factory import TTSFactory | |
from tts.tts_interface import TTSInterface | |
from tts.tts_preprocessor import TTSPreprocessor | |
from utils.logger import log_info, log_error, log_debug, log_warning | |
class TTSJob:
    """A single text-to-speech synthesis request and its outcome.

    Fields set at construction: ``job_id``, ``session_id``, ``text``,
    ``is_welcome`` and ``created_at`` (naive UTC timestamp).

    Fields filled in later: ``audio_data`` (by :meth:`complete`), ``error``
    (by :meth:`fail`), ``completed_at`` (by either) and ``chunks_sent``
    (a counter that starts at 0 and is incremented by whoever streams the audio).
    """

    def __init__(self, job_id: str, session_id: str, text: str, is_welcome: bool = False):
        # What was requested, and for whom.
        self.job_id, self.session_id = job_id, session_id
        self.text, self.is_welcome = text, is_welcome

        # Outcome fields; each stays None until complete() or fail() is called.
        self.audio_data: Optional[bytes] = None
        self.error: Optional[str] = None
        self.completed_at: Optional[datetime] = None
        self.chunks_sent = 0

        self.created_at = datetime.utcnow()

    def _mark_finished(self) -> None:
        """Stamp the job with the current (naive UTC) finish time."""
        self.completed_at = datetime.utcnow()

    def complete(self, audio_data: bytes):
        """Store the synthesized audio and record the finish time."""
        self.audio_data = audio_data
        self._mark_finished()

    def fail(self, error: str):
        """Store the failure reason and record the finish time."""
        self.error = error
        self._mark_finished()
class TTSSession:
    """Per-session bundle of a TTS instance, its preprocessor and job bookkeeping.

    Holds the TTS instance handed to the constructor, an optional text
    preprocessor (``None`` until assigned by the owner), a dict of active jobs,
    a list of completed jobs, running totals (``total_jobs``, ``total_chars``)
    and two naive-UTC timestamps (``created_at``, ``last_activity``).
    """

    def __init__(self, session_id: str, tts_instance: TTSInterface):
        self.session_id = session_id
        self.tts_instance = tts_instance
        self.preprocessor: Optional[TTSPreprocessor] = None

        # Job bookkeeping and running totals all start out empty.
        self.active_jobs: Dict[str, TTSJob] = {}
        self.completed_jobs: List[TTSJob] = []
        self.total_jobs = self.total_chars = 0

        self.created_at = datetime.utcnow()
        # A brand-new session counts as "just active".
        self.update_activity()

    def update_activity(self):
        """Set ``last_activity`` to the current (naive UTC) time."""
        self.last_activity = datetime.utcnow()
class TTSLifecycleManager:
    """Manage per-session TTS instances and the synthesis jobs that use them.

    On construction the manager subscribes to ``TTS_STARTED`` and
    ``SESSION_ENDED`` on the event bus and registers a ``TTS_INSTANCE`` pool
    with the resource manager. For every ``TTS_STARTED`` event it obtains (or
    reuses) a ``TTSSession``, synthesizes the text, base64-encodes the audio
    and publishes it as ``TTS_CHUNK_READY`` events followed by one
    ``TTS_COMPLETED`` event.

    Timestamps use naive ``datetime.utcnow()`` to stay comparable with the
    timestamps stored on ``TTSJob`` / ``TTSSession``.

    NOTE(review): the emoji prefixes in the log messages look mis-encoded in
    this copy of the file; they are reproduced unchanged here.
    """

    # How many finished jobs each session keeps for inspection.
    _MAX_COMPLETED_JOBS = 10

    def __init__(
        self,
        event_bus: EventBus,
        resource_manager: ResourceManager,
        *,
        chunk_size: int = 16384,
    ):
        """Wire the manager to the event bus and the resource manager.

        Args:
            event_bus: Bus used to receive TTS/session events and to publish
                audio chunks and completion events.
            resource_manager: Pool manager that hands out TTS instances.
            chunk_size: Number of base64 characters per published chunk
                (default 16384, i.e. 16KB). The default is a multiple of 4, so
                every chunk is itself a complete base64 string.
                NOTE(review): confirm whether consumers decode chunk-by-chunk
                before choosing a size that is not a multiple of 4.

        Raises:
            ValueError: If ``chunk_size`` is not positive.
        """
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")

        self.event_bus = event_bus
        self.resource_manager = resource_manager
        self.tts_sessions: Dict[str, TTSSession] = {}
        self.chunk_size = chunk_size
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self):
        """Subscribe to the TTS-related events this manager reacts to."""
        self.event_bus.subscribe(EventType.TTS_STARTED, self._handle_tts_start)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self):
        """Register the TTS instance pool with the resource manager."""
        self.resource_manager.register_pool(
            resource_type=ResourceType.TTS_INSTANCE,
            factory=self._create_tts_instance,
            max_idle=3,
            max_age_seconds=600,  # 10 minutes
        )

    async def _create_tts_instance(self) -> Optional[TTSInterface]:
        """Pool factory: build a TTS provider instance.

        Returns:
            The provider created by ``TTSFactory.create_provider()``, or
            ``None`` when the factory returns nothing or raises.
        """
        try:
            tts_instance = TTSFactory.create_provider()
            if not tts_instance:
                log_warning("β οΈ No TTS provider configured")
                return None
            log_debug("π Created new TTS instance")
            return tts_instance
        except Exception as e:
            # Best effort: a broken provider must not crash the pool.
            log_error("β Failed to create TTS instance", error=str(e))
            return None

    async def _get_or_create_session(self, event: Event) -> Optional[TTSSession]:
        """Return the TTS session for ``event.session_id``, creating it if needed.

        A new session acquires a TTS instance from the resource manager and
        gets a preprocessor for the event's ``locale`` (default ``"tr"``).

        Returns:
            The session, or ``None`` when no TTS instance could be acquired.
        """
        session_id = event.session_id
        existing = self.tts_sessions.get(session_id)
        if existing is not None:
            return existing

        tts_instance = await self.resource_manager.acquire(
            resource_id=f"tts_{session_id}",
            session_id=session_id,
            resource_type=ResourceType.TTS_INSTANCE,
            cleanup_callback=self._cleanup_tts_instance,
        )
        if not tts_instance:
            return None

        tts_session = TTSSession(session_id, tts_instance)
        tts_session.preprocessor = TTSPreprocessor(language=event.data.get("locale", "tr"))
        self.tts_sessions[session_id] = tts_session
        return tts_session

    async def _handle_tts_start(self, event: Event):
        """Handle a TTS synthesis request.

        Reads ``text`` and ``is_welcome`` from ``event.data``. Empty text is
        logged and ignored. Any exception raised while setting up or running
        the job is logged and reported via ``publish_error`` (type
        ``tts_error``) instead of propagating to the event bus.
        """
        session_id = event.session_id
        text = event.data.get("text", "")
        is_welcome = event.data.get("is_welcome", False)

        if not text:
            log_warning("β οΈ Empty text for TTS", session_id=session_id)
            return

        try:
            log_info(
                "π Starting TTS",
                session_id=session_id,
                text_length=len(text),
                is_welcome=is_welcome,
            )

            tts_session = await self._get_or_create_session(event)
            if tts_session is None:
                # No TTS available: complete without audio.
                await self._handle_no_tts(session_id, text, is_welcome)
                return

            # The running job count doubles as a per-session sequence number.
            job_id = f"{session_id}_{tts_session.total_jobs}"
            job = TTSJob(job_id, session_id, text, is_welcome)
            tts_session.active_jobs[job_id] = job
            tts_session.total_jobs += 1
            tts_session.total_chars += len(text)
            tts_session.update_activity()

            await self._process_tts_job(tts_session, job)

        except Exception as e:
            log_error(
                "β Failed to start TTS",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc(),
            )
            await publish_error(
                session_id=session_id,
                error_type="tts_error",
                error_message=f"Failed to synthesize speech: {e}",
            )

    async def _process_tts_job(self, tts_session: TTSSession, job: TTSJob):
        """Preprocess, synthesize and stream one job.

        On success the job is appended to the session's completed-job history
        (capped at the last 10). On failure the job is marked failed and an
        error event is published: ``tts_quota_exceeded`` when the error text
        mentions "quota" or "limit", ``tts_error`` otherwise. In every case the
        job is removed from ``active_jobs`` so failed jobs do not pile up there.
        """
        try:
            processed_text = tts_session.preprocessor.preprocess(
                job.text,
                tts_session.tts_instance.get_preprocessing_flags(),
            )
            log_debug(
                "π TTS preprocessed",
                session_id=job.session_id,
                original_length=len(job.text),
                processed_length=len(processed_text),
            )

            audio_data = await tts_session.tts_instance.synthesize(processed_text)

            # _cleanup_session marks in-flight jobs as failed when the session
            # ends; honor that here instead of streaming to a finished session.
            if job.error is not None:
                log_debug(
                    "TTS job cancelled during synthesis, discarding audio",
                    session_id=job.session_id,
                    reason=job.error,
                )
                return

            if not audio_data:
                raise ValueError("TTS returned empty audio data")

            job.complete(audio_data)
            log_info(
                "β TTS synthesis complete",
                session_id=job.session_id,
                audio_size=len(audio_data),
                duration_ms=(datetime.utcnow() - job.created_at).total_seconds() * 1000,
            )

            await self._stream_audio_chunks(tts_session, job)

            # Keep a bounded history of finished jobs.
            history = tts_session.completed_jobs
            history.append(job)
            if len(history) > self._MAX_COMPLETED_JOBS:
                del history[:-self._MAX_COMPLETED_JOBS]

        except Exception as e:
            error_message = str(e)
            job.fail(error_message)

            lowered = error_message.lower()
            if "quota" in lowered or "limit" in lowered:
                log_error("β TTS quota exceeded", session_id=job.session_id)
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_quota_exceeded",
                    error_message="TTS service quota exceeded",
                )
            else:
                log_error(
                    "β TTS synthesis failed",
                    session_id=job.session_id,
                    error=error_message,
                )
                await publish_error(
                    session_id=job.session_id,
                    error_type="tts_error",
                    error_message=error_message,
                )
        finally:
            # Whatever the outcome, the job is no longer in flight.
            tts_session.active_jobs.pop(job.job_id, None)

    async def _stream_audio_chunks(self, tts_session: TTSSession, job: TTSJob):
        """Publish the job's audio as base64 chunks, then a completion event.

        Does nothing when the job has no audio. Each chunk event carries its
        index, the total chunk count and an ``is_last`` flag; ``mime_type`` is
        always reported as ``audio/mpeg``.
        NOTE(review): confirm that every provider really returns MPEG audio.
        """
        if not job.audio_data:
            return

        audio_base64 = base64.b64encode(job.audio_data).decode('utf-8')
        total_length = len(audio_base64)
        # Ceiling division: a final partial chunk still counts as one chunk.
        total_chunks = (total_length + self.chunk_size - 1) // self.chunk_size

        log_debug(
            "π€ Streaming TTS audio",
            session_id=job.session_id,
            total_size=len(job.audio_data),
            base64_size=total_length,
            chunks=total_chunks,
        )

        for chunk_index, offset in enumerate(range(0, total_length, self.chunk_size)):
            await self.event_bus.publish(Event(
                type=EventType.TTS_CHUNK_READY,
                session_id=job.session_id,
                data={
                    "audio_data": audio_base64[offset:offset + self.chunk_size],
                    "chunk_index": chunk_index,
                    "total_chunks": total_chunks,
                    "is_last": chunk_index == total_chunks - 1,
                    "mime_type": "audio/mpeg",
                    "is_welcome": job.is_welcome,
                },
                priority=8,  # Higher priority for audio chunks
            ))
            job.chunks_sent += 1

            # Small delay between chunks to prevent overwhelming
            await asyncio.sleep(0.01)

        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=job.session_id,
            data={
                "job_id": job.job_id,
                "total_chunks": total_chunks,
                "is_welcome": job.is_welcome,
            },
        ))

        log_info(
            "β TTS streaming complete",
            session_id=job.session_id,
            chunks_sent=job.chunks_sent,
        )

    async def _handle_no_tts(self, session_id: str, text: str, is_welcome: bool):
        """Publish ``TTS_COMPLETED`` with ``no_audio=True`` when TTS is unavailable."""
        log_warning("β οΈ No TTS available, skipping audio generation", session_id=session_id)

        await self.event_bus.publish(Event(
            type=EventType.TTS_COMPLETED,
            session_id=session_id,
            data={
                "no_audio": True,
                "text": text,
                "is_welcome": is_welcome,
            },
        ))

    async def _handle_session_ended(self, event: Event):
        """Clean up TTS resources when a session ends."""
        await self._cleanup_session(event.session_id)

    async def _cleanup_session(self, session_id: str):
        """Forget a TTS session and release its TTS instance.

        Unfinished jobs are marked failed with "Session ended". The resource
        is released with ``delay_seconds=120``. Errors during cleanup are
        logged, not raised. Unknown session ids are ignored.
        """
        tts_session = self.tts_sessions.pop(session_id, None)
        if not tts_session:
            return

        try:
            for job in tts_session.active_jobs.values():
                if not job.completed_at:
                    job.fail("Session ended")

            await self.resource_manager.release(f"tts_{session_id}", delay_seconds=120)

            log_info(
                "π§Ή TTS session cleaned up",
                session_id=session_id,
                total_jobs=tts_session.total_jobs,
                total_chars=tts_session.total_chars,
            )
        except Exception as e:
            log_error(
                "β Error cleaning up TTS session",
                session_id=session_id,
                error=str(e),
            )

    async def _cleanup_tts_instance(self, tts_instance: TTSInterface):
        """Cleanup callback passed to the resource manager; currently only logs."""
        try:
            # TTS instances typically don't need special cleanup
            log_debug("π§Ή TTS instance cleaned up")
        except Exception as e:
            log_error("β Error cleaning up TTS instance", error=str(e))

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of manager statistics.

        Returns:
            A dict with ``active_sessions``, ``total_active_jobs`` and
            ``sessions``; the latter maps each session id to its job counts,
            totals, ``uptime_seconds`` and ISO-formatted ``last_activity``.
        """
        # One clock reading so all uptimes refer to the same instant.
        now = datetime.utcnow()
        session_stats = {
            session_id: {
                "active_jobs": len(tts_session.active_jobs),
                "completed_jobs": len(tts_session.completed_jobs),
                "total_jobs": tts_session.total_jobs,
                "total_chars": tts_session.total_chars,
                "uptime_seconds": (now - tts_session.created_at).total_seconds(),
                "last_activity": tts_session.last_activity.isoformat(),
            }
            for session_id, tts_session in self.tts_sessions.items()
        }

        return {
            "active_sessions": len(self.tts_sessions),
            "total_active_jobs": sum(s["active_jobs"] for s in session_stats.values()),
            "sessions": session_stats,
        }