"""
WebSocket Handler for Real-time STT/TTS

Accepts a per-session WebSocket, streams incoming base64 audio chunks to an
STT provider, detects end-of-utterance via silence, passes the transcript to
the chat handler and streams the (optional) TTS reply back to the client.
"""
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
from typing import Dict, Any, Optional
import json
import asyncio
import base64
from datetime import datetime
import sys
import numpy as np
from enum import Enum

from session import Session, session_store
from config_provider import ConfigProvider
from chat_handler import handle_new_message, handle_parameter_followup
from stt_factory import STTFactory
from tts_factory import TTSFactory
from utils import log

# ========================= CONSTANTS =========================
SILENCE_THRESHOLD_MS = 2000   # silence (ms) after speech that ends an utterance
AUDIO_CHUNK_SIZE = 4096       # bytes per outgoing "tts_audio" message
ENERGY_THRESHOLD = 0.01       # normalized RMS below which a chunk counts as silent


# ========================= ENUMS =========================
class ConversationState(Enum):
    """Phases of a single conversation turn."""
    IDLE = "idle"
    LISTENING = "listening"
    PROCESSING_STT = "processing_stt"
    PROCESSING_LLM = "processing_llm"
    PROCESSING_TTS = "processing_tts"
    PLAYING_AUDIO = "playing_audio"


# ========================= CLASSES =========================
class AudioBuffer:
    """Buffer for accumulating decoded audio chunks of the current utterance."""

    def __init__(self):
        self.chunks = []      # decoded byte chunks, in arrival order
        self.total_size = 0   # sum of len() over self.chunks

    def add_chunk(self, chunk_data: str) -> bytes:
        """Decode a base64 audio chunk, append it, and return the decoded bytes.

        Returning the bytes lets callers avoid decoding the same payload twice.
        Raises binascii.Error if ``chunk_data`` is not valid base64.
        """
        decoded = base64.b64decode(chunk_data)
        self.chunks.append(decoded)
        self.total_size += len(decoded)
        return decoded

    def get_audio(self) -> bytes:
        """Return all buffered audio concatenated into one bytes object."""
        return b''.join(self.chunks)

    def clear(self):
        """Drop all buffered audio."""
        self.chunks.clear()
        self.total_size = 0


class SilenceDetector:
    """Energy-based silence detection over a stream of 16-bit PCM chunks."""

    def __init__(self, threshold_ms: int = SILENCE_THRESHOLD_MS,
                 energy_threshold: float = ENERGY_THRESHOLD):
        self.threshold_ms = threshold_ms
        self.energy_threshold = energy_threshold
        # Wall-clock time the current run of silent chunks began; None while speaking.
        self.silence_start: Optional[datetime] = None
        self.sample_rate = 16000  # Default sample rate (not used by the RMS check)

    def is_silence(self, audio_chunk: bytes) -> bool:
        """Return True if the chunk's normalized RMS energy is below the threshold.

        The chunk is interpreted as native-endian signed 16-bit PCM. A trailing
        odd byte (a sample split across chunks) is ignored. Empty chunks and
        decoding errors are reported as "not silence".
        """
        try:
            sample_count = len(audio_chunk) // 2
            if sample_count == 0:
                return False
            # count= lets frombuffer ignore a dangling odd byte instead of raising.
            samples = np.frombuffer(audio_chunk, dtype=np.int16, count=sample_count)
            # Widen before squaring: int16 ** 2 overflows and silently wraps,
            # which would make the RMS meaningless.
            samples = samples.astype(np.float64)
            rms = np.sqrt(np.mean(samples ** 2))
            normalized_rms = rms / 32768.0  # Normalize for 16-bit audio
            return bool(normalized_rms < self.energy_threshold)
        except Exception as e:
            log(f"โš ๏ธ Silence detection error: {e}")
            return False

    def update(self, audio_chunk: bytes) -> Optional[int]:
        """Feed one chunk and return how long (ms) silence has lasted so far.

        Returns 0 for a non-silent chunk and for the first silent chunk of a
        run; for later silent chunks returns the elapsed time since that run
        started.
        """
        if not self.is_silence(audio_chunk):
            if self.silence_start is not None:
                log("๐Ÿ”Š Speech detected, silence broken")
            self.silence_start = None
            return 0

        if self.silence_start is None:
            self.silence_start = datetime.now()
            log("๐Ÿ”‡ Silence started")
            return 0

        silence_duration = (datetime.now() - self.silence_start).total_seconds() * 1000
        return int(silence_duration)


