# flare / stt_lifecycle_manager.py
# Uploaded by ciyidogan ("Upload 8 files"), commit e8a19b3 (verified), 14.1 kB
"""
STT Lifecycle Manager for Flare
===============================
Manages the lifecycle of STT instances for each session
"""
import asyncio
from typing import Dict, Optional, Any
from datetime import datetime
import traceback
import base64
from event_bus import EventBus, Event, EventType, publish_error
from resource_manager import ResourceManager, ResourceType
from stt_factory import STTFactory
from stt_interface import STTInterface, STTConfig, TranscriptionResult
from logger import log_info, log_error, log_debug, log_warning
class STTSession:
    """Per-session wrapper around one STT provider instance.

    Holds the provider instance, whether it is currently streaming, the
    STTConfig it was started with (None until a start succeeds), and
    bookkeeping the lifecycle manager updates and reports: creation and
    last-activity timestamps (naive UTC from ``datetime.utcnow``) and
    chunk/byte counters.
    """

    def __init__(self, session_id: str, stt_instance: STTInterface):
        self.session_id, self.stt_instance = session_id, stt_instance
        self.is_streaming = False
        self.config: Optional[STTConfig] = None
        self.created_at = datetime.utcnow()
        # Initialise last_activity through the same helper used later on
        self.update_activity()
        self.total_chunks = self.total_bytes = 0

    def update_activity(self):
        """Stamp ``last_activity`` with the current (naive) UTC time."""
        self.last_activity = datetime.utcnow()
