# NOTE: "Spaces: Building" page-status banner captured with this file; not part of the module.
"""
STT Lifecycle Manager for Flare - Batch Mode
============================================
Manages STT instances and audio collection
"""
import asyncio
import base64
import traceback
from datetime import datetime
from typing import Any, Dict, Optional

from chat_session.event_bus import EventBus, Event, EventType, publish_error
from chat_session.resource_manager import ResourceManager, ResourceType
from stt.stt_factory import STTFactory
from stt.stt_interface import STTInterface, STTConfig, TranscriptionResult
from stt.voice_activity_detector import VoiceActivityDetector
from utils.logger import log_info, log_error, log_debug, log_warning
class STTSession:
    """Per-session STT state: the provider instance plus collected audio.

    A session starts inactive with no config. Audio chunks are accumulated in
    ``audio_buffer`` and counted in ``total_chunks`` / ``total_bytes``; each
    session owns its own ``VoiceActivityDetector``.
    """

    def __init__(self, session_id: str, stt_instance: STTInterface):
        """Create an inactive session bound to ``stt_instance``.

        Args:
            session_id: Identifier of the chat session this state belongs to.
            stt_instance: STT provider instance used for transcription.
        """
        self.session_id = session_id
        self.stt_instance = stt_instance
        self.is_active = False
        self.config: Optional[STTConfig] = None
        # NOTE(review): naive UTC timestamp. datetime.utcnow() is deprecated
        # since Python 3.12; switching to an aware datetime would change the
        # value's type for any caller comparing it with naive datetimes, so it
        # is left as-is here.
        self.created_at = datetime.utcnow()

        # Audio collection: chunks are appended in arrival order.
        self.audio_buffer = []
        self.vad = VoiceActivityDetector()

        # Stats for the utterance currently being collected.
        self.total_chunks = 0
        self.total_bytes = 0

    def reset(self) -> None:
        """Reset the session for a new utterance.

        Drops the collected audio, resets the voice activity detector and
        zeroes the chunk/byte counters. ``is_active`` and ``config`` are left
        untouched.
        """
        self.audio_buffer = []
        self.vad.reset()
        self.total_chunks = 0
        self.total_bytes = 0