class BargeInHandler:
    """Track where the user interrupted the assistant (barge-in)."""

    def __init__(self):
        self.interrupted_at_state: Optional[ConversationState] = None
        self.accumulated_text: str = ""
        self.pending_audio_chunks = []

    def handle_interruption(self, current_state: ConversationState):
        """Record the state the conversation was in when interrupted."""
        self.interrupted_at_state = current_state
        log(f"๐Ÿ›‘ Barge-in detected at state: {current_state.value}")

    def should_preserve_context(self) -> bool:
        """Return True if the interruption happened during LLM/TTS/playback."""
        return self.interrupted_at_state in (
            ConversationState.PROCESSING_LLM,
            ConversationState.PROCESSING_TTS,
            ConversationState.PLAYING_AUDIO,
        )


class ConversationManager:
    """Per-connection conversation state: buffers, detectors and STT stream."""

    def __init__(self, session: Session):
        self.session = session
        self.state = ConversationState.IDLE
        self.audio_buffer = AudioBuffer()
        self.silence_detector = SilenceDetector()
        self.barge_in_handler = BargeInHandler()
        self.stt_manager = None            # set by initialize_stt(); None when STT is unavailable
        self.current_transcription = ""    # latest final STT result for the current utterance
        self.is_streaming = False
        self.closed = False                # set by cleanup_conversation() so cleanup runs once

    async def initialize_stt(self) -> bool:
        """Create the STT provider and start a continuous streaming session.

        Returns True on success. On failure (no provider, or start_streaming
        raised) returns False and leaves ``stt_manager`` as None, so audio is
        never streamed into a session that did not start.
        """
        try:
            self.stt_manager = STTFactory.create_provider()
            if not self.stt_manager:
                return False

            config = ConfigProvider.get().global_config.stt_settings
            await self.stt_manager.start_streaming({
                "language": config.get("language", "tr-TR"),
                "interim_results": config.get("interim_results", True),
                "single_utterance": False,  # Important for continuous listening
                "enable_punctuation": config.get("enable_punctuation", True),
            })
            log("โœ… STT manager initialized")
            return True
        except Exception as e:
            log(f"โŒ Failed to initialize STT: {e}")
            self.stt_manager = None
            return False

    def change_state(self, new_state: ConversationState):
        """Switch to ``new_state`` and log the transition."""
        old_state = self.state
        self.state = new_state
        log(f"๐Ÿ“Š State change: {old_state.value} โ†’ {new_state.value}")

    def handle_barge_in(self):
        """Record the interruption and go back to listening."""
        self.barge_in_handler.handle_interruption(self.state)
        self.change_state(ConversationState.LISTENING)

    def reset_audio_buffer(self):
        """Clear buffered audio, silence tracking and transcript for a new utterance."""
        self.audio_buffer.clear()
        self.silence_detector.silence_start = None
        self.current_transcription = ""


# ========================= WEBSOCKET HANDLER =========================
async def websocket_endpoint(websocket: WebSocket, session_id: str):
    """Main WebSocket endpoint for real-time conversation.

    Inbound JSON messages are dispatched on their "type" field:
    "audio_chunk", "control" or "ping". STT resources are always released
    when the loop ends, whether by disconnect, error or "end_session".
    """
    await websocket.accept()
    log(f"๐Ÿ”Œ WebSocket connected for session: {session_id}")

    session = session_store.get_session(session_id)
    if not session:
        await websocket.send_json({
            "type": "error",
            "message": "Session not found"
        })
        await websocket.close()
        return

    conversation = ConversationManager(session)

    if not await conversation.initialize_stt():
        # Keep the connection open; audio is still buffered, just not transcribed.
        await websocket.send_json({
            "type": "error",
            "message": "STT initialization failed"
        })

    try:
        while True:
            try:
                message = await websocket.receive_json()
            except json.JSONDecodeError:
                # One malformed frame should not tear down the whole session.
                await websocket.send_json({
                    "type": "error",
                    "message": "Invalid JSON message"
                })
                continue

            if not isinstance(message, dict):
                await websocket.send_json({
                    "type": "error",
                    "message": "Message must be a JSON object"
                })
                continue

            message_type = message.get("type")

            if message_type == "audio_chunk":
                await handle_audio_chunk(websocket, conversation, message)
            elif message_type == "control":
                await handle_control_message(websocket, conversation, message)
                if message.get("action") == "end_session":
                    # The socket has been closed; receiving again would fail.
                    break
            elif message_type == "ping":
                # Keep-alive ping
                await websocket.send_json({"type": "pong"})
            else:
                log(f"Ignoring unknown message type: {message_type!r}")

    except WebSocketDisconnect:
        log(f"๐Ÿ”Œ WebSocket disconnected for session: {session_id}")
    except Exception as e:
        log(f"โŒ WebSocket error: {e}")
        try:
            await websocket.send_json({
                "type": "error",
                "message": str(e)
            })
        except Exception as send_error:
            # The socket is probably already unusable; cleanup still runs below.
            log(f"Could not report error to client: {send_error}")
    finally:
        await cleanup_conversation(conversation)


