import asyncio
import math
import logging
import traceback
from time import time
from datetime import timedelta

import numpy as np
import ffmpeg

from timed_objects import ASRToken
from whisper_streaming_custom.whisper_online import online_factory
from core import WhisperLiveKit

# Set up logging once
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def format_time(seconds: float) -> str:
    """Format seconds as HH:MM:SS."""
    return str(timedelta(seconds=int(seconds)))


class AudioProcessor:
    """
    Processes audio streams for transcription and diarization.
    Handles audio processing, state management, and result formatting.
    """

    def __init__(self):
        """Initialize the audio processor with configuration, models, and state."""
        models = WhisperLiveKit()

        # Audio processing settings
        self.args = models.args
        self.sample_rate = 16000
        self.channels = 1
        self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
        self.bytes_per_sample = 2
        self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
        self.max_bytes_per_sec = 32000 * 5  # 5 seconds of 16 kHz s16le audio (32,000 bytes/s)
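        # Worked example of the byte math above, assuming a min_chunk_size of
        # 1.0 s (an illustration; the real value comes from args): 16 kHz mono
        # s16le audio is 16000 samples/s * 2 bytes = 32000 bytes/s, so
        # bytes_per_sec = 32000 and max_bytes_per_sec = 160000. The reader
        # below therefore accumulates at least ~1 s and processes at most 5 s
        # of PCM per transcription chunk.
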
        # State management
        self.tokens = []
        self.buffer_transcription = ""
        self.buffer_diarization = ""
        self.full_transcription = ""
        self.end_buffer = 0
        self.end_attributed_speaker = 0
        self.lock = asyncio.Lock()
        self.beg_loop = time()
        self.sep = " "  # Default token separator
        self.last_response_content = ""

        # Models and processing
        self.asr = models.asr
        self.tokenizer = models.tokenizer
        self.diarization = models.diarization
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.transcription_queue = asyncio.Queue() if self.args.transcription else None
        self.diarization_queue = asyncio.Queue() if self.args.diarization else None
        self.pcm_buffer = bytearray()

        # Initialize transcription engine if enabled
        if self.args.transcription:
            self.online = online_factory(self.args, models.asr, models.tokenizer)

    def convert_pcm_to_float(self, pcm_buffer):
        """Convert a PCM buffer in s16le format to a normalized NumPy float32 array."""
        return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0

    def start_ffmpeg_decoder(self):
        """Start an FFmpeg process that converts WebM input to raw PCM."""
        return (
            ffmpeg.input("pipe:0", format="webm")
            .output(
                "pipe:1",
                format="s16le",
                acodec="pcm_s16le",
                ac=self.channels,
                ar=str(self.sample_rate),
            )
            .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True)
        )
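
    # For reference, the decoder above is roughly equivalent to running
    # (flag order may differ slightly from what ffmpeg-python generates):
    #
    #   ffmpeg -f webm -i pipe:0 -acodec pcm_s16le -ac 1 -ar 16000 -f s16le pipe:1
    #
    # i.e. WebM in on stdin, raw 16 kHz mono 16-bit little-endian PCM out on stdout.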
    async def restart_ffmpeg(self):
        """Restart the FFmpeg process after a failure."""
        if self.ffmpeg_process:
            try:
                self.ffmpeg_process.kill()
                await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
            except Exception as e:
                logger.warning(f"Error killing FFmpeg process: {e}")
        self.ffmpeg_process = self.start_ffmpeg_decoder()
        self.pcm_buffer = bytearray()

    async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
        """Thread-safe update of transcription state with new data."""
        async with self.lock:
            self.tokens.extend(new_tokens)
            self.buffer_transcription = buffer
            self.end_buffer = end_buffer
            self.full_transcription = full_transcription
            self.sep = sep

    async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
        """Thread-safe update of diarization state with new data."""
        async with self.lock:
            self.end_attributed_speaker = end_attributed_speaker
            if buffer_diarization:
                self.buffer_diarization = buffer_diarization

    async def add_dummy_token(self):
        """Insert a placeholder token when no transcription is available."""
        async with self.lock:
            current_time = time() - self.beg_loop
            self.tokens.append(ASRToken(
                start=current_time, end=current_time + 1,
                text=".", speaker=-1, is_dummy=True
            ))

    async def get_current_state(self):
        """Get a snapshot of the current processing state."""
        async with self.lock:
            current_time = time()

            # How far transcription lags behind real time
            remaining_transcription = 0
            if self.end_buffer > 0:
                remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))

            # How far speaker attribution lags behind transcription
            remaining_diarization = 0
            if self.tokens:
                latest_end = max(self.end_buffer, self.tokens[-1].end)
                remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))

            return {
                "tokens": self.tokens.copy(),
                "buffer_transcription": self.buffer_transcription,
                "buffer_diarization": self.buffer_diarization,
                "end_buffer": self.end_buffer,
                "end_attributed_speaker": self.end_attributed_speaker,
                "sep": self.sep,
                "remaining_time_transcription": remaining_transcription,
                "remaining_time_diarization": remaining_diarization,
            }

    async def reset(self):
        """Reset all state variables to their initial values."""
        async with self.lock:
            self.tokens = []
            self.buffer_transcription = self.buffer_diarization = ""
            self.end_buffer = self.end_attributed_speaker = 0
            self.full_transcription = self.last_response_content = ""
            self.beg_loop = time()

    async def ffmpeg_stdout_reader(self):
        """Read audio data from FFmpeg stdout and route it to the processing queues."""
        loop = asyncio.get_event_loop()
        beg = time()

        while True:
            try:
                # Size the read to roughly match the audio produced since the last read
                elapsed_time = math.floor((time() - beg) * 10) / 10  # Round to 0.1 s
                buffer_size = max(int(32000 * elapsed_time), 4096)  # 32,000 bytes ≈ 1 s of audio
                beg = time()

                # Read a chunk with a timeout so a stalled FFmpeg gets restarted
                try:
                    chunk = await asyncio.wait_for(
                        loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
                        timeout=15.0
                    )
                except asyncio.TimeoutError:
                    logger.warning("FFmpeg read timeout. Restarting...")
                    await self.restart_ffmpeg()
                    beg = time()
                    continue

                if not chunk:
                    logger.info("FFmpeg stdout closed.")
                    break

                self.pcm_buffer.extend(chunk)

                # Send to diarization if enabled
                if self.args.diarization and self.diarization_queue:
                    await self.diarization_queue.put(
                        self.convert_pcm_to_float(self.pcm_buffer).copy()
                    )

                # Process when we have enough data
                if len(self.pcm_buffer) >= self.bytes_per_sec:
                    if len(self.pcm_buffer) > self.max_bytes_per_sec:
                        logger.warning(
                            f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
                            f"Consider using a smaller model."
                        )

                    # Process audio chunk (at most max_bytes_per_sec at a time)
                    pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
                    self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]

                    # Send to transcription if enabled
                    if self.args.transcription and self.transcription_queue:
                        await self.transcription_queue.put(pcm_array.copy())

                    # Yield if no downstream processing is happening
                    if not self.args.transcription and not self.args.diarization:
                        await asyncio.sleep(0.1)

            except Exception as e:
                logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                break

    async def transcription_processor(self):
        """Process audio chunks for transcription."""
        self.full_transcription = ""
        self.sep = self.online.asr.sep

        while True:
            pcm_array = await self.transcription_queue.get()
            try:
                logger.info(
                    f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE:.2f} "
                    f"seconds of audio to process."
                )

                # Process transcription
                self.online.insert_audio_chunk(pcm_array)
                new_tokens = self.online.process_iter()

                if new_tokens:
                    self.full_transcription += self.sep.join([t.text for t in new_tokens])

                # Get buffer information
                _buffer = self.online.get_buffer()
                buffer = _buffer.text
                end_buffer = _buffer.end if _buffer.end else (
                    new_tokens[-1].end if new_tokens else 0
                )

                # Avoid duplicating content already in the committed transcription
                if buffer in self.full_transcription:
                    buffer = ""

                await self.update_transcription(
                    new_tokens, buffer, end_buffer, self.full_transcription, self.sep
                )

            except Exception as e:
                logger.warning(f"Exception in transcription_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                # Only mark the item done after a successful get()
                self.transcription_queue.task_done()

    async def diarization_processor(self, diarization_obj):
        """Process audio chunks for speaker diarization."""
        buffer_diarization = ""

        while True:
            pcm_array = await self.diarization_queue.get()
            try:
                # Process diarization
                await diarization_obj.diarize(pcm_array)

                # Get current state and attribute speakers to tokens
                state = await self.get_current_state()
                new_end = diarization_obj.assign_speakers_to_tokens(
                    state["end_attributed_speaker"], state["tokens"]
                )

                await self.update_diarization(new_end, buffer_diarization)

            except Exception as e:
                logger.warning(f"Exception in diarization_processor: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
            finally:
                self.diarization_queue.task_done()

    async def results_formatter(self):
        """Format processing results for output."""
        while True:
            try:
                # Get current state
                state = await self.get_current_state()
                tokens = state["tokens"]
                buffer_transcription = state["buffer_transcription"]
                buffer_diarization = state["buffer_diarization"]
                end_attributed_speaker = state["end_attributed_speaker"]
                sep = state["sep"]

                # Add dummy tokens if needed (diarization-only mode)
                if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
                    await self.add_dummy_token()
                    await asyncio.sleep(0.5)  # Non-blocking; time.sleep would stall the event loop
                    state = await self.get_current_state()
                    tokens = state["tokens"]

                # Format output
                previous_speaker = -1
                lines = []
                last_end_diarized = 0
                undiarized_text = []

                # Process each token
                for token in tokens:
                    speaker = token.speaker

                    # Handle diarization
                    if self.args.diarization:
                        if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
                            undiarized_text.append(token.text)
                            continue
                        elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
                            speaker = previous_speaker
                        if speaker not in [-1, 0]:
                            last_end_diarized = max(token.end, last_end_diarized)

                    # Group consecutive tokens by speaker
                    if speaker != previous_speaker or not lines:
                        lines.append({
                            "speaker": speaker,
                            "text": token.text,
                            "beg": format_time(token.start),
                            "end": format_time(token.end),
                            "diff": round(token.end - last_end_diarized, 2)
                        })
                        previous_speaker = speaker
                    elif token.text:  # Only append if text isn't empty
                        lines[-1]["text"] += sep + token.text
                        lines[-1]["end"] = format_time(token.end)
                        lines[-1]["diff"] = round(token.end - last_end_diarized, 2)

                # Handle undiarized text
                if undiarized_text:
                    combined = sep.join(undiarized_text)
                    if buffer_transcription:
                        combined += sep
                    await self.update_diarization(end_attributed_speaker, combined)
                    buffer_diarization = combined

                # Create response object
                if not lines:
                    lines = [{
                        "speaker": 1,
                        "text": "",
                        "beg": format_time(0),
                        "end": format_time(tokens[-1].end if tokens else 0),
                        "diff": 0
                    }]

                response = {
                    "lines": lines,
                    "buffer_transcription": buffer_transcription,
                    "buffer_diarization": buffer_diarization,
                    "remaining_time_transcription": state["remaining_time_transcription"],
                    "remaining_time_diarization": state["remaining_time_diarization"]
                }

                # Only yield if content has changed
                response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
                                   f" | {buffer_transcription} | {buffer_diarization}"

                if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
                    yield response
                    self.last_response_content = response_content

                await asyncio.sleep(0.1)  # Avoid overwhelming the client

            except Exception as e:
                logger.warning(f"Exception in results_formatter: {e}")
                logger.warning(f"Traceback: {traceback.format_exc()}")
                await asyncio.sleep(0.5)  # Back off on error
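
    # Illustrative shape of one response yielded above (all values hypothetical;
    # "beg"/"end" come from format_time, "diff" is seconds behind diarization):
    #
    #   {
    #       "lines": [
    #           {"speaker": 1, "text": "hello world", "beg": "0:00:01",
    #            "end": "0:00:03", "diff": 0.4}
    #       ],
    #       "buffer_transcription": "and this part is still",
    #       "buffer_diarization": "",
    #       "remaining_time_transcription": 1.2,
    #       "remaining_time_diarization": 0.0,
    #   }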
    async def create_tasks(self):
        """Create and start the background processing tasks."""
        tasks = []
        if self.args.transcription and self.online:
            tasks.append(asyncio.create_task(self.transcription_processor()))

        if self.args.diarization and self.diarization:
            tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))

        tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
        self.tasks = tasks

        return self.results_formatter()

    async def cleanup(self):
        """Clean up resources when processing is complete."""
        for task in self.tasks:
            task.cancel()

        try:
            await asyncio.gather(*self.tasks, return_exceptions=True)
            self.ffmpeg_process.stdin.close()
            # Wait in an executor so the blocking wait() does not stall the event loop
            await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
        except Exception as e:
            logger.warning(f"Error during cleanup: {e}")

        if self.args.diarization and hasattr(self, 'diarization'):
            self.diarization.close()

    async def process_audio(self, message):
        """Feed incoming audio data to FFmpeg, restarting it on a broken pipe."""
        try:
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
        except (BrokenPipeError, AttributeError) as e:
            logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
            await self.restart_ffmpeg()
            self.ffmpeg_process.stdin.write(message)
            self.ffmpeg_process.stdin.flush()
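
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of this module): one way to drive an
# AudioProcessor from a FastAPI WebSocket endpoint. Assumes FastAPI is
# installed and that the client streams WebM audio as binary messages; the
# app object, route, and helper names below are hypothetical.
#
#   from fastapi import FastAPI, WebSocket, WebSocketDisconnect
#
#   app = FastAPI()
#
#   @app.websocket("/asr")
#   async def asr_endpoint(websocket: WebSocket):
#       await websocket.accept()
#       processor = AudioProcessor()
#       results = await processor.create_tasks()  # async generator of responses
#
#       async def forward_results():
#           # Push each formatted update to the client as it is produced.
#           async for response in results:
#               await websocket.send_json(response)
#
#       forwarder = asyncio.create_task(forward_results())
#       try:
#           while True:
#               message = await websocket.receive_bytes()
#               await processor.process_audio(message)
#       except WebSocketDisconnect:
#           pass
#       finally:
#           forwarder.cancel()
#           await processor.cleanup()
# ---------------------------------------------------------------------------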