""" | |
WebSocket Handler for Real-time STT/TTS with Barge-in Support | |
""" | |
from fastapi import WebSocket, WebSocketDisconnect | |
from typing import Dict, Any, Optional | |
import json | |
import asyncio | |
import base64 | |
from datetime import datetime | |
from collections import deque | |
from enum import Enum | |
import numpy as np | |
import traceback | |
from session import Session, session_store | |
from config_provider import ConfigProvider | |
from chat_handler import handle_new_message, handle_parameter_followup | |
from stt_factory import STTFactory | |
from tts_factory import TTSFactory | |
from logger import log_info, log_error, log_debug, log_warning | |
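
# ---------------------------------------------------------------------------
# Wire protocol sketch, summarized from the handlers below (these are the
# exact message shapes this module reads and writes; nothing here is invented):
#
#   client -> server:
#     {"type": "audio_chunk", "data": "<base64 audio>"}
#     {"type": "control", "action": "start_session" | "end_session" |
#                                   "interrupt" | "reset" | "audio_ended"}
#     {"type": "ping"}
#
#   server -> client:
#     {"type": "transcription", "text": ..., "is_final": ..., "confidence": ...}
#     {"type": "assistant_response", "text": ...}
#     {"type": "tts_audio", "data": "<base64>", "chunk_index": ..., "total_chunks": ..., "is_last": ...}
#     {"type": "state_change", "from": ..., "to": ...}
#     {"type": "control", "action": "stop_playback" | "interrupt_acknowledged"}
#     {"type": "session_started", ...} / {"type": "error", "message": ...} / {"type": "pong"}
# ---------------------------------------------------------------------------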

# ========================= CONSTANTS =========================

# Default values - overridden by config where present
DEFAULT_SILENCE_THRESHOLD_MS = 2000
DEFAULT_AUDIO_CHUNK_SIZE = 4096
DEFAULT_ENERGY_THRESHOLD = 0.01
DEFAULT_AUDIO_BUFFER_MAX_SIZE = 1000

# ========================= ENUMS =========================

class ConversationState(Enum):
    IDLE = "idle"
    LISTENING = "listening"
    PROCESSING_STT = "processing_stt"
    PROCESSING_LLM = "processing_llm"
    PROCESSING_TTS = "processing_tts"
    PLAYING_AUDIO = "playing_audio"

# ========================= CLASSES =========================

class AudioBuffer:
    """Async-safe circular buffer for audio chunks (guarded by an asyncio.Lock, not thread-safe)"""

    def __init__(self, max_size: int = DEFAULT_AUDIO_BUFFER_MAX_SIZE):
        self.buffer = deque(maxlen=max_size)
        self.lock = asyncio.Lock()

    async def add_chunk(self, chunk_data: str):
        """Add a base64-encoded audio chunk"""
        async with self.lock:
            decoded = base64.b64decode(chunk_data)
            self.buffer.append(decoded)

    async def get_all_audio(self) -> bytes:
        """Get all buffered audio concatenated"""
        async with self.lock:
            return b''.join(self.buffer)

    async def clear(self):
        """Clear the buffer"""
        async with self.lock:
            self.buffer.clear()

    def size(self) -> int:
        """Get the current number of buffered chunks"""
        return len(self.buffer)
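
# Note: deque(maxlen=...) silently drops the oldest chunk once the buffer is
# full, so the buffer is a bounded window over the most recent audio rather
# than a complete recording of the utterance.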

class SilenceDetector:
    """Detect sustained silence in an audio stream via RMS energy"""

    def __init__(self, threshold_ms: int = DEFAULT_SILENCE_THRESHOLD_MS,
                 energy_threshold: float = DEFAULT_ENERGY_THRESHOLD):
        self.threshold_ms = threshold_ms
        self.energy_threshold = energy_threshold
        self.silence_start = None
        self.sample_rate = 16000

    def update(self, audio_chunk: bytes) -> int:
        """Update with a new audio chunk and return the current silence duration in ms"""
        if self.is_silence(audio_chunk):
            if self.silence_start is None:
                self.silence_start = datetime.now()
            silence_duration = (datetime.now() - self.silence_start).total_seconds() * 1000
            return int(silence_duration)
        else:
            self.silence_start = None
            return 0

    def is_silence(self, audio_chunk: bytes) -> bool:
        """Check whether an audio chunk is silence.

        Note: this interprets the bytes as raw 16-bit PCM; compressed input
        (such as the WEBM_OPUS stream sent to the STT provider) will not
        yield meaningful RMS values here.
        """
        try:
            audio_data = np.frombuffer(audio_chunk, dtype=np.int16)
            if len(audio_data) == 0:
                return True
            # RMS energy, normalized to [0, 1] for 16-bit audio
            rms = np.sqrt(np.mean(audio_data.astype(float) ** 2))
            normalized_rms = rms / 32768.0
            return normalized_rms < self.energy_threshold
        except Exception as e:
            log_warning(f"Silence detection error: {e}")
            return False

    def reset(self):
        """Reset silence tracking"""
        self.silence_start = None
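
# Worked example for the RMS gate above: a full-scale 16-bit sine wave has an
# RMS of 32767 / sqrt(2) ≈ 23170, i.e. a normalized RMS of about 0.707, while
# the default energy threshold of 0.01 sits at roughly -40 dBFS
# (20 * log10(0.01) == -40). Chunks quieter than that count as silence.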

class BargeInHandler:
    """Handle user interruptions during TTS playback"""

    def __init__(self):
        self.active_tts_task: Optional[asyncio.Task] = None
        self.is_interrupting = False
        self.lock = asyncio.Lock()

    async def start_tts_task(self, coro) -> asyncio.Task:
        """Start a cancellable TTS task, cancelling any task already running"""
        async with self.lock:
            # Cancel any existing task and wait for it to unwind
            if self.active_tts_task and not self.active_tts_task.done():
                self.active_tts_task.cancel()
                try:
                    await self.active_tts_task
                except asyncio.CancelledError:
                    pass
            # Start the new task
            self.active_tts_task = asyncio.create_task(coro)
            return self.active_tts_task

    async def handle_interruption(self, current_state: ConversationState):
        """Handle a barge-in interruption"""
        async with self.lock:
            self.is_interrupting = True
            # Cancel TTS if active
            if self.active_tts_task and not self.active_tts_task.done():
                log_info("Barge-in: cancelling active TTS")
                self.active_tts_task.cancel()
                try:
                    await self.active_tts_task
                except asyncio.CancelledError:
                    pass
            # Hold the lock briefly as a debounce window, then clear the flag
            await asyncio.sleep(0.5)
            self.is_interrupting = False
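
# Note on cancellation semantics: Task.cancel() only *requests* cancellation;
# asyncio delivers the CancelledError inside the coroutine at its next await.
# That is why the handler awaits the cancelled task before proceeding - it
# guarantees the old TTS stream has fully unwound before a new one starts.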

class RealtimeSession:
    """Manage a real-time conversation session"""

    def __init__(self, session: Session):
        self.session = session
        self.state = ConversationState.IDLE

        # Pull tunables from the STT provider settings, falling back to defaults
        config = ConfigProvider.get().global_config.stt_provider.settings
        silence_threshold = config.get("speech_timeout_ms", DEFAULT_SILENCE_THRESHOLD_MS)
        energy_threshold = config.get("energy_threshold", DEFAULT_ENERGY_THRESHOLD)
        buffer_max_size = config.get("audio_buffer_max_size", DEFAULT_AUDIO_BUFFER_MAX_SIZE)

        self.audio_buffer = AudioBuffer(max_size=buffer_max_size)
        self.silence_detector = SilenceDetector(
            threshold_ms=silence_threshold,
            energy_threshold=energy_threshold
        )
        self.barge_in_handler = BargeInHandler()
        self.stt_manager = None
        self.current_transcription = ""
        self.is_streaming = False
        self.lock = asyncio.Lock()

        # Keep config values needed later
        self.audio_chunk_size = config.get("audio_chunk_size", DEFAULT_AUDIO_CHUNK_SIZE)
        self.silence_threshold_ms = silence_threshold

    async def initialize_stt(self) -> bool:
        """Initialize the STT provider and open a streaming session"""
        try:
            self.stt_manager = STTFactory.create_provider()
            if self.stt_manager:
                config = ConfigProvider.get().global_config.stt_provider.settings
                await self.stt_manager.start_streaming({
                    "language": config.get("language", "tr-TR"),
                    "interim_results": config.get("interim_results", True),
                    "single_utterance": False,
                    "enable_punctuation": config.get("enable_punctuation", True),
                    "sample_rate": 16000,
                    "encoding": "WEBM_OPUS"
                })
                log_info("STT manager initialized", session_id=self.session.session_id)
                return True
            return False
        except Exception as e:
            log_error("Failed to initialize STT", error=str(e), session_id=self.session.session_id)
            return False

    async def change_state(self, new_state: ConversationState):
        """Change the conversation state"""
        async with self.lock:
            old_state = self.state
            self.state = new_state
            log_debug(
                f"State change: {old_state.value} → {new_state.value}",
                session_id=self.session.session_id
            )

    async def handle_barge_in(self):
        """Handle a user interruption"""
        await self.barge_in_handler.handle_interruption(self.state)
        await self.change_state(ConversationState.LISTENING)

    async def reset_for_new_utterance(self):
        """Reset state for a new user utterance"""
        await self.audio_buffer.clear()
        self.silence_detector.reset()
        self.current_transcription = ""

    async def cleanup(self):
        """Clean up resources"""
        try:
            if self.stt_manager:
                await self.stt_manager.stop_streaming()
            log_info("Cleaned up realtime session", session_id=self.session.session_id)
        except Exception as e:
            log_warning("Cleanup error", error=str(e), session_id=self.session.session_id)

# ========================= MAIN HANDLER =========================

async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Main WebSocket endpoint for real-time conversation"""
    await websocket.accept()
    log_info("WebSocket connected", session_id=session_id)

    # Look up the session
    session = session_store.get_session(session_id)
    if not session:
        await websocket.send_json({
            "type": "error",
            "message": "Session not found"
        })
        await websocket.close()
        return

    # Mark as a realtime session
    session.is_realtime_session = True
    session_store.update_session(session)

    # Initialize the conversation
    realtime_session = RealtimeSession(session)

    # Initialize STT; on failure, report it but keep the socket open
    # (audio chunks will simply not be transcribed)
    stt_initialized = await realtime_session.initialize_stt()
    if not stt_initialized:
        await websocket.send_json({
            "type": "error",
            "message": "STT initialization failed"
        })

    try:
        while True:
            # Receive the next message
            message = await websocket.receive_json()
            message_type = message.get("type")

            if message_type == "audio_chunk":
                await handle_audio_chunk(websocket, realtime_session, message)
            elif message_type == "control":
                await handle_control_message(websocket, realtime_session, message)
            elif message_type == "ping":
                # Keep-alive ping
                await websocket.send_json({"type": "pong"})
    except WebSocketDisconnect:
        log_info("WebSocket disconnected", session_id=session_id)
    except Exception as e:
        log_error(
            "WebSocket error",
            error=str(e),
            traceback=traceback.format_exc(),
            session_id=session_id
        )
        try:
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        except Exception:
            pass  # The socket may already be closed
    finally:
        await realtime_session.cleanup()

# ========================= MESSAGE HANDLERS =========================

async def handle_audio_chunk(websocket: WebSocket, session: RealtimeSession, message: Dict[str, Any]):
    """Handle an incoming audio chunk, with barge-in support"""
    try:
        audio_data = message.get("data")
        if not audio_data:
            return

        # Audio arriving during TTS/playback means the user is interrupting
        if session.state in [ConversationState.PLAYING_AUDIO, ConversationState.PROCESSING_TTS]:
            # Capture the interrupted state before handle_barge_in() resets it
            interrupted_state = session.state.value
            await session.handle_barge_in()
            await websocket.send_json({
                "type": "control",
                "action": "stop_playback"
            })
            log_info("Barge-in detected", session_id=session.session.session_id, state=interrupted_state)

        # Move from idle to listening on the first chunk
        if session.state == ConversationState.IDLE:
            await session.change_state(ConversationState.LISTENING)
            await websocket.send_json({
                "type": "state_change",
                "from": "idle",
                "to": "listening"
            })

        # Buffer everything - don't lose any audio
        await session.audio_buffer.add_chunk(audio_data)

        # Decode for processing
        decoded_audio = base64.b64decode(audio_data)

        # Track silence
        silence_duration = session.silence_detector.update(decoded_audio)

        # Stream to STT if available
        if session.stt_manager and session.state == ConversationState.LISTENING:
            async for result in session.stt_manager.stream_audio(decoded_audio):
                # Forward transcription updates to the client
                await websocket.send_json({
                    "type": "transcription",
                    "text": result.text,
                    "is_final": result.is_final,
                    "confidence": result.confidence
                })
                if result.is_final:
                    session.current_transcription = result.text

        # Process once silence has lasted long enough and we have a transcription
        if silence_duration > session.silence_threshold_ms and session.current_transcription:
            log_info(
                "User stopped speaking",
                session_id=session.session.session_id,
                silence_ms=silence_duration,
                text=session.current_transcription
            )
            await process_user_input(websocket, session)
    except Exception as e:
        log_error(
            "Audio chunk handling error",
            error=str(e),
            traceback=traceback.format_exc(),
            session_id=session.session.session_id
        )
        await websocket.send_json({
            "type": "error",
            "message": f"Audio processing error: {str(e)}"
        })

async def handle_control_message(websocket: WebSocket, session: RealtimeSession, message: Dict[str, Any]):
    """Handle control messages"""
    action = message.get("action")
    config = message.get("config", {})
    log_debug("Control message", action=action, session_id=session.session.session_id)

    if action == "start_session":
        # Report the session configuration
        await websocket.send_json({
            "type": "session_started",
            "session_id": session.session.session_id,
            "config": {
                "silence_threshold_ms": session.silence_threshold_ms,
                "audio_chunk_size": session.audio_chunk_size,
                "supports_barge_in": True
            }
        })
    elif action == "end_session":
        # Clean up and close
        await session.cleanup()
        await websocket.close()
    elif action == "interrupt":
        # Handle an explicit interrupt
        await session.handle_barge_in()
        await websocket.send_json({
            "type": "control",
            "action": "interrupt_acknowledged"
        })
    elif action == "reset":
        # Reset the conversation state (capture the old state before changing
        # it, so the state_change message reports the real transition)
        old_state = session.state.value
        await session.reset_for_new_utterance()
        await session.change_state(ConversationState.IDLE)
        await websocket.send_json({
            "type": "state_change",
            "from": old_state,
            "to": "idle"
        })
    elif action == "audio_ended":
        # Audio playback finished on the client
        if session.state == ConversationState.PLAYING_AUDIO:
            await session.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": "playing_audio",
                "to": "idle"
            })

# ========================= PROCESSING FUNCTIONS =========================

async def process_user_input(websocket: WebSocket, session: RealtimeSession):
    """Process a complete user utterance"""
    try:
        user_text = session.current_transcription
        if not user_text:
            await session.reset_for_new_utterance()
            await session.change_state(ConversationState.IDLE)
            return

        log_info("Processing user input", text=user_text, session_id=session.session.session_id)

        # State: STT processing
        await session.change_state(ConversationState.PROCESSING_STT)
        await websocket.send_json({
            "type": "state_change",
            "from": "listening",
            "to": "processing_stt"
        })

        # Send the final transcription (confidence is a fixed placeholder here)
        await websocket.send_json({
            "type": "transcription",
            "text": user_text,
            "is_final": True,
            "confidence": 0.95
        })

        # State: LLM processing
        await session.change_state(ConversationState.PROCESSING_LLM)
        await websocket.send_json({
            "type": "state_change",
            "from": "processing_stt",
            "to": "processing_llm"
        })

        # Add to chat history
        session.session.add_message("user", user_text)

        # Get the LLM response based on session state
        if session.session.state == "collect_params":
            response_text = await handle_parameter_followup(session.session, user_text)
        else:
            response_text = await handle_new_message(session.session, user_text)

        # Add the response to history
        session.session.add_message("assistant", response_text)

        # Send the text response
        await websocket.send_json({
            "type": "assistant_response",
            "text": response_text
        })

        # Generate TTS if a provider is configured
        tts_provider = TTSFactory.create_provider()
        if tts_provider:
            await session.change_state(ConversationState.PROCESSING_TTS)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "processing_tts"
            })

            # Run TTS as a cancellable task so barge-in can stop it.
            # start_tts_task is itself a coroutine: it must be awaited to get
            # the Task back before the Task can be awaited.
            tts_task = await session.barge_in_handler.start_tts_task(
                generate_and_stream_tts(websocket, session, tts_provider, response_text)
            )
            try:
                await tts_task
            except asyncio.CancelledError:
                log_info("TTS cancelled due to barge-in", session_id=session.session.session_id)
        else:
            # No TTS, go back to idle
            await session.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "idle"
            })

        # Reset for the next utterance
        await session.reset_for_new_utterance()
    except Exception as e:
        log_error(
            "Error processing user input",
            error=str(e),
            traceback=traceback.format_exc(),
            session_id=session.session.session_id
        )
        await websocket.send_json({
            "type": "error",
            "message": f"Processing error: {str(e)}"
        })
        await session.reset_for_new_utterance()
        await session.change_state(ConversationState.IDLE)

async def generate_and_stream_tts(
    websocket: WebSocket,
    session: RealtimeSession,
    tts_provider,
    text: str
):
    """Generate and stream TTS audio, with cancellation support"""
    try:
        # Generate the audio
        audio_data = await tts_provider.synthesize(text)

        # Change state to playing
        await session.change_state(ConversationState.PLAYING_AUDIO)
        await websocket.send_json({
            "type": "state_change",
            "from": "processing_tts",
            "to": "playing_audio"
        })

        # Stream the audio in chunks. Cancellation (barge-in) surfaces as a
        # CancelledError at one of the awaits below, so no explicit
        # cancelled() poll is needed (Task.cancelled() is False while the
        # task is still running anyway).
        chunk_size = session.audio_chunk_size
        total_chunks = (len(audio_data) + chunk_size - 1) // chunk_size
        for i in range(0, len(audio_data), chunk_size):
            chunk = audio_data[i:i + chunk_size]
            chunk_index = i // chunk_size
            await websocket.send_json({
                "type": "tts_audio",
                "data": base64.b64encode(chunk).decode('utf-8'),
                "chunk_index": chunk_index,
                "total_chunks": total_chunks,
                "is_last": chunk_index == total_chunks - 1
            })
            # Small delay to avoid overwhelming the client
            await asyncio.sleep(0.01)

        log_info(
            "TTS streaming completed",
            session_id=session.session.session_id,
            text_length=len(text),
            audio_size=len(audio_data)
        )
    except asyncio.CancelledError:
        log_info("TTS streaming cancelled", session_id=session.session.session_id)
        raise
    except Exception as e:
        log_error(
            "TTS generation error",
            error=str(e),
            session_id=session.session.session_id
        )
        await websocket.send_json({
            "type": "error",
            "message": f"TTS error: {str(e)}"
        })
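
# ---------------------------------------------------------------------------
# Example wiring (a minimal sketch; the route path and app setup here are
# assumptions - the actual registration lives outside this module):
#
#   from fastapi import FastAPI
#   app = FastAPI()
#
#   @app.websocket("/ws/conversation/{session_id}")
#   async def ws_route(websocket: WebSocket, session_id: str):
#       await websocket_endpoint(websocket, session_id)
# ---------------------------------------------------------------------------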