# ========================= MESSAGE HANDLERS =========================
async def handle_audio_chunk(websocket: WebSocket, conversation: ConversationManager,
                             message: Dict[str, Any]):
    """Handle one incoming audio chunk.

    Steps, in order:
    - Trigger barge-in if the assistant is speaking.
    - Buffer the audio.
    - Update silence tracking.
    - Stream the audio to STT (only while LISTENING).
    - Hand off to process_user_input once enough silence follows a final
      transcript.

    WebSocketDisconnect is re-raised so the endpoint sees the disconnect.
    """
    try:
        audio_data = message.get("data")
        if not audio_data:
            return

        # Check for barge-in
        if conversation.state in (ConversationState.PLAYING_AUDIO,
                                  ConversationState.PROCESSING_TTS):
            conversation.handle_barge_in()
            await websocket.send_json({
                "type": "control",
                "action": "stop_playback"
            })

        # Change state to listening if idle
        if conversation.state == ConversationState.IDLE:
            conversation.change_state(ConversationState.LISTENING)
            await websocket.send_json({
                "type": "state_change",
                "from": "idle",
                "to": "listening"
            })

        # add_chunk returns the decoded bytes, so the payload is decoded only once.
        decoded_audio = conversation.audio_buffer.add_chunk(audio_data)

        silence_duration = conversation.silence_detector.update(decoded_audio)

        # Stream to STT if available
        if conversation.stt_manager and conversation.state == ConversationState.LISTENING:
            async for result in conversation.stt_manager.stream_audio(decoded_audio):
                # Send interim (and final) results to the client
                await websocket.send_json({
                    "type": "transcription",
                    "text": result.text,
                    "is_final": result.is_final,
                    "confidence": result.confidence
                })
                if result.is_final:
                    conversation.current_transcription = result.text

        # User stopped speaking: enough trailing silence after a final transcript.
        if (silence_duration > conversation.silence_detector.threshold_ms
                and conversation.current_transcription):
            log(f"๐Ÿ”‡ User stopped speaking after {silence_duration}ms of silence")
            await process_user_input(websocket, conversation)

    except WebSocketDisconnect:
        raise
    except Exception as e:
        log(f"โŒ Audio chunk handling error: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Audio processing error: {str(e)}"
        })


async def handle_control_message(websocket: WebSocket, conversation: ConversationManager,
                                 message: Dict[str, Any]):
    """Handle a "control" message.

    Supported actions: "start_session", "end_session" (cleans up and closes
    the socket), "interrupt" and "reset". Other actions are ignored.
    """
    action = message.get("action")

    if action == "start_session":
        # Session already started
        await websocket.send_json({
            "type": "session_started",
            "session_id": conversation.session.session_id
        })

    elif action == "end_session":
        await cleanup_conversation(conversation)
        await websocket.close()

    elif action == "interrupt":
        conversation.handle_barge_in()
        await websocket.send_json({
            "type": "control",
            "action": "interrupt_acknowledged"
        })

    elif action == "reset":
        # Capture the state *before* resetting so "from" reports it correctly.
        previous_state = conversation.state
        conversation.reset_audio_buffer()
        conversation.change_state(ConversationState.IDLE)
        await websocket.send_json({
            "type": "state_change",
            "from": previous_state.value,
            "to": "idle"
        })


