AnyaSchen committed
Commit 72277b5 · 1 Parent(s): 9b29b01

Add application file
Dockerfile ADDED
@@ -0,0 +1,44 @@
1
+ # Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
2
+ # you will also find guides on how best to write your Dockerfile
3
+
4
+ FROM python:3.10-slim
5
+
6
+ # Install system dependencies
7
+ RUN apt-get update && apt-get install -y \
8
+ build-essential \
9
+ git \
10
+ curl \
11
+ ffmpeg \
12
+ && rm -rf /var/lib/apt/lists/* \
13
+ && apt-get clean
14
+
15
+ # Set working directory
16
+ WORKDIR /app
17
+
18
+ # Create static directory and set permissions
19
+ RUN mkdir -p /app/static && chmod 777 /app/static
20
+
21
+ # Create a non-root user
22
+ RUN useradd -m -u 1000 user
23
+ USER user
24
+ ENV PATH="/home/user/.local/bin:$PATH"
25
+
26
+ # Copy requirements first to leverage Docker cache
27
+ COPY --chown=user requirements.txt .
28
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
29
+
30
+ # Copy application files
31
+ COPY --chown=user *.py /app/
32
+ COPY --chown=user whisper_streaming_custom /app/whisper_streaming_custom/
33
+ COPY --chown=user diarization /app/diarization/
34
+ COPY --chown=user static /app/static/
35
+
36
+ # Set environment variables
37
+ ENV PYTHONPATH=/app
38
+ ENV PYTHONUNBUFFERED=1
39
+
40
+ # Expose the port the server runs on
41
+ EXPOSE 7860
42
+
43
+ # Run the server using main.py
44
+ CMD ["python", "main.py", "--host", "0.0.0.0", "--port", "7860", "--model", "tiny", "--backend", "faster-whisper", "--task", "transcribe"]
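The image publishes the server on port 7860 (see the CMD above), and main.py in this commit registers a /health route, so a quick smoke test after docker run is to poll that endpoint. A minimal sketch in Python, assuming the container is running locally with port 7860 published:

# hedged sketch: assumes the container was started with port 7860 mapped to localhost
import json
import urllib.request

with urllib.request.urlopen("http://localhost:7860/health", timeout=5) as resp:
    print(json.load(resp))  # the FastAPI route returns {"status": "healthy"}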
__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .core import WhisperLiveKit, parse_args
2
+ from .audio_processor import AudioProcessor
3
+
4
+ __all__ = ['WhisperLiveKit', 'AudioProcessor', 'parse_args']
audio_processor.py ADDED
@@ -0,0 +1,409 @@
1
+ import asyncio
2
+ import numpy as np
3
+ import ffmpeg
4
+ from time import time, sleep
5
+ import math
6
+ import logging
7
+ import traceback
8
+ from datetime import timedelta
9
+ from typing import List, Dict, Any
10
+ from timed_objects import ASRToken
11
+ from whisper_streaming_custom.whisper_online import online_factory
12
+ from core import WhisperLiveKit
13
+
14
+ # Set up logging once
15
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
16
+ logger = logging.getLogger(__name__)
17
+ logger.setLevel(logging.DEBUG)
18
+
19
+ def format_time(seconds: float) -> str:
20
+ """Format seconds as HH:MM:SS."""
21
+ return str(timedelta(seconds=int(seconds)))
22
+
23
+ class AudioProcessor:
24
+ """
25
+ Processes audio streams for transcription and diarization.
26
+ Handles audio processing, state management, and result formatting.
27
+ """
28
+
29
+ def __init__(self):
30
+ """Initialize the audio processor with configuration, models, and state."""
31
+
32
+ models = WhisperLiveKit()
33
+
34
+ # Audio processing settings
35
+ self.args = models.args
36
+ self.sample_rate = 16000
37
+ self.channels = 1
38
+ self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
39
+ self.bytes_per_sample = 2
40
+ self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
41
+ self.max_bytes_per_sec = 32000 * 5 # 5 seconds of 16 kHz, 16-bit mono audio (32,000 bytes/s)
42
+
43
+ # State management
44
+ self.tokens = []
45
+ self.buffer_transcription = ""
46
+ self.buffer_diarization = ""
47
+ self.full_transcription = ""
48
+ self.end_buffer = 0
49
+ self.end_attributed_speaker = 0
50
+ self.lock = asyncio.Lock()
51
+ self.beg_loop = time()
52
+ self.sep = " " # Default separator
53
+ self.last_response_content = ""
54
+
55
+ # Models and processing
56
+ self.asr = models.asr
57
+ self.tokenizer = models.tokenizer
58
+ self.diarization = models.diarization
59
+ self.ffmpeg_process = self.start_ffmpeg_decoder()
60
+ self.transcription_queue = asyncio.Queue() if self.args.transcription else None
61
+ self.diarization_queue = asyncio.Queue() if self.args.diarization else None
62
+ self.pcm_buffer = bytearray()
63
+
64
+ # Initialize transcription engine if enabled
65
+ if self.args.transcription:
66
+ self.online = online_factory(self.args, models.asr, models.tokenizer)
67
+
68
+ def convert_pcm_to_float(self, pcm_buffer):
69
+ """Convert PCM buffer in s16le format to normalized NumPy array."""
70
+ return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
71
+
72
+ def start_ffmpeg_decoder(self):
73
+ """Start FFmpeg process for WebM to PCM conversion."""
74
+ return (ffmpeg.input("pipe:0", format="webm")
75
+ .output("pipe:1", format="s16le", acodec="pcm_s16le",
76
+ ac=self.channels, ar=str(self.sample_rate))
77
+ .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
78
+
79
+ async def restart_ffmpeg(self):
80
+ """Restart the FFmpeg process after failure."""
81
+ if self.ffmpeg_process:
82
+ try:
83
+ self.ffmpeg_process.kill()
84
+ await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
85
+ except Exception as e:
86
+ logger.warning(f"Error killing FFmpeg process: {e}")
87
+ self.ffmpeg_process = self.start_ffmpeg_decoder()
88
+ self.pcm_buffer = bytearray()
89
+
90
+ async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
91
+ """Thread-safe update of transcription with new data."""
92
+ async with self.lock:
93
+ self.tokens.extend(new_tokens)
94
+ self.buffer_transcription = buffer
95
+ self.end_buffer = end_buffer
96
+ self.full_transcription = full_transcription
97
+ self.sep = sep
98
+
99
+ async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
100
+ """Thread-safe update of diarization with new data."""
101
+ async with self.lock:
102
+ self.end_attributed_speaker = end_attributed_speaker
103
+ if buffer_diarization:
104
+ self.buffer_diarization = buffer_diarization
105
+
106
+ async def add_dummy_token(self):
107
+ """Placeholder token when no transcription is available."""
108
+ async with self.lock:
109
+ current_time = time() - self.beg_loop
110
+ self.tokens.append(ASRToken(
111
+ start=current_time, end=current_time + 1,
112
+ text=".", speaker=-1, is_dummy=True
113
+ ))
114
+
115
+ async def get_current_state(self):
116
+ """Get current state."""
117
+ async with self.lock:
118
+ current_time = time()
119
+
120
+ # Calculate remaining times
121
+ remaining_transcription = 0
122
+ if self.end_buffer > 0:
123
+ remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
124
+
125
+ remaining_diarization = 0
126
+ if self.tokens:
127
+ latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
128
+ remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
129
+
130
+ return {
131
+ "tokens": self.tokens.copy(),
132
+ "buffer_transcription": self.buffer_transcription,
133
+ "buffer_diarization": self.buffer_diarization,
134
+ "end_buffer": self.end_buffer,
135
+ "end_attributed_speaker": self.end_attributed_speaker,
136
+ "sep": self.sep,
137
+ "remaining_time_transcription": remaining_transcription,
138
+ "remaining_time_diarization": remaining_diarization
139
+ }
140
+
141
+ async def reset(self):
142
+ """Reset all state variables to initial values."""
143
+ async with self.lock:
144
+ self.tokens = []
145
+ self.buffer_transcription = self.buffer_diarization = ""
146
+ self.end_buffer = self.end_attributed_speaker = 0
147
+ self.full_transcription = self.last_response_content = ""
148
+ self.beg_loop = time()
149
+
150
+ async def ffmpeg_stdout_reader(self):
151
+ """Read audio data from FFmpeg stdout and process it."""
152
+ loop = asyncio.get_event_loop()
153
+ beg = time()
154
+
155
+ while True:
156
+ try:
157
+ # Calculate buffer size based on elapsed time
158
+ elapsed_time = math.floor((time() - beg) * 10) / 10 # Round to 0.1 sec
159
+ buffer_size = max(int(32000 * elapsed_time), 4096)
160
+ beg = time()
161
+
162
+ # Read chunk with timeout
163
+ try:
164
+ chunk = await asyncio.wait_for(
165
+ loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
166
+ timeout=15.0
167
+ )
168
+ except asyncio.TimeoutError:
169
+ logger.warning("FFmpeg read timeout. Restarting...")
170
+ await self.restart_ffmpeg()
171
+ beg = time()
172
+ continue
173
+
174
+ if not chunk:
175
+ logger.info("FFmpeg stdout closed.")
176
+ break
177
+
178
+ self.pcm_buffer.extend(chunk)
179
+
180
+ # Send to diarization if enabled
181
+ if self.args.diarization and self.diarization_queue:
182
+ await self.diarization_queue.put(
183
+ self.convert_pcm_to_float(self.pcm_buffer).copy()
184
+ )
185
+
186
+ # Process when we have enough data
187
+ if len(self.pcm_buffer) >= self.bytes_per_sec:
188
+ if len(self.pcm_buffer) > self.max_bytes_per_sec:
189
+ logger.warning(
190
+ f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
191
+ f"Consider using a smaller model."
192
+ )
193
+
194
+ # Process audio chunk
195
+ pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
196
+ self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
197
+
198
+ # Send to transcription if enabled
199
+ if self.args.transcription and self.transcription_queue:
200
+ await self.transcription_queue.put(pcm_array.copy())
201
+
202
+ # Sleep if no processing is happening
203
+ if not self.args.transcription and not self.args.diarization:
204
+ await asyncio.sleep(0.1)
205
+
206
+ except Exception as e:
207
+ logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
208
+ logger.warning(f"Traceback: {traceback.format_exc()}")
209
+ break
210
+
211
+ async def transcription_processor(self):
212
+ """Process audio chunks for transcription."""
213
+ self.full_transcription = ""
214
+ self.sep = self.online.asr.sep
215
+
216
+ while True:
217
+ try:
218
+ pcm_array = await self.transcription_queue.get()
219
+
220
+ logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")
221
+
222
+ # Process transcription
223
+ self.online.insert_audio_chunk(pcm_array)
224
+ new_tokens = self.online.process_iter()
225
+
226
+ if new_tokens:
227
+ self.full_transcription += self.sep.join([t.text for t in new_tokens])
228
+
229
+ # Get buffer information
230
+ _buffer = self.online.get_buffer()
231
+ buffer = _buffer.text
232
+ end_buffer = _buffer.end if _buffer.end else (
233
+ new_tokens[-1].end if new_tokens else 0
234
+ )
235
+
236
+ # Avoid duplicating content
237
+ if buffer in self.full_transcription:
238
+ buffer = ""
239
+
240
+ await self.update_transcription(
241
+ new_tokens, buffer, end_buffer, self.full_transcription, self.sep
242
+ )
243
+
244
+ except Exception as e:
245
+ logger.warning(f"Exception in transcription_processor: {e}")
246
+ logger.warning(f"Traceback: {traceback.format_exc()}")
247
+ finally:
248
+ self.transcription_queue.task_done()
249
+
250
+ async def diarization_processor(self, diarization_obj):
251
+ """Process audio chunks for speaker diarization."""
252
+ buffer_diarization = ""
253
+
254
+ while True:
255
+ try:
256
+ pcm_array = await self.diarization_queue.get()
257
+
258
+ # Process diarization
259
+ await diarization_obj.diarize(pcm_array)
260
+
261
+ # Get current state and update speakers
262
+ state = await self.get_current_state()
263
+ new_end = diarization_obj.assign_speakers_to_tokens(
264
+ state["end_attributed_speaker"], state["tokens"]
265
+ )
266
+
267
+ await self.update_diarization(new_end, buffer_diarization)
268
+
269
+ except Exception as e:
270
+ logger.warning(f"Exception in diarization_processor: {e}")
271
+ logger.warning(f"Traceback: {traceback.format_exc()}")
272
+ finally:
273
+ self.diarization_queue.task_done()
274
+
275
+ async def results_formatter(self):
276
+ """Format processing results for output."""
277
+ while True:
278
+ try:
279
+ # Get current state
280
+ state = await self.get_current_state()
281
+ tokens = state["tokens"]
282
+ buffer_transcription = state["buffer_transcription"]
283
+ buffer_diarization = state["buffer_diarization"]
284
+ end_attributed_speaker = state["end_attributed_speaker"]
285
+ sep = state["sep"]
286
+
287
+ # Add dummy tokens if needed
288
+ if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
289
+ await self.add_dummy_token()
290
+ await asyncio.sleep(0.5)  # non-blocking; time.sleep() here would stall the event loop
291
+ state = await self.get_current_state()
292
+ tokens = state["tokens"]
293
+
294
+ # Format output
295
+ previous_speaker = -1
296
+ lines = []
297
+ last_end_diarized = 0
298
+ undiarized_text = []
299
+
300
+ # Process each token
301
+ for token in tokens:
302
+ speaker = token.speaker
303
+
304
+ # Handle diarization
305
+ if self.args.diarization:
306
+ if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
307
+ undiarized_text.append(token.text)
308
+ continue
309
+ elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
310
+ speaker = previous_speaker
311
+ if speaker not in [-1, 0]:
312
+ last_end_diarized = max(token.end, last_end_diarized)
313
+
314
+ # Group by speaker
315
+ if speaker != previous_speaker or not lines:
316
+ lines.append({
317
+ "speaker": speaker,
318
+ "text": token.text,
319
+ "beg": format_time(token.start),
320
+ "end": format_time(token.end),
321
+ "diff": round(token.end - last_end_diarized, 2)
322
+ })
323
+ previous_speaker = speaker
324
+ elif token.text: # Only append if text isn't empty
325
+ lines[-1]["text"] += sep + token.text
326
+ lines[-1]["end"] = format_time(token.end)
327
+ lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
328
+
329
+ # Handle undiarized text
330
+ if undiarized_text:
331
+ combined = sep.join(undiarized_text)
332
+ if buffer_transcription:
333
+ combined += sep
334
+ await self.update_diarization(end_attributed_speaker, combined)
335
+ buffer_diarization = combined
336
+
337
+ # Create response object
338
+ if not lines:
339
+ lines = [{
340
+ "speaker": 1,
341
+ "text": "",
342
+ "beg": format_time(0),
343
+ "end": format_time(tokens[-1].end if tokens else 0),
344
+ "diff": 0
345
+ }]
346
+
347
+ response = {
348
+ "lines": lines,
349
+ "buffer_transcription": buffer_transcription,
350
+ "buffer_diarization": buffer_diarization,
351
+ "remaining_time_transcription": state["remaining_time_transcription"],
352
+ "remaining_time_diarization": state["remaining_time_diarization"]
353
+ }
354
+
355
+ # Only yield if content has changed
356
+ response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
357
+ f" | {buffer_transcription} | {buffer_diarization}"
358
+
359
+ if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
360
+ yield response
361
+ self.last_response_content = response_content
362
+
363
+ await asyncio.sleep(0.1) # Avoid overwhelming the client
364
+
365
+ except Exception as e:
366
+ logger.warning(f"Exception in results_formatter: {e}")
367
+ logger.warning(f"Traceback: {traceback.format_exc()}")
368
+ await asyncio.sleep(0.5) # Back off on error
369
+
370
+ async def create_tasks(self):
371
+ """Create and start processing tasks."""
372
+
373
+ tasks = []
374
+ if self.args.transcription and self.online:
375
+ tasks.append(asyncio.create_task(self.transcription_processor()))
376
+
377
+ if self.args.diarization and self.diarization:
378
+ tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))
379
+
380
+ tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
381
+ self.tasks = tasks
382
+
383
+ return self.results_formatter()
384
+
385
+ async def cleanup(self):
386
+ """Clean up resources when processing is complete."""
387
+ for task in self.tasks:
388
+ task.cancel()
389
+
390
+ try:
391
+ await asyncio.gather(*self.tasks, return_exceptions=True)
392
+ self.ffmpeg_process.stdin.close()
393
+ self.ffmpeg_process.wait()
394
+ except Exception as e:
395
+ logger.warning(f"Error during cleanup: {e}")
396
+
397
+ if self.args.diarization and hasattr(self, 'diarization'):
398
+ self.diarization.close()
399
+
400
+ async def process_audio(self, message):
401
+ """Process incoming audio data."""
402
+ try:
403
+ self.ffmpeg_process.stdin.write(message)
404
+ self.ffmpeg_process.stdin.flush()
405
+ except (BrokenPipeError, AttributeError) as e:
406
+ logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
407
+ await self.restart_ffmpeg()
408
+ self.ffmpeg_process.stdin.write(message)
409
+ self.ffmpeg_process.stdin.flush()
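For orientation, here is a minimal sketch of how this class is driven; it mirrors the WebSocket handler in main.py further down. create_tasks() starts the background workers and returns the results_formatter() generator, process_audio() pipes encoded bytes into FFmpeg, and cleanup() tears everything down. The name webm_chunks is a hypothetical stand-in for a source of WebM-encoded byte chunks (the browser's MediaRecorder in the real app).

import asyncio
from audio_processor import AudioProcessor

async def consume(results):
    # results_formatter() yields dicts with "lines", "buffer_transcription", etc.
    async for response in results:
        print(response["lines"], response["buffer_transcription"])

async def run(webm_chunks):
    processor = AudioProcessor()                  # pulls args/models from the WhisperLiveKit singleton
    results = await processor.create_tasks()      # starts the ffmpeg/transcription/diarization tasks
    consumer = asyncio.create_task(consume(results))
    for chunk in webm_chunks:                     # hypothetical iterable of WebM bytes
        await processor.process_audio(chunk)      # written to FFmpeg's stdin pipe
    consumer.cancel()
    await processor.cleanup()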
core.py ADDED
@@ -0,0 +1,61 @@
1
+ from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
2
+ from argparse import Namespace, ArgumentParser
3
+
4
+ class WhisperLiveKit:
5
+ _instance = None
6
+ _initialized = False
7
+
8
+ def __new__(cls, *args, **kwargs):
9
+ if cls._instance is None:
10
+ cls._instance = super().__new__(cls)
11
+ return cls._instance
12
+
13
+ def __init__(self, args=None, **kwargs):
14
+ if WhisperLiveKit._initialized:
15
+ return
16
+
17
+ if args is None:
18
+ args = Namespace(
19
+ host="localhost",
20
+ port=8000,
21
+ warmup_file=None,
22
+ confidence_validation=False,
23
+ diarization=False,
24
+ transcription=True,
25
+ min_chunk_size=0.5,
26
+ model="base",
27
+ model_cache_dir=None,
28
+ model_dir=None,
29
+ lan="auto",
30
+ task="transcribe",
31
+ backend="faster-whisper",
32
+ vac=False,
33
+ vac_chunk_size=0.04,
34
+ vad=True,
35
+ buffer_trimming="sentence",
36
+ buffer_trimming_sec=1.0,
37
+ log_level="INFO"
38
+ )
39
+
40
+ self.args = args
41
+
42
+ self.asr = None
43
+ self.tokenizer = None
44
+ self.diarization = None
45
+
46
+ if self.args.transcription:
47
+ self.asr, self.tokenizer = backend_factory(self.args)
48
+ warmup_asr(self.asr, self.args.warmup_file)
49
+
50
+ if self.args.diarization:
51
+ from diarization.diarization_online import DiartDiarization
52
+ self.diarization = DiartDiarization()
53
+
54
+ WhisperLiveKit._initialized = True
55
+
56
+ def web_interface(self):
57
+ import pkg_resources
58
+ html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
59
+ with open(html_path, "r", encoding="utf-8") as f:
60
+ html = f.read()
61
+ return html
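Because __new__ and __init__ are guarded by _instance and _initialized, the first construction fixes the configuration and loads the models; every later call returns that same object. A small sketch of this singleton behaviour, using the default Namespace above (the base model is downloaded and warmed up on the first call):

from core import WhisperLiveKit

kit = WhisperLiveKit()       # first call: builds the ASR backend and tokenizer from the defaults
again = WhisperLiveKit()     # later calls return the already-initialized singleton
assert again is kit and again.asr is kit.asr
print(kit.args.model, kit.args.backend)   # "base", "faster-whisper" with the defaults above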
diarization/__init__.py ADDED
File without changes
diarization/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (172 Bytes)
 
diarization/__pycache__/diarization_online.cpython-310.pyc ADDED
Binary file (6.63 kB)
 
diarization/diarization_online.py ADDED
@@ -0,0 +1,153 @@
1
+ import asyncio
2
+ import re
3
+ import threading
4
+ import numpy as np
5
+ import logging
6
+
7
+
8
+ from diart import SpeakerDiarization, SpeakerDiarizationConfig
9
+ from diart.inference import StreamingInference
10
+ from diart.sources import AudioSource
11
+ from timed_objects import SpeakerSegment
12
+ from diart.sources import MicrophoneAudioSource
13
+ from rx.core import Observer
14
+ from typing import Tuple, Any, List
15
+ from pyannote.core import Annotation
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ def extract_number(s: str) -> int:
20
+ m = re.search(r'\d+', s)
21
+ return int(m.group()) if m else None
22
+
23
+ class DiarizationObserver(Observer):
24
+ """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
25
+
26
+ def __init__(self):
27
+ self.speaker_segments = []
28
+ self.processed_time = 0
29
+ self.segment_lock = threading.Lock()
30
+
31
+ def on_next(self, value: Tuple[Annotation, Any]):
32
+ annotation, audio = value
33
+
34
+ logger.debug("\n--- New Diarization Result ---")
35
+
36
+ duration = audio.extent.end - audio.extent.start
37
+ logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
38
+ logger.debug(f"Audio shape: {audio.data.shape}")
39
+
40
+ with self.segment_lock:
41
+ if audio.extent.end > self.processed_time:
42
+ self.processed_time = audio.extent.end
43
+ if annotation and len(annotation._labels) > 0:
44
+ logger.debug("\nSpeaker segments:")
45
+ for speaker, label in annotation._labels.items():
46
+ for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
47
+ logger.debug(f" {speaker}: {start:.2f}s-{end:.2f}s")
48
+ self.speaker_segments.append(SpeakerSegment(
49
+ speaker=speaker,
50
+ start=start,
51
+ end=end
52
+ ))
53
+ else:
54
+ logger.debug("\nNo speakers detected in this segment")
55
+
56
+ def get_segments(self) -> List[SpeakerSegment]:
57
+ """Get a copy of the current speaker segments."""
58
+ with self.segment_lock:
59
+ return self.speaker_segments.copy()
60
+
61
+ def clear_old_segments(self, older_than: float = 30.0):
62
+ """Clear segments older than the specified time."""
63
+ with self.segment_lock:
64
+ current_time = self.processed_time
65
+ self.speaker_segments = [
66
+ segment for segment in self.speaker_segments
67
+ if current_time - segment.end < older_than
68
+ ]
69
+
70
+ def on_error(self, error):
71
+ """Handle an error in the stream."""
72
+ logger.debug(f"Error in diarization stream: {error}")
73
+
74
+ def on_completed(self):
75
+ """Handle the completion of the stream."""
76
+ logger.debug("Diarization stream completed")
77
+
78
+
79
+ class WebSocketAudioSource(AudioSource):
80
+ """
81
+ Custom AudioSource that blocks in read() until close() is called.
82
+ Use push_audio() to inject PCM chunks.
83
+ """
84
+ def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
85
+ super().__init__(uri, sample_rate)
86
+ self._closed = False
87
+ self._close_event = threading.Event()
88
+
89
+ def read(self):
90
+ self._close_event.wait()
91
+
92
+ def close(self):
93
+ if not self._closed:
94
+ self._closed = True
95
+ self.stream.on_completed()
96
+ self._close_event.set()
97
+
98
+ def push_audio(self, chunk: np.ndarray):
99
+ if not self._closed:
100
+ new_audio = np.expand_dims(chunk, axis=0)
101
+ logger.debug(f"Adding new chunk with shape: {new_audio.shape}")
102
+ self.stream.on_next(new_audio)
103
+
104
+
105
+ class DiartDiarization:
106
+ def __init__(self, sample_rate: int = 16000, config : SpeakerDiarizationConfig = None, use_microphone: bool = False):
107
+ self.pipeline = SpeakerDiarization(config=config)
108
+ self.observer = DiarizationObserver()
109
+
110
+ if use_microphone:
111
+ self.source = MicrophoneAudioSource()
112
+ self.custom_source = None
113
+ else:
114
+ self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
115
+ self.source = self.custom_source
116
+
117
+ self.inference = StreamingInference(
118
+ pipeline=self.pipeline,
119
+ source=self.source,
120
+ do_plot=False,
121
+ show_progress=False,
122
+ )
123
+ self.inference.attach_observers(self.observer)
124
+ asyncio.get_event_loop().run_in_executor(None, self.inference)
125
+
126
+ async def diarize(self, pcm_array: np.ndarray):
127
+ """
128
+ Process audio data for diarization.
129
+ Only used when working with WebSocketAudioSource.
130
+ """
131
+ if self.custom_source:
132
+ self.custom_source.push_audio(pcm_array)
133
+ self.observer.clear_old_segments()
134
+ return self.observer.get_segments()
135
+
136
+ def close(self):
137
+ """Close the audio source."""
138
+ if self.custom_source:
139
+ self.custom_source.close()
140
+
141
+ def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
142
+ """
143
+ Assign speakers to tokens based on timing overlap with speaker segments.
144
+ Uses the segments collected by the observer.
145
+ """
146
+ segments = self.observer.get_segments()
147
+
148
+ for token in tokens:
149
+ for segment in segments:
150
+ if not (segment.end <= token.start or segment.start >= token.end):
151
+ token.speaker = extract_number(segment.speaker) + 1
152
+ end_attributed_speaker = max(token.end, end_attributed_speaker)
153
+ return end_attributed_speaker
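To make the interval-overlap rule in assign_speakers_to_tokens() concrete, here is a self-contained sketch of the same test using the dataclasses from timed_objects.py. The "speakerN" labels are an assumption about how the pipeline names speakers (the code above only requires that the label contain a digit), and the regex mirrors what extract_number() does.

import re
from timed_objects import ASRToken, SpeakerSegment

segments = [SpeakerSegment(start=0.0, end=2.5, speaker="speaker0"),
            SpeakerSegment(start=2.5, end=5.0, speaker="speaker1")]
tokens = [ASRToken(start=1.0, end=1.4, text="hello"),
          ASRToken(start=3.0, end=3.4, text="world")]

for token in tokens:
    for segment in segments:
        # intervals overlap unless one ends before the other starts -- same test as above
        if not (segment.end <= token.start or segment.start >= token.end):
            # what extract_number() does: "speaker0" -> 0, then shift to 1-based speaker ids
            token.speaker = int(re.search(r"\d+", segment.speaker).group()) + 1

print([(t.text, t.speaker) for t in tokens])   # [('hello', 1), ('world', 2)]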
docker-compose.yml ADDED
@@ -0,0 +1,36 @@
1
+ version: '3.8'
2
+
3
+ services:
4
+ websocket-server:
5
+ build:
6
+ context: .
7
+ dockerfile: Dockerfile
8
+ ports:
9
+ - "8000:7860"  # the server inside the container listens on 7860 (see the Dockerfile CMD)
10
+ environment:
11
+ - PYTHONUNBUFFERED=1
12
+ volumes:
13
+ - .:/app
14
+ networks:
15
+ - app-network
16
+ healthcheck:
17
+ test: ["CMD", "curl", "-f", "http://localhost:7860/health"]
18
+ interval: 30s
19
+ timeout: 10s
20
+ retries: 3
21
+
22
+ frontend:
23
+ build:
24
+ context: ./frontend
25
+ dockerfile: Dockerfile
26
+ ports:
27
+ - "80:80"
28
+ depends_on:
29
+ websocket-server:
30
+ condition: service_healthy
31
+ networks:
32
+ - app-network
33
+
34
+ networks:
35
+ app-network:
36
+ driver: bridge
main.py ADDED
@@ -0,0 +1,275 @@
1
+ from contextlib import asynccontextmanager
2
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.responses import JSONResponse
5
+ from fastapi.staticfiles import StaticFiles
6
+ from fastapi.responses import FileResponse
7
+ import asyncio
8
+ import logging
9
+ import os
10
+ import traceback
11
+ import argparse
12
+ import uvicorn
13
+
14
+ from core import WhisperLiveKit
15
+ from audio_processor import AudioProcessor
16
+
17
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
18
+ logging.getLogger().setLevel(logging.WARNING)
19
+ logger = logging.getLogger(__name__)
20
+ logger.setLevel(logging.DEBUG)
21
+
22
+ kit = None
23
+
24
+ @asynccontextmanager
25
+ async def lifespan(app: FastAPI):
26
+ global kit
27
+ kit = WhisperLiveKit()
28
+ yield
29
+
30
+ app = FastAPI(lifespan=lifespan)
31
+
32
+ app.add_middleware(
33
+ CORSMiddleware,
34
+ allow_origins=["*"], # Allows all origins
35
+ allow_credentials=True,
36
+ allow_methods=["*"], # Allows all methods
37
+ allow_headers=["*"], # Allows all headers
38
+ )
39
+
40
+ # Mount static files
41
+ app.mount("/static", StaticFiles(directory="static"), name="static")
42
+
43
+
44
+
45
+ @app.get("/")
46
+ async def read_root():
47
+ return FileResponse("static/index.html")
48
+
49
+ @app.get("/health")
50
+ async def health_check():
51
+ return JSONResponse({"status": "healthy"})
52
+
53
+ async def handle_websocket_results(websocket, results_generator):
54
+ """Consumes results from the audio processor and sends them via WebSocket."""
55
+ try:
56
+ async for response in results_generator:
57
+ try:
58
+ logger.debug(f"Sending response: {response}")
59
+ if isinstance(response, dict):
60
+ # Ensure the response has a consistent format
61
+ if 'buffer_transcription' in response:
62
+ await websocket.send_json({
63
+ 'buffer_transcription': response['buffer_transcription']
64
+ })
65
+ elif 'full_transcription' in response:
66
+ await websocket.send_json({
67
+ 'full_transcription': response['full_transcription']
68
+ })
69
+ else:
70
+ await websocket.send_json(response)
71
+ else:
72
+ # If response is not a dict, wrap it in a text field
73
+ await websocket.send_json({"text": str(response)})
74
+ except Exception as e:
75
+ logger.error(f"Error sending message: {e}")
76
+ logger.error(f"Traceback: {traceback.format_exc()}")
77
+ raise
78
+ except Exception as e:
79
+ logger.warning(f"Error in WebSocket results handler: {e}")
80
+ logger.warning(f"Traceback: {traceback.format_exc()}")
81
+
82
+ @app.websocket("/asr")
83
+ async def websocket_endpoint(websocket: WebSocket):
84
+ logger.info("New WebSocket connection request")
85
+ audio_processor = None
86
+ websocket_task = None
87
+
88
+ try:
89
+ await websocket.accept()
90
+ logger.info("WebSocket connection accepted")
91
+
92
+ audio_processor = AudioProcessor()
93
+ results_generator = await audio_processor.create_tasks()
94
+ websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
95
+
96
+ while True:
97
+ try:
98
+ message = await websocket.receive_bytes()
99
+ logger.debug(f"Received audio chunk of size: {len(message)}")
100
+ await audio_processor.process_audio(message)
101
+ except WebSocketDisconnect:
102
+ logger.warning("WebSocket disconnected.")
103
+ break
104
+ except Exception as e:
105
+ logger.error(f"Error processing audio chunk: {e}")
106
+ logger.error(f"Traceback: {traceback.format_exc()}")
107
+ break
108
+
109
+ except WebSocketDisconnect:
110
+ logger.warning("WebSocket disconnected during setup.")
111
+ except Exception as e:
112
+ logger.error(f"Error in WebSocket endpoint: {e}")
113
+ logger.error(f"Traceback: {traceback.format_exc()}")
114
+ finally:
115
+ if websocket_task:
116
+ websocket_task.cancel()
117
+ try:
118
+ await websocket_task
119
+ except asyncio.CancelledError:
120
+ pass
121
+ if audio_processor:
122
+ await audio_processor.cleanup()
123
+ logger.info("WebSocket endpoint cleaned up.")
124
+
125
+ def parse_args():
126
+ parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
127
+ parser.add_argument(
128
+ "--host",
129
+ type=str,
130
+ default="localhost",
131
+ help="The host address to bind the server to.",
132
+ )
133
+ parser.add_argument(
134
+ "--port", type=int, default=8000, help="The port number to bind the server to."
135
+ )
136
+ parser.add_argument(
137
+ "--warmup-file",
138
+ type=str,
139
+ default=None,
140
+ dest="warmup_file",
141
+ help="""
142
+ The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
143
+ If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
144
+ If False, no warmup is performed.
145
+ """,
146
+ )
147
+
148
+ parser.add_argument(
149
+ "--confidence-validation",
150
+ action="store_true",
151
+ help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
152
+ )
153
+
154
+ parser.add_argument(
155
+ "--diarization",
156
+ action="store_true",
157
+ default=False,
158
+ help="Enable speaker diarization.",
159
+ )
160
+
161
+ parser.add_argument(
162
+ "--no-transcription",
163
+ action="store_true",
164
+ help="Disable transcription to only see live diarization results.",
165
+ )
166
+
167
+ parser.add_argument(
168
+ "--min-chunk-size",
169
+ type=float,
170
+ default=0.5,
171
+ help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes shorter time, it waits, otherwise it processes the whole segment that was received by this time.",
172
+ )
173
+
174
+ parser.add_argument(
175
+ "--model",
176
+ type=str,
177
+ default="tiny",
178
+ help="Name or size of the Whisper model to use (default: tiny). Suggested values: tiny.en,tiny,base.en,base,small.en,small,medium.en,medium,large-v1,large-v2,large-v3,large,large-v3-turbo. The model is automatically downloaded from the model hub if not present in the model cache dir.",
179
+ )
180
+
181
+ parser.add_argument(
182
+ "--model_cache_dir",
183
+ type=str,
184
+ default=None,
185
+ help="Overriding the default model cache dir where models downloaded from the hub are saved",
186
+ )
187
+ parser.add_argument(
188
+ "--model_dir",
189
+ type=str,
190
+ default=None,
191
+ help="Dir where Whisper model.bin and other files are saved. This option overrides --model and --model_cache_dir parameter.",
192
+ )
193
+ parser.add_argument(
194
+ "--lan",
195
+ "--language",
196
+ type=str,
197
+ default="en",
198
+ help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
199
+ )
200
+ parser.add_argument(
201
+ "--task",
202
+ type=str,
203
+ default="transcribe",
204
+ choices=["transcribe", "translate"],
205
+ help="Transcribe or translate.",
206
+ )
207
+ parser.add_argument(
208
+ "--backend",
209
+ type=str,
210
+ default="faster-whisper",
211
+ choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
212
+ help="Load only this backend for Whisper processing.",
213
+ )
214
+ parser.add_argument(
215
+ "--vac",
216
+ action="store_true",
217
+ default=False,
218
+ help="Use VAC = voice activity controller. Recommended. Requires torch.",
219
+ )
220
+ parser.add_argument(
221
+ "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
222
+ )
223
+ parser.add_argument(
224
+ "--no-vad",
225
+ action="store_true",
226
+ default=False,
227
+ help="Disable VAD = voice activity detection. Not recommended.",
228
+ )
229
+ parser.add_argument(
230
+ "--buffer_trimming",
231
+ type=str,
232
+ default="sentence",
233
+ choices=["sentence", "segment"],
234
+ help="Buffer trimming strategy.",
235
+ )
236
+ parser.add_argument(
237
+ "--buffer_trimming_sec",
238
+ type=float,
239
+ default=1.0,
240
+ help="Buffer trimming length in seconds.",
241
+ )
242
+ parser.add_argument(
243
+ "-l",
244
+ "--log-level",
245
+ type=str,
246
+ default="INFO",
247
+ choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
248
+ help="Set the logging level.",
249
+ )
250
+
251
+ args = parser.parse_args()
252
+
253
+ args.transcription = not args.no_transcription
254
+ args.vad = not args.no_vad
255
+ delattr(args, 'no_transcription')
256
+ delattr(args, 'no_vad')
257
+
258
+ return args
259
+
260
+ def main():
261
+ args = parse_args()
262
+
263
+ # Initialize WhisperLiveKit with parsed arguments
264
+ kit = WhisperLiveKit(args=args)
265
+
266
+ # Start the server
267
+ uvicorn.run(
268
+ "main:app",
269
+ host=args.host,
270
+ port=args.port,
271
+ log_level=args.log_level.lower()
272
+ )
273
+
274
+ if __name__ == "__main__":
275
+ main()
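A minimal client sketch for the /asr endpoint, using the websockets package already listed in requirements.txt. The file name sample.webm and the 3200-byte chunk size are illustrative assumptions; the server only requires WebM-encoded bytes, since they are piped straight into FFmpeg. Port 8000 matches the parse_args default (the Docker image runs on 7860 instead).

import asyncio
import json
import websockets

async def stream(path="sample.webm", url="ws://localhost:8000/asr"):
    async with websockets.connect(url) as ws:
        async def reader():
            async for message in ws:     # the server pushes JSON updates as they are formatted
                print(json.loads(message))
        reader_task = asyncio.create_task(reader())
        with open(path, "rb") as f:
            while chunk := f.read(3200):
                await ws.send(chunk)     # binary frames, like the browser's MediaRecorder chunks
                await asyncio.sleep(0.1) # pace roughly in real time
        reader_task.cancel()

asyncio.run(stream())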
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ fastapi>=0.68.0
2
+ uvicorn[standard]
3
+ python-multipart>=0.0.5
4
+ numpy>=1.21.0
5
+ ffmpeg-python>=0.2.0
6
+ torch>=2.0.0
7
+ torchaudio>=2.0.0
8
+ faster-whisper>=0.9.0
9
+ websockets>=10.0
10
+ pydantic>=1.8.0
11
+ python-dotenv>=0.19.0
12
+ setuptools>=65.5.1
13
+ librosa>=0.10.0
14
+ mosestokenizer
15
+ hf_xet
setup.py ADDED
@@ -0,0 +1,22 @@
1
+ from setuptools import setup, find_packages
2
+
3
+ setup(
4
+ name="whisperlivekit",
5
+ version="0.1.0",
6
+ packages=find_packages(),
7
+ install_requires=[
8
+ "fastapi>=0.68.0",
9
+ "uvicorn>=0.15.0",
10
+ "python-multipart>=0.0.5",
11
+ "numpy>=1.21.0",
12
+ "ffmpeg-python>=0.2.0",
13
+ "torch>=2.0.0",
14
+ "torchaudio>=2.0.0",
15
+ "faster-whisper>=0.9.0",
16
+ "websockets>=10.0",
17
+ "pydantic>=1.8.0",
18
+ "python-dotenv>=0.19.0",
19
+ "setuptools>=65.5.1",
20
+ ],
21
+ python_requires=">=3.10",
22
+ )
static/index.html ADDED
@@ -0,0 +1,248 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Whisper Live Transcription</title>
7
+ <style>
8
+ body {
9
+ font-family: Arial, sans-serif;
10
+ max-width: 800px;
11
+ margin: 0 auto;
12
+ padding: 20px;
13
+ background-color: #f5f5f5;
14
+ }
15
+ .container {
16
+ background-color: white;
17
+ padding: 20px;
18
+ border-radius: 8px;
19
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
20
+ }
21
+ .controls {
22
+ margin: 20px 0;
23
+ }
24
+ button {
25
+ padding: 10px 20px;
26
+ margin: 5px;
27
+ border: none;
28
+ border-radius: 4px;
29
+ background-color: #007bff;
30
+ color: white;
31
+ cursor: pointer;
32
+ }
33
+ button:disabled {
34
+ background-color: #cccccc;
35
+ cursor: not-allowed;
36
+ }
37
+ #clearBtn {
38
+ background-color: #dc3545;
39
+ }
40
+ #clearBtn:hover {
41
+ background-color: #c82333;
42
+ }
43
+ #transcription {
44
+ margin-top: 20px;
45
+ padding: 15px;
46
+ border: 1px solid #ddd;
47
+ border-radius: 4px;
48
+ min-height: 100px;
49
+ white-space: pre-wrap;
50
+ font-family: monospace;
51
+ line-height: 1.5;
52
+ }
53
+ .status {
54
+ margin: 10px 0;
55
+ padding: 10px;
56
+ border-radius: 4px;
57
+ }
58
+ .config {
59
+ margin: 10px 0;
60
+ padding: 10px;
61
+ background-color: #f8f9fa;
62
+ border-radius: 4px;
63
+ }
64
+ .config input {
65
+ margin-left: 10px;
66
+ padding: 5px;
67
+ border: 1px solid #ddd;
68
+ border-radius: 4px;
69
+ }
70
+ .connected {
71
+ background-color: #d4edda;
72
+ color: #155724;
73
+ }
74
+ .disconnected {
75
+ background-color: #f8d7da;
76
+ color: #721c24;
77
+ }
78
+ #error-message {
79
+ color: #721c24;
80
+ margin: 10px 0;
81
+ padding: 10px;
82
+ background-color: #f8d7da;
83
+ border-radius: 4px;
84
+ display: none;
85
+ }
86
+ </style>
87
+ </head>
88
+ <body>
89
+ <div class="container">
90
+ <h1>Whisper Live Transcription</h1>
91
+ <div class="config">
92
+ <label for="wsUrl">WebSocket URL:</label>
93
+ <input type="text" id="wsUrl" value="ws://localhost:7860/asr" style="width: 300px;">
94
+ </div>
95
+ <div id="status" class="status disconnected">Disconnected</div>
96
+ <div id="error-message"></div>
97
+ <div class="controls">
98
+ <button id="startBtn">Start Recording</button>
99
+ <button id="stopBtn" disabled>Stop Recording</button>
100
+ <button id="reconnectBtn">Reconnect</button>
101
+ <button id="clearBtn">Clear Transcription</button>
102
+ </div>
103
+ <div id="transcription"></div>
104
+ </div>
105
+
106
+ <script>
107
+ let ws = null;
108
+ let mediaRecorder = null;
109
+ let audioChunks = [];
110
+ let reconnectAttempts = 0;
111
+ const maxReconnectAttempts = 5;
112
+ const reconnectDelay = 2000; // 2 seconds
113
+
114
+ const startBtn = document.getElementById('startBtn');
115
+ const stopBtn = document.getElementById('stopBtn');
116
+ const reconnectBtn = document.getElementById('reconnectBtn');
117
+ const clearBtn = document.getElementById('clearBtn');
118
+ const wsUrlInput = document.getElementById('wsUrl');
119
+ const statusDiv = document.getElementById('status');
120
+ const transcriptionDiv = document.getElementById('transcription');
121
+ const errorMessageDiv = document.getElementById('error-message');
122
+
123
+ function showError(message) {
124
+ errorMessageDiv.textContent = message;
125
+ errorMessageDiv.style.display = 'block';
126
+ }
127
+
128
+ function hideError() {
129
+ errorMessageDiv.style.display = 'none';
130
+ }
131
+
132
+ function updateStatus(connected) {
133
+ statusDiv.textContent = connected ? 'Connected' : 'Disconnected';
134
+ statusDiv.className = `status ${connected ? 'connected' : 'disconnected'}`;
135
+ reconnectBtn.disabled = connected;
136
+ }
137
+
138
+ function connectWebSocket() {
139
+ if (ws) {
140
+ ws.close();
141
+ }
142
+
143
+ const wsUrl = wsUrlInput.value;
144
+ console.log('Attempting to connect to:', wsUrl);
145
+ ws = new WebSocket(wsUrl);
146
+
147
+ ws.onopen = () => {
148
+ console.log('WebSocket connection established');
149
+ updateStatus(true);
150
+ startBtn.disabled = false;
151
+ stopBtn.disabled = true;
152
+ hideError();
153
+ reconnectAttempts = 0;
154
+ };
155
+
156
+ ws.onclose = (event) => {
157
+ console.log('WebSocket connection closed:', event.code, event.reason);
158
+ updateStatus(false);
159
+ startBtn.disabled = true;
160
+ stopBtn.disabled = true;
161
+
162
+ if (event.code === 1006) {
163
+ showError('Connection lost. Click "Reconnect" to try again.');
164
+ }
165
+ };
166
+
167
+ ws.onerror = (error) => {
168
+ console.error('WebSocket error:', error);
169
+ updateStatus(false);
170
+ showError('Connection error occurred. Click "Reconnect" to try again.');
171
+ };
172
+
173
+ ws.onmessage = (event) => {
174
+ try {
175
+ console.log('Received message:', event.data);
176
+ const response = JSON.parse(event.data);
177
+ console.log('Parsed response:', response);
178
+
179
+ if (response.text) {
180
+ console.log('Adding text:', response.text);
181
+ transcriptionDiv.textContent += response.text + '\n';
182
+ } else if (response.partial) {
183
+ console.log('Adding partial text:', response.partial);
184
+ transcriptionDiv.textContent = transcriptionDiv.textContent.replace(/[^.!?]+$/, '') + response.partial;
185
+ } else if (response.error) {
186
+ console.error('Server error:', response.error);
187
+ showError('Server error: ' + response.error);
188
+ } else if (response.buffer_transcription) {
189
+ console.log('Adding buffer transcription:', response.buffer_transcription);
190
+ transcriptionDiv.textContent += response.buffer_transcription + '\n';
191
+ } else if (response.full_transcription) {
192
+ console.log('Adding full transcription:', response.full_transcription);
193
+ transcriptionDiv.textContent += response.full_transcription + '\n';
194
+ } else if (typeof response === 'string') {
195
+ console.log('Adding raw text:', response);
196
+ transcriptionDiv.textContent += response + '\n';
197
+ }
198
+ } catch (error) {
199
+ console.error('Error parsing message:', error);
200
+ console.error('Raw message:', event.data);
201
+ }
202
+ };
203
+ }
204
+
205
+ async function startRecording() {
206
+ try {
207
+ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
208
+ mediaRecorder = new MediaRecorder(stream);
209
+ audioChunks = [];
210
+
211
+ mediaRecorder.ondataavailable = (event) => {
212
+ if (event.data.size > 0 && ws && ws.readyState === WebSocket.OPEN) {
213
+ ws.send(event.data);
214
+ }
215
+ };
216
+
217
+ mediaRecorder.start(100); // Send chunks every 100ms
218
+ startBtn.disabled = true;
219
+ stopBtn.disabled = false;
220
+ } catch (error) {
221
+ console.error('Error accessing microphone:', error);
222
+ showError('Error accessing microphone. Please ensure you have granted microphone permissions.');
223
+ }
224
+ }
225
+
226
+ function stopRecording() {
227
+ if (mediaRecorder && mediaRecorder.state !== 'inactive') {
228
+ mediaRecorder.stop();
229
+ mediaRecorder.stream.getTracks().forEach(track => track.stop());
230
+ startBtn.disabled = false;
231
+ stopBtn.disabled = true;
232
+ }
233
+ }
234
+
235
+ function clearTranscription() {
236
+ transcriptionDiv.textContent = '';
237
+ }
238
+
239
+ startBtn.addEventListener('click', startRecording);
240
+ stopBtn.addEventListener('click', stopRecording);
241
+ reconnectBtn.addEventListener('click', connectWebSocket);
242
+ clearBtn.addEventListener('click', clearTranscription);
243
+
244
+ // Connect to WebSocket when page loads
245
+ connectWebSocket();
246
+ </script>
247
+ </body>
248
+ </html>
timed_objects.py ADDED
@@ -0,0 +1,29 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ @dataclass
5
+ class TimedText:
6
+ start: Optional[float]
7
+ end: Optional[float]
8
+ text: Optional[str] = ''
9
+ speaker: Optional[int] = -1
10
+ probability: Optional[float] = None
11
+ is_dummy: Optional[bool] = False
12
+
13
+ @dataclass
14
+ class ASRToken(TimedText):
15
+ def with_offset(self, offset: float) -> "ASRToken":
16
+ """Return a new token with the time offset added."""
17
+ return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
18
+
19
+ @dataclass
20
+ class Sentence(TimedText):
21
+ pass
22
+
23
+ @dataclass
24
+ class Transcript(TimedText):
25
+ pass
26
+
27
+ @dataclass
28
+ class SpeakerSegment(TimedText):
29
+ pass
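A small sketch of how ASRToken.with_offset is meant to be used: backends return chunk-relative word timestamps, and the streaming code shifts them onto the global timeline by the chunk's start time.

from timed_objects import ASRToken

token = ASRToken(start=0.4, end=0.9, text="hello", probability=0.98)
shifted = token.with_offset(12.0)   # the chunk this word came from started at t = 12 s
print(shifted.start, shifted.end, shifted.text)   # 12.4 12.9 hello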
whisper_streaming_custom/__init__.py ADDED
File without changes
whisper_streaming_custom/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (185 Bytes)
 
whisper_streaming_custom/__pycache__/backends.cpython-310.pyc ADDED
Binary file (11.5 kB)
 
whisper_streaming_custom/__pycache__/online_asr.cpython-310.pyc ADDED
Binary file (16 kB)
 
whisper_streaming_custom/__pycache__/whisper_online.cpython-310.pyc ADDED
Binary file (5.5 kB)
 
whisper_streaming_custom/backends.py ADDED
@@ -0,0 +1,295 @@
1
+ import sys
2
+ import logging
3
+ import io
4
+ import soundfile as sf
5
+ import math
6
+ try:
7
+ import torch
8
+ except ImportError:
9
+ torch = None
10
+ from typing import List
11
+ import numpy as np
12
+ from timed_objects import ASRToken
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ class ASRBase:
17
+ sep = " " # join transcribe words with this character (" " for whisper_timestamped,
18
+ # "" for faster-whisper because it emits the spaces when needed)
19
+
20
+ def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
21
+ self.logfile = logfile
22
+ self.transcribe_kargs = {}
23
+ if lan == "auto":
24
+ self.original_language = None
25
+ else:
26
+ self.original_language = lan
27
+ self.model = self.load_model(modelsize, cache_dir, model_dir)
28
+
29
+ def with_offset(self, offset: float) -> ASRToken:
30
+ # This method is kept for compatibility (typically you will use ASRToken.with_offset)
31
+ return ASRToken(self.start + offset, self.end + offset, self.text)
32
+
33
+ def __repr__(self):
34
+ return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"
35
+
36
+ def load_model(self, modelsize, cache_dir, model_dir):
37
+ raise NotImplementedError("must be implemented in the child class")
38
+
39
+ def transcribe(self, audio, init_prompt=""):
40
+ raise NotImplementedError("must be implemented in the child class")
41
+
42
+ def use_vad(self):
43
+ raise NotImplementedError("must be implemented in the child class")
44
+
45
+
46
+ class WhisperTimestampedASR(ASRBase):
47
+ """Uses whisper_timestamped as the backend."""
48
+ sep = " "
49
+
50
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
51
+ import whisper
52
+ import whisper_timestamped
53
+ from whisper_timestamped import transcribe_timestamped
54
+
55
+ self.transcribe_timestamped = transcribe_timestamped
56
+ if model_dir is not None:
57
+ logger.debug("ignoring model_dir, not implemented")
58
+ return whisper.load_model(modelsize, download_root=cache_dir)
59
+
60
+ def transcribe(self, audio, init_prompt=""):
61
+ result = self.transcribe_timestamped(
62
+ self.model,
63
+ audio,
64
+ language=self.original_language,
65
+ initial_prompt=init_prompt,
66
+ verbose=None,
67
+ condition_on_previous_text=True,
68
+ **self.transcribe_kargs,
69
+ )
70
+ return result
71
+
72
+ def ts_words(self, r) -> List[ASRToken]:
73
+ """
74
+ Converts the whisper_timestamped result to a list of ASRToken objects.
75
+ """
76
+ tokens = []
77
+ for segment in r["segments"]:
78
+ for word in segment["words"]:
79
+ token = ASRToken(word["start"], word["end"], word["text"])
80
+ tokens.append(token)
81
+ return tokens
82
+
83
+ def segments_end_ts(self, res) -> List[float]:
84
+ return [segment["end"] for segment in res["segments"]]
85
+
86
+ def use_vad(self):
87
+ self.transcribe_kargs["vad"] = True
88
+
89
+ def set_translate_task(self):
90
+ self.transcribe_kargs["task"] = "translate"
91
+
92
+
93
+ class FasterWhisperASR(ASRBase):
94
+ """Uses faster-whisper as the backend."""
95
+ sep = ""
96
+
97
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
98
+ from faster_whisper import WhisperModel
99
+
100
+ if model_dir is not None:
101
+ logger.debug(f"Loading whisper model from model_dir {model_dir}. "
102
+ f"modelsize and cache_dir parameters are not used.")
103
+ model_size_or_path = model_dir
104
+ elif modelsize is not None:
105
+ model_size_or_path = modelsize
106
+ else:
107
+ raise ValueError("Either modelsize or model_dir must be set")
108
+ device = "cuda" if torch and torch.cuda.is_available() else "cpu"
109
+ compute_type = "float16" if device == "cuda" else "float32"
110
+
111
+ model = WhisperModel(
112
+ model_size_or_path,
113
+ device=device,
114
+ compute_type=compute_type,
115
+ download_root=cache_dir,
116
+ )
117
+ return model
118
+
119
+ def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
120
+ segments, info = self.model.transcribe(
121
+ audio,
122
+ language=self.original_language,
123
+ initial_prompt=init_prompt,
124
+ beam_size=5,
125
+ word_timestamps=True,
126
+ condition_on_previous_text=True,
127
+ **self.transcribe_kargs,
128
+ )
129
+ return list(segments)
130
+
131
+ def ts_words(self, segments) -> List[ASRToken]:
132
+ tokens = []
133
+ for segment in segments:
134
+ if segment.no_speech_prob > 0.9:
135
+ continue
136
+ for word in segment.words:
137
+ token = ASRToken(word.start, word.end, word.word, probability=word.probability)
138
+ tokens.append(token)
139
+ return tokens
140
+
141
+ def segments_end_ts(self, segments) -> List[float]:
142
+ return [segment.end for segment in segments]
143
+
144
+ def use_vad(self):
145
+ self.transcribe_kargs["vad_filter"] = True
146
+
147
+ def set_translate_task(self):
148
+ self.transcribe_kargs["task"] = "translate"
149
+
150
+
151
+ class MLXWhisper(ASRBase):
152
+ """
153
+ Uses MLX Whisper optimized for Apple Silicon.
154
+ """
155
+ sep = ""
156
+
157
+ def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
158
+ from mlx_whisper.transcribe import ModelHolder, transcribe
159
+ import mlx.core as mx
160
+
161
+ if model_dir is not None:
162
+ logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
163
+ model_size_or_path = model_dir
164
+ elif modelsize is not None:
165
+ model_size_or_path = self.translate_model_name(modelsize)
166
+ logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
167
+ else:
168
+ raise ValueError("Either modelsize or model_dir must be set")
169
+
170
+ self.model_size_or_path = model_size_or_path
171
+ dtype = mx.float16
172
+ ModelHolder.get_model(model_size_or_path, dtype)
173
+ return transcribe
174
+
175
+ def translate_model_name(self, model_name):
176
+ model_mapping = {
177
+ "tiny.en": "mlx-community/whisper-tiny.en-mlx",
178
+ "tiny": "mlx-community/whisper-tiny-mlx",
179
+ "base.en": "mlx-community/whisper-base.en-mlx",
180
+ "base": "mlx-community/whisper-base-mlx",
181
+ "small.en": "mlx-community/whisper-small.en-mlx",
182
+ "small": "mlx-community/whisper-small-mlx",
183
+ "medium.en": "mlx-community/whisper-medium.en-mlx",
184
+ "medium": "mlx-community/whisper-medium-mlx",
185
+ "large-v1": "mlx-community/whisper-large-v1-mlx",
186
+ "large-v2": "mlx-community/whisper-large-v2-mlx",
187
+ "large-v3": "mlx-community/whisper-large-v3-mlx",
188
+ "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
189
+ "large": "mlx-community/whisper-large-mlx",
190
+ }
191
+ mlx_model_path = model_mapping.get(model_name)
192
+ if mlx_model_path:
193
+ return mlx_model_path
194
+ else:
195
+ raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")
196
+
197
+ def transcribe(self, audio, init_prompt=""):
198
+ if self.transcribe_kargs:
199
+ logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
200
+ segments = self.model(
201
+ audio,
202
+ language=self.original_language,
203
+ initial_prompt=init_prompt,
204
+ word_timestamps=True,
205
+ condition_on_previous_text=True,
206
+ path_or_hf_repo=self.model_size_or_path,
207
+ )
208
+ return segments.get("segments", [])
209
+
210
+ def ts_words(self, segments) -> List[ASRToken]:
211
+ tokens = []
212
+ for segment in segments:
213
+ if segment.get("no_speech_prob", 0) > 0.9:
214
+ continue
215
+ for word in segment.get("words", []):
216
+ token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
217
+ tokens.append(token)
218
+ return tokens
219
+
220
+ def segments_end_ts(self, res) -> List[float]:
221
+ return [s["end"] for s in res]
222
+
223
+ def use_vad(self):
224
+ self.transcribe_kargs["vad_filter"] = True
225
+
226
+ def set_translate_task(self):
227
+ self.transcribe_kargs["task"] = "translate"
228
+
229
+
230
+ class OpenaiApiASR(ASRBase):
231
+ """Uses OpenAI's Whisper API for transcription."""
232
+ def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
233
+ self.logfile = logfile
234
+ self.modelname = "whisper-1"
235
+ self.original_language = None if lan == "auto" else lan
236
+ self.response_format = "verbose_json"
237
+ self.temperature = temperature
238
+ self.load_model()
239
+ self.use_vad_opt = False
240
+ self.task = "transcribe"
241
+
242
+ def load_model(self, *args, **kwargs):
243
+ from openai import OpenAI
244
+ self.client = OpenAI()
245
+ self.transcribed_seconds = 0
246
+
247
+ def ts_words(self, segments) -> List[ASRToken]:
248
+ """
249
+ Converts OpenAI API response words into ASRToken objects while
250
+ optionally skipping words that fall into no-speech segments.
251
+ """
252
+ no_speech_segments = []
253
+ if self.use_vad_opt:
254
+ for segment in segments.segments:
255
+ if segment.no_speech_prob > 0.8:
256
+ no_speech_segments.append((segment.start, segment.end))
257
+ tokens = []
258
+ for word in segments.words:
259
+ start = word.start
260
+ end = word.end
261
+ if any(s[0] <= start <= s[1] for s in no_speech_segments):
262
+ continue
263
+ tokens.append(ASRToken(start, end, word.word))
264
+ return tokens
265
+
266
+ def segments_end_ts(self, res) -> List[float]:
267
+ return [s.end for s in res.words]
268
+
269
+ def transcribe(self, audio_data, prompt=None, *args, **kwargs):
270
+ buffer = io.BytesIO()
271
+ buffer.name = "temp.wav"
272
+ sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
273
+ buffer.seek(0)
274
+ self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
275
+ params = {
276
+ "model": self.modelname,
277
+ "file": buffer,
278
+ "response_format": self.response_format,
279
+ "temperature": self.temperature,
280
+ "timestamp_granularities": ["word", "segment"],
281
+ }
282
+ if self.task != "translate" and self.original_language:
283
+ params["language"] = self.original_language
284
+ if prompt:
285
+ params["prompt"] = prompt
286
+ proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
287
+ transcript = proc.create(**params)
288
+ logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
289
+ return transcript
290
+
291
+ def use_vad(self):
292
+ self.use_vad_opt = True
293
+
294
+ def set_translate_task(self):
295
+ self.task = "translate"
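For reference, a sketch of exercising one backend directly, outside the streaming loop. It assumes faster-whisper and librosa are installed (both are in requirements.txt) and that speech.wav is an illustrative 16 kHz mono recording; the tiny model is downloaded on first use.

import librosa
from whisper_streaming_custom.backends import FasterWhisperASR

asr = FasterWhisperASR(lan="en", modelsize="tiny")
audio, _ = librosa.load("speech.wav", sr=16000, mono=True)   # float32 mono at 16 kHz
segments = asr.transcribe(audio)       # faster-whisper segments with word timestamps
for token in asr.ts_words(segments):   # filtered, word-level ASRToken objects
    print(f"{token.start:.2f}-{token.end:.2f}:{token.text}")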
whisper_streaming_custom/online_asr.py ADDED
@@ -0,0 +1,453 @@
+ import sys
+ import numpy as np
+ import logging
+ from typing import List, Tuple, Optional
+ from timed_objects import ASRToken, Sentence, Transcript
+
+ logger = logging.getLogger(__name__)
+
+
+ class HypothesisBuffer:
+     """
+     Buffer to store and process ASR hypothesis tokens.
+
+     It holds:
+     - committed_in_buffer: tokens that have been confirmed (committed)
+     - buffer: the last hypothesis that is not yet committed
+     - new: new tokens coming from the recognizer
+     """
+     def __init__(self, logfile=sys.stderr, confidence_validation=False):
+         self.confidence_validation = confidence_validation
+         self.committed_in_buffer: List[ASRToken] = []
+         self.buffer: List[ASRToken] = []
+         self.new: List[ASRToken] = []
+         self.last_committed_time = 0.0
+         self.last_committed_word: Optional[str] = None
+         self.logfile = logfile
+
+     def insert(self, new_tokens: List[ASRToken], offset: float):
+         """
+         Insert new tokens (after applying a time offset) and compare them with the
+         already committed tokens. Only tokens that extend the committed hypothesis
+         are added.
+         """
+         # Apply the offset to each token.
+         new_tokens = [token.with_offset(offset) for token in new_tokens]
+         # Only keep tokens that are roughly "new"
+         self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]
+
+         if self.new:
+             first_token = self.new[0]
+             if abs(first_token.start - self.last_committed_time) < 1:
+                 if self.committed_in_buffer:
+                     committed_len = len(self.committed_in_buffer)
+                     new_len = len(self.new)
+                     # Try to match 1 to 5 consecutive tokens
+                     max_ngram = min(min(committed_len, new_len), 5)
+                     for i in range(1, max_ngram + 1):
+                         committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
+                         new_ngram = " ".join(token.text for token in self.new[:i])
+                         if committed_ngram == new_ngram:
+                             removed = []
+                             for _ in range(i):
+                                 removed_token = self.new.pop(0)
+                                 removed.append(repr(removed_token))
+                             logger.debug(f"Removing last {i} words: {' '.join(removed)}")
+                             break
+
+     def flush(self) -> List[ASRToken]:
+         """
+         Returns the committed chunk, defined as the longest common prefix
+         between the previous hypothesis and the new tokens.
+         """
+         committed: List[ASRToken] = []
+         while self.new:
+             current_new = self.new[0]
+             if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
+                 committed.append(current_new)
+                 self.last_committed_word = current_new.text
+                 self.last_committed_time = current_new.end
+                 self.new.pop(0)
+                 if self.buffer:
+                     self.buffer.pop(0)
+             elif not self.buffer:
+                 break
+             elif current_new.text == self.buffer[0].text:
+                 committed.append(current_new)
+                 self.last_committed_word = current_new.text
+                 self.last_committed_time = current_new.end
+                 self.buffer.pop(0)
+                 self.new.pop(0)
+             else:
+                 break
+         self.buffer = self.new
+         self.new = []
+         self.committed_in_buffer.extend(committed)
+         return committed
+
+     def pop_committed(self, time: float):
+         """
+         Remove tokens (from the beginning) that have ended at or before `time`.
+         """
+         while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
+             self.committed_in_buffer.pop(0)
+
+
+
96
+ class OnlineASRProcessor:
97
+ """
98
+ Processes incoming audio in a streaming fashion, calling the ASR system
99
+ periodically, and uses a hypothesis buffer to commit and trim recognized text.
100
+
101
+ The processor supports two types of buffer trimming:
102
+ - "sentence": trims at sentence boundaries (using a sentence tokenizer)
103
+ - "segment": trims at fixed segment durations.
104
+ """
105
+ SAMPLING_RATE = 16000
106
+
107
+ def __init__(
108
+ self,
109
+ asr,
110
+ tokenize_method: Optional[callable] = None,
111
+ buffer_trimming: Tuple[str, float] = ("segment", 15),
112
+ confidence_validation = False,
113
+ logfile=sys.stderr,
114
+ ):
115
+ """
116
+ asr: An ASR system object (for example, a WhisperASR instance) that
117
+ provides a `transcribe` method, a `ts_words` method (to extract tokens),
118
+ a `segments_end_ts` method, and a separator attribute `sep`.
119
+ tokenize_method: A function that receives text and returns a list of sentence strings.
120
+ buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
121
+ """
122
+ self.asr = asr
123
+ self.tokenize = tokenize_method
124
+ self.logfile = logfile
125
+ self.confidence_validation = confidence_validation
126
+ self.init()
127
+
128
+ self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming
129
+
130
+ if self.buffer_trimming_way not in ["sentence", "segment"]:
131
+ raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
132
+ if self.buffer_trimming_sec <= 0:
133
+ raise ValueError("buffer_trimming_sec must be positive")
134
+ elif self.buffer_trimming_sec > 30:
135
+ logger.warning(
136
+ f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
137
+ )
138
+
139
+ def init(self, offset: Optional[float] = None):
140
+ """Initialize or reset the processing buffers."""
141
+ self.audio_buffer = np.array([], dtype=np.float32)
142
+ self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
143
+ self.buffer_time_offset = offset if offset is not None else 0.0
144
+ self.transcript_buffer.last_committed_time = self.buffer_time_offset
145
+ self.committed: List[ASRToken] = []
146
+
147
+ def insert_audio_chunk(self, audio: np.ndarray):
148
+ """Append an audio chunk (a numpy array) to the current audio buffer."""
149
+ self.audio_buffer = np.append(self.audio_buffer, audio)
150
+
151
+ def prompt(self) -> Tuple[str, str]:
152
+ """
153
+ Returns a tuple: (prompt, context), where:
154
+ - prompt is a 200-character suffix of committed text that falls
155
+ outside the current audio buffer.
156
+ - context is the committed text within the current audio buffer.
157
+ """
158
+ k = len(self.committed)
159
+ while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
160
+ k -= 1
161
+
162
+ prompt_tokens = self.committed[:k]
163
+ prompt_words = [token.text for token in prompt_tokens]
164
+ prompt_list = []
165
+ length_count = 0
166
+ # Use the last words until reaching 200 characters.
167
+ while prompt_words and length_count < 200:
168
+ word = prompt_words.pop(-1)
169
+ length_count += len(word) + 1
170
+ prompt_list.append(word)
171
+ non_prompt_tokens = self.committed[k:]
172
+ context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
173
+ return self.asr.sep.join(prompt_list[::-1]), context_text
174
+
175
+ def get_buffer(self):
176
+ """
177
+ Get the unvalidated buffer in string format.
178
+ """
179
+ return self.concatenate_tokens(self.transcript_buffer.buffer)
180
+
181
+
182
+ def process_iter(self) -> Transcript:
183
+ """
184
+ Processes the current audio buffer.
185
+
186
+ Returns a Transcript object representing the committed transcript.
187
+ """
188
+ prompt_text, _ = self.prompt()
189
+ logger.debug(
190
+ f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
191
+ )
192
+ res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
193
+ tokens = self.asr.ts_words(res) # Expecting List[ASRToken]
194
+ self.transcript_buffer.insert(tokens, self.buffer_time_offset)
195
+ committed_tokens = self.transcript_buffer.flush()
196
+ self.committed.extend(committed_tokens)
197
+ completed = self.concatenate_tokens(committed_tokens)
198
+ logger.debug(f">>>> COMPLETE NOW: {completed.text}")
199
+ incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
200
+ logger.debug(f"INCOMPLETE: {incomp.text}")
201
+
202
+ if committed_tokens and self.buffer_trimming_way == "sentence":
203
+ if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
204
+ self.chunk_completed_sentence()
205
+
206
+ s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
207
+ if len(self.audio_buffer) / self.SAMPLING_RATE > s:
208
+ self.chunk_completed_segment(res)
209
+ logger.debug("Chunking segment")
210
+ logger.debug(
211
+ f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
212
+ )
213
+ return committed_tokens
214
+
215
+ def chunk_completed_sentence(self):
216
+ """
217
+ If the committed tokens form at least two sentences, chunk the audio
218
+ buffer at the end time of the penultimate sentence.
219
+ """
220
+ if not self.committed:
221
+ return
222
+ logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
223
+ sentences = self.words_to_sentences(self.committed)
224
+ for sentence in sentences:
225
+ logger.debug(f"\tSentence: {sentence.text}")
226
+ if len(sentences) < 2:
227
+ return
228
+ # Keep the last two sentences.
229
+ while len(sentences) > 2:
230
+ sentences.pop(0)
231
+ chunk_time = sentences[-2].end
232
+ logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
233
+ self.chunk_at(chunk_time)
234
+
235
+ def chunk_completed_segment(self, res):
236
+ """
237
+ Chunk the audio buffer based on segment-end timestamps reported by the ASR.
238
+ """
239
+ if not self.committed:
240
+ return
241
+ ends = self.asr.segments_end_ts(res)
242
+ last_committed_time = self.committed[-1].end
243
+ if len(ends) > 1:
244
+ e = ends[-2] + self.buffer_time_offset
245
+ while len(ends) > 2 and e > last_committed_time:
246
+ ends.pop(-1)
247
+ e = ends[-2] + self.buffer_time_offset
248
+ if e <= last_committed_time:
249
+ logger.debug(f"--- Segment chunked at {e:.2f}")
250
+ self.chunk_at(e)
251
+ else:
252
+ logger.debug("--- Last segment not within committed area")
253
+ else:
254
+ logger.debug("--- Not enough segments to chunk")
255
+
256
+ def chunk_at(self, time: float):
257
+ """
258
+ Trim both the hypothesis and audio buffer at the given time.
259
+ """
260
+ logger.debug(f"Chunking at {time:.2f}s")
261
+ logger.debug(
262
+ f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
263
+ )
264
+ self.transcript_buffer.pop_committed(time)
265
+ cut_seconds = time - self.buffer_time_offset
266
+ self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
267
+ self.buffer_time_offset = time
268
+ logger.debug(
269
+ f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
270
+ )
271
+
272
+ def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
273
+ """
274
+ Converts a list of tokens to a list of Sentence objects using the provided
275
+ sentence tokenizer.
276
+ """
277
+ if not tokens:
278
+ return []
279
+
280
+ full_text = " ".join(token.text for token in tokens)
281
+
282
+ if self.tokenize:
283
+ try:
284
+ sentence_texts = self.tokenize(full_text)
285
+ except Exception as e:
286
+ # Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
287
+ try:
288
+ sentence_texts = self.tokenize([full_text])
289
+ except Exception as e2:
290
+ raise ValueError("Tokenization failed") from e2
291
+ else:
292
+ sentence_texts = [full_text]
293
+
294
+ sentences: List[Sentence] = []
295
+ token_index = 0
296
+ for sent_text in sentence_texts:
297
+ sent_text = sent_text.strip()
298
+ if not sent_text:
299
+ continue
300
+ sent_tokens = []
301
+ accumulated = ""
302
+ # Accumulate tokens until roughly matching the length of the sentence text.
303
+ while token_index < len(tokens) and len(accumulated) < len(sent_text):
304
+ token = tokens[token_index]
305
+ accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
306
+ sent_tokens.append(token)
307
+ token_index += 1
308
+ if sent_tokens:
309
+ sentence = Sentence(
310
+ start=sent_tokens[0].start,
311
+ end=sent_tokens[-1].end,
312
+ text=" ".join(t.text for t in sent_tokens),
313
+ )
314
+ sentences.append(sentence)
315
+ return sentences
316
+ def finish(self) -> Transcript:
317
+ """
318
+ Flush the remaining transcript when processing ends.
319
+ """
320
+ remaining_tokens = self.transcript_buffer.buffer
321
+ final_transcript = self.concatenate_tokens(remaining_tokens)
322
+ logger.debug(f"Final non-committed transcript: {final_transcript}")
323
+ self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
324
+ return final_transcript
325
+
326
+ def concatenate_tokens(
327
+ self,
328
+ tokens: List[ASRToken],
329
+ sep: Optional[str] = None,
330
+ offset: float = 0
331
+ ) -> Transcript:
332
+ sep = sep if sep is not None else self.asr.sep
333
+ text = sep.join(token.text for token in tokens)
334
+ probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
335
+ if tokens:
336
+ start = offset + tokens[0].start
337
+ end = offset + tokens[-1].end
338
+ else:
339
+ start = None
340
+ end = None
341
+ return Transcript(start, end, text, probability=probability)
342
+
343
+
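A minimal sketch of the intended calling pattern for this class, feeding a file in one-second chunks; the backend construction mirrors backend_factory in whisper_online.py below, and the wav path is illustrative:

import librosa
from whisper_streaming_custom.backends import FasterWhisperASR
from whisper_streaming_custom.online_asr import OnlineASRProcessor

asr = FasterWhisperASR(modelsize="tiny", lan="en", cache_dir=None, model_dir=None)
online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))

audio, _ = librosa.load("speech.wav", sr=OnlineASRProcessor.SAMPLING_RATE)  # placeholder path
step = OnlineASRProcessor.SAMPLING_RATE  # one second per iteration
for i in range(0, len(audio), step):
    online.insert_audio_chunk(audio[i:i + step])
    for token in online.process_iter():          # newly committed ASRTokens
        print(f"{token.start:6.2f}-{token.end:6.2f} {token.text}")

print("tail:", online.finish().text)             # any uncommitted text at the end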
+ class VACOnlineASRProcessor:
+     """
+     Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).
+
+     It receives small chunks of audio, applies VAD (e.g. with Silero),
+     and when the system detects a pause in speech (or end of an utterance)
+     it finalizes the utterance immediately.
+     """
+     SAMPLING_RATE = 16000
+
+     def __init__(self, online_chunk_size: float, *args, **kwargs):
+         self.online_chunk_size = online_chunk_size
+         self.online = OnlineASRProcessor(*args, **kwargs)
+
+         # Load a VAD model (e.g. Silero VAD)
+         import torch
+         model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
+         from silero_vad_iterator import FixedVADIterator
+
+         self.vac = FixedVADIterator(model)
+         self.logfile = self.online.logfile
+         self.init()
+
+     def init(self):
+         self.online.init()
+         self.vac.reset_states()
+         self.current_online_chunk_buffer_size = 0
+         self.is_currently_final = False
+         self.status: Optional[str] = None  # "voice" or "nonvoice"
+         self.audio_buffer = np.array([], dtype=np.float32)
+         self.buffer_offset = 0  # in frames
+
+     def clear_buffer(self):
+         self.buffer_offset += len(self.audio_buffer)
+         self.audio_buffer = np.array([], dtype=np.float32)
+
+     def insert_audio_chunk(self, audio: np.ndarray):
+         """
+         Process an incoming small audio chunk:
+         - run VAD on the chunk,
+         - decide whether to send the audio to the online ASR processor immediately,
+         - and/or to mark the current utterance as finished.
+         """
+         res = self.vac(audio)
+         self.audio_buffer = np.append(self.audio_buffer, audio)
+
+         if res is not None:
+             # VAD returned a result; adjust the frame number
+             frame = list(res.values())[0] - self.buffer_offset
+             if "start" in res and "end" not in res:
+                 self.status = "voice"
+                 send_audio = self.audio_buffer[frame:]
+                 self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
+                 self.online.insert_audio_chunk(send_audio)
+                 self.current_online_chunk_buffer_size += len(send_audio)
+                 self.clear_buffer()
+             elif "end" in res and "start" not in res:
+                 self.status = "nonvoice"
+                 send_audio = self.audio_buffer[:frame]
+                 self.online.insert_audio_chunk(send_audio)
+                 self.current_online_chunk_buffer_size += len(send_audio)
+                 self.is_currently_final = True
+                 self.clear_buffer()
+             else:
+                 beg = res["start"] - self.buffer_offset
+                 end = res["end"] - self.buffer_offset
+                 self.status = "nonvoice"
+                 send_audio = self.audio_buffer[beg:end]
+                 self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
+                 self.online.insert_audio_chunk(send_audio)
+                 self.current_online_chunk_buffer_size += len(send_audio)
+                 self.is_currently_final = True
+                 self.clear_buffer()
+         else:
+             if self.status == "voice":
+                 self.online.insert_audio_chunk(self.audio_buffer)
+                 self.current_online_chunk_buffer_size += len(self.audio_buffer)
+                 self.clear_buffer()
+             else:
+                 # Keep 1 second worth of audio in case VAD later detects voice,
+                 # but trim to avoid unbounded memory usage.
+                 self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
+                 self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]
+
+     def process_iter(self) -> Transcript:
+         """
+         Depending on the VAD status and the amount of accumulated audio,
+         process the current audio chunk.
+         """
+         if self.is_currently_final:
+             return self.finish()
+         elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
+             self.current_online_chunk_buffer_size = 0
+             return self.online.process_iter()
+         else:
+             logger.debug("No online update, only VAD")
+             return Transcript(None, None, "")
+
+     def finish(self) -> Transcript:
+         """Finish processing by flushing any remaining text."""
+         result = self.online.finish()
+         self.current_online_chunk_buffer_size = 0
+         self.is_currently_final = False
+         return result
+
+     def get_buffer(self):
+         """
+         Get the unvalidated buffer in string format.
+         """
+         return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text
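The VAC wrapper is driven the same way but with small (roughly 100 ms) chunks, and its constructor pulls Silero VAD through torch.hub, so it needs torch and network access on first use. A hedged sketch (silence stands in for microphone audio; with VAD active it will simply never trigger an update):

import numpy as np
from whisper_streaming_custom.backends import FasterWhisperASR
from whisper_streaming_custom.online_asr import VACOnlineASRProcessor

asr = FasterWhisperASR(modelsize="tiny", lan="en", cache_dir=None, model_dir=None)
vac_online = VACOnlineASRProcessor(1.0, asr, buffer_trimming=("segment", 15))

stream = np.zeros(16000 * 5, dtype=np.float32)   # placeholder for microphone input
step = 1600                                      # ~100 ms per chunk at 16 kHz
for i in range(0, len(stream), step):
    vac_online.insert_audio_chunk(stream[i:i + step])
    result = vac_online.process_iter()           # Transcript, or committed tokens after a full chunk
    text = getattr(result, "text", "")
    if text:
        print(text)

print("final:", vac_online.finish().text)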
whisper_streaming_custom/whisper_online.py ADDED
@@ -0,0 +1,194 @@
+ #!/usr/bin/env python3
+ import sys
+ import numpy as np
+ import librosa
+ from functools import lru_cache
+ import time
+ import logging
+ from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
+ from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor
+
+ logger = logging.getLogger(__name__)
+
+
+ WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(
+     ","
+ )
+
+
+ def create_tokenizer(lan):
+     """Returns an object with a `split` function that works like the one of MosesTokenizer."""
+
+     assert (
+         lan in WHISPER_LANG_CODES
+     ), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)
+
+     if lan == "uk":
+         import tokenize_uk
+
+         class UkrainianTokenizer:
+             def split(self, text):
+                 return tokenize_uk.tokenize_sents(text)
+
+         return UkrainianTokenizer()
+
+     # supported by fast-mosestokenizer
+     if (
+         lan
+         in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
+     ):
+         from mosestokenizer import MosesSentenceSplitter
+
+         return MosesSentenceSplitter(lan)
+
+     # the following languages are in Whisper, but not in wtpsplit:
+     if (
+         lan
+         in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
+     ):
+         logger.debug(
+             f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
+         )
+         lan = None
+
+     from wtpsplit import WtP
+
+     # downloads the model from huggingface on the first use
+     wtp = WtP("wtp-canine-s-12l-no-adapters")
+
+     class WtPtok:
+         def split(self, sent):
+             return wtp.split(sent, lang_code=lan)
+
+     return WtPtok()
+
+
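create_tokenizer returns whichever third-party splitter matches the language; downstream code only relies on getting sentence strings back from a `split` call. A hypothetical stand-in with the same shape, handy for quick experiments without installing mosestokenizer or wtpsplit (not part of this commit):

import re

class NaiveSentenceSplitter:
    """Hypothetical stand-in with the same `split` contract as the tokenizers above."""
    def split(self, text):
        # Accept either a plain string or a single-element list, like the callers do.
        if isinstance(text, list):
            text = " ".join(text)
        return [s.strip() for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]

splitter = NaiveSentenceSplitter()
print(splitter.split("This is one sentence. This is another one!"))
# ['This is one sentence.', 'This is another one!']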
+ def backend_factory(args):
+     backend = args.backend
+     if backend == "openai-api":
+         logger.debug("Using OpenAI API.")
+         asr = OpenaiApiASR(lan=args.lan)
+     else:
+         if backend == "faster-whisper":
+             asr_cls = FasterWhisperASR
+         elif backend == "mlx-whisper":
+             asr_cls = MLXWhisper
+         else:
+             asr_cls = WhisperTimestampedASR
+
+         # Only for FasterWhisperASR and WhisperTimestampedASR
+         size = args.model
+         t = time.time()
+         logger.info(f"Loading Whisper {size} model for language {args.lan}...")
+         asr = asr_cls(
+             modelsize=size,
+             lan=args.lan,
+             cache_dir=args.model_cache_dir,
+             model_dir=args.model_dir,
+         )
+         e = time.time()
+         logger.info(f"done. It took {round(e - t, 2)} seconds.")
+
+     # Apply common configurations
+     if getattr(args, "vad", False):  # Checks if VAD argument is present and True
+         logger.info("Setting VAD filter")
+         asr.use_vad()
+
+     language = args.lan
+     if args.task == "translate":
+         asr.set_translate_task()
+         tgt_language = "en"  # Whisper translates into English
+     else:
+         tgt_language = language  # Whisper transcribes in this language
+
+     # Create the tokenizer
+     if args.buffer_trimming == "sentence":
+         tokenizer = create_tokenizer(tgt_language)
+     else:
+         tokenizer = None
+     return asr, tokenizer
+
+
+ def online_factory(args, asr, tokenizer, logfile=sys.stderr):
+     if args.vac:
+         online = VACOnlineASRProcessor(
+             args.min_chunk_size,
+             asr,
+             tokenizer,
+             logfile=logfile,
+             buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
+             confidence_validation=args.confidence_validation,
+         )
+     else:
+         online = OnlineASRProcessor(
+             asr,
+             tokenizer,
+             logfile=logfile,
+             buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
+             confidence_validation=args.confidence_validation,
+         )
+     return online
+
+
+ def asr_factory(args, logfile=sys.stderr):
+     """
+     Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
+     """
+     asr, tokenizer = backend_factory(args)
+     online = online_factory(args, asr, tokenizer, logfile=logfile)
+     return asr, online
+
+
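The factories expect an argparse-style namespace; a minimal sketch with a hand-built Namespace whose fields are exactly the attributes read above (the values are illustrative defaults):

from argparse import Namespace
from whisper_streaming_custom.whisper_online import asr_factory

args = Namespace(
    backend="faster-whisper",
    model="tiny",
    lan="en",
    model_cache_dir=None,
    model_dir=None,
    vad=False,
    vac=False,
    task="transcribe",
    buffer_trimming="segment",
    buffer_trimming_sec=15,
    min_chunk_size=1.0,
    confidence_validation=False,
)

asr, online = asr_factory(args)
# `asr` is the backend instance, `online` the streaming processor handed to the audio layer.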
140
+
141
+ def warmup_asr(asr, warmup_file=None, timeout=5):
142
+ """
143
+ Warmup the ASR model by transcribing a short audio file.
144
+ """
145
+ import os
146
+ import tempfile
147
+
148
+
149
+ if warmup_file is None:
150
+ # Download JFK sample if not already present
151
+ jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
152
+ temp_dir = tempfile.gettempdir()
153
+ warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")
154
+
155
+ if not os.path.exists(warmup_file):
156
+ logger.debug(f"Downloading warmup file from {jfk_url}")
157
+ print(f"Downloading warmup file from {jfk_url}")
158
+ import time
159
+ import urllib.request
160
+ import urllib.error
161
+ import socket
162
+
163
+ original_timeout = socket.getdefaulttimeout()
164
+ socket.setdefaulttimeout(timeout)
165
+
166
+ start_time = time.time()
167
+ try:
168
+ urllib.request.urlretrieve(jfk_url, warmup_file)
169
+ logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
170
+ except (urllib.error.URLError, socket.timeout) as e:
171
+ logger.warning(f"Download failed: {e}. Proceeding without warmup.")
172
+ return False
173
+ finally:
174
+ socket.setdefaulttimeout(original_timeout)
175
+ elif not warmup_file:
176
+ return False
177
+
178
+ if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
179
+ logger.warning(f"Warmup file {warmup_file} invalid or missing.")
180
+ return False
181
+
182
+ print(f"Warmping up Whisper with {warmup_file}")
183
+ try:
184
+ import librosa
185
+ audio, sr = librosa.load(warmup_file, sr=16000)
186
+ except Exception as e:
187
+ logger.warning(f"Failed to load audio file: {e}")
188
+ return False
189
+
190
+ # Process the audio
191
+ asr.transcribe(audio)
192
+
193
+ logger.info("Whisper is warmed up")
194
+
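Warmup is a one-shot call after the backend exists; a short sketch with an explicit local file so no download is attempted (the Namespace fields mirror what backend_factory reads, and the wav path is illustrative):

from argparse import Namespace
from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr

args = Namespace(backend="faster-whisper", model="tiny", lan="en",
                 model_cache_dir=None, model_dir=None, vad=False,
                 task="transcribe", buffer_trimming="segment")
asr, _ = backend_factory(args)

# Without an explicit file, warmup_asr() downloads a short JFK sample to the temp
# directory; passing a local wav avoids the network round-trip.
warmup_asr(asr, warmup_file="/app/static/warmup.wav")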