class STTLifecycleManager:
    """Manage per-session STT instances, driven by event-bus events.

    On construction it subscribes to STT_STARTED, STT_STOPPED,
    AUDIO_CHUNK_RECEIVED and SESSION_ENDED, and registers an STT_INSTANCE
    pool with the ResourceManager. Each session gets one ``STTSession``
    wrapping an instance acquired from that pool under the resource id
    ``stt_<session_id>``.
    """

    # Short locale -> STT language code; used by _get_language_code.
    _LOCALE_MAP: Dict[str, str] = {
        "tr": "tr-TR",
        "en": "en-US",
        "de": "de-DE",
        "fr": "fr-FR",
        "es": "es-ES",
        "it": "it-IT",
        "pt": "pt-BR",
        "ru": "ru-RU",
        "ja": "ja-JP",
        "ko": "ko-KR",
        "zh": "zh-CN",
        "ar": "ar-SA",
    }

    # Locale assumed when a start request carries no (or an empty) locale.
    _DEFAULT_LOCALE = "tr"

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        # session_id -> STTSession for every session currently holding an instance
        self.stt_sessions: Dict[str, STTSession] = {}
        self._setup_event_handlers()
        self._setup_resource_pool()

    @staticmethod
    def _resource_id(session_id: str) -> str:
        """Return the ResourceManager id under which a session's STT instance is held."""
        return f"stt_{session_id}"

    def _setup_event_handlers(self):
        """Subscribe to STT-related events."""
        self.event_bus.subscribe(EventType.STT_STARTED, self._handle_stt_start)
        self.event_bus.subscribe(EventType.STT_STOPPED, self._handle_stt_stop)
        self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    def _setup_resource_pool(self):
        """Register the STT instance pool with the resource manager."""
        self.resource_manager.register_pool(
            resource_type=ResourceType.STT_INSTANCE,
            factory=self._create_stt_instance,
            max_idle=5,
            max_age_seconds=300,  # 5 minutes
        )

    async def _create_stt_instance(self) -> STTInterface:
        """Pool factory: create a new STT provider via STTFactory.

        Raises:
            ValueError: if ``STTFactory.create_provider()`` returns a falsy value.
            Any exception from the factory is logged and re-raised.
        """
        try:
            stt_instance = STTFactory.create_provider()
            if not stt_instance:
                raise ValueError("Failed to create STT instance")
            log_debug("🎀 Created new STT instance")
            return stt_instance
        except Exception as e:
            log_error("❌ Failed to create STT instance", error=str(e))
            raise

    def _build_stt_config(self, config_data: Dict[str, Any]) -> STTConfig:
        """Build an STTConfig from a start request's data, applying defaults."""
        # A missing, None or empty locale falls back to the default locale
        locale = config_data.get("locale") or self._DEFAULT_LOCALE
        return STTConfig(
            language=self._get_language_code(locale),
            sample_rate=config_data.get("sample_rate", 16000),
            encoding=config_data.get("encoding", "WEBM_OPUS"),
            enable_punctuation=config_data.get("enable_punctuation", True),
            enable_word_timestamps=False,
            model=config_data.get("model", "latest_long"),
            use_enhanced=config_data.get("use_enhanced", True),
            single_utterance=False,  # Continuous listening
            interim_results=config_data.get("interim_results", True),
            vad_enabled=config_data.get("vad_enabled", True),
            speech_timeout_ms=config_data.get("speech_timeout_ms", 2000),
            noise_reduction_enabled=config_data.get("noise_reduction_enabled", True),
            noise_reduction_level=config_data.get("noise_reduction_level", 2),
        )

    async def _handle_stt_start(self, event: Event):
        """Start streaming recognition for ``event.session_id``.

        Reuses the session's existing STT instance when one is held;
        otherwise acquires one from the pool. If the session is already
        streaming, logs a warning and returns. On success publishes
        STT_READY with the resolved language. On any failure the session is
        cleaned up and an ``stt_error`` is published.
        """
        session_id = event.session_id
        config_data = event.data or {}
        try:
            log_info("🎀 Starting STT", session_id=session_id)

            stt_session = self.stt_sessions.get(session_id)
            if stt_session is None:
                # No instance held yet: acquire one from the pool
                stt_instance = await self.resource_manager.acquire(
                    resource_id=self._resource_id(session_id),
                    session_id=session_id,
                    resource_type=ResourceType.STT_INSTANCE,
                    cleanup_callback=self._cleanup_stt_instance,
                )
                stt_session = STTSession(session_id, stt_instance)
                self.stt_sessions[session_id] = stt_session
            elif stt_session.is_streaming:
                log_warning("⚠️ STT already streaming", session_id=session_id)
                return

            stt_config = self._build_stt_config(config_data)
            stt_session.config = stt_config

            await stt_session.stt_instance.start_streaming(stt_config)
            stt_session.is_streaming = True
            stt_session.update_activity()
            log_info("βœ… STT started", session_id=session_id, language=stt_config.language)

            await self.event_bus.publish(Event(
                type=EventType.STT_READY,
                session_id=session_id,
                data={"language": stt_config.language},
            ))
        except Exception as e:
            log_error(
                "❌ Failed to start STT",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc(),
            )
            if session_id in self.stt_sessions:
                await self._cleanup_session(session_id)
            await publish_error(
                session_id=session_id,
                error_type="stt_error",
                error_message=f"Failed to start STT: {str(e)}",
            )

    async def _handle_stt_stop(self, event: Event):
        """Stop streaming for ``event.session_id`` but keep the session.

        If the provider returns a final result with text, it is published as
        a final STT_RESULT. The session (and its pooled instance) is kept so
        streaming can be restarted later.
        """
        session_id = event.session_id
        reason = (event.data or {}).get("reason", "unknown")
        log_info("πŸ›‘ Stopping STT", session_id=session_id, reason=reason)

        stt_session = self.stt_sessions.get(session_id)
        if not stt_session:
            log_warning("⚠️ No STT session found", session_id=session_id)
            return

        try:
            if stt_session.is_streaming:
                try:
                    final_result = await stt_session.stt_instance.stop_streaming()
                finally:
                    # Clear the flag even if stop_streaming() raised; otherwise a
                    # later STT_STARTED would be rejected as "already streaming"
                    # and audio chunks would keep going to a dead stream.
                    stt_session.is_streaming = False

                if final_result and final_result.text:
                    await self.event_bus.publish(Event(
                        type=EventType.STT_RESULT,
                        session_id=session_id,
                        data={
                            "text": final_result.text,
                            "is_final": True,
                            "confidence": final_result.confidence,
                        },
                    ))

            # Don't remove session immediately - might restart
            stt_session.update_activity()
            log_info("βœ… STT stopped", session_id=session_id)
        except Exception as e:
            log_error(
                "❌ Error stopping STT",
                session_id=session_id,
                error=str(e),
            )

    async def _handle_audio_chunk(self, event: Event):
        """Decode a base64 audio chunk and feed it to the session's STT stream.

        Chunks for sessions that are unknown or not streaming are ignored.
        Every result yielded by ``stream_audio`` is published as STT_RESULT.
        Errors are published as ``stt_timeout`` (message mentions a stream
        duration or timeout) or ``stt_error``.
        """
        session_id = event.session_id
        stt_session = self.stt_sessions.get(session_id)
        if not stt_session or not stt_session.is_streaming:
            # STT not ready, ignore chunk
            return

        try:
            audio_data = base64.b64decode((event.data or {}).get("audio_data", ""))

            stt_session.total_chunks += 1
            stt_session.total_bytes += len(audio_data)
            stt_session.update_activity()

            async for result in stt_session.stt_instance.stream_audio(audio_data):
                await self.event_bus.publish(Event(
                    type=EventType.STT_RESULT,
                    session_id=session_id,
                    data={
                        "text": result.text,
                        "is_final": result.is_final,
                        "confidence": result.confidence,
                        "timestamp": result.timestamp,
                    },
                ))

                if result.is_final:
                    text = result.text or ""
                    log_info(
                        "πŸ“ STT final result",
                        session_id=session_id,
                        text=text[:50] + "..." if len(text) > 50 else text,
                        confidence=result.confidence,
                    )

            # Progress log once per 100 chunks (evaluated once per chunk)
            if stt_session.total_chunks % 100 == 0:
                log_debug(
                    "πŸ“Š STT progress",
                    session_id=session_id,
                    chunks=stt_session.total_chunks,
                    bytes=stt_session.total_bytes,
                )
        except Exception as e:
            log_error(
                "❌ Error processing audio chunk",
                session_id=session_id,
                error=str(e),
            )
            # Match both markers case-insensitively; these signal that the
            # stream must be restarted rather than a generic failure.
            message = str(e).lower()
            if "stream duration" in message or "timeout" in message:
                await publish_error(
                    session_id=session_id,
                    error_type="stt_timeout",
                    error_message="STT stream timeout, restart needed",
                )
            else:
                await publish_error(
                    session_id=session_id,
                    error_type="stt_error",
                    error_message=str(e),
                )

    async def _handle_session_ended(self, event: Event):
        """Clean up STT resources when a session ends."""
        await self._cleanup_session(event.session_id)

    async def _cleanup_session(self, session_id: str):
        """Forget a session, stop its stream and release its pooled instance.

        The release is attempted even if stopping the stream fails, so a
        misbehaving provider cannot leak the pooled resource.
        """
        stt_session = self.stt_sessions.pop(session_id, None)
        if not stt_session:
            return

        if stt_session.is_streaming:
            try:
                await stt_session.stt_instance.stop_streaming()
            except Exception as e:
                log_error(
                    "❌ Error cleaning up STT session",
                    session_id=session_id,
                    error=str(e),
                )
            finally:
                stt_session.is_streaming = False

        try:
            await self.resource_manager.release(self._resource_id(session_id), delay_seconds=60)
            log_info(
                "🧹 STT session cleaned up",
                session_id=session_id,
                total_chunks=stt_session.total_chunks,
                total_bytes=stt_session.total_bytes,
            )
        except Exception as e:
            log_error(
                "❌ Error cleaning up STT session",
                session_id=session_id,
                error=str(e),
            )

    async def _cleanup_stt_instance(self, stt_instance: STTInterface):
        """Pool cleanup callback: stop the instance's stream if it reports one."""
        try:
            # Not every provider necessarily exposes is_streaming
            if getattr(stt_instance, "is_streaming", False):
                await stt_instance.stop_streaming()
            log_debug("🧹 STT instance cleaned up")
        except Exception as e:
            log_error("❌ Error cleaning up STT instance", error=str(e))

    def _get_language_code(self, locale: str) -> str:
        """Convert a locale string to an STT language code.

        - Short codes in ``_LOCALE_MAP`` (case-insensitive) map to their
          configured code, e.g. ``"tr"``/``"TR"`` -> ``"tr-TR"``.
        - Underscores are treated as hyphens (``"tr_TR"`` -> ``"tr-TR"``).
        - Any other hyphenated tag is returned as-is, including ones that
          are not 5 characters long (``"es-419"``, ``"zh-Hant-TW"``).
        - Remaining bare codes become ``"<code>-<CODE>"``.
        - ``None``/empty input uses ``_DEFAULT_LOCALE``.
        """
        normalized = (locale or "").strip().replace("_", "-") or self._DEFAULT_LOCALE

        mapped = self._LOCALE_MAP.get(normalized.lower())
        if mapped:
            return mapped

        # Already a full language tag
        if "-" in normalized:
            return normalized

        # Default to locale-LOCALE format
        return f"{normalized.lower()}-{normalized.upper()}"

    def get_stats(self) -> Dict[str, Any]:
        """Return per-session and aggregate STT statistics."""
        now = datetime.utcnow()
        session_stats = {
            session_id: {
                "is_streaming": stt_session.is_streaming,
                "total_chunks": stt_session.total_chunks,
                "total_bytes": stt_session.total_bytes,
                "uptime_seconds": (now - stt_session.created_at).total_seconds(),
                "last_activity": stt_session.last_activity.isoformat(),
            }
            for session_id, stt_session in self.stt_sessions.items()
        }
        return {
            "active_sessions": len(self.stt_sessions),
            "streaming_sessions": sum(1 for s in self.stt_sessions.values() if s.is_streaming),
            "sessions": session_stats,
        }