# ========================= PROCESSING FUNCTIONS =========================
async def _send_tts_audio(websocket: WebSocket, audio_data: bytes,
                          chunk_size: int = AUDIO_CHUNK_SIZE) -> None:
    """Stream synthesized audio to the client as base64 "tts_audio" messages."""
    total = len(audio_data)
    for chunk_index, start in enumerate(range(0, total, chunk_size)):
        chunk = audio_data[start:start + chunk_size]
        await websocket.send_json({
            "type": "tts_audio",
            "data": base64.b64encode(chunk).decode('utf-8'),
            "chunk_index": chunk_index,
            "is_last": start + chunk_size >= total
        })


async def process_user_input(websocket: WebSocket, conversation: ConversationManager):
    """Run a completed utterance through LLM and (optionally) TTS.

    Flow: PROCESSING_STT -> PROCESSING_LLM -> PROCESSING_TTS -> PLAYING_AUDIO.
    When there is no TTS provider or no audio, the state goes back to IDLE.
    On any error (other than a disconnect, which is re-raised) the client is
    notified and the conversation is reset to IDLE.
    """
    try:
        user_text = conversation.current_transcription
        if not user_text:
            conversation.reset_audio_buffer()
            conversation.change_state(ConversationState.IDLE)
            return

        log(f"๐Ÿ’ฌ Processing user input: {user_text}")

        conversation.change_state(ConversationState.PROCESSING_STT)
        await websocket.send_json({
            "type": "state_change",
            "from": "listening",
            "to": "processing_stt"
        })

        # NOTE(review): confidence is a fixed placeholder, not the STT's own value.
        await websocket.send_json({
            "type": "transcription",
            "text": user_text,
            "is_final": True,
            "confidence": 0.95
        })

        conversation.change_state(ConversationState.PROCESSING_LLM)
        await websocket.send_json({
            "type": "state_change",
            "from": "processing_stt",
            "to": "processing_llm"
        })

        conversation.session.add_turn("user", user_text)

        # Pick the chat handler based on the session state
        if conversation.session.state == "await_param":
            response_text = await handle_parameter_followup(conversation.session, user_text)
        else:
            response_text = await handle_new_message(conversation.session, user_text)

        conversation.session.add_turn("assistant", response_text)

        await websocket.send_json({
            "type": "assistant_response",
            "text": response_text
        })

        tts_provider = TTSFactory.create_provider()
        if tts_provider:
            conversation.change_state(ConversationState.PROCESSING_TTS)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "processing_tts"
            })

            audio_data = await tts_provider.synthesize(response_text)

            if audio_data:
                await _send_tts_audio(websocket, audio_data)
                conversation.change_state(ConversationState.PLAYING_AUDIO)
                await websocket.send_json({
                    "type": "state_change",
                    "from": "processing_tts",
                    "to": "playing_audio"
                })
            else:
                # Nothing to play: do not enter PLAYING_AUDIO, or the next
                # incoming chunk would be treated as a barge-in.
                conversation.change_state(ConversationState.IDLE)
                await websocket.send_json({
                    "type": "state_change",
                    "from": "processing_tts",
                    "to": "idle"
                })
        else:
            conversation.change_state(ConversationState.IDLE)
            await websocket.send_json({
                "type": "state_change",
                "from": "processing_llm",
                "to": "idle"
            })

        # Reset for next input
        conversation.reset_audio_buffer()

    except WebSocketDisconnect:
        raise
    except Exception as e:
        log(f"โŒ Error processing user input: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"Processing error: {str(e)}"
        })
        conversation.reset_audio_buffer()
        conversation.change_state(ConversationState.IDLE)


# ========================= CLEANUP =========================
async def cleanup_conversation(conversation: ConversationManager):
    """Stop the STT stream for a conversation.

    This is idempotent: it may be reached both from "end_session" and from
    the endpoint's final cleanup, but it only does the work once. Errors are
    logged, not raised (best effort).
    """
    if conversation.closed:
        return
    conversation.closed = True

    stt_manager, conversation.stt_manager = conversation.stt_manager, None
    try:
        if stt_manager:
            await stt_manager.stop_streaming()
        log(f"๐Ÿงน Cleaned up conversation for session: {conversation.session.session_id}")
    except Exception as e:
        log(f"โš ๏ธ Cleanup error: {e}")