class STTLifecycleManager:
    """Manage per-session STT instances and batch-mode audio collection.

    The manager subscribes to STT and session events on the event bus. While
    a session is active, incoming audio chunks are buffered and run through
    the session's voice activity detector. Once the detector reports enough
    trailing silence, an ``STT_STOPPED`` event is published; the stop handler
    then transcribes the buffered audio in one batch call and publishes the
    text as an ``STT_RESULT`` event.
    """

    # Trailing silence (in ms, as reported by the VAD) after which the current
    # utterance is considered finished. Override on a subclass or instance to
    # tune end-of-utterance detection.
    SILENCE_THRESHOLD_MS = 2000

    def __init__(self, event_bus: EventBus, resource_manager: ResourceManager):
        """Wire the manager to the event bus and register the STT pool.

        Args:
            event_bus: Bus used to subscribe to and publish STT events.
            resource_manager: Manager from which STT instances are acquired
                and to which they are released.
        """
        self.event_bus = event_bus
        self.resource_manager = resource_manager
        self.stt_sessions: Dict[str, STTSession] = {}
        self._setup_event_handlers()
        self._setup_resource_pool()

    def _setup_event_handlers(self):
        """Subscribe to STT-related events."""
        self.event_bus.subscribe(EventType.STT_STARTED, self._handle_stt_start)
        self.event_bus.subscribe(EventType.STT_STOPPED, self._handle_stt_stop)
        self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk)
        self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended)

    @staticmethod
    def _build_stt_config(config_data: Dict[str, Any]) -> STTConfig:
        """Build an ``STTConfig`` from a start-event payload, filling in defaults."""
        lookup = config_data.get
        return STTConfig(
            language=lookup("locale", "tr"),
            sample_rate=lookup("sample_rate", 16000),
            encoding=lookup("encoding", "LINEAR16"),
            enable_punctuation=lookup("enable_punctuation", True),
            model=lookup("model", "latest_long"),
            use_enhanced=lookup("use_enhanced", True),
        )

    async def _handle_stt_start(self, event: Event):
        """Handle an STT start request.

        Creates the session (acquiring an STT instance from the resource
        manager) or resets an existing one, applies the config from the event
        payload, marks the session active and publishes ``STT_READY``.

        On any failure the error is logged, a registered session is cleaned
        up, and an ``stt_error`` is published; the exception is not re-raised.
        """
        session_id = event.session_id
        # A missing payload is treated as "use the default for every option".
        config_data = event.data or {}

        try:
            log_info("🎤 Starting STT", session_id=session_id)

            # Get or create session
            if session_id not in self.stt_sessions:
                # Acquire STT instance from pool
                stt_instance = await self.resource_manager.acquire(
                    resource_id=f"stt_{session_id}",
                    session_id=session_id,
                    resource_type=ResourceType.STT_INSTANCE,
                    cleanup_callback=self._cleanup_stt_instance
                )
                stt_session = STTSession(session_id, stt_instance)
                self.stt_sessions[session_id] = stt_session
            else:
                # Reuse the session, discarding audio from a previous utterance.
                stt_session = self.stt_sessions[session_id]
                stt_session.reset()

            stt_config = self._build_stt_config(config_data)
            stt_session.config = stt_config
            stt_session.is_active = True

            log_info("✅ STT started in batch mode", session_id=session_id, language=stt_config.language)

            # Notify STT is ready
            await self.event_bus.publish(Event(
                type=EventType.STT_READY,
                session_id=session_id,
                data={"language": stt_config.language}
            ))

        except Exception as e:
            log_error(
                "❌ Failed to start STT",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

            # Clean up on error
            if session_id in self.stt_sessions:
                await self._cleanup_session(session_id)

            # Publish error event
            await publish_error(
                session_id=session_id,
                error_type="stt_error",
                error_message=f"Failed to start STT: {str(e)}"
            )

    async def _handle_audio_chunk(self, event: Event):
        """Buffer an audio chunk and run it through the session's VAD.

        Chunks are ignored when the session is unknown or inactive, and when
        the payload carries no audio. When the VAD reports non-speech with at
        least ``SILENCE_THRESHOLD_MS`` of silence, the session is deactivated
        and ``STT_STOPPED`` (reason ``silence_detected``) is published.
        Errors are logged and not re-raised.
        """
        session_id = event.session_id
        stt_session = self.stt_sessions.get(session_id)

        # No STT session for this id: ignore the chunk.
        if not stt_session:
            return

        # STT is inactive: ignore the chunk entirely.
        if not stt_session.is_active:
            return

        try:
            # Decode audio data; tolerate a missing payload or missing key.
            payload = event.data or {}
            audio_data = base64.b64decode(payload.get("audio_data", ""))
            if not audio_data:
                # Nothing to collect; do not feed an empty chunk to the VAD.
                return

            # Add to buffer (only reached while the session is active).
            stt_session.audio_buffer.append(audio_data)
            stt_session.total_chunks += 1
            stt_session.total_bytes += len(audio_data)

            # Process through VAD
            is_speech, silence_duration_ms = stt_session.vad.process_chunk(audio_data)

            # Check if utterance ended (silence threshold reached)
            if not is_speech and silence_duration_ms >= self.SILENCE_THRESHOLD_MS:
                log_info(f"💬 Utterance ended after {silence_duration_ms}ms silence", session_id=session_id)

                # Deactivate immediately so no further chunks are processed.
                stt_session.is_active = False

                # Tell the frontend right away to stop recording. This manager
                # is itself subscribed to STT_STOPPED, so this event also
                # reaches _handle_stt_stop, which performs the transcription.
                await self.event_bus.publish(Event(
                    type=EventType.STT_STOPPED,
                    session_id=session_id,
                    data={"reason": "silence_detected", "stop_recording": True}
                ))

                log_info("🛑 STT stopped and frontend notified to stop recording", session_id=session_id)

            # Log progress periodically
            if stt_session.total_chunks % 100 == 0:
                log_debug(
                    "📊 STT progress",
                    session_id=session_id,
                    chunks=stt_session.total_chunks,
                    bytes=stt_session.total_bytes,
                    vad_stats=stt_session.vad.get_stats()
                )

        except Exception as e:
            log_error(
                "❌ Error processing audio chunk",
                session_id=session_id,
                error=str(e)
            )

    async def _handle_session_ended(self, event: Event):
        """Clean up STT resources when the session ends."""
        await self._cleanup_session(event.session_id)

    async def _handle_stt_stop(self, event: Event):
        """Handle an STT stop request and, if applicable, transcribe.

        The session is deactivated. Buffered audio is transcribed only when
        the stop reason is ``silence_detected`` and the buffer is non-empty;
        a non-empty result text is published as a final ``STT_RESULT``. The
        session is then reset for the next utterance. Errors are logged and
        not re-raised.
        """
        session_id = event.session_id
        # Tolerate a stop event without a payload.
        reason = (event.data or {}).get("reason", "unknown")

        log_info("🛑 Stopping STT", session_id=session_id, reason=reason)

        stt_session = self.stt_sessions.get(session_id)
        if not stt_session:
            log_warning("⚠️ No STT session found", session_id=session_id)
            return

        try:
            # Deactivate so no further chunks are processed.
            stt_session.is_active = False

            # Transcribe only when there is buffered audio and the stop was
            # caused by detected silence.
            if reason == "silence_detected" and stt_session.audio_buffer:
                # Combine audio chunks
                combined_audio = b''.join(stt_session.audio_buffer)

                log_info(f"📝 Transcribing {len(combined_audio)} bytes of audio", session_id=session_id)

                # Transcribe using batch mode
                result = await stt_session.stt_instance.transcribe(
                    audio_data=combined_audio,
                    config=stt_session.config
                )

                # Publish result if we got transcription
                if result and result.text:
                    await self.event_bus.publish(Event(
                        type=EventType.STT_RESULT,
                        session_id=session_id,
                        data={
                            "text": result.text,
                            "is_final": True,
                            "confidence": result.confidence
                        }
                    ))
                    log_info(f"✅ Transcription completed: '{result.text}'", session_id=session_id)
                else:
                    log_warning("⚠️ No transcription result", session_id=session_id)

            elif reason != "silence_detected":
                log_info(f"📝 STT stopped without transcription (reason: {reason})", session_id=session_id)

            # Reset session for next utterance
            stt_session.reset()

            log_info("✅ STT session reset and ready for next utterance", session_id=session_id)

        except Exception as e:
            log_error(
                "❌ Error stopping STT",
                session_id=session_id,
                error=str(e),
                traceback=traceback.format_exc()
            )

    async def _cleanup_session(self, session_id: str):
        """Remove the session and release its STT resource.

        Does nothing when no session is registered for ``session_id``.
        Errors during release are logged and not re-raised.
        """
        stt_session = self.stt_sessions.pop(session_id, None)
        if not stt_session:
            return

        try:
            # Mark as inactive
            stt_session.is_active = False

            # Release resource. NOTE(review): delay_seconds=60 presumably
            # defers the actual release; confirm against ResourceManager.
            await self.resource_manager.release(f"stt_{session_id}", delay_seconds=60)

            log_info(
                "🧹 STT session cleaned up",
                session_id=session_id,
                total_chunks=stt_session.total_chunks,
                total_bytes=stt_session.total_bytes
            )

        except Exception as e:
            log_error(
                "❌ Error cleaning up STT session",
                session_id=session_id,
                error=str(e)
            )

    async def _cleanup_stt_instance(self, stt_instance: STTInterface):
        """Cleanup callback passed to the resource manager on acquire.

        Currently only logs; the instance itself is not closed here.
        """
        try:
            log_debug("🧹 STT instance cleaned up")
        except Exception as e:
            log_error("❌ Error cleaning up STT instance", error=str(e))

    def _setup_resource_pool(self):
        """Register the STT instance pool with the resource manager."""
        self.resource_manager.register_pool(
            resource_type=ResourceType.STT_INSTANCE,
            factory=self._create_stt_instance,
            max_idle=5,
            max_age_seconds=300  # 5 minutes
        )

    async def _create_stt_instance(self) -> STTInterface:
        """Factory for creating STT instances.

        Raises:
            ValueError: If the STT factory returns a falsy provider.
            Exception: Any error raised by the factory is logged and re-raised.
        """
        try:
            stt_instance = STTFactory.create_provider()
            if not stt_instance:
                raise ValueError("Failed to create STT instance")

            log_debug("🎤 Created new STT instance")
            return stt_instance

        except Exception as e:
            log_error("❌ Failed to create STT instance", error=str(e))
            raise

    def get_stats(self) -> Dict[str, Any]:
        """Return STT manager statistics.

        Returns:
            A dict with ``active_sessions`` (number of tracked sessions),
            ``active_streaming`` (number of sessions currently active) and
            ``sessions`` (per-session activity flag, counters and VAD stats).
        """
        sessions = self.stt_sessions
        return {
            "active_sessions": len(sessions),
            "active_streaming": sum(1 for s in sessions.values() if s.is_active),
            "sessions": {
                session_id: {
                    "is_active": stt_session.is_active,
                    "total_chunks": stt_session.total_chunks,
                    "total_bytes": stt_session.total_bytes,
                    "vad_stats": stt_session.vad.get_stats() if stt_session.vad else {}
                }
                for session_id, stt_session in sessions.items()
            },
        }