Add application file
- Dockerfile +44 -0
- __init__.py +4 -0
- audio_processor.py +409 -0
- core.py +61 -0
- diarization/__init__.py +0 -0
- diarization/__pycache__/__init__.cpython-310.pyc +0 -0
- diarization/__pycache__/diarization_online.cpython-310.pyc +0 -0
- diarization/diarization_online.py +153 -0
- docker-compose.yml +36 -0
- main.py +275 -0
- requirements.txt +15 -0
- setup.py +22 -0
- static/index.html +248 -0
- timed_objects.py +29 -0
- whisper_streaming_custom/__init__.py +0 -0
- whisper_streaming_custom/__pycache__/__init__.cpython-310.pyc +0 -0
- whisper_streaming_custom/__pycache__/backends.cpython-310.pyc +0 -0
- whisper_streaming_custom/__pycache__/online_asr.cpython-310.pyc +0 -0
- whisper_streaming_custom/__pycache__/whisper_online.cpython-310.pyc +0 -0
- whisper_streaming_custom/backends.py +295 -0
- whisper_streaming_custom/online_asr.py +453 -0
- whisper_streaming_custom/whisper_online.py +194 -0
Dockerfile
ADDED
@@ -0,0 +1,44 @@
+# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
+# you will also find guides on how best to write your Dockerfile
+
+FROM python:3.10-slim
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    build-essential \
+    git \
+    curl \
+    ffmpeg \
+    && rm -rf /var/lib/apt/lists/* \
+    && apt-get clean
+
+# Set working directory
+WORKDIR /app
+
+# Create static directory and set permissions
+RUN mkdir -p /app/static && chmod 777 /app/static
+
+# Create a non-root user
+RUN useradd -m -u 1000 user
+USER user
+ENV PATH="/home/user/.local/bin:$PATH"
+
+# Copy requirements first to leverage Docker cache
+COPY --chown=user requirements.txt .
+RUN pip install --no-cache-dir --upgrade -r requirements.txt
+
+# Copy application files
+COPY --chown=user *.py /app/
+COPY --chown=user whisper_streaming_custom /app/whisper_streaming_custom/
+COPY --chown=user diarization /app/diarization/
+COPY --chown=user static /app/static/
+
+# Set environment variables
+ENV PYTHONPATH=/app
+ENV PYTHONUNBUFFERED=1
+
+# Expose the port the server runs on
+EXPOSE 7860
+
+# Run the server using main.py
+CMD ["python", "main.py", "--host", "0.0.0.0", "--port", "7860", "--model", "tiny", "--backend", "faster-whisper", "--task", "transcribe"]
__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .core import WhisperLiveKit, parse_args
+from .audio_processor import AudioProcessor
+
+__all__ = ['WhisperLiveKit', 'AudioProcessor', 'parse_args']
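A minimal sketch of how these exports are meant to be consumed. It assumes the modules resolve the way they do inside the container (PYTHONPATH=/app), and note that constructing WhisperLiveKit loads the default faster-whisper model, so this is not free to run:

    # Hedged sketch: same names as __all__ above; model loading happens
    # on first construction, so this is not a no-op.
    from core import WhisperLiveKit
    from audio_processor import AudioProcessor

    kit = WhisperLiveKit()        # singleton; loads ASR with the defaults from core.py
    processor = AudioProcessor()  # picks up the same models via WhisperLiveKit()
    print(kit.args.model, kit.args.backend)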
audio_processor.py
ADDED
@@ -0,0 +1,409 @@
+import asyncio
+import numpy as np
+import ffmpeg
+from time import time
+import math
+import logging
+import traceback
+from datetime import timedelta
+from typing import List, Dict, Any
+from timed_objects import ASRToken
+from whisper_streaming_custom.whisper_online import online_factory
+from core import WhisperLiveKit
+
+# Set up logging once
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+def format_time(seconds: float) -> str:
+    """Format seconds as HH:MM:SS."""
+    return str(timedelta(seconds=int(seconds)))
+
+class AudioProcessor:
+    """
+    Processes audio streams for transcription and diarization.
+    Handles audio processing, state management, and result formatting.
+    """
+
+    def __init__(self):
+        """Initialize the audio processor with configuration, models, and state."""
+        models = WhisperLiveKit()
+
+        # Audio processing settings
+        self.args = models.args
+        self.sample_rate = 16000
+        self.channels = 1
+        self.samples_per_sec = int(self.sample_rate * self.args.min_chunk_size)
+        self.bytes_per_sample = 2
+        self.bytes_per_sec = self.samples_per_sec * self.bytes_per_sample
+        self.max_bytes_per_sec = 32000 * 5  # 5 seconds of 16 kHz, 16-bit mono audio (32000 bytes/s)
+
+        # State management
+        self.tokens = []
+        self.buffer_transcription = ""
+        self.buffer_diarization = ""
+        self.full_transcription = ""
+        self.end_buffer = 0
+        self.end_attributed_speaker = 0
+        self.lock = asyncio.Lock()
+        self.beg_loop = time()
+        self.sep = " "  # Default separator
+        self.last_response_content = ""
+
+        # Models and processing
+        self.asr = models.asr
+        self.tokenizer = models.tokenizer
+        self.diarization = models.diarization
+        self.ffmpeg_process = self.start_ffmpeg_decoder()
+        self.transcription_queue = asyncio.Queue() if self.args.transcription else None
+        self.diarization_queue = asyncio.Queue() if self.args.diarization else None
+        self.pcm_buffer = bytearray()
+
+        # Initialize transcription engine if enabled
+        if self.args.transcription:
+            self.online = online_factory(self.args, models.asr, models.tokenizer)
+
+    def convert_pcm_to_float(self, pcm_buffer):
+        """Convert PCM buffer in s16le format to normalized NumPy array."""
+        return np.frombuffer(pcm_buffer, dtype=np.int16).astype(np.float32) / 32768.0
+
+    def start_ffmpeg_decoder(self):
+        """Start FFmpeg process for WebM to PCM conversion."""
+        return (ffmpeg.input("pipe:0", format="webm")
+                .output("pipe:1", format="s16le", acodec="pcm_s16le",
+                        ac=self.channels, ar=str(self.sample_rate))
+                .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True))
+
+    async def restart_ffmpeg(self):
+        """Restart the FFmpeg process after failure."""
+        if self.ffmpeg_process:
+            try:
+                self.ffmpeg_process.kill()
+                await asyncio.get_event_loop().run_in_executor(None, self.ffmpeg_process.wait)
+            except Exception as e:
+                logger.warning(f"Error killing FFmpeg process: {e}")
+        self.ffmpeg_process = self.start_ffmpeg_decoder()
+        self.pcm_buffer = bytearray()
+
+    async def update_transcription(self, new_tokens, buffer, end_buffer, full_transcription, sep):
+        """Thread-safe update of transcription with new data."""
+        async with self.lock:
+            self.tokens.extend(new_tokens)
+            self.buffer_transcription = buffer
+            self.end_buffer = end_buffer
+            self.full_transcription = full_transcription
+            self.sep = sep
+
+    async def update_diarization(self, end_attributed_speaker, buffer_diarization=""):
+        """Thread-safe update of diarization with new data."""
+        async with self.lock:
+            self.end_attributed_speaker = end_attributed_speaker
+            if buffer_diarization:
+                self.buffer_diarization = buffer_diarization
+
+    async def add_dummy_token(self):
+        """Placeholder token when no transcription is available."""
+        async with self.lock:
+            current_time = time() - self.beg_loop
+            self.tokens.append(ASRToken(
+                start=current_time, end=current_time + 1,
+                text=".", speaker=-1, is_dummy=True
+            ))
+
+    async def get_current_state(self):
+        """Get current state."""
+        async with self.lock:
+            current_time = time()
+
+            # Calculate remaining times
+            remaining_transcription = 0
+            if self.end_buffer > 0:
+                remaining_transcription = max(0, round(current_time - self.beg_loop - self.end_buffer, 2))
+
+            remaining_diarization = 0
+            if self.tokens:
+                latest_end = max(self.end_buffer, self.tokens[-1].end if self.tokens else 0)
+                remaining_diarization = max(0, round(latest_end - self.end_attributed_speaker, 2))
+
+            return {
+                "tokens": self.tokens.copy(),
+                "buffer_transcription": self.buffer_transcription,
+                "buffer_diarization": self.buffer_diarization,
+                "end_buffer": self.end_buffer,
+                "end_attributed_speaker": self.end_attributed_speaker,
+                "sep": self.sep,
+                "remaining_time_transcription": remaining_transcription,
+                "remaining_time_diarization": remaining_diarization
+            }
+
+    async def reset(self):
+        """Reset all state variables to initial values."""
+        async with self.lock:
+            self.tokens = []
+            self.buffer_transcription = self.buffer_diarization = ""
+            self.end_buffer = self.end_attributed_speaker = 0
+            self.full_transcription = self.last_response_content = ""
+            self.beg_loop = time()
+
+    async def ffmpeg_stdout_reader(self):
+        """Read audio data from FFmpeg stdout and process it."""
+        loop = asyncio.get_event_loop()
+        beg = time()
+
+        while True:
+            try:
+                # Calculate buffer size based on elapsed time
+                elapsed_time = math.floor((time() - beg) * 10) / 10  # Round to 0.1 sec
+                buffer_size = max(int(32000 * elapsed_time), 4096)
+                beg = time()
+
+                # Read chunk with timeout
+                try:
+                    chunk = await asyncio.wait_for(
+                        loop.run_in_executor(None, self.ffmpeg_process.stdout.read, buffer_size),
+                        timeout=15.0
+                    )
+                except asyncio.TimeoutError:
+                    logger.warning("FFmpeg read timeout. Restarting...")
+                    await self.restart_ffmpeg()
+                    beg = time()
+                    continue
+
+                if not chunk:
+                    logger.info("FFmpeg stdout closed.")
+                    break
+
+                self.pcm_buffer.extend(chunk)
+
+                # Send to diarization if enabled
+                if self.args.diarization and self.diarization_queue:
+                    await self.diarization_queue.put(
+                        self.convert_pcm_to_float(self.pcm_buffer).copy()
+                    )
+
+                # Process when we have enough data
+                if len(self.pcm_buffer) >= self.bytes_per_sec:
+                    if len(self.pcm_buffer) > self.max_bytes_per_sec:
+                        logger.warning(
+                            f"Audio buffer too large: {len(self.pcm_buffer) / self.bytes_per_sec:.2f}s. "
+                            f"Consider using a smaller model."
+                        )
+
+                    # Process audio chunk
+                    pcm_array = self.convert_pcm_to_float(self.pcm_buffer[:self.max_bytes_per_sec])
+                    self.pcm_buffer = self.pcm_buffer[self.max_bytes_per_sec:]
+
+                    # Send to transcription if enabled
+                    if self.args.transcription and self.transcription_queue:
+                        await self.transcription_queue.put(pcm_array.copy())
+
+                # Sleep if no processing is happening
+                if not self.args.transcription and not self.args.diarization:
+                    await asyncio.sleep(0.1)
+
+            except Exception as e:
+                logger.warning(f"Exception in ffmpeg_stdout_reader: {e}")
+                logger.warning(f"Traceback: {traceback.format_exc()}")
+                break
+
+    async def transcription_processor(self):
+        """Process audio chunks for transcription."""
+        self.full_transcription = ""
+        self.sep = self.online.asr.sep
+
+        while True:
+            try:
+                pcm_array = await self.transcription_queue.get()
+
+                logger.info(f"{len(self.online.audio_buffer) / self.online.SAMPLING_RATE} seconds of audio to process.")
+
+                # Process transcription
+                self.online.insert_audio_chunk(pcm_array)
+                new_tokens = self.online.process_iter()
+
+                if new_tokens:
+                    self.full_transcription += self.sep.join([t.text for t in new_tokens])
+
+                # Get buffer information
+                _buffer = self.online.get_buffer()
+                buffer = _buffer.text
+                end_buffer = _buffer.end if _buffer.end else (
+                    new_tokens[-1].end if new_tokens else 0
+                )
+
+                # Avoid duplicating content
+                if buffer in self.full_transcription:
+                    buffer = ""
+
+                await self.update_transcription(
+                    new_tokens, buffer, end_buffer, self.full_transcription, self.sep
+                )
+
+            except Exception as e:
+                logger.warning(f"Exception in transcription_processor: {e}")
+                logger.warning(f"Traceback: {traceback.format_exc()}")
+            finally:
+                self.transcription_queue.task_done()
+
+    async def diarization_processor(self, diarization_obj):
+        """Process audio chunks for speaker diarization."""
+        buffer_diarization = ""
+
+        while True:
+            try:
+                pcm_array = await self.diarization_queue.get()
+
+                # Process diarization
+                await diarization_obj.diarize(pcm_array)
+
+                # Get current state and update speakers
+                state = await self.get_current_state()
+                new_end = diarization_obj.assign_speakers_to_tokens(
+                    state["end_attributed_speaker"], state["tokens"]
+                )
+
+                await self.update_diarization(new_end, buffer_diarization)
+
+            except Exception as e:
+                logger.warning(f"Exception in diarization_processor: {e}")
+                logger.warning(f"Traceback: {traceback.format_exc()}")
+            finally:
+                self.diarization_queue.task_done()
+
+    async def results_formatter(self):
+        """Format processing results for output."""
+        while True:
+            try:
+                # Get current state
+                state = await self.get_current_state()
+                tokens = state["tokens"]
+                buffer_transcription = state["buffer_transcription"]
+                buffer_diarization = state["buffer_diarization"]
+                end_attributed_speaker = state["end_attributed_speaker"]
+                sep = state["sep"]
+
+                # Add dummy tokens if needed
+                if (not tokens or tokens[-1].is_dummy) and not self.args.transcription and self.args.diarization:
+                    await self.add_dummy_token()
+                    await asyncio.sleep(0.5)  # non-blocking; time.sleep() here would stall the event loop
+                    state = await self.get_current_state()
+                    tokens = state["tokens"]
+
+                # Format output
+                previous_speaker = -1
+                lines = []
+                last_end_diarized = 0
+                undiarized_text = []
+
+                # Process each token
+                for token in tokens:
+                    speaker = token.speaker
+
+                    # Handle diarization
+                    if self.args.diarization:
+                        if (speaker in [-1, 0]) and token.end >= end_attributed_speaker:
+                            undiarized_text.append(token.text)
+                            continue
+                        elif (speaker in [-1, 0]) and token.end < end_attributed_speaker:
+                            speaker = previous_speaker
+                        if speaker not in [-1, 0]:
+                            last_end_diarized = max(token.end, last_end_diarized)
+
+                    # Group by speaker
+                    if speaker != previous_speaker or not lines:
+                        lines.append({
+                            "speaker": speaker,
+                            "text": token.text,
+                            "beg": format_time(token.start),
+                            "end": format_time(token.end),
+                            "diff": round(token.end - last_end_diarized, 2)
+                        })
+                        previous_speaker = speaker
+                    elif token.text:  # Only append if text isn't empty
+                        lines[-1]["text"] += sep + token.text
+                        lines[-1]["end"] = format_time(token.end)
+                        lines[-1]["diff"] = round(token.end - last_end_diarized, 2)
+
+                # Handle undiarized text
+                if undiarized_text:
+                    combined = sep.join(undiarized_text)
+                    if buffer_transcription:
+                        combined += sep
+                    await self.update_diarization(end_attributed_speaker, combined)
+                    buffer_diarization = combined
+
+                # Create response object
+                if not lines:
+                    lines = [{
+                        "speaker": 1,
+                        "text": "",
+                        "beg": format_time(0),
+                        "end": format_time(tokens[-1].end if tokens else 0),
+                        "diff": 0
+                    }]
+
+                response = {
+                    "lines": lines,
+                    "buffer_transcription": buffer_transcription,
+                    "buffer_diarization": buffer_diarization,
+                    "remaining_time_transcription": state["remaining_time_transcription"],
+                    "remaining_time_diarization": state["remaining_time_diarization"]
+                }
+
+                # Only yield if content has changed
+                response_content = ' '.join([f"{line['speaker']} {line['text']}" for line in lines]) + \
+                                   f" | {buffer_transcription} | {buffer_diarization}"
+
+                if response_content != self.last_response_content and (lines or buffer_transcription or buffer_diarization):
+                    yield response
+                    self.last_response_content = response_content
+
+                await asyncio.sleep(0.1)  # Avoid overwhelming the client
+
+            except Exception as e:
+                logger.warning(f"Exception in results_formatter: {e}")
+                logger.warning(f"Traceback: {traceback.format_exc()}")
+                await asyncio.sleep(0.5)  # Back off on error
+
+    async def create_tasks(self):
+        """Create and start processing tasks."""
+        tasks = []
+        if self.args.transcription and self.online:
+            tasks.append(asyncio.create_task(self.transcription_processor()))
+
+        if self.args.diarization and self.diarization:
+            tasks.append(asyncio.create_task(self.diarization_processor(self.diarization)))
+
+        tasks.append(asyncio.create_task(self.ffmpeg_stdout_reader()))
+        self.tasks = tasks
+
+        return self.results_formatter()
+
+    async def cleanup(self):
+        """Clean up resources when processing is complete."""
+        for task in self.tasks:
+            task.cancel()
+
+        try:
+            await asyncio.gather(*self.tasks, return_exceptions=True)
+            self.ffmpeg_process.stdin.close()
+            self.ffmpeg_process.wait()
+        except Exception as e:
+            logger.warning(f"Error during cleanup: {e}")
+
+        if self.args.diarization and hasattr(self, 'diarization'):
+            self.diarization.close()
+
+    async def process_audio(self, message):
+        """Process incoming audio data."""
+        try:
+            self.ffmpeg_process.stdin.write(message)
+            self.ffmpeg_process.stdin.flush()
+        except (BrokenPipeError, AttributeError) as e:
+            logger.warning(f"Error writing to FFmpeg: {e}. Restarting...")
+            await self.restart_ffmpeg()
+            self.ffmpeg_process.stdin.write(message)
+            self.ffmpeg_process.stdin.flush()
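A hedged sketch of driving AudioProcessor without the FastAPI layer. "sample.webm" is a hypothetical file; main.py feeds WebSocket frames to process_audio() the same way. Since results_formatter() loops until cancelled, the sketch only takes a handful of responses:

    import asyncio

    # Sketch under the assumptions above; not the project's own entry point.
    async def transcribe_file(path: str, n_responses: int = 20):
        processor = AudioProcessor()
        results = await processor.create_tasks()   # starts ffmpeg reader + workers

        async def feed():
            with open(path, "rb") as f:
                while chunk := f.read(4096):
                    await processor.process_audio(chunk)
                    await asyncio.sleep(0.05)       # rough real-time pacing

        feeder = asyncio.create_task(feed())
        count = 0
        async for response in results:              # dicts: "lines", buffers, lag estimates
            print(response["lines"])
            count += 1
            if count >= n_responses:
                break
        feeder.cancel()
        await processor.cleanup()

    # asyncio.run(transcribe_file("sample.webm"))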
core.py
ADDED
@@ -0,0 +1,61 @@
+from whisper_streaming_custom.whisper_online import backend_factory, warmup_asr
+from argparse import Namespace, ArgumentParser
+
+class WhisperLiveKit:
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def __init__(self, args=None, **kwargs):
+        if WhisperLiveKit._initialized:
+            return
+
+        if args is None:
+            args = Namespace(
+                host="localhost",
+                port=8000,
+                warmup_file=None,
+                confidence_validation=False,
+                diarization=False,
+                transcription=True,
+                min_chunk_size=0.5,
+                model="base",
+                model_cache_dir=None,
+                model_dir=None,
+                lan="auto",
+                task="transcribe",
+                backend="faster-whisper",
+                vac=False,
+                vac_chunk_size=0.04,
+                vad=True,
+                buffer_trimming="sentence",
+                buffer_trimming_sec=1.0,
+                log_level="INFO"
+            )
+
+        self.args = args
+
+        self.asr = None
+        self.tokenizer = None
+        self.diarization = None
+
+        if self.args.transcription:
+            self.asr, self.tokenizer = backend_factory(self.args)
+            warmup_asr(self.asr, self.args.warmup_file)
+
+        if self.args.diarization:
+            from diarization.diarization_online import DiartDiarization
+            self.diarization = DiartDiarization()
+
+        WhisperLiveKit._initialized = True
+
+    def web_interface(self):
+        import pkg_resources
+        html_path = pkg_resources.resource_filename('whisperlivekit', 'web/live_transcription.html')
+        with open(html_path, "r", encoding="utf-8") as f:
+            html = f.read()
+        return html
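WhisperLiveKit acts as a process-wide singleton via `__new__` plus the `_initialized` flag: the first construction configures everything, and later calls (e.g. from AudioProcessor) get the same instance back. A small sketch of that behavior, with transcription and diarization disabled so no models are loaded:

    from argparse import Namespace
    from core import WhisperLiveKit

    # First construction wins; later args are ignored because __init__ returns early.
    args = Namespace(transcription=False, diarization=False)
    kit_a = WhisperLiveKit(args=args)
    kit_b = WhisperLiveKit()            # same object, no re-initialization
    assert kit_a is kit_b
    assert kit_b.args.transcription is False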
diarization/__init__.py
ADDED
File without changes
diarization/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (172 Bytes)
diarization/__pycache__/diarization_online.cpython-310.pyc
ADDED
Binary file (6.63 kB)
diarization/diarization_online.py
ADDED
@@ -0,0 +1,153 @@
+import asyncio
+import re
+import threading
+import numpy as np
+import logging
+
+from diart import SpeakerDiarization, SpeakerDiarizationConfig
+from diart.inference import StreamingInference
+from diart.sources import AudioSource
+from timed_objects import SpeakerSegment
+from diart.sources import MicrophoneAudioSource
+from rx.core import Observer
+from typing import Tuple, Any, List
+from pyannote.core import Annotation
+
+logger = logging.getLogger(__name__)
+
+def extract_number(s: str) -> int:
+    m = re.search(r'\d+', s)
+    return int(m.group()) if m else None
+
+class DiarizationObserver(Observer):
+    """Observer that logs all data emitted by the diarization pipeline and stores speaker segments."""
+
+    def __init__(self):
+        self.speaker_segments = []
+        self.processed_time = 0
+        self.segment_lock = threading.Lock()
+
+    def on_next(self, value: Tuple[Annotation, Any]):
+        annotation, audio = value
+
+        logger.debug("\n--- New Diarization Result ---")
+
+        duration = audio.extent.end - audio.extent.start
+        logger.debug(f"Audio segment: {audio.extent.start:.2f}s - {audio.extent.end:.2f}s (duration: {duration:.2f}s)")
+        logger.debug(f"Audio shape: {audio.data.shape}")
+
+        with self.segment_lock:
+            if audio.extent.end > self.processed_time:
+                self.processed_time = audio.extent.end
+            if annotation and len(annotation._labels) > 0:
+                logger.debug("\nSpeaker segments:")
+                for speaker, label in annotation._labels.items():
+                    for start, end in zip(label.segments_boundaries_[:-1], label.segments_boundaries_[1:]):
+                        logger.debug(f"  {speaker}: {start:.2f}s-{end:.2f}s")
+                        self.speaker_segments.append(SpeakerSegment(
+                            speaker=speaker,
+                            start=start,
+                            end=end
+                        ))
+            else:
+                logger.debug("\nNo speakers detected in this segment")
+
+    def get_segments(self) -> List[SpeakerSegment]:
+        """Get a copy of the current speaker segments."""
+        with self.segment_lock:
+            return self.speaker_segments.copy()
+
+    def clear_old_segments(self, older_than: float = 30.0):
+        """Clear segments older than the specified time."""
+        with self.segment_lock:
+            current_time = self.processed_time
+            self.speaker_segments = [
+                segment for segment in self.speaker_segments
+                if current_time - segment.end < older_than
+            ]
+
+    def on_error(self, error):
+        """Handle an error in the stream."""
+        logger.debug(f"Error in diarization stream: {error}")
+
+    def on_completed(self):
+        """Handle the completion of the stream."""
+        logger.debug("Diarization stream completed")
+
+
+class WebSocketAudioSource(AudioSource):
+    """
+    Custom AudioSource that blocks in read() until close() is called.
+    Use push_audio() to inject PCM chunks.
+    """
+    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
+        super().__init__(uri, sample_rate)
+        self._closed = False
+        self._close_event = threading.Event()
+
+    def read(self):
+        self._close_event.wait()
+
+    def close(self):
+        if not self._closed:
+            self._closed = True
+            self.stream.on_completed()
+            self._close_event.set()
+
+    def push_audio(self, chunk: np.ndarray):
+        if not self._closed:
+            new_audio = np.expand_dims(chunk, axis=0)
+            logger.debug(f"Add new chunk with shape: {new_audio.shape}")
+            self.stream.on_next(new_audio)
+
+
+class DiartDiarization:
+    def __init__(self, sample_rate: int = 16000, config: SpeakerDiarizationConfig = None, use_microphone: bool = False):
+        self.pipeline = SpeakerDiarization(config=config)
+        self.observer = DiarizationObserver()
+
+        if use_microphone:
+            self.source = MicrophoneAudioSource()
+            self.custom_source = None
+        else:
+            self.custom_source = WebSocketAudioSource(uri="websocket_source", sample_rate=sample_rate)
+            self.source = self.custom_source
+
+        self.inference = StreamingInference(
+            pipeline=self.pipeline,
+            source=self.source,
+            do_plot=False,
+            show_progress=False,
+        )
+        self.inference.attach_observers(self.observer)
+        asyncio.get_event_loop().run_in_executor(None, self.inference)
+
+    async def diarize(self, pcm_array: np.ndarray):
+        """
+        Process audio data for diarization.
+        Only used when working with WebSocketAudioSource.
+        """
+        if self.custom_source:
+            self.custom_source.push_audio(pcm_array)
+        self.observer.clear_old_segments()
+        return self.observer.get_segments()
+
+    def close(self):
+        """Close the audio source."""
+        if self.custom_source:
+            self.custom_source.close()
+
+    def assign_speakers_to_tokens(self, end_attributed_speaker, tokens: list) -> float:
+        """
+        Assign speakers to tokens based on timing overlap with speaker segments.
+        Uses the segments collected by the observer.
+        """
+        segments = self.observer.get_segments()
+
+        for token in tokens:
+            for segment in segments:
+                if not (segment.end <= token.start or segment.start >= token.end):
+                    token.speaker = extract_number(segment.speaker) + 1
+                    end_attributed_speaker = max(token.end, end_attributed_speaker)
+        return end_attributed_speaker
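assign_speakers_to_tokens labels a token whenever its time interval intersects a diarization segment. A self-contained illustration of the same interval test (in the real code the segment speaker is a diart string label like "speaker0", parsed by extract_number; an int stands in for it here):

    from timed_objects import ASRToken, SpeakerSegment

    # A token gets a speaker when the intervals intersect, i.e. NOT
    # (segment ends before the token starts or starts after it ends).
    token = ASRToken(start=2.0, end=2.6, text="hello")
    segment = SpeakerSegment(start=1.5, end=2.3, speaker=0)  # stand-in for diart's "speaker0"

    if not (segment.end <= token.start or segment.start >= token.end):
        token.speaker = segment.speaker + 1   # 1-based speakers, as in the code above
    print(token.speaker)  # -> 1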
docker-compose.yml
ADDED
@@ -0,0 +1,36 @@
+version: '3.8'
+
+services:
+  websocket-server:
+    build:
+      context: .
+      dockerfile: Dockerfile
+    # NOTE: the Dockerfile CMD binds port 7860; override the command or
+    # adjust the ports/healthcheck below if you keep these 8000 values.
+    ports:
+      - "8000:8000"
+    environment:
+      - PYTHONUNBUFFERED=1
+    volumes:
+      - .:/app
+    networks:
+      - app-network
+    healthcheck:
+      test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
+      interval: 30s
+      timeout: 10s
+      retries: 3
+
+  frontend:
+    build:
+      context: ./frontend
+      dockerfile: Dockerfile
+    ports:
+      - "80:80"
+    depends_on:
+      websocket-server:
+        condition: service_healthy
+    networks:
+      - app-network
+
+networks:
+  app-network:
+    driver: bridge
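The healthcheck curls the /health endpoint defined in main.py. An equivalent check from Python using only the standard library, assuming the server is reachable on localhost:8000 (use 7860 if you run the Dockerfile's CMD unchanged):

    import json
    import urllib.request

    # Stdlib equivalent of the compose healthcheck above.
    def is_healthy(url: str = "http://localhost:8000/health") -> bool:
        try:
            with urllib.request.urlopen(url, timeout=10) as resp:
                return json.load(resp).get("status") == "healthy"
        except OSError:
            return False

    print(is_healthy())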
main.py
ADDED
@@ -0,0 +1,275 @@
+from contextlib import asynccontextmanager
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import JSONResponse
+from fastapi.staticfiles import StaticFiles
+from fastapi.responses import FileResponse
+import asyncio
+import logging
+import os
+import traceback
+import argparse
+import uvicorn
+
+from core import WhisperLiveKit
+from audio_processor import AudioProcessor
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+logging.getLogger().setLevel(logging.WARNING)
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
+
+kit = None
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    global kit
+    kit = WhisperLiveKit()
+    yield
+
+app = FastAPI(lifespan=lifespan)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Allows all origins
+    allow_credentials=True,
+    allow_methods=["*"],  # Allows all methods
+    allow_headers=["*"],  # Allows all headers
+)
+
+# Mount static files
+app.mount("/static", StaticFiles(directory="static"), name="static")
+
+@app.get("/")
+async def read_root():
+    return FileResponse("static/index.html")
+
+@app.get("/health")
+async def health_check():
+    return JSONResponse({"status": "healthy"})
+
+async def handle_websocket_results(websocket, results_generator):
+    """Consumes results from the audio processor and sends them via WebSocket."""
+    try:
+        async for response in results_generator:
+            try:
+                logger.debug(f"Sending response: {response}")
+                if isinstance(response, dict):
+                    # Ensure the response has a consistent format
+                    if 'buffer_transcription' in response:
+                        await websocket.send_json({
+                            'buffer_transcription': response['buffer_transcription']
+                        })
+                    elif 'full_transcription' in response:
+                        await websocket.send_json({
+                            'full_transcription': response['full_transcription']
+                        })
+                    else:
+                        await websocket.send_json(response)
+                else:
+                    # If response is not a dict, wrap it in a text field
+                    await websocket.send_json({"text": str(response)})
+            except Exception as e:
+                logger.error(f"Error sending message: {e}")
+                logger.error(f"Traceback: {traceback.format_exc()}")
+                raise
+    except Exception as e:
+        logger.warning(f"Error in WebSocket results handler: {e}")
+        logger.warning(f"Traceback: {traceback.format_exc()}")
+
+@app.websocket("/asr")
+async def websocket_endpoint(websocket: WebSocket):
+    logger.info("New WebSocket connection request")
+    audio_processor = None
+    websocket_task = None
+
+    try:
+        await websocket.accept()
+        logger.info("WebSocket connection accepted")
+
+        audio_processor = AudioProcessor()
+        results_generator = await audio_processor.create_tasks()
+        websocket_task = asyncio.create_task(handle_websocket_results(websocket, results_generator))
+
+        while True:
+            try:
+                message = await websocket.receive_bytes()
+                logger.debug(f"Received audio chunk of size: {len(message)}")
+                await audio_processor.process_audio(message)
+            except WebSocketDisconnect:
+                logger.warning("WebSocket disconnected.")
+                break
+            except Exception as e:
+                logger.error(f"Error processing audio chunk: {e}")
+                logger.error(f"Traceback: {traceback.format_exc()}")
+                break
+
+    except WebSocketDisconnect:
+        logger.warning("WebSocket disconnected during setup.")
+    except Exception as e:
+        logger.error(f"Error in WebSocket endpoint: {e}")
+        logger.error(f"Traceback: {traceback.format_exc()}")
+    finally:
+        if websocket_task:
+            websocket_task.cancel()
+            try:
+                await websocket_task
+            except asyncio.CancelledError:
+                pass
+        if audio_processor:
+            await audio_processor.cleanup()
+        logger.info("WebSocket endpoint cleaned up.")
+
+def parse_args():
+    parser = argparse.ArgumentParser(description="Whisper FastAPI Online Server")
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="localhost",
+        help="The host address to bind the server to.",
+    )
+    parser.add_argument(
+        "--port", type=int, default=8000, help="The port number to bind the server to."
+    )
+    parser.add_argument(
+        "--warmup-file",
+        type=str,
+        default=None,
+        dest="warmup_file",
+        help="""
+        The path to a speech audio wav file to warm up Whisper so that the very first chunk processing is fast.
+        If not set, uses https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav.
+        If False, no warmup is performed.
+        """,
+    )
+
+    parser.add_argument(
+        "--confidence-validation",
+        action="store_true",
+        help="Accelerates validation of tokens using confidence scores. Transcription will be faster but punctuation might be less accurate.",
+    )
+
+    parser.add_argument(
+        "--diarization",
+        action="store_true",
+        default=False,
+        help="Enable speaker diarization.",
+    )
+
+    parser.add_argument(
+        "--no-transcription",
+        action="store_true",
+        help="Disable transcription to only see live diarization results.",
+    )
+
+    parser.add_argument(
+        "--min-chunk-size",
+        type=float,
+        default=0.5,
+        help="Minimum audio chunk size in seconds. It waits up to this time to do processing. If the processing takes less time, it waits; otherwise it processes the whole segment that was received by this time.",
+    )
+
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="tiny",
+        help="Name or size of the Whisper model to use (default: tiny). Suggested values: tiny.en, tiny, base.en, base, small.en, small, medium.en, medium, large-v1, large-v2, large-v3, large, large-v3-turbo. The model is automatically downloaded from the model hub if not present in the model cache dir.",
+    )
+
+    parser.add_argument(
+        "--model_cache_dir",
+        type=str,
+        default=None,
+        help="Override the default model cache dir where models downloaded from the hub are saved.",
+    )
+    parser.add_argument(
+        "--model_dir",
+        type=str,
+        default=None,
+        help="Dir where the Whisper model.bin and other files are saved. This option overrides the --model and --model_cache_dir parameters.",
+    )
+    parser.add_argument(
+        "--lan",
+        "--language",
+        type=str,
+        default="en",
+        help="Source language code, e.g. en,de,cs, or 'auto' for language detection.",
+    )
+    parser.add_argument(
+        "--task",
+        type=str,
+        default="transcribe",
+        choices=["transcribe", "translate"],
+        help="Transcribe or translate.",
+    )
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="faster-whisper",
+        choices=["faster-whisper", "whisper_timestamped", "mlx-whisper", "openai-api"],
+        help="Load only this backend for Whisper processing.",
+    )
+    parser.add_argument(
+        "--vac",
+        action="store_true",
+        default=False,
+        help="Use VAC = voice activity controller. Recommended. Requires torch.",
+    )
+    parser.add_argument(
+        "--vac-chunk-size", type=float, default=0.04, help="VAC sample size in seconds."
+    )
+    parser.add_argument(
+        "--no-vad",
+        action="store_true",
+        default=False,
+        help="Disable VAD = voice activity detection. Not recommended.",
+    )
+    parser.add_argument(
+        "--buffer_trimming",
+        type=str,
+        default="sentence",
+        choices=["sentence", "segment"],
+        help="Buffer trimming strategy.",
+    )
+    parser.add_argument(
+        "--buffer_trimming_sec",
+        type=float,
+        default=1.0,
+        help="Buffer trimming length in seconds.",
+    )
+    parser.add_argument(
+        "-l",
+        "--log-level",
+        type=str,
+        default="INFO",
+        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
+        help="Set the logging level.",
+    )
+
+    args = parser.parse_args()
+
+    args.transcription = not args.no_transcription
+    args.vad = not args.no_vad
+    delattr(args, 'no_transcription')
+    delattr(args, 'no_vad')
+
+    return args
+
+def main():
+    args = parse_args()
+
+    # Initialize WhisperLiveKit with parsed arguments
+    kit = WhisperLiveKit(args=args)
+
+    # Start the server
+    uvicorn.run(
+        "main:app",
+        host=args.host,
+        port=args.port,
+        log_level=args.log_level.lower()
+    )
+
+if __name__ == "__main__":
+    main()
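A hedged sketch of a non-browser client for the /asr endpoint, using the websockets library already listed in requirements.txt. "mic.webm" is a hypothetical WebM capture; static/index.html streams MediaRecorder chunks over the same socket. Adjust the port (8000 vs 7860) to match how the server was started:

    import asyncio
    import json
    import websockets

    async def stream(path: str, url: str = "ws://localhost:8000/asr"):
        async with websockets.connect(url) as ws:
            async def sender():
                with open(path, "rb") as f:
                    while chunk := f.read(4096):
                        await ws.send(chunk)
                        await asyncio.sleep(0.05)   # pace roughly in real time
            send_task = asyncio.create_task(sender())
            async for message in ws:                # server pushes JSON updates
                print(json.loads(message).get("lines"))
            send_task.cancel()

    # asyncio.run(stream("mic.webm"))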
requirements.txt
ADDED
@@ -0,0 +1,15 @@
+fastapi>=0.68.0
+uvicorn[standard]
+python-multipart>=0.0.5
+numpy>=1.21.0
+ffmpeg-python>=0.2.0
+torch>=2.0.0
+torchaudio>=2.0.0
+faster-whisper>=0.9.0
+websockets>=10.0
+pydantic>=1.8.0
+python-dotenv>=0.19.0
+setuptools>=65.5.1
+librosa>=0.10.0
+mosestokenizer
+hf_xet
setup.py
ADDED
@@ -0,0 +1,22 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="whisperlivekit",
+    version="0.1.0",
+    packages=find_packages(),
+    install_requires=[
+        "fastapi>=0.68.0",
+        "uvicorn>=0.15.0",
+        "python-multipart>=0.0.5",
+        "numpy>=1.21.0",
+        "ffmpeg-python>=0.2.0",
+        "torch>=2.0.0",
+        "torchaudio>=2.0.0",
+        "faster-whisper>=0.9.0",
+        "websockets>=10.0",
+        "pydantic>=1.8.0",
+        "python-dotenv>=0.19.0",
+        "setuptools>=65.5.1",
+    ],
+    python_requires=">=3.10",
+)
static/index.html
ADDED
@@ -0,0 +1,248 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Whisper Live Transcription</title>
+    <style>
+        body {
+            font-family: Arial, sans-serif;
+            max-width: 800px;
+            margin: 0 auto;
+            padding: 20px;
+            background-color: #f5f5f5;
+        }
+        .container {
+            background-color: white;
+            padding: 20px;
+            border-radius: 8px;
+            box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+        }
+        .controls {
+            margin: 20px 0;
+        }
+        button {
+            padding: 10px 20px;
+            margin: 5px;
+            border: none;
+            border-radius: 4px;
+            background-color: #007bff;
+            color: white;
+            cursor: pointer;
+        }
+        button:disabled {
+            background-color: #cccccc;
+            cursor: not-allowed;
+        }
+        #clearBtn {
+            background-color: #dc3545;
+        }
+        #clearBtn:hover {
+            background-color: #c82333;
+        }
+        #transcription {
+            margin-top: 20px;
+            padding: 15px;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+            min-height: 100px;
+            white-space: pre-wrap;
+            font-family: monospace;
+            line-height: 1.5;
+        }
+        .status {
+            margin: 10px 0;
+            padding: 10px;
+            border-radius: 4px;
+        }
+        .config {
+            margin: 10px 0;
+            padding: 10px;
+            background-color: #f8f9fa;
+            border-radius: 4px;
+        }
+        .config input {
+            margin-left: 10px;
+            padding: 5px;
+            border: 1px solid #ddd;
+            border-radius: 4px;
+        }
+        .connected {
+            background-color: #d4edda;
+            color: #155724;
+        }
+        .disconnected {
+            background-color: #f8d7da;
+            color: #721c24;
+        }
+        #error-message {
+            color: #721c24;
+            margin: 10px 0;
+            padding: 10px;
+            background-color: #f8d7da;
+            border-radius: 4px;
+            display: none;
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1>Whisper Live Transcription</h1>
+        <div class="config">
+            <label for="wsUrl">WebSocket URL:</label>
+            <input type="text" id="wsUrl" value="ws://localhost:7860/asr" style="width: 300px;">
+        </div>
+        <div id="status" class="status disconnected">Disconnected</div>
+        <div id="error-message"></div>
+        <div class="controls">
+            <button id="startBtn">Start Recording</button>
+            <button id="stopBtn" disabled>Stop Recording</button>
+            <button id="reconnectBtn">Reconnect</button>
+            <button id="clearBtn">Clear Transcription</button>
+        </div>
+        <div id="transcription"></div>
+    </div>
+
+    <script>
+        let ws = null;
+        let mediaRecorder = null;
+        let audioChunks = [];
+        let reconnectAttempts = 0;
+        const maxReconnectAttempts = 5;
+        const reconnectDelay = 2000; // 2 seconds
+
+        const startBtn = document.getElementById('startBtn');
+        const stopBtn = document.getElementById('stopBtn');
+        const reconnectBtn = document.getElementById('reconnectBtn');
+        const clearBtn = document.getElementById('clearBtn');
+        const wsUrlInput = document.getElementById('wsUrl');
+        const statusDiv = document.getElementById('status');
+        const transcriptionDiv = document.getElementById('transcription');
+        const errorMessageDiv = document.getElementById('error-message');
+
+        function showError(message) {
+            errorMessageDiv.textContent = message;
+            errorMessageDiv.style.display = 'block';
+        }
+
+        function hideError() {
+            errorMessageDiv.style.display = 'none';
+        }
+
+        function updateStatus(connected) {
+            statusDiv.textContent = connected ? 'Connected' : 'Disconnected';
+            statusDiv.className = `status ${connected ? 'connected' : 'disconnected'}`;
+            reconnectBtn.disabled = connected;
+        }
+
+        function connectWebSocket() {
+            if (ws) {
+                ws.close();
+            }
+
+            const wsUrl = wsUrlInput.value;
+            console.log('Attempting to connect to:', wsUrl);
+            ws = new WebSocket(wsUrl);
+
+            ws.onopen = () => {
+                console.log('WebSocket connection established');
+                updateStatus(true);
+                startBtn.disabled = false;
+                stopBtn.disabled = true;
+                hideError();
+                reconnectAttempts = 0;
+            };
+
+            ws.onclose = (event) => {
+                console.log('WebSocket connection closed:', event.code, event.reason);
+                updateStatus(false);
+                startBtn.disabled = true;
+                stopBtn.disabled = true;
+
+                if (event.code === 1006) {
+                    showError('Connection lost. Click "Reconnect" to try again.');
+                }
+            };
+
+            ws.onerror = (error) => {
+                console.error('WebSocket error:', error);
+                updateStatus(false);
+                showError('Connection error occurred. Click "Reconnect" to try again.');
+            };
+
+            ws.onmessage = (event) => {
+                try {
+                    console.log('Received message:', event.data);
+                    const response = JSON.parse(event.data);
+                    console.log('Parsed response:', response);
+
+                    if (response.text) {
+                        console.log('Adding text:', response.text);
+                        transcriptionDiv.textContent += response.text + '\n';
+                    } else if (response.partial) {
+                        console.log('Adding partial text:', response.partial);
+                        transcriptionDiv.textContent = transcriptionDiv.textContent.replace(/[^.!?]+$/, '') + response.partial;
+                    } else if (response.error) {
+                        console.error('Server error:', response.error);
+                        showError('Server error: ' + response.error);
+                    } else if (response.buffer_transcription) {
+                        console.log('Adding buffer transcription:', response.buffer_transcription);
+                        transcriptionDiv.textContent += response.buffer_transcription + '\n';
+                    } else if (response.full_transcription) {
+                        console.log('Adding full transcription:', response.full_transcription);
+                        transcriptionDiv.textContent += response.full_transcription + '\n';
+                    } else if (typeof response === 'string') {
+                        console.log('Adding raw text:', response);
+                        transcriptionDiv.textContent += response + '\n';
+                    }
+                } catch (error) {
+                    console.error('Error parsing message:', error);
+                    console.error('Raw message:', event.data);
+                }
+            };
+        }
+
+        async function startRecording() {
+            try {
+                const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+                mediaRecorder = new MediaRecorder(stream);
+                audioChunks = [];
+
+                mediaRecorder.ondataavailable = (event) => {
+                    if (event.data.size > 0 && ws && ws.readyState === WebSocket.OPEN) {
+                        ws.send(event.data);
+                    }
+                };
+
+                mediaRecorder.start(100); // Send chunks every 100ms
+                startBtn.disabled = true;
+                stopBtn.disabled = false;
+            } catch (error) {
+                console.error('Error accessing microphone:', error);
+                showError('Error accessing microphone. Please ensure you have granted microphone permissions.');
+            }
+        }
+
+        function stopRecording() {
+            if (mediaRecorder && mediaRecorder.state !== 'inactive') {
+                mediaRecorder.stop();
+                mediaRecorder.stream.getTracks().forEach(track => track.stop());
+                startBtn.disabled = false;
+                stopBtn.disabled = true;
+            }
+        }
+
+        function clearTranscription() {
+            transcriptionDiv.textContent = '';
+        }
+
+        startBtn.addEventListener('click', startRecording);
+        stopBtn.addEventListener('click', stopRecording);
+        reconnectBtn.addEventListener('click', connectWebSocket);
+        clearBtn.addEventListener('click', clearTranscription);
+
+        // Connect to WebSocket when page loads
+        connectWebSocket();
+    </script>
+</body>
+</html>
timed_objects.py
ADDED
@@ -0,0 +1,29 @@
+from dataclasses import dataclass
+from typing import Optional
+
+@dataclass
+class TimedText:
+    start: Optional[float]
+    end: Optional[float]
+    text: Optional[str] = ''
+    speaker: Optional[int] = -1
+    probability: Optional[float] = None
+    is_dummy: Optional[bool] = False
+
+@dataclass
+class ASRToken(TimedText):
+    def with_offset(self, offset: float) -> "ASRToken":
+        """Return a new token with the time offset added."""
+        return ASRToken(self.start + offset, self.end + offset, self.text, self.speaker, self.probability)
+
+@dataclass
+class Sentence(TimedText):
+    pass
+
+@dataclass
+class Transcript(TimedText):
+    pass
+
+@dataclass
+class SpeakerSegment(TimedText):
+    pass
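Tokens come out of the ASR timed relative to the current audio buffer; with_offset() rebases them onto the stream clock. A small usage sketch (note it copies speaker and probability but resets is_dummy to the default):

    from timed_objects import ASRToken

    tok = ASRToken(start=0.4, end=0.9, text="hi", speaker=2, probability=0.87)
    rebased = tok.with_offset(12.0)
    print(rebased.start, rebased.end)  # 12.4 12.9
    assert rebased.speaker == 2 and rebased.is_dummy is False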
whisper_streaming_custom/__init__.py
ADDED
File without changes
whisper_streaming_custom/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (185 Bytes)
whisper_streaming_custom/__pycache__/backends.cpython-310.pyc
ADDED
Binary file (11.5 kB)
whisper_streaming_custom/__pycache__/online_asr.cpython-310.pyc
ADDED
Binary file (16 kB)
whisper_streaming_custom/__pycache__/whisper_online.cpython-310.pyc
ADDED
Binary file (5.5 kB)
whisper_streaming_custom/backends.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys
import logging
import io
import soundfile as sf
import math
try:
    import torch
except ImportError:
    torch = None
from typing import List
import numpy as np
from timed_objects import ASRToken

logger = logging.getLogger(__name__)

class ASRBase:
    sep = " "  # join transcribed words with this character (" " for whisper_timestamped,
               # "" for faster-whisper, because it emits the spaces when needed)

    def __init__(self, lan, modelsize=None, cache_dir=None, model_dir=None, logfile=sys.stderr):
        self.logfile = logfile
        self.transcribe_kargs = {}
        if lan == "auto":
            self.original_language = None
        else:
            self.original_language = lan
        self.model = self.load_model(modelsize, cache_dir, model_dir)

    # NOTE: the two methods below reference token attributes (start, end, text)
    # that ASRBase itself does not define; they belong on ASRToken and are kept
    # here only for backward compatibility (typically you will use
    # ASRToken.with_offset instead).
    def with_offset(self, offset: float) -> ASRToken:
        return ASRToken(self.start + offset, self.end + offset, self.text)

    def __repr__(self):
        return f"ASRToken(start={self.start:.2f}, end={self.end:.2f}, text={self.text!r})"

    def load_model(self, modelsize, cache_dir, model_dir):
        raise NotImplementedError("must be implemented in the child class")

    def transcribe(self, audio, init_prompt=""):
        raise NotImplementedError("must be implemented in the child class")

    def use_vad(self):
        raise NotImplementedError("must be implemented in the child class")

+
class WhisperTimestampedASR(ASRBase):
|
47 |
+
"""Uses whisper_timestamped as the backend."""
|
48 |
+
sep = " "
|
49 |
+
|
50 |
+
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
51 |
+
import whisper
|
52 |
+
import whisper_timestamped
|
53 |
+
from whisper_timestamped import transcribe_timestamped
|
54 |
+
|
55 |
+
self.transcribe_timestamped = transcribe_timestamped
|
56 |
+
if model_dir is not None:
|
57 |
+
logger.debug("ignoring model_dir, not implemented")
|
58 |
+
return whisper.load_model(modelsize, download_root=cache_dir)
|
59 |
+
|
60 |
+
def transcribe(self, audio, init_prompt=""):
|
61 |
+
result = self.transcribe_timestamped(
|
62 |
+
self.model,
|
63 |
+
audio,
|
64 |
+
language=self.original_language,
|
65 |
+
initial_prompt=init_prompt,
|
66 |
+
verbose=None,
|
67 |
+
condition_on_previous_text=True,
|
68 |
+
**self.transcribe_kargs,
|
69 |
+
)
|
70 |
+
return result
|
71 |
+
|
72 |
+
def ts_words(self, r) -> List[ASRToken]:
|
73 |
+
"""
|
74 |
+
Converts the whisper_timestamped result to a list of ASRToken objects.
|
75 |
+
"""
|
76 |
+
tokens = []
|
77 |
+
for segment in r["segments"]:
|
78 |
+
for word in segment["words"]:
|
79 |
+
token = ASRToken(word["start"], word["end"], word["text"])
|
80 |
+
tokens.append(token)
|
81 |
+
return tokens
|
82 |
+
|
83 |
+
def segments_end_ts(self, res) -> List[float]:
|
84 |
+
return [segment["end"] for segment in res["segments"]]
|
85 |
+
|
86 |
+
def use_vad(self):
|
87 |
+
self.transcribe_kargs["vad"] = True
|
88 |
+
|
89 |
+
def set_translate_task(self):
|
90 |
+
self.transcribe_kargs["task"] = "translate"
|
91 |
+
|
92 |
+
|
93 |
+
class FasterWhisperASR(ASRBase):
|
94 |
+
"""Uses faster-whisper as the backend."""
|
95 |
+
sep = ""
|
96 |
+
|
97 |
+
def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
|
98 |
+
from faster_whisper import WhisperModel
|
99 |
+
|
100 |
+
if model_dir is not None:
|
101 |
+
logger.debug(f"Loading whisper model from model_dir {model_dir}. "
|
102 |
+
f"modelsize and cache_dir parameters are not used.")
|
103 |
+
model_size_or_path = model_dir
|
104 |
+
elif modelsize is not None:
|
105 |
+
model_size_or_path = modelsize
|
106 |
+
else:
|
107 |
+
raise ValueError("Either modelsize or model_dir must be set")
|
108 |
+
device = "cuda" if torch and torch.cuda.is_available() else "cpu"
|
109 |
+
compute_type = "float16" if device == "cuda" else "float32"
|
110 |
+
|
111 |
+
model = WhisperModel(
|
112 |
+
model_size_or_path,
|
113 |
+
device=device,
|
114 |
+
compute_type=compute_type,
|
115 |
+
download_root=cache_dir,
|
116 |
+
)
|
117 |
+
return model
|
118 |
+
|
119 |
+
def transcribe(self, audio: np.ndarray, init_prompt: str = "") -> list:
|
120 |
+
segments, info = self.model.transcribe(
|
121 |
+
audio,
|
122 |
+
language=self.original_language,
|
123 |
+
initial_prompt=init_prompt,
|
124 |
+
beam_size=5,
|
125 |
+
word_timestamps=True,
|
126 |
+
condition_on_previous_text=True,
|
127 |
+
**self.transcribe_kargs,
|
128 |
+
)
|
129 |
+
return list(segments)
|
130 |
+
|
131 |
+
def ts_words(self, segments) -> List[ASRToken]:
|
132 |
+
tokens = []
|
133 |
+
for segment in segments:
|
134 |
+
if segment.no_speech_prob > 0.9:
|
135 |
+
continue
|
136 |
+
for word in segment.words:
|
137 |
+
token = ASRToken(word.start, word.end, word.word, probability=word.probability)
|
138 |
+
tokens.append(token)
|
139 |
+
return tokens
|
140 |
+
|
141 |
+
def segments_end_ts(self, segments) -> List[float]:
|
142 |
+
return [segment.end for segment in segments]
|
143 |
+
|
144 |
+
def use_vad(self):
|
145 |
+
self.transcribe_kargs["vad_filter"] = True
|
146 |
+
|
147 |
+
def set_translate_task(self):
|
148 |
+
self.transcribe_kargs["task"] = "translate"
|
149 |
+
|
150 |
+
|
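A rough usage sketch for this backend (assuming `faster-whisper` is installed and the `tiny` model can be downloaded; the silent one-second buffer is only a placeholder input):

    import numpy as np
    from whisper_streaming_custom.backends import FasterWhisperASR

    asr = FasterWhisperASR(lan="en", modelsize="tiny")

    # Any 16 kHz float32 mono buffer works here; one second of silence as a stand-in.
    audio = np.zeros(16000, dtype=np.float32)
    segments = asr.transcribe(audio)
    for token in asr.ts_words(segments):
        print(f"{token.start:.2f}-{token.end:.2f}: {token.text}")
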
class MLXWhisper(ASRBase):
    """
    Uses MLX Whisper, optimized for Apple Silicon.
    """
    sep = ""

    def load_model(self, modelsize=None, cache_dir=None, model_dir=None):
        from mlx_whisper.transcribe import ModelHolder, transcribe
        import mlx.core as mx

        if model_dir is not None:
            logger.debug(f"Loading whisper model from model_dir {model_dir}. modelsize parameter is not used.")
            model_size_or_path = model_dir
        elif modelsize is not None:
            model_size_or_path = self.translate_model_name(modelsize)
            logger.debug(f"Loading whisper model {modelsize}. You use mlx whisper, so {model_size_or_path} will be used.")
        else:
            raise ValueError("Either modelsize or model_dir must be set")

        self.model_size_or_path = model_size_or_path
        dtype = mx.float16
        ModelHolder.get_model(model_size_or_path, dtype)
        return transcribe

    def translate_model_name(self, model_name):
        """Map a standard Whisper model name to its mlx-community repository."""
        model_mapping = {
            "tiny.en": "mlx-community/whisper-tiny.en-mlx",
            "tiny": "mlx-community/whisper-tiny-mlx",
            "base.en": "mlx-community/whisper-base.en-mlx",
            "base": "mlx-community/whisper-base-mlx",
            "small.en": "mlx-community/whisper-small.en-mlx",
            "small": "mlx-community/whisper-small-mlx",
            "medium.en": "mlx-community/whisper-medium.en-mlx",
            "medium": "mlx-community/whisper-medium-mlx",
            "large-v1": "mlx-community/whisper-large-v1-mlx",
            "large-v2": "mlx-community/whisper-large-v2-mlx",
            "large-v3": "mlx-community/whisper-large-v3-mlx",
            "large-v3-turbo": "mlx-community/whisper-large-v3-turbo",
            "large": "mlx-community/whisper-large-mlx",
        }
        mlx_model_path = model_mapping.get(model_name)
        if mlx_model_path:
            return mlx_model_path
        else:
            raise ValueError(f"Model name '{model_name}' is not recognized or not supported.")

    def transcribe(self, audio, init_prompt=""):
        if self.transcribe_kargs:
            logger.warning("Transcribe kwargs (vad, task) are not compatible with MLX Whisper and will be ignored.")
        segments = self.model(
            audio,
            language=self.original_language,
            initial_prompt=init_prompt,
            word_timestamps=True,
            condition_on_previous_text=True,
            path_or_hf_repo=self.model_size_or_path,
        )
        return segments.get("segments", [])

    def ts_words(self, segments) -> List[ASRToken]:
        tokens = []
        for segment in segments:
            if segment.get("no_speech_prob", 0) > 0.9:
                continue
            for word in segment.get("words", []):
                token = ASRToken(word["start"], word["end"], word["word"], probability=word["probability"])
                tokens.append(token)
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s["end"] for s in res]

    def use_vad(self):
        self.transcribe_kargs["vad_filter"] = True

    def set_translate_task(self):
        self.transcribe_kargs["task"] = "translate"

class OpenaiApiASR(ASRBase):
    """Uses OpenAI's Whisper API for transcription."""
    def __init__(self, lan=None, temperature=0, logfile=sys.stderr):
        self.logfile = logfile
        self.modelname = "whisper-1"
        self.original_language = None if lan == "auto" else lan
        self.response_format = "verbose_json"
        self.temperature = temperature
        self.load_model()
        self.use_vad_opt = False
        self.task = "transcribe"

    def load_model(self, *args, **kwargs):
        from openai import OpenAI
        self.client = OpenAI()
        self.transcribed_seconds = 0  # running total, used for logging API usage

    def ts_words(self, segments) -> List[ASRToken]:
        """
        Converts OpenAI API response words into ASRToken objects while
        optionally skipping words that fall into no-speech segments.
        """
        no_speech_segments = []
        if self.use_vad_opt:
            for segment in segments.segments:
                if segment.no_speech_prob > 0.8:
                    no_speech_segments.append((segment.start, segment.end))
        tokens = []
        for word in segments.words:
            start = word.start
            end = word.end
            if any(s[0] <= start <= s[1] for s in no_speech_segments):
                continue
            tokens.append(ASRToken(start, end, word.word))
        return tokens

    def segments_end_ts(self, res) -> List[float]:
        return [s.end for s in res.words]

    def transcribe(self, audio_data, prompt=None, *args, **kwargs):
        # Wrap the raw audio in an in-memory WAV file for the API upload.
        buffer = io.BytesIO()
        buffer.name = "temp.wav"
        sf.write(buffer, audio_data, samplerate=16000, format="WAV", subtype="PCM_16")
        buffer.seek(0)
        self.transcribed_seconds += math.ceil(len(audio_data) / 16000)
        params = {
            "model": self.modelname,
            "file": buffer,
            "response_format": self.response_format,
            "temperature": self.temperature,
            "timestamp_granularities": ["word", "segment"],
        }
        if self.task != "translate" and self.original_language:
            params["language"] = self.original_language
        if prompt:
            params["prompt"] = prompt
        proc = self.client.audio.translations if self.task == "translate" else self.client.audio.transcriptions
        transcript = proc.create(**params)
        logger.debug(f"OpenAI API processed accumulated {self.transcribed_seconds} seconds")
        return transcript

    def use_vad(self):
        self.use_vad_opt = True

    def set_translate_task(self):
        self.task = "translate"
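A hedged sketch of driving the API backend (it assumes the `openai` package is installed and `OPENAI_API_KEY` is set in the environment; the silent buffer is again only a placeholder, so the response may well contain no words):

    import os
    import numpy as np
    from whisper_streaming_custom.backends import OpenaiApiASR

    # The OpenAI client reads OPENAI_API_KEY from the environment.
    assert os.environ.get("OPENAI_API_KEY"), "set OPENAI_API_KEY first"

    asr = OpenaiApiASR(lan="en")
    audio = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of 16 kHz silence
    response = asr.transcribe(audio)
    for token in asr.ts_words(response):
        print(token.start, token.end, token.text)
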
whisper_streaming_custom/online_asr.py
ADDED
@@ -0,0 +1,453 @@
import sys
import numpy as np
import logging
from typing import List, Tuple, Optional
from timed_objects import ASRToken, Sentence, Transcript

logger = logging.getLogger(__name__)


class HypothesisBuffer:
    """
    Buffer to store and process ASR hypothesis tokens.

    It holds:
    - committed_in_buffer: tokens that have been confirmed (committed)
    - buffer: the last hypothesis that is not yet committed
    - new: new tokens coming from the recognizer
    """
    def __init__(self, logfile=sys.stderr, confidence_validation=False):
        self.confidence_validation = confidence_validation
        self.committed_in_buffer: List[ASRToken] = []
        self.buffer: List[ASRToken] = []
        self.new: List[ASRToken] = []
        self.last_committed_time = 0.0
        self.last_committed_word: Optional[str] = None
        self.logfile = logfile

    def insert(self, new_tokens: List[ASRToken], offset: float):
        """
        Insert new tokens (after applying a time offset) and compare them with the
        already committed tokens. Only tokens that extend the committed hypothesis
        are added.
        """
        # Apply the offset to each token.
        new_tokens = [token.with_offset(offset) for token in new_tokens]
        # Only keep tokens that are roughly "new".
        self.new = [token for token in new_tokens if token.start > self.last_committed_time - 0.1]

        if self.new:
            first_token = self.new[0]
            if abs(first_token.start - self.last_committed_time) < 1:
                if self.committed_in_buffer:
                    committed_len = len(self.committed_in_buffer)
                    new_len = len(self.new)
                    # Try to match 1 to 5 consecutive tokens against the tail of the
                    # committed text and drop any duplicated prefix from the new hypothesis.
                    max_ngram = min(min(committed_len, new_len), 5)
                    for i in range(1, max_ngram + 1):
                        committed_ngram = " ".join(token.text for token in self.committed_in_buffer[-i:])
                        new_ngram = " ".join(token.text for token in self.new[:i])
                        if committed_ngram == new_ngram:
                            removed = []
                            for _ in range(i):
                                removed_token = self.new.pop(0)
                                removed.append(repr(removed_token))
                            logger.debug(f"Removing last {i} words: {' '.join(removed)}")
                            break

    def flush(self) -> List[ASRToken]:
        """
        Returns the committed chunk, defined as the longest common prefix
        between the previous hypothesis and the new tokens.
        """
        committed: List[ASRToken] = []
        while self.new:
            current_new = self.new[0]
            if self.confidence_validation and current_new.probability and current_new.probability > 0.95:
                # High-confidence tokens are committed immediately.
                committed.append(current_new)
                self.last_committed_word = current_new.text
                self.last_committed_time = current_new.end
                self.new.pop(0)
                if self.buffer:
                    self.buffer.pop(0)
            elif not self.buffer:
                break
            elif current_new.text == self.buffer[0].text:
                committed.append(current_new)
                self.last_committed_word = current_new.text
                self.last_committed_time = current_new.end
                self.buffer.pop(0)
                self.new.pop(0)
            else:
                break
        self.buffer = self.new
        self.new = []
        self.committed_in_buffer.extend(committed)
        return committed

    def pop_committed(self, time: float):
        """
        Remove tokens (from the beginning) that have ended before `time`.
        """
        while self.committed_in_buffer and self.committed_in_buffer[0].end <= time:
            self.committed_in_buffer.pop(0)

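A small self-contained sketch of the commit logic above: a word is committed only once it appears at the same position in two consecutive hypotheses (the words and timings below are invented for illustration):

    from timed_objects import ASRToken
    from whisper_streaming_custom.online_asr import HypothesisBuffer

    buf = HypothesisBuffer()

    # First pass: nothing is committed yet, the hypothesis only fills the buffer.
    buf.insert([ASRToken(0.0, 0.4, "hello"), ASRToken(0.4, 0.9, "world")], offset=0.0)
    print(buf.flush())  # []

    # Second pass agrees on "hello world", so that prefix is committed.
    buf.insert([ASRToken(0.0, 0.4, "hello"), ASRToken(0.4, 0.9, "world"),
                ASRToken(0.9, 1.3, "again")], offset=0.0)
    print([t.text for t in buf.flush()])  # ['hello', 'world']
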
class OnlineASRProcessor:
    """
    Processes incoming audio in a streaming fashion, calling the ASR system
    periodically, and uses a hypothesis buffer to commit and trim recognized text.

    The processor supports two types of buffer trimming:
    - "sentence": trims at sentence boundaries (using a sentence tokenizer)
    - "segment": trims at fixed segment durations.
    """
    SAMPLING_RATE = 16000

    def __init__(
        self,
        asr,
        tokenize_method: Optional[callable] = None,
        buffer_trimming: Tuple[str, float] = ("segment", 15),
        confidence_validation=False,
        logfile=sys.stderr,
    ):
        """
        asr: An ASR system object (for example, a WhisperASR instance) that
             provides a `transcribe` method, a `ts_words` method (to extract tokens),
             a `segments_end_ts` method, and a separator attribute `sep`.
        tokenize_method: A function that receives text and returns a list of sentence strings.
        buffer_trimming: A tuple (option, seconds), where option is either "sentence" or "segment".
        """
        self.asr = asr
        self.tokenize = tokenize_method
        self.logfile = logfile
        self.confidence_validation = confidence_validation
        self.init()

        self.buffer_trimming_way, self.buffer_trimming_sec = buffer_trimming

        if self.buffer_trimming_way not in ["sentence", "segment"]:
            raise ValueError("buffer_trimming must be either 'sentence' or 'segment'")
        if self.buffer_trimming_sec <= 0:
            raise ValueError("buffer_trimming_sec must be positive")
        elif self.buffer_trimming_sec > 30:
            logger.warning(
                f"buffer_trimming_sec is set to {self.buffer_trimming_sec}, which is very long. It may cause OOM."
            )

    def init(self, offset: Optional[float] = None):
        """Initialize or reset the processing buffers."""
        self.audio_buffer = np.array([], dtype=np.float32)
        self.transcript_buffer = HypothesisBuffer(logfile=self.logfile, confidence_validation=self.confidence_validation)
        self.buffer_time_offset = offset if offset is not None else 0.0
        self.transcript_buffer.last_committed_time = self.buffer_time_offset
        self.committed: List[ASRToken] = []

    def insert_audio_chunk(self, audio: np.ndarray):
        """Append an audio chunk (a numpy array) to the current audio buffer."""
        self.audio_buffer = np.append(self.audio_buffer, audio)

    def prompt(self) -> Tuple[str, str]:
        """
        Returns a tuple: (prompt, context), where:
        - prompt is a 200-character suffix of committed text that falls
          outside the current audio buffer.
        - context is the committed text within the current audio buffer.
        """
        k = len(self.committed)
        while k > 0 and self.committed[k - 1].end > self.buffer_time_offset:
            k -= 1

        prompt_tokens = self.committed[:k]
        prompt_words = [token.text for token in prompt_tokens]
        prompt_list = []
        length_count = 0
        # Use the last words until reaching 200 characters.
        while prompt_words and length_count < 200:
            word = prompt_words.pop(-1)
            length_count += len(word) + 1
            prompt_list.append(word)
        non_prompt_tokens = self.committed[k:]
        context_text = self.asr.sep.join(token.text for token in non_prompt_tokens)
        return self.asr.sep.join(prompt_list[::-1]), context_text

    def get_buffer(self):
        """
        Get the unvalidated buffer in string format.
        """
        return self.concatenate_tokens(self.transcript_buffer.buffer)

    def process_iter(self) -> List[ASRToken]:
        """
        Processes the current audio buffer.

        Returns the list of newly committed ASRToken objects.
        """
        prompt_text, _ = self.prompt()
        logger.debug(
            f"Transcribing {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds from {self.buffer_time_offset:.2f}"
        )
        res = self.asr.transcribe(self.audio_buffer, init_prompt=prompt_text)
        tokens = self.asr.ts_words(res)  # expecting List[ASRToken]
        self.transcript_buffer.insert(tokens, self.buffer_time_offset)
        committed_tokens = self.transcript_buffer.flush()
        self.committed.extend(committed_tokens)
        completed = self.concatenate_tokens(committed_tokens)
        logger.debug(f">>>> COMPLETE NOW: {completed.text}")
        incomp = self.concatenate_tokens(self.transcript_buffer.buffer)
        logger.debug(f"INCOMPLETE: {incomp.text}")

        if committed_tokens and self.buffer_trimming_way == "sentence":
            if len(self.audio_buffer) / self.SAMPLING_RATE > self.buffer_trimming_sec:
                self.chunk_completed_sentence()

        # In "segment" mode trim at the configured length; otherwise fall back
        # to a hard 30-second cap to keep the buffer bounded.
        s = self.buffer_trimming_sec if self.buffer_trimming_way == "segment" else 30
        if len(self.audio_buffer) / self.SAMPLING_RATE > s:
            self.chunk_completed_segment(res)
            logger.debug("Chunking segment")
        logger.debug(
            f"Length of audio buffer now: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f} seconds"
        )
        return committed_tokens

    def chunk_completed_sentence(self):
        """
        If the committed tokens form at least two sentences, chunk the audio
        buffer at the end time of the penultimate sentence.
        """
        if not self.committed:
            return
        logger.debug("COMPLETED SENTENCE: " + " ".join(token.text for token in self.committed))
        sentences = self.words_to_sentences(self.committed)
        for sentence in sentences:
            logger.debug(f"\tSentence: {sentence.text}")
        if len(sentences) < 2:
            return
        # Keep the last two sentences.
        while len(sentences) > 2:
            sentences.pop(0)
        chunk_time = sentences[-2].end
        logger.debug(f"--- Sentence chunked at {chunk_time:.2f}")
        self.chunk_at(chunk_time)

    def chunk_completed_segment(self, res):
        """
        Chunk the audio buffer based on segment-end timestamps reported by the ASR.
        """
        if not self.committed:
            return
        ends = self.asr.segments_end_ts(res)
        last_committed_time = self.committed[-1].end
        if len(ends) > 1:
            e = ends[-2] + self.buffer_time_offset
            while len(ends) > 2 and e > last_committed_time:
                ends.pop(-1)
                e = ends[-2] + self.buffer_time_offset
            if e <= last_committed_time:
                logger.debug(f"--- Segment chunked at {e:.2f}")
                self.chunk_at(e)
            else:
                logger.debug("--- Last segment not within committed area")
        else:
            logger.debug("--- Not enough segments to chunk")

    def chunk_at(self, time: float):
        """
        Trim both the hypothesis and audio buffer at the given time.
        """
        logger.debug(f"Chunking at {time:.2f}s")
        logger.debug(
            f"Audio buffer length before chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
        )
        self.transcript_buffer.pop_committed(time)
        cut_seconds = time - self.buffer_time_offset
        self.audio_buffer = self.audio_buffer[int(cut_seconds * self.SAMPLING_RATE):]
        self.buffer_time_offset = time
        logger.debug(
            f"Audio buffer length after chunking: {len(self.audio_buffer)/self.SAMPLING_RATE:.2f}s"
        )

    def words_to_sentences(self, tokens: List[ASRToken]) -> List[Sentence]:
        """
        Converts a list of tokens to a list of Sentence objects using the provided
        sentence tokenizer.
        """
        if not tokens:
            return []

        full_text = " ".join(token.text for token in tokens)

        if self.tokenize:
            try:
                sentence_texts = self.tokenize(full_text)
            except Exception as e:
                # Some tokenizers (e.g., MosesSentenceSplitter) expect a list input.
                try:
                    sentence_texts = self.tokenize([full_text])
                except Exception as e2:
                    raise ValueError("Tokenization failed") from e2
        else:
            sentence_texts = [full_text]

        sentences: List[Sentence] = []
        token_index = 0
        for sent_text in sentence_texts:
            sent_text = sent_text.strip()
            if not sent_text:
                continue
            sent_tokens = []
            accumulated = ""
            # Accumulate tokens until roughly matching the length of the sentence text.
            while token_index < len(tokens) and len(accumulated) < len(sent_text):
                token = tokens[token_index]
                accumulated = (accumulated + " " + token.text).strip() if accumulated else token.text
                sent_tokens.append(token)
                token_index += 1
            if sent_tokens:
                sentence = Sentence(
                    start=sent_tokens[0].start,
                    end=sent_tokens[-1].end,
                    text=" ".join(t.text for t in sent_tokens),
                )
                sentences.append(sentence)
        return sentences

    def finish(self) -> Transcript:
        """
        Flush the remaining transcript when processing ends.
        """
        remaining_tokens = self.transcript_buffer.buffer
        final_transcript = self.concatenate_tokens(remaining_tokens)
        logger.debug(f"Final non-committed transcript: {final_transcript}")
        self.buffer_time_offset += len(self.audio_buffer) / self.SAMPLING_RATE
        return final_transcript

    def concatenate_tokens(
        self,
        tokens: List[ASRToken],
        sep: Optional[str] = None,
        offset: float = 0
    ) -> Transcript:
        sep = sep if sep is not None else self.asr.sep
        text = sep.join(token.text for token in tokens)
        probability = sum(token.probability for token in tokens if token.probability) / len(tokens) if tokens else None
        if tokens:
            start = offset + tokens[0].start
            end = offset + tokens[-1].end
        else:
            start = None
            end = None
        return Transcript(start, end, text, probability=probability)

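Putting the pieces together, a hedged sketch of the intended processing loop (the backend construction mirrors the earlier example; `speech.wav` is a hypothetical input file):

    import librosa
    from whisper_streaming_custom.backends import FasterWhisperASR
    from whisper_streaming_custom.online_asr import OnlineASRProcessor

    asr = FasterWhisperASR(lan="en", modelsize="tiny")
    online = OnlineASRProcessor(asr, buffer_trimming=("segment", 15))

    audio, _ = librosa.load("speech.wav", sr=16000)  # hypothetical input file
    chunk = 16000  # feed one second at a time
    for i in range(0, len(audio), chunk):
        online.insert_audio_chunk(audio[i:i + chunk])
        for token in online.process_iter():
            print(f"committed: {token.text}")
    print("tail:", online.finish().text)
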
class VACOnlineASRProcessor:
    """
    Wraps an OnlineASRProcessor with a Voice Activity Controller (VAC).

    It receives small chunks of audio, applies VAD (e.g. with Silero),
    and when the system detects a pause in speech (or the end of an utterance)
    it finalizes the utterance immediately.
    """
    SAMPLING_RATE = 16000

    def __init__(self, online_chunk_size: float, *args, **kwargs):
        self.online_chunk_size = online_chunk_size
        self.online = OnlineASRProcessor(*args, **kwargs)

        # Load a VAD model (e.g. Silero VAD).
        import torch
        model, _ = torch.hub.load(repo_or_dir="snakers4/silero-vad", model="silero_vad")
        from silero_vad_iterator import FixedVADIterator

        self.vac = FixedVADIterator(model)
        self.logfile = self.online.logfile
        self.init()

    def init(self):
        self.online.init()
        self.vac.reset_states()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        self.status: Optional[str] = None  # "voice" or "nonvoice"
        self.audio_buffer = np.array([], dtype=np.float32)
        self.buffer_offset = 0  # in frames

    def clear_buffer(self):
        self.buffer_offset += len(self.audio_buffer)
        self.audio_buffer = np.array([], dtype=np.float32)

    def insert_audio_chunk(self, audio: np.ndarray):
        """
        Process an incoming small audio chunk:
        - run VAD on the chunk,
        - decide whether to send the audio to the online ASR processor immediately,
        - and/or to mark the current utterance as finished.
        """
        res = self.vac(audio)
        self.audio_buffer = np.append(self.audio_buffer, audio)

        if res is not None:
            # VAD returned a result; adjust the frame number.
            frame = list(res.values())[0] - self.buffer_offset
            if "start" in res and "end" not in res:
                self.status = "voice"
                send_audio = self.audio_buffer[frame:]
                self.online.init(offset=(frame + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.clear_buffer()
            elif "end" in res and "start" not in res:
                self.status = "nonvoice"
                send_audio = self.audio_buffer[:frame]
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
            else:
                # Both a start and an end fell inside this chunk.
                beg = res["start"] - self.buffer_offset
                end = res["end"] - self.buffer_offset
                self.status = "nonvoice"
                send_audio = self.audio_buffer[beg:end]
                self.online.init(offset=(beg + self.buffer_offset) / self.SAMPLING_RATE)
                self.online.insert_audio_chunk(send_audio)
                self.current_online_chunk_buffer_size += len(send_audio)
                self.is_currently_final = True
                self.clear_buffer()
        else:
            if self.status == "voice":
                self.online.insert_audio_chunk(self.audio_buffer)
                self.current_online_chunk_buffer_size += len(self.audio_buffer)
                self.clear_buffer()
            else:
                # Keep one second's worth of audio in case VAD later detects voice,
                # but trim to avoid unbounded memory usage.
                self.buffer_offset += max(0, len(self.audio_buffer) - self.SAMPLING_RATE)
                self.audio_buffer = self.audio_buffer[-self.SAMPLING_RATE:]

    def process_iter(self):
        """
        Depending on the VAD status and the amount of accumulated audio,
        process the current audio chunk.

        Note: this returns whatever the wrapped processor returns (a list of
        committed tokens), or an empty Transcript when only VAD ran.
        """
        if self.is_currently_final:
            return self.finish()
        elif self.current_online_chunk_buffer_size > self.SAMPLING_RATE * self.online_chunk_size:
            self.current_online_chunk_buffer_size = 0
            return self.online.process_iter()
        else:
            logger.debug("No online update, only VAD")
            return Transcript(None, None, "")

    def finish(self) -> Transcript:
        """Finish processing by flushing any remaining text."""
        result = self.online.finish()
        self.current_online_chunk_buffer_size = 0
        self.is_currently_final = False
        return result

    def get_buffer(self):
        """
        Get the unvalidated buffer in string format.
        """
        return self.online.concatenate_tokens(self.online.transcript_buffer.buffer).text
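A hedged sketch of feeding microphone-sized chunks through the VAC wrapper (it downloads Silero VAD from torch.hub on first use and needs the local `silero_vad_iterator` module; the input array and the 512-sample step are illustrative):

    import numpy as np
    from whisper_streaming_custom.backends import FasterWhisperASR
    from whisper_streaming_custom.online_asr import VACOnlineASRProcessor

    asr = FasterWhisperASR(lan="en", modelsize="tiny")
    vac = VACOnlineASRProcessor(1.0, asr)  # run the online step about once per second of voiced audio

    stream = np.zeros(16000 * 5, dtype=np.float32)  # stand-in for live input
    step = 512  # small fixed-size windows at 16 kHz, as Silero VAD expects
    for i in range(0, len(stream), step):
        vac.insert_audio_chunk(stream[i:i + step])
        result = vac.process_iter()
        if isinstance(result, list):  # newly committed tokens
            for token in result:
                print(token.text)
        elif result.text:  # a Transcript (possibly empty)
            print(result.text)
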
whisper_streaming_custom/whisper_online.py
ADDED
@@ -0,0 +1,194 @@
#!/usr/bin/env python3
import sys
import numpy as np
import librosa
from functools import lru_cache
import time
import logging
from .backends import FasterWhisperASR, MLXWhisper, WhisperTimestampedASR, OpenaiApiASR
from .online_asr import OnlineASRProcessor, VACOnlineASRProcessor

logger = logging.getLogger(__name__)


WHISPER_LANG_CODES = "af,am,ar,as,az,ba,be,bg,bn,bo,br,bs,ca,cs,cy,da,de,el,en,es,et,eu,fa,fi,fo,fr,gl,gu,ha,haw,he,hi,hr,ht,hu,hy,id,is,it,ja,jw,ka,kk,km,kn,ko,la,lb,ln,lo,lt,lv,mg,mi,mk,ml,mn,mr,ms,mt,my,ne,nl,nn,no,oc,pa,pl,ps,pt,ro,ru,sa,sd,si,sk,sl,sn,so,sq,sr,su,sv,sw,ta,te,tg,th,tk,tl,tr,tt,uk,ur,uz,vi,yi,yo,zh".split(",")


def create_tokenizer(lan):
    """Returns an object with a split function that works like the one of MosesTokenizer."""

    assert (
        lan in WHISPER_LANG_CODES
    ), "language must be Whisper's supported lang code: " + " ".join(WHISPER_LANG_CODES)

    if lan == "uk":
        import tokenize_uk

        class UkrainianTokenizer:
            def split(self, text):
                return tokenize_uk.tokenize_sents(text)

        return UkrainianTokenizer()

    # Supported by fast-mosestokenizer.
    if (
        lan
        in "as bn ca cs de el en es et fi fr ga gu hi hu is it kn lt lv ml mni mr nl or pa pl pt ro ru sk sl sv ta te yue zh".split()
    ):
        from mosestokenizer import MosesSentenceSplitter

        return MosesSentenceSplitter(lan)

    # The following languages are in Whisper, but not in wtpsplit:
    if (
        lan
        in "as ba bo br bs fo haw hr ht jw lb ln lo mi nn oc sa sd sn so su sw tk tl tt".split()
    ):
        logger.debug(
            f"{lan} code is not supported by wtpsplit. Going to use None lang_code option."
        )
        lan = None

    from wtpsplit import WtP

    # Downloads the model from Hugging Face on first use.
    wtp = WtP("wtp-canine-s-12l-no-adapters")

    class WtPtok:
        def split(self, sent):
            return wtp.split(sent, lang_code=lan)

    return WtPtok()

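For example (assuming the optional `mosestokenizer` dependency is installed; note that, as `words_to_sentences` anticipates, the Moses splitter is called with a list):

    from whisper_streaming_custom.whisper_online import create_tokenizer

    tok = create_tokenizer("en")  # MosesSentenceSplitter for English
    print(tok(["Hello there. How are you?"]))
    # expected: ['Hello there.', 'How are you?']
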
def backend_factory(args):
    backend = args.backend
    if backend == "openai-api":
        logger.debug("Using OpenAI API.")
        asr = OpenaiApiASR(lan=args.lan)
    else:
        if backend == "faster-whisper":
            asr_cls = FasterWhisperASR
        elif backend == "mlx-whisper":
            asr_cls = MLXWhisper
        else:
            asr_cls = WhisperTimestampedASR

        # Only for FasterWhisperASR and WhisperTimestampedASR.
        size = args.model
        t = time.time()
        logger.info(f"Loading Whisper {size} model for language {args.lan}...")
        asr = asr_cls(
            modelsize=size,
            lan=args.lan,
            cache_dir=args.model_cache_dir,
            model_dir=args.model_dir,
        )
        e = time.time()
        logger.info(f"done. It took {round(e - t, 2)} seconds.")

    # Apply common configurations.
    if getattr(args, "vad", False):  # checks if the VAD argument is present and True
        logger.info("Setting VAD filter")
        asr.use_vad()

    language = args.lan
    if args.task == "translate":
        asr.set_translate_task()
        tgt_language = "en"  # Whisper translates into English
    else:
        tgt_language = language  # Whisper transcribes in this language

    # Create the tokenizer.
    if args.buffer_trimming == "sentence":
        tokenizer = create_tokenizer(tgt_language)
    else:
        tokenizer = None
    return asr, tokenizer


def online_factory(args, asr, tokenizer, logfile=sys.stderr):
    if args.vac:
        online = VACOnlineASRProcessor(
            args.min_chunk_size,
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
            confidence_validation=args.confidence_validation,
        )
    else:
        online = OnlineASRProcessor(
            asr,
            tokenizer,
            logfile=logfile,
            buffer_trimming=(args.buffer_trimming, args.buffer_trimming_sec),
            confidence_validation=args.confidence_validation,
        )
    return online


def asr_factory(args, logfile=sys.stderr):
    """
    Creates and configures an ASR and ASR Online instance based on the specified backend and arguments.
    """
    asr, tokenizer = backend_factory(args)
    online = online_factory(args, asr, tokenizer, logfile=logfile)
    return asr, online

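A sketch of calling the factory directly with an `argparse.Namespace` carrying the fields read above (the values here are illustrative, not the project's canonical defaults):

    from argparse import Namespace
    from whisper_streaming_custom.whisper_online import asr_factory, warmup_asr

    args = Namespace(
        backend="faster-whisper",
        model="tiny",
        lan="en",
        task="transcribe",
        model_cache_dir=None,
        model_dir=None,
        vad=False,
        vac=False,
        min_chunk_size=1.0,
        buffer_trimming="segment",
        buffer_trimming_sec=15,
        confidence_validation=False,
    )
    asr, online = asr_factory(args)
    warmup_asr(asr)  # optional: transcribe a short sample so the first request is fast
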
def warmup_asr(asr, warmup_file=None, timeout=5):
    """
    Warm up the ASR model by transcribing a short audio file.
    """
    import os
    import tempfile

    if warmup_file is None:
        # Download the JFK sample if it is not already present.
        jfk_url = "https://github.com/ggerganov/whisper.cpp/raw/master/samples/jfk.wav"
        temp_dir = tempfile.gettempdir()
        warmup_file = os.path.join(temp_dir, "whisper_warmup_jfk.wav")

        if not os.path.exists(warmup_file):
            logger.debug(f"Downloading warmup file from {jfk_url}")
            print(f"Downloading warmup file from {jfk_url}")
            import time
            import urllib.request
            import urllib.error
            import socket

            original_timeout = socket.getdefaulttimeout()
            socket.setdefaulttimeout(timeout)

            start_time = time.time()
            try:
                urllib.request.urlretrieve(jfk_url, warmup_file)
                logger.debug(f"Download successful in {time.time() - start_time:.2f}s")
            except (urllib.error.URLError, socket.timeout) as e:
                logger.warning(f"Download failed: {e}. Proceeding without warmup.")
                return False
            finally:
                socket.setdefaulttimeout(original_timeout)
    elif not warmup_file:
        return False

    if not warmup_file or not os.path.exists(warmup_file) or os.path.getsize(warmup_file) == 0:
        logger.warning(f"Warmup file {warmup_file} invalid or missing.")
        return False

    print(f"Warming up Whisper with {warmup_file}")
    try:
        audio, sr = librosa.load(warmup_file, sr=16000)
    except Exception as e:
        logger.warning(f"Failed to load audio file: {e}")
        return False

    # Process the audio once so the first real request is fast.
    asr.transcribe(audio)

    logger.info("Whisper is warmed up")
    return True