Spaces:
Building
Building
""" | |
Audio Buffer Manager for Flare | |
============================== | |
Manages audio buffering, silence detection, and chunk processing | |
""" | |
import asyncio | |
from typing import Dict, Optional, List, Tuple, Any | |
from collections import deque | |
from datetime import datetime | |
import base64 | |
import numpy as np | |
from dataclasses import dataclass | |
import traceback | |
from chat_session.event_bus import EventBus, Event, EventType | |
from utils.logger import log_info, log_error, log_debug, log_warning | |
@dataclass
class AudioChunk:
    """Audio chunk with metadata.

    The ``@dataclass`` decorator generates the keyword-argument constructor
    that ``AudioBuffer.add_chunk`` uses (``AudioChunk(data=..., timestamp=...,
    chunk_index=...)``). Without it the class only carries bare annotations
    and that call raises ``TypeError``.
    """

    # Raw audio payload exactly as it was handed to the buffer.
    data: bytes
    # Time associated with the chunk (supplied by the caller or taken when
    # the chunk is added to a buffer).
    timestamp: datetime
    # Value of the owning buffer's running counter when the chunk was added.
    chunk_index: int
    # Speech flag; defaults to True. Nothing in this module updates it.
    is_speech: bool = True
    # Energy measurement for the chunk; defaults to 0.0. Nothing in this
    # module updates it.
    energy_level: float = 0.0
class SilenceDetector:
    """Detect silence in an audio stream.

    Each chunk is interpreted as 16-bit integer samples. A chunk counts as
    silent when its RMS energy, normalized by 32768, is below
    ``energy_threshold``. The detector remembers when the current run of
    silent chunks started so it can report how long the silence has lasted.
    """

    def __init__(self,
                 threshold_ms: int = 2000,
                 energy_threshold: float = 0.01,
                 sample_rate: int = 16000):
        """Store the detector configuration.

        Args:
            threshold_ms: Stored on the instance; not consulted by this class
                (presumably the silence length callers compare against --
                TODO confirm against callers).
            energy_threshold: Normalized RMS level below which a chunk is
                treated as silence.
            sample_rate: Stored on the instance; not consulted by this class.
        """
        self.threshold_ms = threshold_ms
        self.energy_threshold = energy_threshold
        self.sample_rate = sample_rate
        # Start of the current silent run, or None while audio is non-silent.
        self.silence_start: Optional[datetime] = None

    def detect_silence(self, audio_chunk: bytes) -> Tuple[bool, int]:
        """Classify a chunk as silence and report the current silence length.

        Args:
            audio_chunk: Raw bytes read as 16-bit integer samples.

        Returns:
            ``(is_silence, silence_duration_ms)``. ``is_silence`` is always a
            built-in ``bool``. The duration is the wall-clock time elapsed
            since the first chunk of the current silent run, and ``0`` for
            non-silent chunks. Empty or sub-sample chunks return
            ``(True, 0)`` without touching the tracked silence start. Any
            processing error is logged and reported as ``(False, 0)``.
        """
        try:
            # Handle empty or invalid chunk
            if not audio_chunk or len(audio_chunk) < 2:
                return True, 0

            # Ensure even number of bytes for 16-bit audio
            if len(audio_chunk) % 2 != 0:
                audio_chunk = audio_chunk[:-1]

            # Convert to numpy array
            audio_data = np.frombuffer(audio_chunk, dtype=np.int16)
            if len(audio_data) == 0:
                return True, 0

            # Calculate RMS energy (in float to avoid int16 overflow on squaring)
            rms = np.sqrt(np.mean(audio_data.astype(float) ** 2))
            normalized_rms = rms / 32768.0  # Normalize for 16-bit audio

            # The comparison yields numpy.bool_; convert so the result matches
            # the annotated return type and behaves like a plain bool
            # (identity checks, JSON serialization).
            is_silence = bool(normalized_rms < self.energy_threshold)

            # Track silence duration
            now = datetime.utcnow()
            if is_silence:
                if self.silence_start is None:
                    self.silence_start = now
                duration_ms = int((now - self.silence_start).total_seconds() * 1000)
            else:
                self.silence_start = None
                duration_ms = 0

            return is_silence, duration_ms
        except Exception as e:
            # Best effort: a chunk that cannot be analyzed is treated as
            # non-silent rather than failing the caller.
            log_warning(f"Silence detection error: {e}")
            return False, 0

    def reset(self):
        """Reset silence detection state"""
        self.silence_start = None
