""" Audio Buffer Manager for Flare ============================== Manages audio buffering, silence detection, and chunk processing """ import asyncio from typing import Dict, Optional, List, Tuple, Any from collections import deque from datetime import datetime import base64 import numpy as np from dataclasses import dataclass import traceback from chat_session.event_bus import EventBus, Event, EventType from utils.logger import log_info, log_error, log_debug, log_warning @dataclass class AudioChunk: """Audio chunk with metadata""" data: bytes timestamp: datetime chunk_index: int is_speech: bool = True energy_level: float = 0.0 class SilenceDetector: """Detect silence in audio stream""" def __init__(self, threshold_ms: int = 2000, energy_threshold: float = 0.01, sample_rate: int = 16000): self.threshold_ms = threshold_ms self.energy_threshold = energy_threshold self.sample_rate = sample_rate self.silence_start: Optional[datetime] = None def detect_silence(self, audio_chunk: bytes) -> Tuple[bool, int]: """ Detect if chunk is silence and return duration Returns: (is_silence, silence_duration_ms) """ try: # Handle empty or invalid chunk if not audio_chunk or len(audio_chunk) < 2: return True, 0 # Ensure even number of bytes for 16-bit audio if len(audio_chunk) % 2 != 0: audio_chunk = audio_chunk[:-1] # Convert to numpy array audio_data = np.frombuffer(audio_chunk, dtype=np.int16) if len(audio_data) == 0: return True, 0 # Calculate RMS energy rms = np.sqrt(np.mean(audio_data.astype(float) ** 2)) normalized_rms = rms / 32768.0 # Normalize for 16-bit audio is_silence = normalized_rms < self.energy_threshold # Track silence duration now = datetime.utcnow() if is_silence: if self.silence_start is None: self.silence_start = now duration_ms = int((now - self.silence_start).total_seconds() * 1000) else: self.silence_start = None duration_ms = 0 return is_silence, duration_ms except Exception as e: log_warning(f"Silence detection error: {e}") return False, 0 def reset(self): """Reset silence detection state""" self.silence_start = None class AudioBuffer: """Manage audio chunks for a session""" def __init__(self, session_id: str, max_chunks: int = 1000, chunk_size_bytes: int = 4096): self.session_id = session_id self.max_chunks = max_chunks self.chunk_size_bytes = chunk_size_bytes self.chunks: deque[AudioChunk] = deque(maxlen=max_chunks) self.chunk_counter = 0 self.total_bytes = 0 self.lock = asyncio.Lock() async def add_chunk(self, audio_data: bytes, timestamp: Optional[datetime] = None) -> AudioChunk: """Add audio chunk to buffer""" async with self.lock: if timestamp is None: timestamp = datetime.utcnow() chunk = AudioChunk( data=audio_data, timestamp=timestamp, chunk_index=self.chunk_counter ) self.chunks.append(chunk) self.chunk_counter += 1 self.total_bytes += len(audio_data) return chunk async def get_recent_audio(self, duration_ms: int = 5000) -> bytes: """Get recent audio data""" async with self.lock: cutoff_time = datetime.utcnow() audio_parts = [] # Iterate backwards through chunks for chunk in reversed(self.chunks): time_diff = (cutoff_time - chunk.timestamp).total_seconds() * 1000 if time_diff > duration_ms: break audio_parts.append(chunk.data) # Reverse to maintain chronological order audio_parts.reverse() return b''.join(audio_parts) async def clear(self): """Clear buffer""" async with self.lock: self.chunks.clear() self.chunk_counter = 0 self.total_bytes = 0 def get_stats(self) -> Dict[str, Any]: """Get buffer statistics""" return { "chunks": len(self.chunks), "total_bytes": self.total_bytes, "chunk_counter": self.chunk_counter, "oldest_chunk": self.chunks[0].timestamp if self.chunks else None, "newest_chunk": self.chunks[-1].timestamp if self.chunks else None } class AudioBufferManager: """Manage audio buffers for all sessions""" def __init__(self, event_bus: EventBus): self.event_bus = event_bus self.session_buffers: Dict[str, AudioBuffer] = {} self.silence_detectors: Dict[str, SilenceDetector] = {} self._setup_event_handlers() def _setup_event_handlers(self): """Subscribe to audio events""" self.event_bus.subscribe(EventType.SESSION_STARTED, self._handle_session_started) self.event_bus.subscribe(EventType.SESSION_ENDED, self._handle_session_ended) self.event_bus.subscribe(EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk) async def _handle_session_started(self, event: Event): """Initialize buffer for new session""" session_id = event.session_id config = event.data # Create audio buffer self.session_buffers[session_id] = AudioBuffer( session_id=session_id, max_chunks=config.get("max_chunks", 1000), chunk_size_bytes=config.get("chunk_size", 4096) ) log_info(f"📦 Audio buffer initialized", session_id=session_id) async def _handle_session_ended(self, event: Event): """Cleanup session buffers""" session_id = event.session_id # Clear and remove buffer if session_id in self.session_buffers: await self.session_buffers[session_id].clear() del self.session_buffers[session_id] # Remove silence detector if session_id in self.silence_detectors: del self.silence_detectors[session_id] log_info(f"📦 Audio buffer cleaned up", session_id=session_id) async def _handle_audio_chunk(self, event: Event): """Process incoming audio chunk""" session_id = event.session_id buffer = self.session_buffers.get(session_id) if not buffer: log_warning(f"⚠️ No buffer for session", session_id=session_id) return try: # Decode audio data audio_data = base64.b64decode(event.data.get("audio_data", "")) # Add to buffer chunk = await buffer.add_chunk(audio_data) # Log periodically if chunk.chunk_index % 100 == 0: stats = buffer.get_stats() log_debug( f"📊 Buffer stats", session_id=session_id, **stats ) except Exception as e: log_error( f"❌ Error processing audio chunk", session_id=session_id, error=str(e), traceback=traceback.format_exc() ) async def get_buffer(self, session_id: str) -> Optional[AudioBuffer]: """Get buffer for session""" return self.session_buffers.get(session_id) async def reset_buffer(self, session_id: str): """Reset buffer for new utterance""" buffer = self.session_buffers.get(session_id) detector = self.silence_detectors.get(session_id) if buffer: await buffer.clear() if detector: detector.reset() log_debug(f"🔄 Audio buffer reset", session_id=session_id) def get_all_stats(self) -> Dict[str, Dict[str, Any]]: """Get statistics for all buffers""" stats = {} for session_id, buffer in self.session_buffers.items(): stats[session_id] = buffer.get_stats() return stats