class AudioBuffer:
    """Hold the audio chunks of one session in a bounded, lock-guarded deque."""

    def __init__(self,
                 session_id: str,
                 max_chunks: int = 1000,
                 chunk_size_bytes: int = 4096):
        """Create an empty buffer.

        Args:
            session_id: Identifier of the session this buffer belongs to.
            max_chunks: Capacity of the deque; once reached, appending drops
                the oldest chunk.
            chunk_size_bytes: Stored on the instance; not consulted by this
                class.
        """
        self.session_id = session_id
        self.max_chunks = max_chunks
        self.chunk_size_bytes = chunk_size_bytes
        # Running totals; they are only reset by clear(), not by eviction.
        self.chunk_counter = 0
        self.total_bytes = 0
        self.chunks: deque[AudioChunk] = deque(maxlen=max_chunks)
        self.lock = asyncio.Lock()

    async def add_chunk(self, audio_data: bytes, timestamp: Optional[datetime] = None) -> AudioChunk:
        """Append audio to the buffer and return the stored chunk.

        Args:
            audio_data: Raw audio bytes to store.
            timestamp: Time to record for the chunk; the current UTC time is
                used when omitted.
        """
        async with self.lock:
            stamped_at = datetime.utcnow() if timestamp is None else timestamp
            stored = AudioChunk(
                data=audio_data,
                timestamp=stamped_at,
                chunk_index=self.chunk_counter,
            )
            self.chunks.append(stored)
            self.chunk_counter += 1
            self.total_bytes += len(audio_data)
            return stored

    async def get_recent_audio(self, duration_ms: int = 5000) -> bytes:
        """Return the concatenated data of chunks no older than ``duration_ms``.

        Chunks are scanned newest-first and the scan stops at the first chunk
        whose age exceeds the window; the result is in chronological order.
        """
        async with self.lock:
            reference = datetime.utcnow()
            newest_first = []
            for item in reversed(self.chunks):
                age_ms = (reference - item.timestamp).total_seconds() * 1000
                if age_ms > duration_ms:
                    break
                newest_first.append(item.data)
            # Flip back so the oldest selected chunk comes first.
            return b''.join(reversed(newest_first))

    async def clear(self):
        """Drop all chunks and zero the running totals."""
        async with self.lock:
            self.chunk_counter = 0
            self.total_bytes = 0
            self.chunks.clear()

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of counters and the oldest/newest timestamps."""
        if self.chunks:
            oldest = self.chunks[0].timestamp
            newest = self.chunks[-1].timestamp
        else:
            oldest = None
            newest = None
        return {
            "chunks": len(self.chunks),
            "total_bytes": self.total_bytes,
            "chunk_counter": self.chunk_counter,
            "oldest_chunk": oldest,
            "newest_chunk": newest,
        }
class AudioBufferManager:
    """Own one AudioBuffer per session and keep them in sync with bus events."""

    def __init__(self, event_bus: EventBus):
        """Store the bus, create empty registries, and subscribe the handlers."""
        self.event_bus = event_bus
        # Both registries are keyed by session id.
        self.session_buffers: Dict[str, AudioBuffer] = {}
        self.silence_detectors: Dict[str, SilenceDetector] = {}
        self._setup_event_handlers()

    def _setup_event_handlers(self):
        """Register this manager's handlers for session and audio events."""
        subscriptions = (
            (EventType.SESSION_STARTED, self._handle_session_started),
            (EventType.SESSION_ENDED, self._handle_session_ended),
            (EventType.AUDIO_CHUNK_RECEIVED, self._handle_audio_chunk),
        )
        for event_type, handler in subscriptions:
            self.event_bus.subscribe(event_type, handler)

    async def _handle_session_started(self, event: Event):
        """Create the buffer for a newly started session.

        Reads ``max_chunks`` and ``chunk_size`` from ``event.data``, falling
        back to 1000 and 4096 respectively.
        """
        sid = event.session_id
        settings = event.data

        fresh_buffer = AudioBuffer(
            session_id=sid,
            max_chunks=settings.get("max_chunks", 1000),
            chunk_size_bytes=settings.get("chunk_size", 4096),
        )
        self.session_buffers[sid] = fresh_buffer

        log_info(f"π¦ Audio buffer initialized", session_id=sid)

    async def _handle_session_ended(self, event: Event):
        """Empty and forget the buffer and silence detector of a session."""
        sid = event.session_id
        buffers = self.session_buffers

        # Clear first, then unregister, so the buffer stays reachable while
        # the clear is pending.
        if sid in buffers:
            await buffers[sid].clear()
            del buffers[sid]

        self.silence_detectors.pop(sid, None)

        log_info(f"π¦ Audio buffer cleaned up", session_id=sid)

    async def _handle_audio_chunk(self, event: Event):
        """Decode a base64 audio payload and append it to the session buffer.

        Events for sessions without a buffer are logged and ignored. Any
        failure while decoding or storing is logged, not raised.
        """
        sid = event.session_id
        target = self.session_buffers.get(sid)

        if not target:
            log_warning(f"β οΈ No buffer for session", session_id=sid)
            return

        try:
            encoded = event.data.get("audio_data", "")
            stored = await target.add_chunk(base64.b64decode(encoded))

            # Emit buffer statistics on every 100th chunk index.
            if stored.chunk_index % 100 == 0:
                log_debug(
                    f"π Buffer stats",
                    session_id=sid,
                    **target.get_stats()
                )
        except Exception as exc:
            log_error(
                f"β Error processing audio chunk",
                session_id=sid,
                error=str(exc),
                traceback=traceback.format_exc()
            )

    async def get_buffer(self, session_id: str) -> Optional[AudioBuffer]:
        """Return the buffer registered for ``session_id``, or None."""
        return self.session_buffers.get(session_id)

    async def reset_buffer(self, session_id: str):
        """Clear the session's buffer and reset its silence detector, if any."""
        current_buffer = self.session_buffers.get(session_id)
        current_detector = self.silence_detectors.get(session_id)

        if current_buffer:
            await current_buffer.clear()

        if current_detector:
            current_detector.reset()

        log_debug(f"π Audio buffer reset", session_id=session_id)

    def get_all_stats(self) -> Dict[str, Dict[str, Any]]:
        """Return ``get_stats()`` of every registered buffer, keyed by session id."""
        return {
            sid: session_buffer.get_stats()
            for sid, session_buffer in self.session_buffers.items()
        }