tomer9080 committed
Commit f787576 · 1 Parent(s): 784aada
first commit
.gitignore ADDED
@@ -0,0 +1,4 @@
+ *.ipynb
+ old_demo_code/
+ htokenf.txt
+ __pycache__/
dockerfile ADDED
@@ -0,0 +1,54 @@
+ # Use NVIDIA CUDA base image
+ FROM nvidia/cuda:11.8.0-devel-ubuntu20.04
+
+ # Set environment variables
+ ENV DEBIAN_FRONTEND=noninteractive
+ ENV PYTHONUNBUFFERED=1
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+ python3.9 \
+ python3.9-dev \
+ python3-pip \
+ ffmpeg \
+ git \
+ wget \
+ curl \
+ && rm -rf /var/lib/apt/lists/*
+
+ # Create symlinks for python
+ RUN ln -s /usr/bin/python3.9 /usr/bin/python
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+
+ # Switch to the "user" user
+ USER user
+
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+ PATH=/home/user/.local/bin:$PATH
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Upgrade pip
+ RUN python -m pip install --no-cache-dir --upgrade pip
+
+ # Install PyTorch with CUDA support first
+ RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+
+ # Copy requirements file
+ COPY --chown=user requirements.txt .
+
+ # Install other Python dependencies
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Copy the application code
+ COPY --chown=user . .
+
+ # Expose the port
+ EXPOSE 7860
+
+ # Run the application
+ CMD ["python", "unified_socket_server.py", "--host", "0.0.0.0", "--port", "7860"]
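The image installs PyTorch for CUDA 11.8 separately from the pinned requirements and launches unified_socket_server.py on port 7860, the port exposed above. A minimal build-and-run sketch, assuming Docker with the NVIDIA Container Toolkit for GPU passthrough; the image name whisper-stream-demo is illustrative:

docker build -t whisper-stream-demo .
docker run --gpus all -p 7860:7860 whisper-stream-demo

Without --gpus all the container still starts; the server checks torch.cuda.is_available() and falls back to loading the models on CPU.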
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ evaluate==0.4.3
+ librosa==0.10.2.post1
+ more_itertools==10.7.0
+ numba==0.60.0
+ numpy==1.26.0
+ openai_whisper==20240930
+ pandas==2.3.1
+ praatio==6.2.0
+ pyaudio==0.2.11
+ pytorch_lightning==2.5.0.post0
+ regex==2024.11.6
+ soundfile==0.12.1
+ tiktoken==0.8.0
+ tqdm==4.66.5
+ triton==3.2.0
+ websockets==15.0.1
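Note that unified_socket_server.py also imports aiohttp and torch, neither of which is pinned here: torch, torchvision and torchaudio come from the cu118 wheel index in the dockerfile, and aiohttp has to be added by hand for a local, non-Docker install, e.g.

pip install -r requirements.txt aiohttp

(aiohttp is left unpinned here as an assumption, since this commit does not pin a version for it).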
static/client.html ADDED
@@ -0,0 +1,567 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Real-Time Whisper Transcription</title>
7
+ <style>
8
+ body {
9
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
10
+ max-width: 1200px;
11
+ margin: 0 auto;
12
+ padding: 20px;
13
+ background-color: #f5f5f5;
14
+ }
15
+
16
+ .container {
17
+ background: white;
18
+ border-radius: 10px;
19
+ padding: 30px;
20
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
21
+ }
22
+
23
+ h1 {
24
+ color: #333;
25
+ text-align: center;
26
+ margin-bottom: 30px;
27
+ }
28
+
29
+ .config-section {
30
+ display: grid;
31
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
32
+ gap: 15px;
33
+ margin-bottom: 30px;
34
+ padding: 20px;
35
+ background: #f8f9fa;
36
+ border-radius: 8px;
37
+ }
38
+
39
+ .config-group {
40
+ display: flex;
41
+ flex-direction: column;
42
+ }
43
+
44
+ label {
45
+ font-weight: 600;
46
+ margin-bottom: 5px;
47
+ color: #555;
48
+ }
49
+
50
+ input, select {
51
+ padding: 10px;
52
+ border: 2px solid #ddd;
53
+ border-radius: 5px;
54
+ font-size: 14px;
55
+ transition: border-color 0.3s;
56
+ }
57
+
58
+ input:focus, select:focus {
59
+ outline: none;
60
+ border-color: #007bff;
61
+ }
62
+
63
+ .controls {
64
+ display: flex;
65
+ gap: 10px;
66
+ justify-content: center;
67
+ margin-bottom: 30px;
68
+ }
69
+
70
+ button {
71
+ padding: 12px 24px;
72
+ border: none;
73
+ border-radius: 5px;
74
+ font-size: 16px;
75
+ font-weight: 600;
76
+ cursor: pointer;
77
+ transition: all 0.3s;
78
+ }
79
+
80
+ .start-btn {
81
+ background: #28a745;
82
+ color: white;
83
+ }
84
+
85
+ .start-btn:hover:not(:disabled) {
86
+ background: #218838;
87
+ }
88
+
89
+ .stop-btn {
90
+ background: #dc3545;
91
+ color: white;
92
+ }
93
+
94
+ .stop-btn:hover:not(:disabled) {
95
+ background: #c82333;
96
+ }
97
+
98
+ .clear-btn {
99
+ background: #6c757d;
100
+ color: white;
101
+ }
102
+
103
+ .clear-btn:hover:not(:disabled) {
104
+ background: #5a6268;
105
+ }
106
+
107
+ button:disabled {
108
+ opacity: 0.6;
109
+ cursor: not-allowed;
110
+ }
111
+
112
+ .status {
113
+ display: flex;
114
+ align-items: center;
115
+ justify-content: center;
116
+ gap: 10px;
117
+ margin-bottom: 20px;
118
+ font-weight: 600;
119
+ }
120
+
121
+ .status-indicator {
122
+ width: 12px;
123
+ height: 12px;
124
+ border-radius: 50%;
125
+ background: #dc3545;
126
+ animation: pulse 2s infinite;
127
+ }
128
+
129
+ .status-indicator.connected {
130
+ background: #28a745;
131
+ }
132
+
133
+ .status-indicator.streaming {
134
+ background: #ffc107;
135
+ }
136
+
137
+ @keyframes pulse {
138
+ 0% { opacity: 1; }
139
+ 50% { opacity: 0.5; }
140
+ 100% { opacity: 1; }
141
+ }
142
+
143
+ .transcription-section {
144
+ display: grid;
145
+ grid-template-columns: 1fr 1fr;
146
+ gap: 20px;
147
+ }
148
+
149
+ .transcription-panel {
150
+ background: #f8f9fa;
151
+ border-radius: 8px;
152
+ padding: 20px;
153
+ }
154
+
155
+ .transcription-panel h3 {
156
+ margin-top: 0;
157
+ color: #333;
158
+ }
159
+
160
+ .log-area, .transcription-area {
161
+ background: #fff;
162
+ border: 1px solid #ddd;
163
+ border-radius: 5px;
164
+ padding: 15px;
165
+ height: 300px;
166
+ overflow-y: auto;
167
+ font-family: 'Courier New', monospace;
168
+ font-size: 14px;
169
+ line-height: 1.4;
170
+ white-space: pre-wrap;
171
+ }
172
+
173
+ .transcription-area {
174
+ font-family: inherit;
175
+ font-size: 16px;
176
+ line-height: 1.6;
177
+ }
178
+
179
+ .stats {
180
+ display: grid;
181
+ grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
182
+ gap: 15px;
183
+ margin-top: 20px;
184
+ }
185
+
186
+ .stat-item {
187
+ text-align: center;
188
+ padding: 10px;
189
+ background: #e9ecef;
190
+ border-radius: 5px;
191
+ }
192
+
193
+ .stat-value {
194
+ font-size: 18px;
195
+ font-weight: bold;
196
+ color: #007bff;
197
+ }
198
+
199
+ .stat-label {
200
+ font-size: 12px;
201
+ color: #666;
202
+ margin-top: 5px;
203
+ }
204
+
205
+ @media (max-width: 768px) {
206
+ .config-section {
207
+ grid-template-columns: 1fr;
208
+ }
209
+
210
+ .transcription-section {
211
+ grid-template-columns: 1fr;
212
+ }
213
+
214
+ .controls {
215
+ flex-direction: column;
216
+ align-items: center;
217
+ }
218
+
219
+ button {
220
+ width: 200px;
221
+ }
222
+ }
223
+ </style>
224
+ </head>
225
+ <body>
226
+ <div class="container">
227
+ <h1>🎤 Real-Time Whisper Transcription</h1>
228
+
229
+ <div class="config-section">
230
+ <div class="config-group">
231
+ <label for="modelSize">Model Size:</label>
232
+ <select id="modelSize">
233
+ <option value="base">Base (74 MB)</option>
234
+ <option value="small" selected>Small (244 MB)</option>
235
+ <option value="large-v2">Large-v2 (1550 MB)</option>
236
+ </select>
237
+ </div>
238
+
239
+ <div class="config-group">
240
+ <label for="chunkSize">Chunk Size (ms):</label>
241
+ <input type="number" id="chunkSize" value="300" min="100" max="2000" step="100">
242
+ </div>
243
+
244
+ <div class="config-group">
245
+ <label for="beamSize">Beam Size:</label>
246
+ <input type="number" id="beamSize" value="0" min="0" max="10">
247
+ </div>
248
+
249
+ <div class="config-group">
250
+ <label for="temperature">Temperature:</label>
251
+ <input type="number" id="temperature" value="0.0" min="0.0" max="1.0" step="0.1">
252
+ </div>
253
+ </div>
254
+
255
+ <div class="controls">
256
+ <button id="startBtn" class="start-btn" onclick="startStreaming()">🎤 Start Recording</button>
257
+ <button id="stopBtn" class="stop-btn" onclick="stopStreaming()" disabled>⏹️ Stop Recording</button>
258
+ <button id="clearBtn" class="clear-btn" onclick="clearAll()">🗑️ Clear All</button>
259
+ </div>
260
+
261
+ <div class="status">
262
+ <div class="status-indicator" id="statusIndicator"></div>
263
+ <span id="statusText">Disconnected</span>
264
+ </div>
265
+
266
+ <div class="transcription-section">
267
+ <div class="transcription-panel">
268
+ <h3>📝 Transcription</h3>
269
+ <div id="transcriptionArea" class="transcription-area"></div>
270
+ </div>
271
+
272
+ <div class="transcription-panel">
273
+ <h3>📋 System Log</h3>
274
+ <div id="logArea" class="log-area"></div>
275
+ </div>
276
+ </div>
277
+
278
+ <div class="stats">
279
+ <div class="stat-item">
280
+ <div class="stat-value" id="durationStat">0.0s</div>
281
+ <div class="stat-label">Duration</div>
282
+ </div>
283
+ <div class="stat-item">
284
+ <div class="stat-value" id="chunksStat">0</div>
285
+ <div class="stat-label">Chunks Sent</div>
286
+ </div>
287
+ <div class="stat-item">
288
+ <div class="stat-value" id="transcriptionsStat">0</div>
289
+ <div class="stat-label">Transcriptions</div>
290
+ </div>
291
+ <div class="stat-item">
292
+ <div class="stat-value" id="errorsStat">0</div>
293
+ <div class="stat-label">Errors</div>
294
+ </div>
295
+ </div>
296
+ </div>
297
+
298
+ <script>
299
+ let socket;
300
+ let audioContext, processor, micStream;
301
+ let isStreaming = false;
302
+ let startTime = 0;
303
+ let stats = {
304
+ chunks: 0,
305
+ transcriptions: 0,
306
+ errors: 0
307
+ };
308
+
309
+ // UI Elements
310
+ const startBtn = document.getElementById('startBtn');
311
+ const stopBtn = document.getElementById('stopBtn');
312
+ const statusIndicator = document.getElementById('statusIndicator');
313
+ const statusText = document.getElementById('statusText');
314
+ const logArea = document.getElementById('logArea');
315
+ const transcriptionArea = document.getElementById('transcriptionArea');
316
+
317
+ function log(message, isError = false) {
318
+ const timestamp = new Date().toLocaleTimeString();
319
+ const logMessage = `[${timestamp}] ${message}`;
320
+ logArea.textContent += logMessage + '\n';
321
+ logArea.scrollTop = logArea.scrollHeight;
322
+
323
+ if (isError) {
324
+ console.error(logMessage);
325
+ stats.errors++;
326
+ updateStats();
327
+ } else {
328
+ console.log(logMessage);
329
+ }
330
+ }
331
+
332
+ function updateStatus(status, color) {
333
+ statusText.textContent = status;
334
+ statusIndicator.className = `status-indicator ${color}`;
335
+ }
336
+
337
+ function updateStats() {
338
+ document.getElementById('durationStat').textContent =
339
+ isStreaming ? ((Date.now() - startTime) / 1000).toFixed(1) + 's' : '0.0s';
340
+ document.getElementById('chunksStat').textContent = stats.chunks;
341
+ document.getElementById('transcriptionsStat').textContent = stats.transcriptions;
342
+ document.getElementById('errorsStat').textContent = stats.errors;
343
+ }
344
+
345
+ function addTranscription(text, timestamp) {
346
+ const transcriptionText = `[${timestamp.toFixed(1)}s] ${text}\n`;
347
+ transcriptionArea.textContent += transcriptionText;
348
+ transcriptionArea.scrollTop = transcriptionArea.scrollHeight;
349
+ stats.transcriptions++;
350
+ updateStats();
351
+ }
352
+
353
+ async function startStreaming() {
354
+ try {
355
+ // Get configuration
356
+ const config = {
357
+ model_size: document.getElementById('modelSize').value,
358
+ chunk_size: parseInt(document.getElementById('chunkSize').value),
359
+ beam_size: parseInt(document.getElementById('beamSize').value),
360
+ temperature: parseFloat(document.getElementById('temperature').value)
361
+ };
362
+
363
+ log('Starting transcription session...');
364
+ log(`Config: ${JSON.stringify(config, null, 2)}`);
365
+
366
+ // Create WebSocket URL (relative to current page)
367
+ const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
368
+ const wsUrl = `${wsProtocol}//${window.location.host}/ws`;
369
+
370
+ // Initialize WebSocket
371
+ socket = new WebSocket(wsUrl);
372
+ socket.binaryType = 'arraybuffer';
373
+
374
+ socket.onopen = async () => {
375
+ log('WebSocket connected');
376
+ updateStatus('Connected', 'connected');
377
+
378
+ // Send configuration
379
+ log('Sending configuration...');
380
+ socket.send(JSON.stringify(config));
381
+ };
382
+
383
+ socket.onmessage = (event) => {
384
+ try {
385
+ const data = JSON.parse(event.data);
386
+
387
+ if (data.error) {
388
+ log(`Server error: ${data.error}`, true);
389
+ return;
390
+ }
391
+
392
+ if (data.status === 'CONFIG_RECEIVED') {
393
+ log(`Configuration received. GPU: ${data.gpu ? 'Yes' : 'No'}`);
394
+ if (data.fallback) {
395
+ log('Using fallback Whisper model');
396
+ }
397
+ // Wait a bit for WebSocket to be fully ready
398
+ setTimeout(() => {
399
+ startAudioCapture();
400
+ }, 100);
401
+ } else if (data.text) {
402
+ addTranscription(data.text, data.timestamp || 0);
403
+ }
404
+ } catch (e) {
405
+ log(`Server message: ${event.data}`);
406
+ }
407
+ };
408
+
409
+ socket.onerror = (error) => {
410
+ log(`WebSocket error: ${error}`, true);
411
+ updateStatus('Error', 'error');
412
+ };
413
+
414
+ socket.onclose = () => {
415
+ log('WebSocket disconnected');
416
+ updateStatus('Disconnected', '');
417
+ stopStreaming();
418
+ };
419
+
420
+ } catch (error) {
421
+ log(`Failed to start: ${error.message}`, true);
422
+ updateStatus('Error', 'error');
423
+ }
424
+ }
425
+
426
+ async function startAudioCapture() {
427
+ try {
428
+ log('Requesting microphone access...');
429
+
430
+ const stream = await navigator.mediaDevices.getUserMedia({
431
+ audio: {
432
+ sampleRate: 16000,
433
+ channelCount: 1,
434
+ echoCancellation: true,
435
+ noiseSuppression: true,
436
+ autoGainControl: true
437
+ }
438
+ });
439
+
440
+ log('Microphone access granted');
441
+
442
+ // Create audio context
443
+ audioContext = new (window.AudioContext || window.webkitAudioContext)({
444
+ sampleRate: 16000
445
+ });
446
+
447
+ // Resume audio context if suspended
448
+ if (audioContext.state === 'suspended') {
449
+ await audioContext.resume();
450
+ log('Audio context resumed');
451
+ }
452
+
453
+ // Create media stream source
454
+ micStream = audioContext.createMediaStreamSource(stream);
455
+
456
+ // Create script processor
457
+ const bufferSize = 4096;
458
+ processor = audioContext.createScriptProcessor(bufferSize, 1, 1);
459
+
460
+ processor.onaudioprocess = (event) => {
461
+ if (!isStreaming || !socket || socket.readyState !== WebSocket.OPEN) {
462
+ return;
463
+ }
464
+
465
+ const inputData = event.inputBuffer.getChannelData(0);
466
+
467
+ // Convert to 16-bit PCM
468
+ const int16Array = new Int16Array(inputData.length);
469
+ for (let i = 0; i < inputData.length; i++) {
470
+ const sample = Math.max(-1, Math.min(1, inputData[i]));
471
+ int16Array[i] = sample * 32767;
472
+ }
473
+
474
+ // Send to server
475
+ socket.send(int16Array.buffer);
476
+ stats.chunks++;
477
+ };
478
+
479
+ // Connect audio nodes
480
+ micStream.connect(processor);
481
+ processor.connect(audioContext.destination);
482
+
483
+ // Update UI
484
+ isStreaming = true;
485
+ startTime = Date.now();
486
+ startBtn.disabled = true;
487
+ stopBtn.disabled = false;
488
+ updateStatus('Streaming', 'streaming');
489
+
490
+ log('Audio streaming started');
491
+
492
+ // Start stats update timer
493
+ const statsTimer = setInterval(() => {
494
+ if (isStreaming) {
495
+ updateStats();
496
+ } else {
497
+ clearInterval(statsTimer);
498
+ }
499
+ }, 100);
500
+
501
+ } catch (error) {
502
+ log(`Audio capture failed: ${error.message}`, true);
503
+ updateStatus('Error', 'error');
504
+ startBtn.disabled = false;
505
+ stopBtn.disabled = true;
506
+ }
507
+ }
508
+
509
+ function stopStreaming() {
510
+ isStreaming = false;
511
+
512
+ // Stop all tracks
513
+ if (micStream && micStream.mediaStream) {
514
+ micStream.mediaStream.getTracks().forEach(track => track.stop());
515
+ }
516
+
517
+ // Disconnect audio nodes
518
+ if (processor) {
519
+ processor.disconnect();
520
+ processor = null;
521
+ }
522
+
523
+ if (micStream) {
524
+ micStream.disconnect();
525
+ micStream = null;
526
+ }
527
+
528
+ if (audioContext) {
529
+ audioContext.close();
530
+ audioContext = null;
531
+ }
532
+
533
+ // Close WebSocket
534
+ if (socket && socket.readyState === WebSocket.OPEN) {
535
+ socket.close();
536
+ }
537
+
538
+ // Update UI
539
+ startBtn.disabled = false;
540
+ stopBtn.disabled = true;
541
+ updateStatus('Disconnected', '');
542
+
543
+ log('Streaming stopped');
544
+ updateStats();
545
+ }
546
+
547
+ function clearAll() {
548
+ logArea.textContent = '';
549
+ transcriptionArea.textContent = '';
550
+ stats = { chunks: 0, transcriptions: 0, errors: 0 };
551
+ updateStats();
552
+ log('All content cleared');
553
+ }
554
+
555
+ // Initialize
556
+ updateStats();
557
+ log('Real-time transcription client ready');
558
+
559
+ // Handle page unload
560
+ window.addEventListener('beforeunload', () => {
561
+ if (isStreaming) {
562
+ stopStreaming();
563
+ }
564
+ });
565
+ </script>
566
+ </body>
567
+ </html>
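The page captures 16 kHz mono audio through a ScriptProcessorNode with a 4096-sample buffer, clamps each float32 sample to [-1, 1], scales it by 32767 into an Int16Array, and sends the raw buffer over the WebSocket; on the other side, unified_socket_server.py recovers floats with np.frombuffer(chunk, dtype=np.int16) / 32768.0. A small Python sketch of that wire format, mirroring the two code paths (illustrative only):

import numpy as np

def encode_like_browser(samples: np.ndarray) -> bytes:
    # client.html onaudioprocess: clamp to [-1, 1], scale by 32767, send raw int16 bytes
    return (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16).tobytes()

def decode_like_server(chunk: bytes) -> np.ndarray:
    # unified_socket_server.transcribe_chunk: int16 bytes -> float32 in [-1, 1)
    return np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0

# One ScriptProcessorNode buffer's worth of samples round-trips up to quantization error
buf = np.random.uniform(-1, 1, 4096).astype(np.float32)
assert np.allclose(decode_like_server(encode_like_browser(buf)), buf, atol=1e-4)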
static/whisper_client.html ADDED
@@ -0,0 +1,595 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Real-Time Whisper Transcription</title>
7
+ <style>
8
+ body {
9
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
10
+ max-width: 1200px;
11
+ margin: 0 auto;
12
+ padding: 20px;
13
+ background-color: #f5f5f5;
14
+ }
15
+
16
+ .container {
17
+ background: white;
18
+ border-radius: 10px;
19
+ padding: 30px;
20
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
21
+ }
22
+
23
+ h1 {
24
+ color: #333;
25
+ text-align: center;
26
+ margin-bottom: 30px;
27
+ }
28
+
29
+ .config-section {
30
+ display: grid;
31
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
32
+ gap: 15px;
33
+ margin-bottom: 30px;
34
+ padding: 20px;
35
+ background: #f8f9fa;
36
+ border-radius: 8px;
37
+ }
38
+
39
+ .config-group {
40
+ display: flex;
41
+ flex-direction: column;
42
+ }
43
+
44
+ label {
45
+ font-weight: 600;
46
+ margin-bottom: 5px;
47
+ color: #555;
48
+ }
49
+
50
+ input, select {
51
+ padding: 10px;
52
+ border: 2px solid #ddd;
53
+ border-radius: 5px;
54
+ font-size: 14px;
55
+ transition: border-color 0.3s;
56
+ }
57
+
58
+ input:focus, select:focus {
59
+ outline: none;
60
+ border-color: #007bff;
61
+ }
62
+
63
+ .controls {
64
+ display: flex;
65
+ gap: 10px;
66
+ justify-content: center;
67
+ margin-bottom: 30px;
68
+ }
69
+
70
+ button {
71
+ padding: 12px 24px;
72
+ border: none;
73
+ border-radius: 5px;
74
+ font-size: 16px;
75
+ font-weight: 600;
76
+ cursor: pointer;
77
+ transition: all 0.3s;
78
+ }
79
+
80
+ .start-btn {
81
+ background: #28a745;
82
+ color: white;
83
+ }
84
+
85
+ .start-btn:hover:not(:disabled) {
86
+ background: #218838;
87
+ }
88
+
89
+ .stop-btn {
90
+ background: #dc3545;
91
+ color: white;
92
+ }
93
+
94
+ .stop-btn:hover:not(:disabled) {
95
+ background: #c82333;
96
+ }
97
+
98
+ .clear-btn {
99
+ background: #6c757d;
100
+ color: white;
101
+ }
102
+
103
+ .clear-btn:hover:not(:disabled) {
104
+ background: #5a6268;
105
+ }
106
+
107
+ button:disabled {
108
+ opacity: 0.6;
109
+ cursor: not-allowed;
110
+ }
111
+
112
+ .status {
113
+ display: flex;
114
+ align-items: center;
115
+ justify-content: center;
116
+ gap: 10px;
117
+ margin-bottom: 20px;
118
+ font-weight: 600;
119
+ }
120
+
121
+ .status-indicator {
122
+ width: 12px;
123
+ height: 12px;
124
+ border-radius: 50%;
125
+ background: #dc3545;
126
+ animation: pulse 2s infinite;
127
+ }
128
+
129
+ .status-indicator.connected {
130
+ background: #28a745;
131
+ }
132
+
133
+ .status-indicator.streaming {
134
+ background: #ffc107;
135
+ }
136
+
137
+ @keyframes pulse {
138
+ 0% { opacity: 1; }
139
+ 50% { opacity: 0.5; }
140
+ 100% { opacity: 1; }
141
+ }
142
+
143
+ .transcription-section {
144
+ display: grid;
145
+ grid-template-columns: 1fr 1fr;
146
+ gap: 20px;
147
+ }
148
+
149
+ .transcription-panel {
150
+ background: #f8f9fa;
151
+ border-radius: 8px;
152
+ padding: 20px;
153
+ }
154
+
155
+ .transcription-panel h3 {
156
+ margin-top: 0;
157
+ color: #333;
158
+ }
159
+
160
+ .log-area, .transcription-area {
161
+ background: #fff;
162
+ border: 1px solid #ddd;
163
+ border-radius: 5px;
164
+ padding: 15px;
165
+ height: 300px;
166
+ overflow-y: auto;
167
+ font-family: 'Courier New', monospace;
168
+ font-size: 14px;
169
+ line-height: 1.4;
170
+ white-space: pre-wrap;
171
+ }
172
+
173
+ .transcription-area {
174
+ font-family: inherit;
175
+ font-size: 16px;
176
+ line-height: 1.6;
177
+ }
178
+
179
+ .stats {
180
+ display: grid;
181
+ grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
182
+ gap: 15px;
183
+ margin-top: 20px;
184
+ }
185
+
186
+ .stat-item {
187
+ text-align: center;
188
+ padding: 10px;
189
+ background: #e9ecef;
190
+ border-radius: 5px;
191
+ }
192
+
193
+ .stat-value {
194
+ font-size: 18px;
195
+ font-weight: bold;
196
+ color: #007bff;
197
+ }
198
+
199
+ .stat-label {
200
+ font-size: 12px;
201
+ color: #666;
202
+ margin-top: 5px;
203
+ }
204
+
205
+ @media (max-width: 768px) {
206
+ .config-section {
207
+ grid-template-columns: 1fr;
208
+ }
209
+
210
+ .transcription-section {
211
+ grid-template-columns: 1fr;
212
+ }
213
+
214
+ .controls {
215
+ flex-direction: column;
216
+ align-items: center;
217
+ }
218
+
219
+ button {
220
+ width: 200px;
221
+ }
222
+ }
223
+ </style>
224
+ </head>
225
+ <body>
226
+ <div class="container">
227
+ <h1>🎤 Real-Time Whisper Transcription</h1>
228
+
229
+ <div class="config-section">
230
+ <div class="config-group">
231
+ <label for="serverUrl">Server URL:</label>
232
+ <input type="text" id="serverUrl" value="ws://localhost:8000">
233
+ </div>
234
+
235
+ <div class="config-group">
236
+ <label for="modelSize">Model Size:</label>
237
+ <select id="modelSize">
238
+ <option value="base">Base (74 MB)</option>
239
+ <option value="small" selected>Small (244 MB)</option>
240
+ <option value="large-v2">Large-v2 (1550 MB)</option>
241
+ </select>
242
+ </div>
243
+
244
+ <div class="config-group">
245
+ <label for="chunkSize">Chunk Size (ms):</label>
246
+ <input type="number" id="chunkSize" value="300" min="100" max="2000" step="100">
247
+ </div>
248
+
249
+ <div class="config-group">
250
+ <label for="beamSize">Beam Size:</label>
251
+ <input type="number" id="beamSize" value="0" min="0" max="10">
252
+ </div>
253
+
254
+ <div class="config-group">
255
+ <label for="temperature">Temperature:</label>
256
+ <input type="number" id="temperature" value="0.0" min="0.0" max="1.0" step="0.1">
257
+ </div>
258
+ </div>
259
+
260
+ <div class="controls">
261
+ <button id="startBtn" class="start-btn" onclick="startStreaming()">🎤 Start Recording</button>
262
+ <button id="stopBtn" class="stop-btn" onclick="stopStreaming()" disabled>⏹️ Stop Recording</button>
263
+ <button id="clearBtn" class="clear-btn" onclick="clearAll()">🗑️ Clear All</button>
264
+ </div>
265
+
266
+ <div class="status">
267
+ <div class="status-indicator" id="statusIndicator"></div>
268
+ <span id="statusText">Disconnected</span>
269
+ </div>
270
+
271
+ <div class="transcription-section">
272
+ <div class="transcription-panel">
273
+ <h3>📝 Transcription</h3>
274
+ <div id="transcriptionArea" class="transcription-area"></div>
275
+ </div>
276
+
277
+ <div class="transcription-panel">
278
+ <h3>📋 System Log</h3>
279
+ <div id="logArea" class="log-area"></div>
280
+ </div>
281
+ </div>
282
+
283
+ <div class="stats">
284
+ <div class="stat-item">
285
+ <div class="stat-value" id="durationStat">0.0s</div>
286
+ <div class="stat-label">Duration</div>
287
+ </div>
288
+ <div class="stat-item">
289
+ <div class="stat-value" id="chunksStat">0</div>
290
+ <div class="stat-label">Chunks Sent</div>
291
+ </div>
292
+ <div class="stat-item">
293
+ <div class="stat-value" id="transcriptionsStat">0</div>
294
+ <div class="stat-label">Transcriptions</div>
295
+ </div>
296
+ <div class="stat-item">
297
+ <div class="stat-value" id="errorsStat">0</div>
298
+ <div class="stat-label">Errors</div>
299
+ </div>
300
+ </div>
301
+ </div>
302
+
303
+ <script>
304
+ let socket, audioContext, processor, micStream;
305
+ let isStreaming = false;
306
+ let startTime = 0;
307
+ let stats = {
308
+ chunks: 0,
309
+ transcriptions: 0,
310
+ errors: 0
311
+ };
312
+
313
+ // UI Elements
314
+ const startBtn = document.getElementById('startBtn');
315
+ const stopBtn = document.getElementById('stopBtn');
316
+ const statusIndicator = document.getElementById('statusIndicator');
317
+ const statusText = document.getElementById('statusText');
318
+ const logArea = document.getElementById('logArea');
319
+ const transcriptionArea = document.getElementById('transcriptionArea');
320
+
321
+ function log(message, isError = false) {
322
+ const timestamp = new Date().toLocaleTimeString();
323
+ const logMessage = `[${timestamp}] ${message}`;
324
+ logArea.textContent += logMessage + '\n';
325
+ logArea.scrollTop = logArea.scrollHeight;
326
+
327
+ if (isError) {
328
+ console.error(logMessage);
329
+ stats.errors++;
330
+ updateStats();
331
+ } else {
332
+ console.log(logMessage);
333
+ }
334
+ }
335
+
336
+ function updateStatus(status, color) {
337
+ statusText.textContent = status;
338
+ statusIndicator.className = `status-indicator ${color}`;
339
+ }
340
+
341
+ function updateStats() {
342
+ document.getElementById('durationStat').textContent =
343
+ isStreaming ? ((Date.now() - startTime) / 1000).toFixed(1) + 's' : '0.0s';
344
+ document.getElementById('chunksStat').textContent = stats.chunks;
345
+ document.getElementById('transcriptionsStat').textContent = stats.transcriptions;
346
+ document.getElementById('errorsStat').textContent = stats.errors;
347
+ }
348
+
349
+ function addTranscription(text, timestamp) {
350
+ const transcriptionText = `[${timestamp.toFixed(1)}s] ${text}\n`;
351
+ transcriptionArea.textContent += transcriptionText;
352
+ transcriptionArea.scrollTop = transcriptionArea.scrollHeight;
353
+ stats.transcriptions++;
354
+ updateStats();
355
+ }
356
+
357
+ async function startStreaming() {
358
+ try {
359
+ // Get configuration
360
+ const config = {
361
+ serverUrl: document.getElementById('serverUrl').value,
362
+ model_size: document.getElementById('modelSize').value,
363
+ chunk_size: parseInt(document.getElementById('chunkSize').value),
364
+ beam_size: parseInt(document.getElementById('beamSize').value),
365
+ temperature: parseFloat(document.getElementById('temperature').value)
366
+ };
367
+
368
+ log('Starting transcription session...');
369
+ log(`Config: ${JSON.stringify(config, null, 2)}`);
370
+
371
+ // Initialize WebSocket
372
+ socket = new WebSocket(config.serverUrl);
373
+ socket.binaryType = 'arraybuffer';
374
+
375
+ socket.onopen = async () => {
376
+ log('WebSocket connected');
377
+ updateStatus('Connected', 'connected');
378
+
379
+ // Send configuration
380
+ log('Sending configuration...');
381
+ socket.send(JSON.stringify(config));
382
+ };
383
+
384
+ socket.onmessage = (event) => {
385
+ try {
386
+ const data = JSON.parse(event.data);
387
+
388
+ if (data.error) {
389
+ log(`Server error: ${data.error}`, true);
390
+ return;
391
+ }
392
+
393
+ if (data.status === 'CONFIG_RECEIVED') {
394
+ log(`Configuration received. GPU: ${data.gpu ? 'Yes' : 'No'}`);
395
+ if (data.fallback) {
396
+ log('Using fallback Whisper model');
397
+ }
398
+ // Wait a bit for WebSocket to be fully ready
399
+ setTimeout(() => {
400
+ startAudioCapture();
401
+ }, 100);
402
+ } else if (data.text) {
403
+ addTranscription(data.text, data.timestamp || 0);
404
+ }
405
+ } catch (e) {
406
+ log(`Server message: ${event.data}`);
407
+ }
408
+ };
409
+
410
+ socket.onerror = (error) => {
411
+ log(`WebSocket error: ${error}`, true);
412
+ updateStatus('Error', 'error');
413
+ };
414
+
415
+ socket.onclose = () => {
416
+ log('WebSocket disconnected');
417
+ updateStatus('Disconnected', '');
418
+ stopStreaming();
419
+ };
420
+
421
+ } catch (error) {
422
+ log(`Failed to start: ${error.message}`, true);
423
+ updateStatus('Error', 'error');
424
+ }
425
+ }
426
+
427
+ async function startAudioCapture() {
428
+ try {
429
+ // Request microphone access with explicit user gesture
430
+ log('Requesting microphone access...');
431
+
432
+ // Check if we have permission first
433
+ if (navigator.permissions) {
434
+ const permission = await navigator.permissions.query({name: 'microphone'});
435
+ log(`Microphone permission state: ${permission.state}`);
436
+ }
437
+
438
+ const stream = await navigator.mediaDevices.getUserMedia({
439
+ audio: {
440
+ sampleRate: 16000,
441
+ channelCount: 1,
442
+ echoCancellation: true,
443
+ noiseSuppression: true,
444
+ autoGainControl: true
445
+ }
446
+ });
447
+
448
+ log('Microphone access granted');
449
+
450
+ // Create audio context
451
+ audioContext = new (window.AudioContext || window.webkitAudioContext)({
452
+ sampleRate: 16000
453
+ });
454
+
455
+ // Resume audio context if suspended (required by some browsers)
456
+ if (audioContext.state === 'suspended') {
457
+ await audioContext.resume();
458
+ log('Audio context resumed');
459
+ }
460
+
461
+ // Create media stream source
462
+ micStream = audioContext.createMediaStreamSource(stream);
463
+
464
+ // Create script processor
465
+ const bufferSize = 4096;
466
+ processor = audioContext.createScriptProcessor(bufferSize, 1, 1);
467
+
468
+ processor.onaudioprocess = (event) => {
469
+ if (!isStreaming || !socket || socket.readyState !== WebSocket.OPEN) {
470
+ return;
471
+ }
472
+
473
+ const inputData = event.inputBuffer.getChannelData(0);
474
+
475
+ // Convert to 16-bit PCM
476
+ const int16Array = new Int16Array(inputData.length);
477
+ for (let i = 0; i < inputData.length; i++) {
478
+ const sample = Math.max(-1, Math.min(1, inputData[i]));
479
+ int16Array[i] = sample * 32767;
480
+ }
481
+
482
+ // Send to server
483
+ socket.send(int16Array.buffer);
484
+ stats.chunks++;
485
+ };
486
+
487
+ // Connect audio nodes
488
+ micStream.connect(processor);
489
+ processor.connect(audioContext.destination);
490
+
491
+ // Update UI
492
+ isStreaming = true;
493
+ startTime = Date.now();
494
+ startBtn.disabled = true;
495
+ stopBtn.disabled = false;
496
+ updateStatus('Streaming', 'streaming');
497
+
498
+ log('Audio streaming started');
499
+
500
+ // Start stats update timer
501
+ const statsTimer = setInterval(() => {
502
+ if (isStreaming) {
503
+ updateStats();
504
+ } else {
505
+ clearInterval(statsTimer);
506
+ }
507
+ }, 100);
508
+
509
+ } catch (error) {
510
+ log(`Audio capture failed: ${error.message}`, true);
511
+ log(`Error details: ${error.name} - ${error.message}`);
512
+
513
+ // Handle specific permission errors
514
+ if (error.name === 'NotAllowedError') {
515
+ log('Microphone access denied by user. Please allow microphone access and try again.', true);
516
+ } else if (error.name === 'NotFoundError') {
517
+ log('No microphone found. Please check your audio devices.', true);
518
+ } else if (error.name === 'NotReadableError') {
519
+ log('Microphone is already in use by another application.', true);
520
+ }
521
+
522
+ updateStatus('Error', 'error');
523
+
524
+ // Reset UI state
525
+ startBtn.disabled = false;
526
+ stopBtn.disabled = true;
527
+ }
528
+ }
529
+
530
+ function stopStreaming() {
531
+ isStreaming = false;
532
+
533
+ // Stop all tracks to properly release the microphone
534
+ if (micStream && micStream.mediaStream) {
535
+ micStream.mediaStream.getTracks().forEach(track => track.stop());
536
+ }
537
+
538
+ // Disconnect audio nodes
539
+ if (processor) {
540
+ processor.disconnect();
541
+ processor = null;
542
+ }
543
+
544
+ if (micStream) {
545
+ micStream.disconnect();
546
+ micStream = null;
547
+ }
548
+
549
+ if (audioContext) {
550
+ audioContext.close();
551
+ audioContext = null;
552
+ }
553
+
554
+ // Close WebSocket
555
+ if (socket && socket.readyState === WebSocket.OPEN) {
556
+ socket.close();
557
+ }
558
+
559
+ // Update UI
560
+ startBtn.disabled = false;
561
+ stopBtn.disabled = true;
562
+ updateStatus('Disconnected', '');
563
+
564
+ log('Streaming stopped');
565
+ updateStats();
566
+ }
567
+
568
+ function clearAll() {
569
+ logArea.textContent = '';
570
+ transcriptionArea.textContent = '';
571
+ stats = { chunks: 0, transcriptions: 0, errors: 0 };
572
+ updateStats();
573
+ log('All content cleared');
574
+ }
575
+
576
+ // Initialize stats display
577
+ updateStats();
578
+
579
+ // Handle page unload
580
+ window.addEventListener('beforeunload', () => {
581
+ if (isStreaming) {
582
+ stopStreaming();
583
+ }
584
+ });
585
+
586
+ // Handle errors
587
+ window.addEventListener('error', (event) => {
588
+ log(`JavaScript error: ${event.message}`, true);
589
+ });
590
+
591
+ // Show initial log message
592
+ log('Real-time transcription client ready');
593
+ </script>
594
+ </body>
595
+ </html>
unified_socket_server.py ADDED
@@ -0,0 +1,325 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Unified WebSocket/HTTP Whisper Transcription Server
4
+ Handles real-time audio streaming, transcription using Whisper, and HTTP serving
5
+ """
6
+
7
+ import asyncio
8
+ import websockets
9
+ import json
10
+ import numpy as np
11
+ import torch
12
+ import logging
13
+ import traceback
14
+ import os
15
+ from typing import Dict, Any
16
+ from aiohttp import web, WSMsgType
17
+ from aiohttp.web_ws import WebSocketResponse
18
+
19
+
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ try:
25
+ from whisper_stream import load_streaming_model_correct
26
+ from whisper_stream.streaming_decoding import DecodingOptions
27
+ except ImportError:
28
+ logger.error("whisper_stream not found; falling back to the regular whisper package")
29
+ # Fallback to regular whisper if whisper_stream is not available
30
+ import whisper
31
+
32
+ class UnifiedTranscriptionServer:
33
+ def __init__(self, host: str = "0.0.0.0", port: int = 8000):
34
+ self.host = host
35
+ self.port = port
36
+ self.clients: Dict[str, Dict[str, Any]] = {}
37
+ self.app = web.Application()
38
+ self.setup_routes()
39
+
40
+ def setup_routes(self):
41
+ """Setup HTTP routes and WebSocket endpoint"""
42
+ # HTTP routes
43
+ self.app.router.add_get('/', self.serve_index)
44
+ self.app.router.add_get('/health', self.health_check)
45
+
46
+ # WebSocket endpoint
47
+ self.app.router.add_get('/ws', self.websocket_handler)
48
+
49
+ # Static file serving (if needed)
50
+ if os.path.exists('static'):
51
+ self.app.router.add_static('/static/', 'static')
52
+
53
+ async def serve_index(self, request):
54
+ """Serve the main HTML page"""
55
+ try:
56
+ with open("./static/client.html", "r", encoding='utf-8') as f:
57
+ html_content = f.read()
58
+ return web.Response(text=html_content, content_type='text/html')
59
+ except FileNotFoundError:
60
+ return web.Response(text="client.html not found!", status=404)
61
+ except Exception as e:
62
+ logger.error(f"Error serving client.html! {e}")
63
+ return web.Response(text="Error loading page...", status=500)
64
+
65
+ async def health_check(self, request):
66
+ """Health check endpoint"""
67
+ return web.json_response({"status": "healthy", "cuda": torch.cuda.is_available()})
68
+
69
+ async def websocket_handler(self, request):
70
+ """Handle WebSocket connections"""
71
+ ws = WebSocketResponse()
72
+ await ws.prepare(request)
73
+
74
+ # Generate client ID
75
+ client_id = f"{request.remote}:{id(ws)}"
76
+ logger.info(f"New WebSocket client connected: {client_id}")
77
+
78
+ # Initialize client state
79
+ self.clients[client_id] = {
80
+ 'websocket': ws,
81
+ 'model': None,
82
+ 'config': None,
83
+ 'buffer': bytearray(),
84
+ 'total_samples': 0,
85
+ 'is_first_chunk': True
86
+ }
87
+
88
+ try:
89
+ await self.process_websocket_messages(client_id)
90
+ except Exception as e:
91
+ logger.error(f"Error handling WebSocket client {client_id}: {e}")
92
+ logger.error(traceback.format_exc())
93
+ finally:
94
+ # Cleanup
95
+ if client_id in self.clients:
96
+ del self.clients[client_id]
97
+
98
+ if not ws.closed:
99
+ await ws.close()
100
+
101
+ return ws
102
+
103
+ async def process_websocket_messages(self, client_id: str):
104
+ """Process messages from a WebSocket client"""
105
+ client = self.clients[client_id]
106
+ ws = client['websocket']
107
+
108
+ async for msg in ws:
109
+ if msg.type == WSMsgType.TEXT:
110
+ # Handle configuration message
111
+ await self.handle_config_message(client_id, msg.data)
112
+ elif msg.type == WSMsgType.BINARY:
113
+ # Handle audio data
114
+ await self.handle_audio_data(client_id, msg.data)
115
+ elif msg.type == WSMsgType.ERROR:
116
+ logger.error(f'WebSocket error for client {client_id}: {ws.exception()}')
117
+ break
118
+
119
+ async def handle_config_message(self, client_id: str, message: str):
120
+ """Handle configuration message from client"""
121
+ client = self.clients[client_id]
122
+ ws = client['websocket']
123
+
124
+ try:
125
+ config = json.loads(message)
126
+ logger.info(f"Received config from {client_id}: {config}")
127
+
128
+ # Validate config
129
+ required_fields = ['model_size', 'chunk_size', 'beam_size', 'temperature']
130
+ for field in required_fields:
131
+ if field not in config:
132
+ await ws.send_str(json.dumps({"error": f"Missing required field: {field}"}))
133
+ return
134
+
135
+ # Load model
136
+ model_size = config['model_size']
137
+ chunk_size = config['chunk_size']
138
+
139
+ logger.info(f"Loading model {model_size} for client {client_id}")
140
+
141
+ # Try to use whisper_stream, fallback to regular whisper
142
+ try:
143
+ model = load_streaming_model_correct(model_size, chunk_size)
144
+ client['first_chunk'] = True
145
+ if torch.cuda.is_available():
146
+ model = model.to("cuda")
147
+ logger.info(f"Model loaded on GPU for client {client_id}")
148
+ else:
149
+ logger.info(f"Model loaded on CPU for client {client_id}")
150
+
151
+ model.reset(use_stream=True)
152
+ model.eval()
153
+
154
+ client['model'] = model
155
+ client['config'] = config
156
+
157
+ await ws.send_str(json.dumps({"status": "CONFIG_RECEIVED", "gpu": torch.cuda.is_available()}))
158
+
159
+ except Exception as e:
160
+ logger.error(f"Error loading streaming model: {e}")
161
+ # Fallback to regular whisper
162
+ try:
163
+ model = whisper.load_model(model_size)
164
+ if torch.cuda.is_available():
165
+ model = model.to("cuda")
166
+
167
+ client['model'] = model
168
+ client['config'] = config
169
+ client['use_streaming'] = False
170
+
171
+ await ws.send_str(json.dumps({"status": "CONFIG_RECEIVED", "gpu": torch.cuda.is_available(), "fallback": True}))
172
+ except Exception as e2:
173
+ logger.error(f"Error loading fallback model: {e2}")
174
+ await ws.send_str(json.dumps({"error": f"Failed to load model: {e2}"}))
175
+
176
+ except json.JSONDecodeError as e:
177
+ await ws.send_str(json.dumps({"error": f"Invalid JSON: {e}"}))
178
+ except Exception as e:
179
+ logger.error(f"Error handling config for client {client_id}: {e}")
180
+ await ws.send_str(json.dumps({"error": str(e)}))
181
+
182
+ async def handle_audio_data(self, client_id: str, audio_data: bytes):
183
+ """Handle audio data from client"""
184
+ client = self.clients[client_id]
185
+ ws = client['websocket']
186
+
187
+ if client['config'] is None:
188
+ await ws.send_str(json.dumps({"error": "Config not set"}))
189
+ return
190
+
191
+ if client['model'] is None:
192
+ await ws.send_str(json.dumps({"error": "Model not loaded"}))
193
+ return
194
+
195
+ # Add audio data to buffer
196
+ client['buffer'].extend(audio_data)
197
+
198
+ # Calculate chunk size in bytes
199
+ chunk_size_ms = client['config']['chunk_size']
200
+ sample_rate = 16000
201
+ chunk_samples = int(sample_rate * (chunk_size_ms / 1000))
202
+ chunk_bytes = chunk_samples * 2 # 16-bit audio = 2 bytes per sample
203
+ if client.get('first_chunk', True):
204
+ chunk_bytes += 720
205
+
206
+ # Process complete chunks
207
+ while len(client['buffer']) >= chunk_bytes:
208
+ chunk = client['buffer'][:chunk_bytes]
209
+ client['buffer'] = client['buffer'][chunk_bytes:]
210
+
211
+ try:
212
+ if client.get('first_chunk', True):
213
+ client['first_chunk'] = False
214
+ await self.transcribe_chunk(client_id, chunk)
215
+ except Exception as e:
216
+ logger.error(f"Error transcribing chunk for client {client_id}: {e}")
217
+ await ws.send_str(json.dumps({"error": f"Transcription error: {str(e)}"}))
218
+
219
+ async def transcribe_chunk(self, client_id: str, chunk: bytes):
220
+ """Transcribe audio chunk"""
221
+ client = self.clients[client_id]
222
+ ws = client['websocket']
223
+ model = client['model']
224
+ config = client['config']
225
+
226
+ try:
227
+ # Convert bytes to numpy array
228
+ pcm = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
229
+
230
+ # Convert to torch tensor
231
+ audio = torch.tensor(pcm)
232
+ if torch.cuda.is_available() and next(model.parameters()).is_cuda:
233
+ audio = audio.to("cuda")
234
+
235
+ # Transcribe based on model type
236
+ if hasattr(model, 'decode') and 'use_streaming' not in client:
237
+ # Using whisper_stream
238
+ decoding_options = DecodingOptions(
239
+ language="en",
240
+ gran=(config['chunk_size'] // 20),
241
+ single_frame_mel=True,
242
+ without_timestamps=True,
243
+ beam_size=config['beam_size'],
244
+ temperature=config['temperature'],
245
+ stream_decode=True,
246
+ use_ca_kv_cache=True,
247
+ look_ahead_blocks=model.extra_gran_blocks
248
+ )
249
+ result = model.decode(audio, decoding_options, use_frames=True)
250
+ text = result.text
251
+ else:
252
+ # Using regular whisper
253
+ # Pad audio to minimum length if needed
254
+ min_length = 16000 # 1 second at 16kHz
255
+ if len(audio) < min_length:
256
+ audio = torch.nn.functional.pad(audio, (0, min_length - len(audio)))
257
+
258
+ result = model.transcribe(audio.cpu().numpy(),
259
+ language="en",
260
+ beam_size=config['beam_size'],
261
+ temperature=config['temperature'])
262
+ text = result['text']
263
+
264
+ # Send transcription result
265
+ if text.strip():
266
+ client['total_samples'] += len(pcm)
267
+ duration = client['total_samples'] / 16000 # seconds
268
+
269
+ await ws.send_str(json.dumps({
270
+ "text": text.strip(),
271
+ "timestamp": duration,
272
+ "chunk_duration": len(pcm) / 16000
273
+ }))
274
+
275
+ except Exception as e:
276
+ logger.error(f"Error in transcription for client {client_id}: {e}")
277
+ logger.exception("Exception occurred")
278
+ raise
279
+
280
+ async def start_server(self):
281
+ """Start the unified server"""
282
+ logger.info(f"Starting unified server on {self.host}:{self.port}")
283
+ logger.info(f"CUDA available: {torch.cuda.is_available()}")
284
+
285
+ runner = web.AppRunner(self.app)
286
+ await runner.setup()
287
+ site = web.TCPSite(runner, self.host, self.port)
288
+ await site.start()
289
+
290
+ logger.info(f"Server running on http://{self.host}:{self.port}")
291
+ logger.info(f"WebSocket endpoint: ws://{self.host}:{self.port}/ws")
292
+
293
+ # Keep the server running
294
+ try:
295
+ await asyncio.Future() # Run forever
296
+ except KeyboardInterrupt:
297
+ logger.info("Server stopped by user")
298
+ finally:
299
+ await runner.cleanup()
300
+
301
+ def main():
302
+ import argparse
303
+
304
+ parser = argparse.ArgumentParser(description='Unified WebSocket/HTTP Whisper Transcription Server')
305
+ parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
306
+ parser.add_argument('--port', type=int, default=8000, help='Port to bind to')
307
+ parser.add_argument('--log-level', default='INFO', help='Log level')
308
+
309
+ args = parser.parse_args()
310
+
311
+ # Set log level
312
+ logging.getLogger().setLevel(getattr(logging, args.log_level.upper()))
313
+
314
+ server = UnifiedTranscriptionServer(args.host, args.port)
315
+
316
+ try:
317
+ asyncio.run(server.start_server())
318
+ except KeyboardInterrupt:
319
+ logger.info("Server stopped by user")
320
+ except Exception as e:
321
+ logger.error(f"Server error: {e}")
322
+ logger.error(traceback.format_exc())
323
+
324
+ if __name__ == '__main__':
325
+ main()
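The protocol the server implements is simple: each WebSocket session on /ws starts with one JSON text message carrying model_size, chunk_size, beam_size and temperature, the server answers with {"status": "CONFIG_RECEIVED", ...}, and from then on the client streams raw 16 kHz int16 PCM as binary frames while the server replies with JSON objects containing text and timestamp. A minimal command-line test client sketch using the pinned websockets and soundfile packages; the URL (Docker port 7860; the bare script defaults to 8000) and the file name test.wav are illustrative assumptions:

import asyncio, json
import numpy as np
import soundfile as sf       # pinned in requirements.txt
import websockets            # pinned in requirements.txt

CONFIG = {"model_size": "small", "chunk_size": 300, "beam_size": 0, "temperature": 0.0}

async def stream_file(path: str, url: str = "ws://localhost:7860/ws") -> None:
    audio, sr = sf.read(path, dtype="float32")
    if audio.ndim > 1:
        audio = audio.mean(axis=1)                 # down-mix to mono
    assert sr == 16000, "the server assumes 16 kHz input; resample beforehand"
    pcm = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16).tobytes()

    async with websockets.connect(url) as ws:
        await ws.send(json.dumps(CONFIG))          # same fields the browser clients send
        print(await ws.recv())                     # expect {"status": "CONFIG_RECEIVED", ...}

        step = 4096 * 2                            # 4096 samples * 2 bytes, like the browser buffers
        for i in range(0, len(pcm), step):
            await ws.send(pcm[i:i + step])

        # Drain transcription messages until the server goes quiet for a second
        try:
            while True:
                print(json.loads(await asyncio.wait_for(ws.recv(), timeout=1.0)))
        except asyncio.TimeoutError:
            pass

if __name__ == "__main__":
    asyncio.run(stream_file("test.wav"))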
whisper_stream/__init__.py ADDED
@@ -0,0 +1,331 @@
1
+ import hashlib
2
+ import io
3
+ import os
4
+ import urllib
5
+ import warnings
6
+ from typing import List, Optional, Union
7
+
8
+ import torch
9
+ from tqdm import tqdm
10
+
11
+ from .audio import load_audio, pad_or_trim, log_mel_spectrogram
12
+ from .model import ModelDimensions, Whisper
13
+ from .streaming_model import StreamingWhisper
14
+ from .version import __version__
15
+
16
+ _MODELS = {
17
+ "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
18
+ "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
19
+ "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
20
+ "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
21
+ "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
22
+ "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
23
+ "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
24
+ "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
25
+ "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
26
+ "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
27
+ "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
28
+ "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
29
+ "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
30
+ "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
31
+ }
32
+
33
+ _STREAMING_MODELS = {
34
+ "base": {
35
+ "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.25/checkpoint/checkpoint-epoch=0009.pt",
36
+ "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0009.pt",
37
+ "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0009.pt",
38
+ "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.02/checkpoint/checkpoint-epoch=0006.pt",
39
+ },
40
+ "small": {
41
+ "1000": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g50_eg0_top5_full-streamTrue_random-orderFalse_fraction0.4/checkpoint/checkpoint-epoch=0009.pt",
42
+ "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.25/checkpoint/checkpoint-epoch=0009.pt",
43
+ "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0009.pt",
44
+ "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0009.pt",
45
+ "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.02/checkpoint/checkpoint-epoch=0009.pt",
46
+ },
47
+ "large-v2": {
48
+ "1000": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g50_eg0_top5_full-streamTrue_random-orderFalse_fraction0.3/checkpoint/checkpoint-epoch=0002.pt",
49
+ "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0002.pt",
50
+ "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.07/checkpoint/checkpoint-epoch=0002.pt",
51
+ "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.03/checkpoint/checkpoint-epoch=0002.pt",
52
+ "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.01/checkpoint/checkpoint-epoch=0002.pt",
53
+ "300-multi": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-BLEND-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0001.pt",
54
+ }
55
+ }
56
+
57
+ _STREAMING_MODELS_HF = {
58
+ "base": {
59
+ "300": "base_300.pt",
60
+ "200": "base_200.pt",
61
+ "100": "base_100.pt",
62
+ "40": "base_40.pt",
63
+ },
64
+ "small": {
65
+ "1000": "small_1000.pt",
66
+ "300": "small_300.pt",
67
+ "200": "small_200.pt",
68
+ "100": "small_100.pt",
69
+ "40": "small_40.pt",
70
+ },
71
+ "large-v2": {
72
+ "1000": "large-v2_1000.pt",
73
+ "300": "large-v2_300.pt",
74
+ "200": "large-v2_200.pt",
75
+ "100": "large-v2_100.pt",
76
+ "40": "large-v2_40.pt",
77
+ "300-multi": "large-v2_300_multi.pt",
78
+ }
79
+ }
80
+
81
+ # base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
82
+ # highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
83
+ _ALIGNMENT_HEADS = {
84
+ "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
85
+ "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
86
+ "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
87
+ "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
88
+ "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
89
+ "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
90
+ "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
91
+ "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
92
+ "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
93
+ "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
94
+ "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
95
+ "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
96
+ "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
97
+ "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
98
+ }
99
+
100
+
101
+ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
102
+ os.makedirs(root, exist_ok=True)
103
+
104
+ expected_sha256 = url.split("/")[-2]
105
+ download_target = os.path.join(root, os.path.basename(url))
106
+
107
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
108
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
109
+
110
+ if os.path.isfile(download_target):
111
+ with open(download_target, "rb") as f:
112
+ model_bytes = f.read()
113
+ if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
114
+ return model_bytes if in_memory else download_target
115
+ else:
116
+ warnings.warn(
117
+ f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
118
+ )
119
+
120
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
121
+ with tqdm(
122
+ total=int(source.info().get("Content-Length")),
123
+ ncols=80,
124
+ unit="iB",
125
+ unit_scale=True,
126
+ unit_divisor=1024,
127
+ ) as loop:
128
+ while True:
129
+ buffer = source.read(8192)
130
+ if not buffer:
131
+ break
132
+
133
+ output.write(buffer)
134
+ loop.update(len(buffer))
135
+
136
+ model_bytes = open(download_target, "rb").read()
137
+ if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
138
+ raise RuntimeError(
139
+ "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
140
+ )
141
+
142
+ return model_bytes if in_memory else download_target
143
+
144
+
145
+ def available_models() -> List[str]:
146
+ """Returns the names of available models"""
147
+ return list(_MODELS.keys())
148
+
149
+
150
+ def load_model(
151
+ name: str,
152
+ device: Optional[Union[str, torch.device]] = None,
153
+ download_root: str = None,
154
+ in_memory: bool = False,
155
+ ) -> Whisper:
156
+ """
157
+ Load a Whisper ASR model
158
+
159
+ Parameters
160
+ ----------
161
+ name : str
162
+ one of the official model names listed by `whisper.available_models()`, or
163
+ path to a model checkpoint containing the model dimensions and the model state_dict.
164
+ device : Union[str, torch.device]
165
+ the PyTorch device to put the model into
166
+ download_root: str
167
+ path to download the model files; by default, it uses "~/.cache/whisper"
168
+ in_memory: bool
169
+ whether to preload the model weights into host memory
170
+
171
+ Returns
172
+ -------
173
+ model : Whisper
174
+ The Whisper ASR model instance
175
+ """
176
+
177
+ if device is None:
178
+ device = "cuda" if torch.cuda.is_available() else "cpu"
179
+ if download_root is None:
180
+ default = os.path.join(os.path.expanduser("~"), ".cache")
181
+ download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
182
+
183
+ if name in _MODELS:
184
+ checkpoint_file = _download(_MODELS[name], download_root, in_memory)
185
+ alignment_heads = _ALIGNMENT_HEADS[name]
186
+ elif os.path.isfile(name):
187
+ checkpoint_file = open(name, "rb").read() if in_memory else name
188
+ alignment_heads = None
189
+ else:
190
+ raise RuntimeError(
191
+ f"Model {name} not found; available models = {available_models()}"
192
+ )
193
+
194
+ with (
195
+ io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
196
+ ) as fp:
197
+ checkpoint = torch.load(fp, map_location=device)
198
+ del checkpoint_file
199
+
200
+ dims = ModelDimensions(**checkpoint["dims"])
201
+ model = Whisper(dims)
202
+ model.load_state_dict(checkpoint["model_state_dict"])
203
+
204
+ if alignment_heads is not None:
205
+ model.set_alignment_heads(alignment_heads)
206
+
207
+ return model.to(device)
208
+
209
+
210
+ def load_streaming_model(
211
+ name: str,
212
+ advisor_ckpt_path: str = None,
213
+ ft_model_ckpt_path: str = None,
214
+ device: Optional[Union[str, torch.device]] = None,
215
+ download_root: str = None,
216
+ in_memory: bool = False,
217
+ cache_gran: bool = True,
218
+ gran: int = 15,
219
+ rank: int = 8,
220
+ extra_gran_blocks: int = 0,
221
+ n_advisor_class: int = 4,
222
+ **kwargs: any
223
+ ) -> StreamingWhisper:
224
+ """
225
+ Load a StreamingWhisper ASR model
226
+
227
+ Parameters
228
+ ----------
229
+ name : str
230
+ one of the official model names listed by `whisper.available_models()`, or
231
+ path to a model checkpoint containing the model dimensions and the model state_dict.
232
+ device : Union[str, torch.device]
233
+ the PyTorch device to put the model into
234
+ download_root: str
235
+ path to download the model files; by default, it uses "~/.cache/whisper"
236
+ in_memory: bool
237
+ whether to preload the model weights into host memory
238
+
239
+ Returns
240
+ -------
241
+ model : StreamingWhisper
242
+ The StreamingWhisper ASR model instance
243
+ """
244
+ if ft_model_ckpt_path is None:
245
+ if device is None:
246
+ device = "cuda" if torch.cuda.is_available() else "cpu"
247
+ if download_root is None:
248
+ default = os.path.join(os.path.expanduser("~"), ".cache")
249
+ download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
250
+
251
+ if name in _MODELS:
252
+ checkpoint_file = _download(_MODELS[name], download_root, in_memory)
253
+ alignment_heads = _ALIGNMENT_HEADS[name]
254
+ elif os.path.isfile(name):
255
+ checkpoint_file = open(name, "rb").read() if in_memory else name
256
+ alignment_heads = None
257
+ else:
258
+ raise RuntimeError(
259
+ f"Model {name} not found; available models = {available_models()}"
260
+ )
261
+
262
+ with (
263
+ io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
264
+ ) as fp:
265
+ checkpoint = torch.load(fp, map_location=device)
266
+ del checkpoint_file
267
+ else:
268
+ checkpoint = torch.load(ft_model_ckpt_path, weights_only=False)
269
+
270
+ decoder_advisor_chkpt = torch.load(advisor_ckpt_path, weights_only=False) if advisor_ckpt_path is not None else {"state_dict": {}}
271
+ advisor_state_dict = {k: v for k, v in decoder_advisor_chkpt["state_dict"].items() if "decoder_advisor" in k}
272
+
273
+ whisper_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint.keys() else checkpoint["state_dict"]
274
+
275
+ whisper_dict = {k.replace("weight", "base_layer.weight") if "attn." in k and "weight" in k
276
+ else k.replace("bias", "base_layer.bias") if "attn." in k and "bias" in k
277
+ else k: v for k, v in whisper_dict.items()}
278
+
279
+ streaming_whisper_state_dict = {**advisor_state_dict, **whisper_dict}
280
+
281
+ dims = ModelDimensions(**checkpoint["dims"])
282
+
283
+ model = StreamingWhisper(dims,
284
+ cache_gran=cache_gran,
285
+ gran=gran,
286
+ rank=rank,
287
+ extra_gran_blocks=extra_gran_blocks,
288
+ n_advisor_class=n_advisor_class,
289
+ **kwargs)
290
+
291
+ model.load_state_dict(streaming_whisper_state_dict, strict=False)
292
+
293
+ # for n, p in model.named_parameters():
294
+ # print(n, p)
295
+
296
+ if ft_model_ckpt_path is None and alignment_heads is not None:
297
+ model.set_alignment_heads(alignment_heads)
298
+
299
+ return model.to(device)
300
+
301
+
302
+ def load_streaming_model_correct(
303
+ name: str,
304
+ gran: int = 300,
305
+ multilingual: bool = False,
306
+ device: Optional[Union[str, torch.device]] = None,
307
+ download_root: str = None,
308
+ in_memory: bool = False,
309
+ ) -> StreamingWhisper:
310
+
311
+ subname = (str(gran) + '-multi') if multilingual else str(gran)
312
+
313
+ from huggingface_hub import hf_hub_download
314
+
315
+ try:
316
+ ckpt_path = hf_hub_download(repo_id="MLSpeech/causal-whisper", filename=_STREAMING_MODELS_HF[name][subname], repo_type="model", token=True)
317
+ except KeyError as e:
318
+ print(f"Streaming model with the next configs: size {name}, multilingual: {multilingual} and chunk size: {gran} is not available.")
319
+
320
+ checkpoint = torch.load(ckpt_path, weights_only=False)
321
+
322
+ dims = ModelDimensions(**checkpoint["dims"])
323
+
324
+ model = StreamingWhisper(dims,
325
+ gran=checkpoint['cfg']['gran'],
326
+ rank=checkpoint['cfg']['rank'],
327
+ extra_gran_blocks=checkpoint['cfg']['extra_gran_blocks'])
328
+
329
+ model.load_state_dict(checkpoint['state_dict'], strict=False)
330
+
331
+ return model.to(device)
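A minimal usage sketch for the loaders above, assuming they are re-exported at the package level (e.g. whisper_stream.load_model) and that access to the MLSpeech/causal-whisper Hugging Face repo is configured; the import path is an assumption, not something this diff shows:

    import torch
    import whisper_stream

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Plain (non-streaming) Whisper checkpoint, downloaded and cached by _download().
    model = whisper_stream.load_model("base", device=device)

    # Streaming checkpoint fetched from the MLSpeech/causal-whisper repo:
    # "large-v2" weights fine-tuned for 300 ms chunks, English-only.
    streaming_model = whisper_stream.load_streaming_model_correct(
        "large-v2", gran=300, multilingual=False, device=device
    )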
whisper_stream/__main__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .transcribe import cli
2
+
3
+ cli()
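This entry point makes the package runnable as a module (python -m whisper_stream). A hedged sketch of an equivalent programmatic call, assuming transcribe.cli reads its arguments from sys.argv the way the upstream Whisper CLI does; the file name and flags below are illustrative assumptions only:

    import sys
    from whisper_stream.transcribe import cli

    # Hypothetical invocation; the flags actually accepted by cli() are defined
    # in whisper_stream/transcribe.py, which is not shown in this section.
    sys.argv = ["whisper_stream", "example.wav", "--model", "base"]
    cli()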
whisper_stream/assets/gpt2.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
whisper_stream/assets/mel_filters.npz ADDED
Binary file (4.27 kB). View file
 
whisper_stream/assets/multilingual.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
whisper_stream/audio.py ADDED
@@ -0,0 +1,357 @@
1
+ import os
2
+ import time
3
+ from functools import lru_cache
4
+ from subprocess import CalledProcessError, run
5
+ from typing import Optional, Union
6
+
7
+ import wave
8
+ import torch
9
+ import pyaudio
10
+ import numpy as np
11
+ import soundfile as sf
12
+ import torch.nn.functional as F
13
+
14
+ from .utils import exact_div
15
+
16
+ # hard-coded audio hyperparameters
17
+ SAMPLE_RATE = 16000
18
+ N_FFT = 400
19
+ HOP_LENGTH = 160
20
+ CHUNK_LENGTH = 30
21
+ N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
22
+ N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
23
+
24
+ N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions have stride 2
25
+ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
26
+ TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
27
+
28
+
29
+ class MyStream:
30
+ def __init__(self,
31
+ ms_gran: int = 200,
32
+ sample_rate: int = 16000,
33
+ channels: int = 2,
34
+ filename: str = None,
35
+ inp_dtype: any = pyaudio.paInt16,
36
+ simulate_stream: bool = False,
37
+ wav_file: str = None,
38
+ relay: bool = False,
39
+ use_latency: bool = False,
40
+ pad_trim: bool = True,
41
+ use_remote_machine: bool = False):
42
+
43
+ assert ms_gran % 20 == 0, "ms_gran must be a multiple of 20"
44
+
45
+ self.ms_gran = ms_gran
46
+ self.sample_rate = sample_rate
47
+ self.channels = channels
48
+ self.inp_dtype = inp_dtype
49
+ self.relay = relay
50
+ self.use_latency = use_latency
51
+ self.use_remote_machine = use_remote_machine
52
+
53
+ rate_fraction = ms_gran / 1000
54
+ self.chunk_size = int(rate_fraction * sample_rate)
55
+ self.filename = filename
56
+ self.streamed_wav_file = wav_file
57
+
58
+ self.simulate_stream = simulate_stream
59
+ if self.simulate_stream:
60
+ assert wav_file is not None, "when simulating a stream, a wav file must be provided."
61
+ if pad_trim:
62
+ self.wav_array = pad_or_trim(load_audio(wav_file, sample_rate), length=N_SAMPLES+180) # wav array
63
+ else:
64
+ audio = load_audio(wav_file, sample_rate)
65
+ self.wav_array = pad_or_trim(audio, length=audio.shape[-1]+180)
66
+ print(f"{self.wav_array.shape=}")
67
+
68
+ def _simulate_stream_using_wav(self):
69
+ print("Streaming simulation of a wav started...")
70
+
71
+ for i in range(self.wav_array.shape[-1] // self.chunk_size):
72
+ if i == 0:
73
+ yield self.wav_array[..., :(((i + 1) * self.chunk_size) + 40 + 320)] # 320 is extra 20 msec buffer we need!
74
+ else:
75
+ yield self.wav_array[..., ((i * self.chunk_size) + 40 + 320):(((i + 1) * self.chunk_size) + 40 + 320)]
76
+
77
+ if self.use_latency: time.sleep(self.ms_gran / 1000) # simulating the latency between audio chunks
78
+
79
+ def open_stream(self):
80
+ if self.simulate_stream or self.relay or self.use_remote_machine: return
81
+
82
+ self.audio = pyaudio.PyAudio()
83
+ self.stream = self.audio.open(input=True, format=self.inp_dtype, channels=self.channels, rate=self.sample_rate, frames_per_buffer=self.chunk_size)
84
+
85
+ def _read_from_stream(self):
86
+ print("Streaming instance recording started...")
87
+
88
+ while True:
89
+ yield self.stream.read(self.chunk_size)
90
+
91
+ def _follow_growing_wav(self):
92
+ while not os.path.exists(self.streamed_wav_file):
93
+ time.sleep(0.1)
94
+
95
+ with sf.SoundFile(self.streamed_wav_file, mode='r') as f:
96
+ while True:
97
+ block = f.read(self.chunk_size)
98
+ if len(block) == 0:
99
+ time.sleep(self.ms_gran / 1000) # Wait for more data
100
+ continue
101
+ yield block
102
+
103
+ def _read_raw_pcm(self):
104
+ samples_per_chunk = int(self.sample_rate * (self.ms_gran / 1000))
105
+ bytes_per_sample = 2 # s16le = 16 bits = 2 bytes
106
+ chunk_size = samples_per_chunk * bytes_per_sample
107
+
108
+ while not os.path.exists(self.streamed_wav_file):
109
+ time.sleep(0.1)
110
+
111
+ with open(self.streamed_wav_file, 'rb') as f:
112
+ while True:
113
+ chunk = f.read(chunk_size)
114
+ if not chunk:
115
+ time.sleep((self.ms_gran / 1000))
116
+ continue
117
+ yield np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
118
+
119
+ def read(self):
120
+ if self.simulate_stream:
121
+ return self._simulate_stream_using_wav()
122
+
123
+ if self.use_remote_machine:
124
+ return self._read_raw_pcm()
125
+
126
+ return self._read_from_stream()
127
+
128
+ def _save_recording_file(self, frames: list):
129
+ print(f"Saving recorded audio file on path {self.filename}")
130
+
131
+ waveFile = wave.open(self.filename, 'wb')
132
+ waveFile.setnchannels(self.channels)
133
+ waveFile.setsampwidth(self.audio.get_sample_size(self.inp_dtype))
134
+ waveFile.setframerate(self.sample_rate)
135
+ waveFile.writeframes(b''.join(frames))
136
+ waveFile.close()
137
+
138
+ def close_stream(self, frames: list):
139
+ if self.simulate_stream: return
140
+
141
+ # Stop Recording
142
+ self.stream.stop_stream()
143
+ self.stream.close()
144
+ self.audio.terminate()
145
+
146
+ print("Finished recording, stream and audio terminated.")
147
+
148
+ if self.filename: self._save_recording_file(frames)
149
+
150
+
151
+ def load_audio(file: str, sr: int = SAMPLE_RATE):
152
+ """
153
+ Open an audio file and read as mono waveform, resampling as necessary
154
+
155
+ Parameters
156
+ ----------
157
+ file: str
158
+ The audio file to open
159
+
160
+ sr: int
161
+ The sample rate to resample the audio if necessary
162
+
163
+ Returns
164
+ -------
165
+ A NumPy array containing the audio waveform, in float32 dtype.
166
+ """
167
+
168
+ # This launches a subprocess to decode audio while down-mixing
169
+ # and resampling as necessary. Requires the ffmpeg CLI in PATH.
170
+ # fmt: off
171
+ cmd = [
172
+ "ffmpeg",
173
+ "-nostdin",
174
+ "-threads", "0",
175
+ "-i", file,
176
+ "-f", "s16le",
177
+ "-ac", "1",
178
+ "-acodec", "pcm_s16le",
179
+ "-ar", str(sr),
180
+ "-"
181
+ ]
182
+ # fmt: on
183
+ try:
184
+ out = run(cmd, capture_output=True, check=True).stdout
185
+ except CalledProcessError as e:
186
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
187
+
188
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
189
+
190
+
191
+ def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
192
+ """
193
+ Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
194
+ """
195
+ if torch.is_tensor(array):
196
+ if array.shape[axis] > length:
197
+ array = array.index_select(
198
+ dim=axis, index=torch.arange(length, device=array.device)
199
+ )
200
+
201
+ if array.shape[axis] < length:
202
+ pad_widths = [(0, 0)] * array.ndim
203
+ pad_widths[axis] = (0, length - array.shape[axis])
204
+ array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
205
+ else:
206
+ if array.shape[axis] > length:
207
+ array = array.take(indices=range(length), axis=axis)
208
+
209
+ if array.shape[axis] < length:
210
+ pad_widths = [(0, 0)] * array.ndim
211
+ pad_widths[axis] = (0, length - array.shape[axis])
212
+ array = np.pad(array, pad_widths)
213
+
214
+ return array
215
+
216
+
217
+ @lru_cache(maxsize=None)
218
+ def mel_filters(device, n_mels: int) -> torch.Tensor:
219
+ """
220
+ load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
221
+ Allows decoupling librosa dependency; saved using:
222
+
223
+ np.savez_compressed(
224
+ "mel_filters.npz",
225
+ mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
226
+ mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
227
+ )
228
+ """
229
+ assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
230
+
231
+ filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
232
+ with np.load(filters_path, allow_pickle=False) as f:
233
+ return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
234
+
235
+
236
+ def log_mel_spectrogram(
237
+ audio: Union[str, np.ndarray, torch.Tensor],
238
+ n_mels: int = 80,
239
+ padding: int = 0,
240
+ device: Optional[Union[str, torch.device]] = None,
241
+ ):
242
+ """
243
+ Compute the log-Mel spectrogram of the input audio
244
+
245
+ Parameters
246
+ ----------
247
+ audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
248
+ The path to an audio file, or a NumPy array or Tensor containing the audio waveform at 16 kHz
249
+
250
+ n_mels: int
251
+ The number of Mel-frequency filters, only 80 and 128 are supported
252
+
253
+ padding: int
254
+ Number of zero samples to pad to the right
255
+
256
+ device: Optional[Union[str, torch.device]]
257
+ If given, the audio tensor is moved to this device before STFT
258
+
259
+ Returns
260
+ -------
261
+ torch.Tensor, shape = (80, n_frames)
262
+ A Tensor that contains the Mel spectrogram
263
+ """
264
+ if not torch.is_tensor(audio):
265
+ if isinstance(audio, str):
266
+ audio = load_audio(audio)
267
+ audio = torch.from_numpy(audio)
268
+
269
+ if device is not None:
270
+ audio = audio.to(device)
271
+ if padding > 0:
272
+ audio = F.pad(audio, (0, padding))
273
+
274
+ window = torch.hann_window(N_FFT).to(audio.device)
275
+ stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
276
+ magnitudes = stft[..., :-1].abs() ** 2
277
+
278
+ filters = mel_filters(audio.device, n_mels)
279
+ mel_spec = filters @ magnitudes
280
+
281
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
282
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
283
+ log_spec = (log_spec + 4.0) / 4.0
284
+ return log_spec
285
+
286
+
287
+ class SpectrogramStream:
288
+ def __init__(self, n_fft: int = N_FFT, hop_length: int = HOP_LENGTH, n_mels: int = 80, window: Optional[str] = "hann", pad_mode: str = "reflect"):
289
+
290
+ self.n_fft = n_fft
291
+ self.hop_length = hop_length
292
+ self.pad_mode = pad_mode
293
+ self.n_mels = n_mels
294
+
295
+ self.window = torch.hann_window(n_fft)
296
+ self.window_type = window
297
+
298
+ self.ctx_samples = self.n_fft - self.hop_length
299
+
300
+ self.reset()
301
+
302
+ def reset(self):
303
+ self.is_first = True
304
+ self.audio_ctx = torch.tensor([])
305
+ self.log_spec_max = -torch.inf
306
+
307
+ def calc_mel_with_new_frame(self, audio_frame: torch.Tensor, is_last: bool = False):
308
+
309
+ self.window = self.window.to(audio_frame.device)
310
+
311
+ if len(audio_frame.shape) == 1:
312
+ audio_frame = audio_frame.unsqueeze(0)
313
+
314
+ n_batch = audio_frame.shape[0]
315
+
316
+ if isinstance(self.log_spec_max, float):
317
+ self.log_spec_max = torch.ones((n_batch)).to(audio_frame.device) * -torch.inf
318
+
319
+ # check if we are on first frame, if so, pad using reflection
320
+ if self.is_first:
321
+ pad = int(self.n_fft // 2) + 1
322
+ audio_input = F.pad(audio_frame, [pad, 0], self.pad_mode)
323
+ self.is_first = False
324
+ else: # pad with previous context
325
+ audio_input = torch.cat([self.audio_ctx[..., -self.ctx_samples:], audio_frame], dim=-1)
326
+
327
+ if is_last: # pad reflect last frame
328
+ pad = int(self.n_fft // 4) + 1
329
+ audio_input = F.pad(audio_input, [pad, 0], self.pad_mode)
330
+
331
+ self.audio_ctx = audio_frame # now audio ctx is the last frame
332
+
333
+ stft = torch.stft(audio_input, self.n_fft, self.hop_length, window=self.window, return_complex=True, center=False)
334
+ magnitudes = stft.abs() ** 2
335
+ filters = mel_filters(audio_frame.device, self.n_mels)
336
+ mel_spec = filters @ magnitudes
337
+
338
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10() # from shape (b, n_mels, audio_frames)
339
+ self.log_spec_max = torch.maximum(log_spec.view(n_batch, -1).max(dim=-1).values, self.log_spec_max).to(log_spec.device)
340
+
341
+ log_spec = torch.maximum(log_spec.view(n_batch, -1).permute(1, 0), self.log_spec_max - 8.0).permute(1, 0).view(n_batch, self.n_mels, -1)
342
+ log_spec = (log_spec + 4.0) / 4.0
343
+ return log_spec
344
+
345
+ def _simulate_streaming_log_spec(self, audio: torch.Tensor, ms_gran: int = 300, total_frames: int = 3000, get_gt: bool = False):
346
+ self.reset()
347
+
348
+ samples_gran = HOP_LENGTH * (ms_gran // 10)
349
+ sub_mel_frames = int(total_frames / ms_gran) * 10
350
+ # print(samples_gran, sub_mel_frames)
351
+ pred_mel = torch.cat([self.calc_mel_with_new_frame(audio[..., (i * samples_gran) + (40 * int(i != 0)): ((i + 1) * samples_gran) + 40], is_last=(i == sub_mel_frames - 1)) for i in range(sub_mel_frames)], dim=-1)
352
+
353
+ if get_gt:
354
+ gt_mel = log_mel_spectrogram(audio)
355
+ return pred_mel, gt_mel
356
+
357
+ return pred_mel
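A minimal sketch of the streaming mel pipeline above, assuming a 16 kHz mono recording at example.wav (a hypothetical path) and that these helpers are imported from whisper_stream.audio; it compares the chunk-by-chunk spectrogram produced by SpectrogramStream against the offline log_mel_spectrogram:

    import torch
    from whisper_stream.audio import (
        SAMPLE_RATE, N_SAMPLES, load_audio, pad_or_trim,
        log_mel_spectrogram, SpectrogramStream,
    )

    # Load and pad to the 30-second window used throughout this module.
    audio = torch.from_numpy(pad_or_trim(load_audio("example.wav", SAMPLE_RATE), N_SAMPLES))

    streamer = SpectrogramStream(n_mels=80)
    streamed_mel, offline_mel = streamer._simulate_streaming_log_spec(
        audio, ms_gran=300, total_frames=3000, get_gt=True
    )

    # Both should cover 3000 mel frames; small differences near chunk boundaries
    # are expected because the streamed version only sees past context.
    print(streamed_mel.shape, offline_mel.shape)
    print((streamed_mel.squeeze(0) - offline_mel).abs().max())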
whisper_stream/decoding.py ADDED
@@ -0,0 +1,826 @@
1
+ from dataclasses import dataclass, field, replace
2
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.distributions import Categorical
9
+
10
+ from .audio import CHUNK_LENGTH
11
+ from .tokenizer import Tokenizer, get_tokenizer
12
+ from .utils import compression_ratio
13
+
14
+ if TYPE_CHECKING:
15
+ from .model import Whisper
16
+
17
+
18
+ @torch.no_grad()
19
+ def detect_language(
20
+ model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
21
+ ) -> Tuple[Tensor, List[dict]]:
22
+ """
23
+ Detect the spoken language in the audio, and return them as a list of strings, along with the ids
24
+ of the most probable language tokens and the probability distribution over all language tokens.
25
+ This is performed outside the main decode loop in order to not interfere with kv-caching.
26
+
27
+ Returns
28
+ -------
29
+ language_tokens : Tensor, shape = (n_audio,)
30
+ ids of the most probable language tokens, which appears after the startoftranscript token.
31
+ language_probs : List[Dict[str, float]], length = n_audio
32
+ list of dictionaries containing the probability distribution over all languages.
33
+ """
34
+ if tokenizer is None:
35
+ tokenizer = get_tokenizer(
36
+ model.is_multilingual, num_languages=model.num_languages
37
+ )
38
+ if (
39
+ tokenizer.language is None
40
+ or tokenizer.language_token not in tokenizer.sot_sequence
41
+ ):
42
+ raise ValueError(
43
+ "This model doesn't have language tokens so it can't perform lang id"
44
+ )
45
+
46
+ single = mel.ndim == 2
47
+ if single:
48
+ mel = mel.unsqueeze(0)
49
+
50
+ # skip encoder forward pass if already-encoded audio features were given
51
+ if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
52
+ mel = model.encoder(mel)
53
+
54
+ # forward pass using a single token, startoftranscript
55
+ n_audio = mel.shape[0]
56
+ x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
57
+ logits = model.logits(x, mel)[:, 0]
58
+
59
+ # collect detected languages; suppress all non-language tokens
60
+ mask = torch.ones(logits.shape[-1], dtype=torch.bool)
61
+ mask[list(tokenizer.all_language_tokens)] = False
62
+ logits[:, mask] = -np.inf
63
+ language_tokens = logits.argmax(dim=-1)
64
+ language_token_probs = logits.softmax(dim=-1).cpu()
65
+ language_probs = [
66
+ {
67
+ c: language_token_probs[i, j].item()
68
+ for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
69
+ }
70
+ for i in range(n_audio)
71
+ ]
72
+
73
+ if single:
74
+ language_tokens = language_tokens[0]
75
+ language_probs = language_probs[0]
76
+
77
+ return language_tokens, language_probs
78
+
79
+
80
+ @dataclass(frozen=True)
81
+ class DecodingOptions:
82
+ # whether to perform X->X "transcribe" or X->English "translate"
83
+ task: str = "transcribe"
84
+
85
+ # language that the audio is in; uses detected language if None
86
+ language: Optional[str] = None
87
+
88
+ # sampling-related options
89
+ temperature: float = 0.0
90
+ sample_len: Optional[int] = None # maximum number of tokens to sample
91
+ best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
92
+ beam_size: Optional[int] = None # number of beams in beam search, if t == 0
93
+ patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
94
+
95
+ # "alpha" in Google NMT, or None for length norm, when ranking generations
96
+ # to select which to return among the beams or best-of-N samples
97
+ length_penalty: Optional[float] = None
98
+
99
+ # text or tokens to feed as the prompt or the prefix; for more info:
100
+ # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
101
+ prompt: Optional[Union[str, List[int]]] = None # for the previous context
102
+ prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
103
+
104
+ # list of tokens ids (or comma-separated token ids) to suppress
105
+ # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
106
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
107
+ suppress_blank: bool = True # this will suppress blank outputs
108
+
109
+ # timestamp sampling options
110
+ without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
111
+ max_initial_timestamp: Optional[float] = 1.0
112
+
113
+ # implementation details
114
+ fp16: bool = True # use fp16 for most of the calculation
115
+
116
+
117
+ @dataclass(frozen=True)
118
+ class DecodingResult:
119
+ audio_features: Tensor
120
+ language: str
121
+ language_probs: Optional[Dict[str, float]] = None
122
+ tokens: List[int] = field(default_factory=list)
123
+ text: str = ""
124
+ avg_logprob: float = np.nan
125
+ no_speech_prob: float = np.nan
126
+ temperature: float = np.nan
127
+ compression_ratio: float = np.nan
128
+
129
+
130
+ class Inference:
131
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
132
+ """Perform a forward pass on the decoder and return per-token logits"""
133
+ raise NotImplementedError
134
+
135
+ def rearrange_kv_cache(self, source_indices) -> None:
136
+ """Update the key-value cache according to the updated beams"""
137
+ raise NotImplementedError
138
+
139
+ def cleanup_caching(self) -> None:
140
+ """Clean up any resources or hooks after decoding is finished"""
141
+ pass
142
+
143
+
144
+ class PyTorchInference(Inference):
145
+ def __init__(self, model: "Whisper", initial_token_length: int):
146
+ self.model: "Whisper" = model
147
+ self.initial_token_length = initial_token_length
148
+ self.kv_cache = {}
149
+ self.hooks = []
150
+
151
+ key_modules = [block.attn.key for block in self.model.decoder.blocks]
152
+ value_modules = [block.attn.value for block in self.model.decoder.blocks]
153
+ self.kv_modules = key_modules + value_modules
154
+
155
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
156
+ if not self.kv_cache:
157
+ self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
158
+
159
+ if tokens.shape[-1] > self.initial_token_length:
160
+ # only need to use the last token except in the first forward pass
161
+ tokens = tokens[:, -1:]
162
+ return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
163
+
164
+ def cleanup_caching(self):
165
+ for hook in self.hooks:
166
+ hook.remove()
167
+
168
+ self.kv_cache = {}
169
+ self.hooks = []
170
+
171
+ def rearrange_kv_cache(self, source_indices):
172
+ if source_indices != list(range(len(source_indices))):
173
+ for module in self.kv_modules:
174
+ # update the key/value cache to contain the selected sequences
175
+ self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
176
+
177
+
178
+ class SequenceRanker:
179
+ def rank(
180
+ self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
181
+ ) -> List[int]:
182
+ """
183
+ Given a list of groups of samples and their cumulative log probabilities,
184
+ return the indices of the samples in each group to select as the final result
185
+ """
186
+ raise NotImplementedError
187
+
188
+
189
+ class MaximumLikelihoodRanker(SequenceRanker):
190
+ """
191
+ Select the sample with the highest log probabilities, penalized using either
192
+ a simple length normalization or Google NMT paper's length penalty
193
+ """
194
+
195
+ def __init__(self, length_penalty: Optional[float]):
196
+ self.length_penalty = length_penalty
197
+
198
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
199
+ def scores(logprobs, lengths):
200
+ result = []
201
+ for logprob, length in zip(logprobs, lengths):
202
+ if self.length_penalty is None:
203
+ penalty = length
204
+ else:
205
+ # from the Google NMT paper
206
+ penalty = ((5 + length) / 6) ** self.length_penalty
207
+ result.append(logprob / penalty)
208
+ return result
209
+
210
+ # get the sequence with the highest score
211
+ lengths = [[len(t) for t in s] for s in tokens]
212
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
213
+
214
+
215
+ class TokenDecoder:
216
+ def reset(self):
217
+ """Initialize any stateful variables for decoding a new sequence"""
218
+
219
+ def update(
220
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
221
+ ) -> Tuple[Tensor, bool]:
222
+ """Specify how to select the next token, based on the current trace and logits
223
+
224
+ Parameters
225
+ ----------
226
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
227
+ all tokens in the context so far, including the prefix and sot_sequence tokens
228
+
229
+ logits : Tensor, shape = (n_batch, vocab_size)
230
+ per-token logits of the probability distribution at the current step
231
+
232
+ sum_logprobs : Tensor, shape = (n_batch)
233
+ cumulative log probabilities for each sequence
234
+
235
+ Returns
236
+ -------
237
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
238
+ the tokens, appended with the selected next token
239
+
240
+ completed : bool
241
+ True if all sequences have reached the end of text
242
+
243
+ """
244
+ raise NotImplementedError
245
+
246
+ def finalize(
247
+ self, tokens: Tensor, sum_logprobs: Tensor
248
+ ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
249
+ """Finalize search and return the final candidate sequences
250
+
251
+ Parameters
252
+ ----------
253
+ tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
254
+ all tokens in the context so far, including the prefix and sot_sequence
255
+
256
+ sum_logprobs : Tensor, shape = (n_audio, n_group)
257
+ cumulative log probabilities for each sequence
258
+
259
+ Returns
260
+ -------
261
+ tokens : Sequence[Sequence[Tensor]], length = n_audio
262
+ sequence of Tensors containing candidate token sequences, for each audio input
263
+
264
+ sum_logprobs : List[List[float]], length = n_audio
265
+ sequence of cumulative log probabilities corresponding to the above
266
+
267
+ """
268
+ raise NotImplementedError
269
+
270
+
271
+ class GreedyDecoder(TokenDecoder):
272
+ def __init__(self, temperature: float, eot: int):
273
+ self.temperature = temperature
274
+ self.eot = eot
275
+
276
+ def update(
277
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
278
+ ) -> Tuple[Tensor, bool]:
279
+ if self.temperature == 0:
280
+ next_tokens = logits.argmax(dim=-1)
281
+ else:
282
+ next_tokens = Categorical(logits=logits / self.temperature).sample()
283
+
284
+ logprobs = F.log_softmax(logits.float(), dim=-1)
285
+
286
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
287
+ sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
288
+
289
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
290
+ tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
291
+
292
+ completed = (tokens[:, -1] == self.eot).all()
293
+ return tokens, completed
294
+
295
+ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
296
+ # make sure each sequence has at least one EOT token at the end
297
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
298
+ return tokens, sum_logprobs.tolist()
299
+
300
+
301
+ class BeamSearchDecoder(TokenDecoder):
302
+ def __init__(
303
+ self,
304
+ beam_size: int,
305
+ eot: int,
306
+ inference: Inference,
307
+ patience: Optional[float] = None,
308
+ ):
309
+ self.beam_size = beam_size
310
+ self.eot = eot
311
+ self.inference = inference
312
+ self.patience = patience or 1.0
313
+ self.max_candidates: int = round(beam_size * self.patience)
314
+ self.finished_sequences = None
315
+
316
+ assert (
317
+ self.max_candidates > 0
318
+ ), f"Invalid beam size ({beam_size}) or patience ({patience})"
319
+
320
+ def reset(self):
321
+ self.finished_sequences = None
322
+
323
+ def update(
324
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
325
+ ) -> Tuple[Tensor, bool]:
326
+ if tokens.shape[0] % self.beam_size != 0:
327
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
328
+
329
+ n_audio = tokens.shape[0] // self.beam_size
330
+ if self.finished_sequences is None: # for the first update
331
+ self.finished_sequences = [{} for _ in range(n_audio)]
332
+
333
+ logprobs = F.log_softmax(logits.float(), dim=-1)
334
+ next_tokens, source_indices, finished_sequences = [], [], []
335
+ for i in range(n_audio):
336
+ scores, sources, finished = {}, {}, {}
337
+
338
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
339
+ for j in range(self.beam_size):
340
+ idx = i * self.beam_size + j
341
+ prefix = tokens[idx].tolist()
342
+ for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
343
+ new_logprob = (sum_logprobs[idx] + logprob).item()
344
+ sequence = tuple(prefix + [token.item()])
345
+ scores[sequence] = new_logprob
346
+ sources[sequence] = idx
347
+
348
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
349
+ saved = 0
350
+ for sequence in sorted(scores, key=scores.get, reverse=True):
351
+ if sequence[-1] == self.eot:
352
+ finished[sequence] = scores[sequence]
353
+ else:
354
+ sum_logprobs[len(next_tokens)] = scores[sequence]
355
+ next_tokens.append(sequence)
356
+ source_indices.append(sources[sequence])
357
+
358
+ saved += 1
359
+ if saved == self.beam_size:
360
+ break
361
+
362
+ finished_sequences.append(finished)
363
+
364
+ tokens = torch.tensor(next_tokens, device=tokens.device)
365
+ self.inference.rearrange_kv_cache(source_indices)
366
+
367
+ # add newly finished sequences to self.finished_sequences
368
+ assert len(self.finished_sequences) == len(finished_sequences)
369
+ for previously_finished, newly_finished in zip(
370
+ self.finished_sequences, finished_sequences
371
+ ):
372
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
373
+ if len(previously_finished) >= self.max_candidates:
374
+ break # the candidate list is full
375
+ previously_finished[seq] = newly_finished[seq]
376
+
377
+ # mark as completed if all audio inputs have enough finished sequences
378
+ completed = all(
379
+ len(sequences) >= self.max_candidates
380
+ for sequences in self.finished_sequences
381
+ )
382
+ return tokens, completed
383
+
384
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
385
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
386
+ sum_logprobs = sum_logprobs.cpu()
387
+ for i, sequences in enumerate(self.finished_sequences):
388
+ if (
389
+ len(sequences) < self.beam_size
390
+ ): # when not enough sequences are finished
391
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
392
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
393
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
394
+ if len(sequences) >= self.beam_size:
395
+ break
396
+
397
+ tokens: List[List[Tensor]] = [
398
+ [torch.tensor(seq) for seq in sequences.keys()]
399
+ for sequences in self.finished_sequences
400
+ ]
401
+ sum_logprobs: List[List[float]] = [
402
+ list(sequences.values()) for sequences in self.finished_sequences
403
+ ]
404
+ return tokens, sum_logprobs
405
+
406
+
407
+ class LogitFilter:
408
+ def apply(self, logits: Tensor, tokens: Tensor) -> None:
409
+ """Apply any filtering or masking to logits in-place
410
+
411
+ Parameters
412
+ ----------
413
+ logits : Tensor, shape = (n_batch, vocab_size)
414
+ per-token logits of the probability distribution at the current step
415
+
416
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
417
+ all tokens in the context so far, including the prefix and sot_sequence tokens
418
+
419
+ """
420
+ raise NotImplementedError
421
+
422
+
423
+ class SuppressBlank(LogitFilter):
424
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
425
+ self.tokenizer = tokenizer
426
+ self.sample_begin = sample_begin
427
+
428
+ def apply(self, logits: Tensor, tokens: Tensor):
429
+ if tokens.shape[1] == self.sample_begin:
430
+ logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
431
+
432
+
433
+ class SuppressTokens(LogitFilter):
434
+ def __init__(self, suppress_tokens: Sequence[int]):
435
+ self.suppress_tokens = list(suppress_tokens)
436
+
437
+ def apply(self, logits: Tensor, tokens: Tensor):
438
+ logits[:, self.suppress_tokens] = -np.inf
439
+
440
+
441
+ class ApplyTimestampRules(LogitFilter):
442
+ def __init__(
443
+ self,
444
+ tokenizer: Tokenizer,
445
+ sample_begin: int,
446
+ max_initial_timestamp_index: Optional[int],
447
+ ):
448
+ self.tokenizer = tokenizer
449
+ self.sample_begin = sample_begin
450
+ self.max_initial_timestamp_index = max_initial_timestamp_index
451
+
452
+ def apply(self, logits: Tensor, tokens: Tensor):
453
+ # suppress <|notimestamps|> which is handled by without_timestamps
454
+ if self.tokenizer.no_timestamps is not None:
455
+ logits[:, self.tokenizer.no_timestamps] = -np.inf
456
+
457
+ # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
458
+ for k in range(tokens.shape[0]):
459
+ sampled_tokens = tokens[k, self.sample_begin :]
460
+ seq = [t for t in sampled_tokens.tolist()]
461
+ last_was_timestamp = (
462
+ len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
463
+ )
464
+ penultimate_was_timestamp = (
465
+ len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
466
+ )
467
+
468
+ if last_was_timestamp:
469
+ if penultimate_was_timestamp: # has to be non-timestamp
470
+ logits[k, self.tokenizer.timestamp_begin :] = -np.inf
471
+ else: # cannot be normal text tokens
472
+ logits[k, : self.tokenizer.eot] = -np.inf
473
+
474
+ timestamps = sampled_tokens[
475
+ sampled_tokens.ge(self.tokenizer.timestamp_begin)
476
+ ]
477
+ if timestamps.numel() > 0:
478
+ # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
479
+ # also force each segment to have a nonzero length, to prevent infinite looping
480
+ if last_was_timestamp and not penultimate_was_timestamp:
481
+ timestamp_last = timestamps[-1]
482
+ else:
483
+ timestamp_last = timestamps[-1] + 1
484
+ logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
485
+
486
+ if tokens.shape[1] == self.sample_begin:
487
+ # suppress generating non-timestamp tokens at the beginning
488
+ logits[:, : self.tokenizer.timestamp_begin] = -np.inf
489
+
490
+ # apply the `max_initial_timestamp` option
491
+ if self.max_initial_timestamp_index is not None:
492
+ last_allowed = (
493
+ self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
494
+ )
495
+ logits[:, last_allowed + 1 :] = -np.inf
496
+
497
+ # if sum of probability over timestamps is above any other token, sample timestamp
498
+ logprobs = F.log_softmax(logits.float(), dim=-1)
499
+ for k in range(tokens.shape[0]):
500
+ timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
501
+ dim=-1
502
+ )
503
+ max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
504
+ if timestamp_logprob > max_text_token_logprob:
505
+ logits[k, : self.tokenizer.timestamp_begin] = -np.inf
506
+
507
+
508
+ class DecodingTask:
509
+ inference: Inference
510
+ sequence_ranker: SequenceRanker
511
+ decoder: TokenDecoder
512
+ logit_filters: List[LogitFilter]
513
+
514
+ def __init__(self, model: "Whisper", options: DecodingOptions):
515
+ self.model = model
516
+
517
+ language = options.language or "en"
518
+ tokenizer = get_tokenizer(
519
+ model.is_multilingual,
520
+ num_languages=model.num_languages,
521
+ language=language,
522
+ task=options.task,
523
+ )
524
+ self.tokenizer: Tokenizer = tokenizer
525
+ self.options: DecodingOptions = self._verify_options(options)
526
+
527
+ self.n_group: int = options.beam_size or options.best_of or 1
528
+ self.n_ctx: int = model.dims.n_text_ctx
529
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
530
+
531
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
532
+ if self.options.without_timestamps:
533
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
534
+
535
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
536
+ self.sample_begin: int = len(self.initial_tokens)
537
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
538
+
539
+ # inference: implements the forward pass through the decoder, including kv caching
540
+ self.inference = PyTorchInference(model, len(self.initial_tokens))
541
+
542
+ # sequence ranker: implements how to rank a group of sampled sequences
543
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
544
+
545
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
546
+ if options.beam_size is not None:
547
+ self.decoder = BeamSearchDecoder(
548
+ options.beam_size, tokenizer.eot, self.inference, options.patience
549
+ )
550
+ else:
551
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
552
+
553
+ # logit filters: applies various rules to suppress or penalize certain tokens
554
+ self.logit_filters = []
555
+ if self.options.suppress_blank:
556
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
557
+ if self.options.suppress_tokens:
558
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
559
+ if not options.without_timestamps:
560
+ precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
561
+ max_initial_timestamp_index = None
562
+ if options.max_initial_timestamp:
563
+ max_initial_timestamp_index = round(
564
+ self.options.max_initial_timestamp / precision
565
+ )
566
+ self.logit_filters.append(
567
+ ApplyTimestampRules(
568
+ tokenizer, self.sample_begin, max_initial_timestamp_index
569
+ )
570
+ )
571
+
572
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
573
+ if options.beam_size is not None and options.best_of is not None:
574
+ raise ValueError("beam_size and best_of can't be given together")
575
+ if options.temperature == 0:
576
+ if options.best_of is not None:
577
+ raise ValueError("best_of with greedy sampling (T=0) is not compatible")
578
+ if options.patience is not None and options.beam_size is None:
579
+ raise ValueError("patience requires beam_size to be given")
580
+ if options.length_penalty is not None and not (
581
+ 0 <= options.length_penalty <= 1
582
+ ):
583
+ raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
584
+
585
+ return options
586
+
587
+ def _get_initial_tokens(self) -> Tuple[int]:
588
+ tokens = list(self.sot_sequence)
589
+
590
+ if prefix := self.options.prefix:
591
+ prefix_tokens = (
592
+ self.tokenizer.encode(" " + prefix.strip())
593
+ if isinstance(prefix, str)
594
+ else prefix
595
+ )
596
+ if self.sample_len is not None:
597
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
598
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
599
+ tokens = tokens + prefix_tokens
600
+
601
+ if prompt := self.options.prompt:
602
+ prompt_tokens = (
603
+ self.tokenizer.encode(" " + prompt.strip())
604
+ if isinstance(prompt, str)
605
+ else prompt
606
+ )
607
+ tokens = (
608
+ [self.tokenizer.sot_prev]
609
+ + prompt_tokens[-(self.n_ctx // 2 - 1) :]
610
+ + tokens
611
+ )
612
+
613
+ return tuple(tokens)
614
+
615
+ def _get_suppress_tokens(self) -> Tuple[int]:
616
+ suppress_tokens = self.options.suppress_tokens
617
+
618
+ if isinstance(suppress_tokens, str):
619
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
620
+
621
+ if -1 in suppress_tokens:
622
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
623
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
624
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
625
+ suppress_tokens = [] # interpret empty string as an empty list
626
+ else:
627
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
628
+
629
+ suppress_tokens.extend(
630
+ [
631
+ self.tokenizer.transcribe,
632
+ self.tokenizer.translate,
633
+ self.tokenizer.sot,
634
+ self.tokenizer.sot_prev,
635
+ self.tokenizer.sot_lm,
636
+ ]
637
+ )
638
+ if self.tokenizer.no_speech is not None:
639
+ # no-speech probability is collected separately
640
+ suppress_tokens.append(self.tokenizer.no_speech)
641
+
642
+ return tuple(sorted(set(suppress_tokens)))
643
+
644
+ def _get_audio_features(self, mel: Tensor):
645
+ if self.options.fp16:
646
+ mel = mel.half()
647
+
648
+ if mel.shape[-2:] == (
649
+ self.model.dims.n_audio_ctx,
650
+ self.model.dims.n_audio_state,
651
+ ):
652
+ # encoded audio features are given; skip audio encoding
653
+ audio_features = mel
654
+ else:
655
+ audio_features = self.model.encoder(mel)
656
+
657
+ if audio_features.dtype != (
658
+ torch.float16 if self.options.fp16 else torch.float32
659
+ ):
660
+ raise TypeError(
661
+ f"audio_features has an incorrect dtype: {audio_features.dtype}"
662
+ )
663
+
664
+ return audio_features
665
+
666
+ def _detect_language(self, audio_features: Tensor, tokens: Tensor):
667
+ languages = [self.options.language] * audio_features.shape[0]
668
+ lang_probs = None
669
+
670
+ if self.options.language is None or self.options.task == "lang_id":
671
+ lang_tokens, lang_probs = self.model.detect_language(
672
+ audio_features, self.tokenizer
673
+ )
674
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
675
+ if self.options.language is None:
676
+ tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
677
+
678
+ return languages, lang_probs
679
+
680
+ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
681
+ n_batch = tokens.shape[0]
682
+ sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
683
+ no_speech_probs = [np.nan] * n_batch
684
+
685
+ try:
686
+ for i in range(self.sample_len):
687
+
688
+ logits = self.inference.logits(tokens, audio_features)
689
+
690
+ if (
691
+ i == 0 and self.tokenizer.no_speech is not None
692
+ ): # save no_speech_probs
693
+ probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
694
+ no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
695
+
696
+ # now we need to consider the logits at the last token only
697
+ logits = logits[:, -1]
698
+
699
+ # apply the logit filters, e.g. for suppressing or applying penalty to
700
+ for logit_filter in self.logit_filters:
701
+ logit_filter.apply(logits, tokens)
702
+
703
+ # expand the tokens tensor with the selected next tokens
704
+ tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
705
+
706
+ if completed or tokens.shape[-1] > self.n_ctx:
707
+ break
708
+ finally:
709
+ self.inference.cleanup_caching()
710
+
711
+ return tokens, sum_logprobs, no_speech_probs
712
+
713
+ @torch.no_grad()
714
+ def run(self, mel: Tensor) -> List[DecodingResult]:
715
+ self.decoder.reset()
716
+ tokenizer: Tokenizer = self.tokenizer
717
+ n_audio: int = mel.shape[0]
718
+
719
+ audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
720
+ tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
721
+
722
+ # detect language if requested, overwriting the language token
723
+ languages, language_probs = self._detect_language(audio_features, tokens)
724
+ if self.options.task == "lang_id":
725
+ return [
726
+ DecodingResult(
727
+ audio_features=features, language=language, language_probs=probs
728
+ )
729
+ for features, language, probs in zip(
730
+ audio_features, languages, language_probs
731
+ )
732
+ ]
733
+
734
+ # repeat text tensors by the group size, for beam search or best-of-n sampling
735
+ tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
736
+
737
+ # call the main sampling loop
738
+ tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
739
+
740
+ # reshape the tensors to have (n_audio, n_group) as the first two dimensions
741
+ audio_features = audio_features[:: self.n_group]
742
+ no_speech_probs = no_speech_probs[:: self.n_group]
743
+ assert audio_features.shape[0] == len(no_speech_probs) == n_audio
744
+
745
+ tokens = tokens.reshape(n_audio, self.n_group, -1)
746
+ sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
747
+
748
+ # get the final candidates for each group, and slice between the first sampled token and EOT
749
+ tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
750
+ tokens: List[List[Tensor]] = [
751
+ [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
752
+ for s in tokens
753
+ ]
754
+
755
+ # select the top-ranked sample in each group
756
+ selected = self.sequence_ranker.rank(tokens, sum_logprobs)
757
+ tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
758
+ texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
759
+ sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
760
+ avg_logprobs: List[float] = [
761
+ lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
762
+ ]
763
+
764
+ fields = (
765
+ texts,
766
+ languages,
767
+ tokens,
768
+ audio_features,
769
+ avg_logprobs,
770
+ no_speech_probs,
771
+ )
772
+ if len(set(map(len, fields))) != 1:
773
+ raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
774
+
775
+ return [
776
+ DecodingResult(
777
+ audio_features=features,
778
+ language=language,
779
+ tokens=tokens,
780
+ text=text,
781
+ avg_logprob=avg_logprob,
782
+ no_speech_prob=no_speech_prob,
783
+ temperature=self.options.temperature,
784
+ compression_ratio=compression_ratio(text),
785
+ )
786
+ for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
787
+ *fields
788
+ )
789
+ ]
790
+
791
+
792
+ @torch.no_grad()
793
+ def decode(
794
+ model: "Whisper",
795
+ mel: Tensor,
796
+ options: DecodingOptions = DecodingOptions(),
797
+ **kwargs,
798
+ ) -> Union[DecodingResult, List[DecodingResult]]:
799
+ """
800
+ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
801
+
802
+ Parameters
803
+ ----------
804
+ model: Whisper
805
+ the Whisper model instance
806
+
807
+ mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
808
+ A tensor containing the Mel spectrogram(s)
809
+
810
+ options: DecodingOptions
811
+ A dataclass that contains all necessary options for decoding 30-second segments
812
+
813
+ Returns
814
+ -------
815
+ result: Union[DecodingResult, List[DecodingResult]]
816
+ The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
817
+ """
818
+ if single := mel.ndim == 2:
819
+ mel = mel.unsqueeze(0)
820
+
821
+ if kwargs:
822
+ options = replace(options, **kwargs)
823
+
824
+ result = DecodingTask(model, options).run(mel)
825
+
826
+ return result[0] if single else result
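A short sketch of the single-window decoding path defined above, assuming load_model is importable from the package root and using a hypothetical 16 kHz file example.wav:

    import torch
    from whisper_stream import load_model
    from whisper_stream.audio import load_audio, pad_or_trim, log_mel_spectrogram
    from whisper_stream.decoding import DecodingOptions, decode

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model("base", device=device)

    # One 30-second, 16 kHz window, converted to a log-Mel spectrogram.
    audio = pad_or_trim(load_audio("example.wav"))
    mel = log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(device)

    # Greedy decoding, English, no timestamp tokens; fp16 is disabled so the
    # sketch also runs on CPU with the fp32 weights loaded above.
    options = DecodingOptions(language="en", without_timestamps=True, fp16=False)
    result = decode(model, mel, options)
    print(result.text)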
whisper_stream/model.py ADDED
@@ -0,0 +1,329 @@
1
+ import base64
2
+ import gzip
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Iterable, Optional, Tuple
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch import Tensor, nn
10
+
11
+ from .decoding import decode as decode_function
12
+ from .decoding import detect_language as detect_language_function
13
+ from .transcribe import transcribe as transcribe_function
14
+
15
+
16
+ @dataclass
17
+ class ModelDimensions:
18
+ n_mels: int
19
+ n_audio_ctx: int
20
+ n_audio_state: int
21
+ n_audio_head: int
22
+ n_audio_layer: int
23
+ n_vocab: int
24
+ n_text_ctx: int
25
+ n_text_state: int
26
+ n_text_head: int
27
+ n_text_layer: int
28
+
29
+
30
+ class LayerNorm(nn.LayerNorm):
31
+ def forward(self, x: Tensor) -> Tensor:
32
+ return super().forward(x.float()).type(x.dtype)
33
+
34
+
35
+ class Linear(nn.Linear):
36
+ def forward(self, x: Tensor) -> Tensor:
37
+ return F.linear(
38
+ x,
39
+ self.weight.to(x.dtype),
40
+ None if self.bias is None else self.bias.to(x.dtype),
41
+ )
42
+
43
+
44
+ class Conv1d(nn.Conv1d):
45
+ def _conv_forward(
46
+ self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
47
+ ) -> Tensor:
48
+ return super()._conv_forward(
49
+ x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
50
+ )
51
+
52
+
53
+ def sinusoids(length, channels, max_timescale=10000):
54
+ """Returns sinusoids for positional embedding"""
55
+ assert channels % 2 == 0
56
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
57
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
58
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
59
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
60
+
61
+
62
+ class MultiHeadAttention(nn.Module):
63
+ def __init__(self, n_state: int, n_head: int):
64
+ super().__init__()
65
+ self.n_head = n_head
66
+ self.query = Linear(n_state, n_state)
67
+ self.key = Linear(n_state, n_state, bias=False)
68
+ self.value = Linear(n_state, n_state)
69
+ self.out = Linear(n_state, n_state)
70
+
71
+ def forward(
72
+ self,
73
+ x: Tensor,
74
+ xa: Optional[Tensor] = None,
75
+ mask: Optional[Tensor] = None,
76
+ kv_cache: Optional[dict[any, Tensor]] = None,
77
+ ):
78
+ q = self.query(x)
79
+
80
+ if kv_cache is None or xa is None or self.key not in kv_cache:
81
+ # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
82
+ # otherwise, perform key/value projections for self- or cross-attention as usual.
83
+ k = self.key(x if xa is None else xa)
84
+ v = self.value(x if xa is None else xa)
85
+ else:
86
+ # for cross-attention, calculate keys and values once and reuse in subsequent calls.
87
+ k = kv_cache[self.key]
88
+ v = kv_cache[self.value]
89
+
90
+ wv, qk = self.qkv_attention(q, k, v, mask)
91
+
92
+ return self.out(wv), qk
93
+
94
+ def qkv_attention(
95
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
96
+ ):
97
+ # print(f"q shape: {q.shape}")
98
+ n_batch, n_ctx, n_state = q.shape
99
+ scale = (n_state // self.n_head) ** -0.25
100
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
101
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
102
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
103
+
104
+ qk = q @ k
105
+ if mask is not None:
106
+ qk = qk + mask[:n_ctx, :n_ctx]
107
+ qk = qk.float()
108
+
109
+ w = F.softmax(qk, dim=-1).to(q.dtype)
110
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
111
+
112
+
113
+ class ResidualAttentionBlock(nn.Module):
114
+ def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
115
+ super().__init__()
116
+
117
+ self.attn = MultiHeadAttention(n_state, n_head)
118
+ self.attn_ln = LayerNorm(n_state)
119
+
120
+ self.cross_attn = (
121
+ MultiHeadAttention(n_state, n_head) if cross_attention else None
122
+ )
123
+ self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
124
+
125
+ n_mlp = n_state * 4
126
+ self.mlp = nn.Sequential(
127
+ Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
128
+ )
129
+ self.mlp_ln = LayerNorm(n_state)
130
+
131
+ def forward(
132
+ self,
133
+ x: Tensor,
134
+ xa: Optional[Tensor] = None,
135
+ mask: Optional[Tensor] = None,
136
+ kv_cache: Optional[dict] = None,
137
+ ):
138
+ # SA
139
+ x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
140
+
141
+ # CA
142
+ if self.cross_attn:
143
+ x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0]
144
+
145
+ # MLP
146
+ x = x + self.mlp(self.mlp_ln(x))
147
+
148
+ return x
149
+
150
+
151
+ class AudioEncoder(nn.Module):
152
+ def __init__(
153
+ self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
154
+ ):
155
+ super().__init__()
156
+ self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
157
+ self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
158
+ self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
159
+
160
+ self.n_head = n_head
161
+ self.n_layer = n_layer
162
+ self.n_state = n_state
163
+
164
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
165
+ [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
166
+ )
167
+ self.ln_post = LayerNorm(n_state)
168
+
169
+ def forward(self, x: Tensor):
170
+ """
171
+ x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
172
+ the mel spectrogram of the audio
173
+ """
174
+ x = F.gelu(self.conv1(x))
175
+ x = F.gelu(self.conv2(x))
176
+ x = x.permute(0, 2, 1)
177
+
178
+ assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
179
+ x = (x + self.positional_embedding).to(x.dtype)
180
+
181
+ for block in self.blocks:
182
+ x = block(x)
183
+
184
+ x = self.ln_post(x)
185
+ return x
186
+
187
+
188
+ class TextDecoder(nn.Module):
189
+ def __init__(
190
+ self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
191
+ ):
192
+ super().__init__()
193
+
194
+ self.token_embedding = nn.Embedding(n_vocab, n_state)
195
+ self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
196
+
197
+ self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
198
+ [
199
+ ResidualAttentionBlock(n_state, n_head, cross_attention=True)
200
+ for _ in range(n_layer)
201
+ ]
202
+ )
203
+ self.ln = LayerNorm(n_state)
204
+
205
+ mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
206
+ self.register_buffer("mask", mask, persistent=False)
207
+
208
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
209
+ """
210
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
211
+ the text tokens
212
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
213
+ the encoded audio features to be attended on
214
+ kv_cache : Dict[nn.Module, torch.Tensor], optional - cached key/value projections from earlier decoding steps
215
+ """
216
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
217
+ x = (
218
+ self.token_embedding(x)
219
+ + self.positional_embedding[offset : offset + x.shape[-1]]
220
+ )
221
+ x = x.to(xa.dtype)
222
+
223
+ for block in self.blocks:
224
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
225
+
226
+ x = self.ln(x)
227
+ logits = (
228
+ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
229
+ ).float()
230
+
231
+ return logits
232
+
233
+
234
+ class Whisper(nn.Module):
235
+ def __init__(self, dims: ModelDimensions):
236
+ super().__init__()
237
+ self.dims = dims
238
+ self.encoder = AudioEncoder(
239
+ self.dims.n_mels,
240
+ self.dims.n_audio_ctx,
241
+ self.dims.n_audio_state,
242
+ self.dims.n_audio_head,
243
+ self.dims.n_audio_layer,
244
+ )
245
+ self.decoder = TextDecoder(
246
+ self.dims.n_vocab,
247
+ self.dims.n_text_ctx,
248
+ self.dims.n_text_state,
249
+ self.dims.n_text_head,
250
+ self.dims.n_text_layer,
251
+ )
252
+ # use the last half among the decoder layers for time alignment by default;
253
+ # to use a specific set of heads, see `set_alignment_heads()` below.
254
+ all_heads = torch.zeros(
255
+ self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
256
+ )
257
+ all_heads[self.dims.n_text_layer // 2 :] = True
258
+ # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
259
+ self.register_buffer("alignment_heads", all_heads, persistent=False)  # kept dense: sparse buffers do not work with PyTorch Lightning
260
+
261
+ def set_alignment_heads(self, dump: bytes):
262
+ array = np.frombuffer(
263
+ gzip.decompress(base64.b85decode(dump)), dtype=bool
264
+ ).copy()
265
+ mask = torch.from_numpy(array).reshape(
266
+ self.dims.n_text_layer, self.dims.n_text_head
267
+ )
268
+ # self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)
269
+ self.register_buffer("alignment_heads", mask, persistent=False)  # kept dense: sparse buffers do not work with PyTorch Lightning
270
+
271
+ def embed_audio(self, mel: torch.Tensor):
272
+ return self.encoder(mel)
273
+
274
+ def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
275
+ return self.decoder(tokens, audio_features)
276
+
277
+ def forward(
278
+ self, mel: torch.Tensor, tokens: torch.Tensor
279
+ ) -> Dict[str, torch.Tensor]:
280
+ return self.decoder(tokens, self.encoder(mel))
281
+
282
+ @property
283
+ def device(self):
284
+ return next(self.parameters()).device
285
+
286
+ @property
287
+ def is_multilingual(self):
288
+ return self.dims.n_vocab >= 51865
289
+
290
+ @property
291
+ def num_languages(self):
292
+ return self.dims.n_vocab - 51765 - int(self.is_multilingual)
293
+
294
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
295
+ """
296
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
297
+ tensors calculated for the previous positions. This method returns a dictionary that stores
298
+ all caches, and the necessary hooks for the key and value projection modules that save the
299
+ intermediate tensors to be reused during later calculations.
300
+
301
+ Returns
302
+ -------
303
+ cache : Dict[nn.Module, torch.Tensor]
304
+ A dictionary mapping each key/value projection module to its cached tensor
305
+ hooks : List[RemovableHandle]
306
+ List of PyTorch RemovableHandle objects for removing the installed hooks
307
+ """
308
+ cache = {**cache} if cache is not None else {}
309
+ hooks = []
310
+
311
+ def save_to_cache(module, _, output):
312
+ if module not in cache or output.shape[1] > self.dims.n_text_ctx:
313
+ # save as-is, for the first token or cross attention
314
+ cache[module] = output
315
+ else:
316
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
317
+ return cache[module]
318
+
319
+ def install_hooks(layer: nn.Module):
320
+ if isinstance(layer, MultiHeadAttention):
321
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
322
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
323
+
324
+ self.decoder.apply(install_hooks)
325
+ return cache, hooks
326
+
327
+ detect_language = detect_language_function
328
+ transcribe = transcribe_function
329
+ decode = decode_function
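The kv-cache hooks documented in `install_kv_cache_hooks` are easiest to see in a small driver loop. The sketch below is illustrative rather than part of this commit: `greedy_step_by_step` is a made-up helper, and it assumes `model` is a `Whisper` instance, `mel` a (1, n_mels, 3000) log-mel tensor on `model.device`, and `tokens` the tokenizer's start-of-transcript sequence with shape (1, n).

import torch

def greedy_step_by_step(model, mel: torch.Tensor, tokens: torch.Tensor, n_steps: int = 10):
    """Greedy decoding that reuses cached key/value projections between steps."""
    audio_features = model.embed_audio(mel)        # run the encoder once
    cache, hooks = model.install_kv_cache_hooks()  # key/value projections accumulate here
    try:
        for _ in range(n_steps):
            # after the first step only the newest token is fed; earlier projections
            # are read back from `cache` inside MultiHeadAttention.forward
            inp = tokens if not cache else tokens[:, -1:]
            logits = model.decoder(inp, audio_features, kv_cache=cache)
            next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
            tokens = torch.cat([tokens, next_token], dim=-1)
    finally:
        for hook in hooks:
            hook.remove()                          # detach the forward hooks again
    return tokens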
whisper_stream/normalizers/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .basic import BasicTextNormalizer as BasicTextNormalizer
2
+ from .english import EnglishTextNormalizer as EnglishTextNormalizer
whisper_stream/normalizers/basic.py ADDED
@@ -0,0 +1,76 @@
1
+ import re
2
+ import unicodedata
3
+
4
+ import regex
5
+
6
+ # non-ASCII letters that are not separated by "NFKD" normalization
7
+ ADDITIONAL_DIACRITICS = {
8
+ "œ": "oe",
9
+ "Œ": "OE",
10
+ "ø": "o",
11
+ "Ø": "O",
12
+ "æ": "ae",
13
+ "Æ": "AE",
14
+ "ß": "ss",
15
+ "ẞ": "SS",
16
+ "đ": "d",
17
+ "Đ": "D",
18
+ "ð": "d",
19
+ "Ð": "D",
20
+ "þ": "th",
21
+ "Þ": "th",
22
+ "ł": "l",
23
+ "Ł": "L",
24
+ }
25
+
26
+
27
+ def remove_symbols_and_diacritics(s: str, keep=""):
28
+ """
29
+ Replace any other markers, symbols, and punctuations with a space,
30
+ and drop any diacritics (category 'Mn' and some manual mappings)
31
+ """
32
+ return "".join(
33
+ c
34
+ if c in keep
35
+ else ADDITIONAL_DIACRITICS[c]
36
+ if c in ADDITIONAL_DIACRITICS
37
+ else ""
38
+ if unicodedata.category(c) == "Mn"
39
+ else " "
40
+ if unicodedata.category(c)[0] in "MSP"
41
+ else c
42
+ for c in unicodedata.normalize("NFKD", s)
43
+ )
44
+
45
+
46
+ def remove_symbols(s: str):
47
+ """
48
+ Replace any other markers, symbols, punctuations with a space, keeping diacritics
49
+ """
50
+ return "".join(
51
+ " " if unicodedata.category(c)[0] in "MSP" else c
52
+ for c in unicodedata.normalize("NFKC", s)
53
+ )
54
+
55
+
56
+ class BasicTextNormalizer:
57
+ def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
58
+ self.clean = (
59
+ remove_symbols_and_diacritics if remove_diacritics else remove_symbols
60
+ )
61
+ self.split_letters = split_letters
62
+
63
+ def __call__(self, s: str):
64
+ s = s.lower()
65
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
66
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
67
+ s = self.clean(s).lower()
68
+
69
+ if self.split_letters:
70
+ s = " ".join(regex.findall(r"\X", s, regex.U))
71
+
72
+ s = re.sub(
73
+ r"\s+", " ", s
74
+ ) # replace any successive whitespace characters with a space
75
+
76
+ return s
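As a quick illustration of the normalizer above (the example string is made up, and the exact output may differ slightly in spacing):

from whisper_stream.normalizers import BasicTextNormalizer

normalizer = BasicTextNormalizer(remove_diacritics=True)
text = "Café au lait, s'il vous plaît! [background noise]"
print(normalizer(text))
# bracketed spans are removed, punctuation and symbols collapse to spaces, and
# diacritics are stripped, so this prints roughly: "cafe au lait s il vous plait"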
whisper_stream/normalizers/english.json ADDED
@@ -0,0 +1,1741 @@
1
+ {
2
+ "accessorise": "accessorize",
3
+ "accessorised": "accessorized",
4
+ "accessorises": "accessorizes",
5
+ "accessorising": "accessorizing",
6
+ "acclimatisation": "acclimatization",
7
+ "acclimatise": "acclimatize",
8
+ "acclimatised": "acclimatized",
9
+ "acclimatises": "acclimatizes",
10
+ "acclimatising": "acclimatizing",
11
+ "accoutrements": "accouterments",
12
+ "aeon": "eon",
13
+ "aeons": "eons",
14
+ "aerogramme": "aerogram",
15
+ "aerogrammes": "aerograms",
16
+ "aeroplane": "airplane",
17
+ "aeroplanes": "airplanes",
18
+ "aesthete": "esthete",
19
+ "aesthetes": "esthetes",
20
+ "aesthetic": "esthetic",
21
+ "aesthetically": "esthetically",
22
+ "aesthetics": "esthetics",
23
+ "aetiology": "etiology",
24
+ "ageing": "aging",
25
+ "aggrandisement": "aggrandizement",
26
+ "agonise": "agonize",
27
+ "agonised": "agonized",
28
+ "agonises": "agonizes",
29
+ "agonising": "agonizing",
30
+ "agonisingly": "agonizingly",
31
+ "almanack": "almanac",
32
+ "almanacks": "almanacs",
33
+ "aluminium": "aluminum",
34
+ "amortisable": "amortizable",
35
+ "amortisation": "amortization",
36
+ "amortisations": "amortizations",
37
+ "amortise": "amortize",
38
+ "amortised": "amortized",
39
+ "amortises": "amortizes",
40
+ "amortising": "amortizing",
41
+ "amphitheatre": "amphitheater",
42
+ "amphitheatres": "amphitheaters",
43
+ "anaemia": "anemia",
44
+ "anaemic": "anemic",
45
+ "anaesthesia": "anesthesia",
46
+ "anaesthetic": "anesthetic",
47
+ "anaesthetics": "anesthetics",
48
+ "anaesthetise": "anesthetize",
49
+ "anaesthetised": "anesthetized",
50
+ "anaesthetises": "anesthetizes",
51
+ "anaesthetising": "anesthetizing",
52
+ "anaesthetist": "anesthetist",
53
+ "anaesthetists": "anesthetists",
54
+ "anaesthetize": "anesthetize",
55
+ "anaesthetized": "anesthetized",
56
+ "anaesthetizes": "anesthetizes",
57
+ "anaesthetizing": "anesthetizing",
58
+ "analogue": "analog",
59
+ "analogues": "analogs",
60
+ "analyse": "analyze",
61
+ "analysed": "analyzed",
62
+ "analyses": "analyzes",
63
+ "analysing": "analyzing",
64
+ "anglicise": "anglicize",
65
+ "anglicised": "anglicized",
66
+ "anglicises": "anglicizes",
67
+ "anglicising": "anglicizing",
68
+ "annualised": "annualized",
69
+ "antagonise": "antagonize",
70
+ "antagonised": "antagonized",
71
+ "antagonises": "antagonizes",
72
+ "antagonising": "antagonizing",
73
+ "apologise": "apologize",
74
+ "apologised": "apologized",
75
+ "apologises": "apologizes",
76
+ "apologising": "apologizing",
77
+ "appal": "appall",
78
+ "appals": "appalls",
79
+ "appetiser": "appetizer",
80
+ "appetisers": "appetizers",
81
+ "appetising": "appetizing",
82
+ "appetisingly": "appetizingly",
83
+ "arbour": "arbor",
84
+ "arbours": "arbors",
85
+ "archeological": "archaeological",
86
+ "archaeologically": "archeologically",
87
+ "archaeologist": "archeologist",
88
+ "archaeologists": "archeologists",
89
+ "archaeology": "archeology",
90
+ "ardour": "ardor",
91
+ "armour": "armor",
92
+ "armoured": "armored",
93
+ "armourer": "armorer",
94
+ "armourers": "armorers",
95
+ "armouries": "armories",
96
+ "armoury": "armory",
97
+ "artefact": "artifact",
98
+ "artefacts": "artifacts",
99
+ "authorise": "authorize",
100
+ "authorised": "authorized",
101
+ "authorises": "authorizes",
102
+ "authorising": "authorizing",
103
+ "axe": "ax",
104
+ "backpedalled": "backpedaled",
105
+ "backpedalling": "backpedaling",
106
+ "bannister": "banister",
107
+ "bannisters": "banisters",
108
+ "baptise": "baptize",
109
+ "baptised": "baptized",
110
+ "baptises": "baptizes",
111
+ "baptising": "baptizing",
112
+ "bastardise": "bastardize",
113
+ "bastardised": "bastardized",
114
+ "bastardises": "bastardizes",
115
+ "bastardising": "bastardizing",
116
+ "battleax": "battleaxe",
117
+ "baulk": "balk",
118
+ "baulked": "balked",
119
+ "baulking": "balking",
120
+ "baulks": "balks",
121
+ "bedevilled": "bedeviled",
122
+ "bedevilling": "bedeviling",
123
+ "behaviour": "behavior",
124
+ "behavioural": "behavioral",
125
+ "behaviourism": "behaviorism",
126
+ "behaviourist": "behaviorist",
127
+ "behaviourists": "behaviorists",
128
+ "behaviours": "behaviors",
129
+ "behove": "behoove",
130
+ "behoved": "behooved",
131
+ "behoves": "behooves",
132
+ "bejewelled": "bejeweled",
133
+ "belabour": "belabor",
134
+ "belaboured": "belabored",
135
+ "belabouring": "belaboring",
136
+ "belabours": "belabors",
137
+ "bevelled": "beveled",
138
+ "bevvies": "bevies",
139
+ "bevvy": "bevy",
140
+ "biassed": "biased",
141
+ "biassing": "biasing",
142
+ "bingeing": "binging",
143
+ "bougainvillaea": "bougainvillea",
144
+ "bougainvillaeas": "bougainvilleas",
145
+ "bowdlerise": "bowdlerize",
146
+ "bowdlerised": "bowdlerized",
147
+ "bowdlerises": "bowdlerizes",
148
+ "bowdlerising": "bowdlerizing",
149
+ "breathalyse": "breathalyze",
150
+ "breathalysed": "breathalyzed",
151
+ "breathalyser": "breathalyzer",
152
+ "breathalysers": "breathalyzers",
153
+ "breathalyses": "breathalyzes",
154
+ "breathalysing": "breathalyzing",
155
+ "brutalise": "brutalize",
156
+ "brutalised": "brutalized",
157
+ "brutalises": "brutalizes",
158
+ "brutalising": "brutalizing",
159
+ "busses": "buses",
160
+ "bussing": "busing",
161
+ "caesarean": "cesarean",
162
+ "caesareans": "cesareans",
163
+ "calibre": "caliber",
164
+ "calibres": "calibers",
165
+ "calliper": "caliper",
166
+ "callipers": "calipers",
167
+ "callisthenics": "calisthenics",
168
+ "canalise": "canalize",
169
+ "canalised": "canalized",
170
+ "canalises": "canalizes",
171
+ "canalising": "canalizing",
172
+ "cancelation": "cancellation",
173
+ "cancelations": "cancellations",
174
+ "cancelled": "canceled",
175
+ "cancelling": "canceling",
176
+ "candour": "candor",
177
+ "cannibalise": "cannibalize",
178
+ "cannibalised": "cannibalized",
179
+ "cannibalises": "cannibalizes",
180
+ "cannibalising": "cannibalizing",
181
+ "canonise": "canonize",
182
+ "canonised": "canonized",
183
+ "canonises": "canonizes",
184
+ "canonising": "canonizing",
185
+ "capitalise": "capitalize",
186
+ "capitalised": "capitalized",
187
+ "capitalises": "capitalizes",
188
+ "capitalising": "capitalizing",
189
+ "caramelise": "caramelize",
190
+ "caramelised": "caramelized",
191
+ "caramelises": "caramelizes",
192
+ "caramelising": "caramelizing",
193
+ "carbonise": "carbonize",
194
+ "carbonised": "carbonized",
195
+ "carbonises": "carbonizes",
196
+ "carbonising": "carbonizing",
197
+ "carolled": "caroled",
198
+ "carolling": "caroling",
199
+ "catalogue": "catalog",
200
+ "catalogued": "cataloged",
201
+ "catalogues": "catalogs",
202
+ "cataloguing": "cataloging",
203
+ "catalyse": "catalyze",
204
+ "catalysed": "catalyzed",
205
+ "catalyses": "catalyzes",
206
+ "catalysing": "catalyzing",
207
+ "categorise": "categorize",
208
+ "categorised": "categorized",
209
+ "categorises": "categorizes",
210
+ "categorising": "categorizing",
211
+ "cauterise": "cauterize",
212
+ "cauterised": "cauterized",
213
+ "cauterises": "cauterizes",
214
+ "cauterising": "cauterizing",
215
+ "cavilled": "caviled",
216
+ "cavilling": "caviling",
217
+ "centigramme": "centigram",
218
+ "centigrammes": "centigrams",
219
+ "centilitre": "centiliter",
220
+ "centilitres": "centiliters",
221
+ "centimetre": "centimeter",
222
+ "centimetres": "centimeters",
223
+ "centralise": "centralize",
224
+ "centralised": "centralized",
225
+ "centralises": "centralizes",
226
+ "centralising": "centralizing",
227
+ "centre": "center",
228
+ "centred": "centered",
229
+ "centrefold": "centerfold",
230
+ "centrefolds": "centerfolds",
231
+ "centrepiece": "centerpiece",
232
+ "centrepieces": "centerpieces",
233
+ "centres": "centers",
234
+ "channelled": "channeled",
235
+ "channelling": "channeling",
236
+ "characterise": "characterize",
237
+ "characterised": "characterized",
238
+ "characterises": "characterizes",
239
+ "characterising": "characterizing",
240
+ "cheque": "check",
241
+ "chequebook": "checkbook",
242
+ "chequebooks": "checkbooks",
243
+ "chequered": "checkered",
244
+ "cheques": "checks",
245
+ "chilli": "chili",
246
+ "chimaera": "chimera",
247
+ "chimaeras": "chimeras",
248
+ "chiselled": "chiseled",
249
+ "chiselling": "chiseling",
250
+ "circularise": "circularize",
251
+ "circularised": "circularized",
252
+ "circularises": "circularizes",
253
+ "circularising": "circularizing",
254
+ "civilise": "civilize",
255
+ "civilised": "civilized",
256
+ "civilises": "civilizes",
257
+ "civilising": "civilizing",
258
+ "clamour": "clamor",
259
+ "clamoured": "clamored",
260
+ "clamouring": "clamoring",
261
+ "clamours": "clamors",
262
+ "clangour": "clangor",
263
+ "clarinettist": "clarinetist",
264
+ "clarinettists": "clarinetists",
265
+ "collectivise": "collectivize",
266
+ "collectivised": "collectivized",
267
+ "collectivises": "collectivizes",
268
+ "collectivising": "collectivizing",
269
+ "colonisation": "colonization",
270
+ "colonise": "colonize",
271
+ "colonised": "colonized",
272
+ "coloniser": "colonizer",
273
+ "colonisers": "colonizers",
274
+ "colonises": "colonizes",
275
+ "colonising": "colonizing",
276
+ "colour": "color",
277
+ "colourant": "colorant",
278
+ "colourants": "colorants",
279
+ "coloured": "colored",
280
+ "coloureds": "coloreds",
281
+ "colourful": "colorful",
282
+ "colourfully": "colorfully",
283
+ "colouring": "coloring",
284
+ "colourize": "colorize",
285
+ "colourized": "colorized",
286
+ "colourizes": "colorizes",
287
+ "colourizing": "colorizing",
288
+ "colourless": "colorless",
289
+ "colours": "colors",
290
+ "commercialise": "commercialize",
291
+ "commercialised": "commercialized",
292
+ "commercialises": "commercializes",
293
+ "commercialising": "commercializing",
294
+ "compartmentalise": "compartmentalize",
295
+ "compartmentalised": "compartmentalized",
296
+ "compartmentalises": "compartmentalizes",
297
+ "compartmentalising": "compartmentalizing",
298
+ "computerise": "computerize",
299
+ "computerised": "computerized",
300
+ "computerises": "computerizes",
301
+ "computerising": "computerizing",
302
+ "conceptualise": "conceptualize",
303
+ "conceptualised": "conceptualized",
304
+ "conceptualises": "conceptualizes",
305
+ "conceptualising": "conceptualizing",
306
+ "connexion": "connection",
307
+ "connexions": "connections",
308
+ "contextualise": "contextualize",
309
+ "contextualised": "contextualized",
310
+ "contextualises": "contextualizes",
311
+ "contextualising": "contextualizing",
312
+ "cosier": "cozier",
313
+ "cosies": "cozies",
314
+ "cosiest": "coziest",
315
+ "cosily": "cozily",
316
+ "cosiness": "coziness",
317
+ "cosy": "cozy",
318
+ "councillor": "councilor",
319
+ "councillors": "councilors",
320
+ "counselled": "counseled",
321
+ "counselling": "counseling",
322
+ "counsellor": "counselor",
323
+ "counsellors": "counselors",
324
+ "crenelated": "crenellated",
325
+ "criminalise": "criminalize",
326
+ "criminalised": "criminalized",
327
+ "criminalises": "criminalizes",
328
+ "criminalising": "criminalizing",
329
+ "criticise": "criticize",
330
+ "criticised": "criticized",
331
+ "criticises": "criticizes",
332
+ "criticising": "criticizing",
333
+ "crueller": "crueler",
334
+ "cruellest": "cruelest",
335
+ "crystallisation": "crystallization",
336
+ "crystallise": "crystallize",
337
+ "crystallised": "crystallized",
338
+ "crystallises": "crystallizes",
339
+ "crystallising": "crystallizing",
340
+ "cudgelled": "cudgeled",
341
+ "cudgelling": "cudgeling",
342
+ "customise": "customize",
343
+ "customised": "customized",
344
+ "customises": "customizes",
345
+ "customising": "customizing",
346
+ "cypher": "cipher",
347
+ "cyphers": "ciphers",
348
+ "decentralisation": "decentralization",
349
+ "decentralise": "decentralize",
350
+ "decentralised": "decentralized",
351
+ "decentralises": "decentralizes",
352
+ "decentralising": "decentralizing",
353
+ "decriminalisation": "decriminalization",
354
+ "decriminalise": "decriminalize",
355
+ "decriminalised": "decriminalized",
356
+ "decriminalises": "decriminalizes",
357
+ "decriminalising": "decriminalizing",
358
+ "defence": "defense",
359
+ "defenceless": "defenseless",
360
+ "defences": "defenses",
361
+ "dehumanisation": "dehumanization",
362
+ "dehumanise": "dehumanize",
363
+ "dehumanised": "dehumanized",
364
+ "dehumanises": "dehumanizes",
365
+ "dehumanising": "dehumanizing",
366
+ "demeanour": "demeanor",
367
+ "demilitarisation": "demilitarization",
368
+ "demilitarise": "demilitarize",
369
+ "demilitarised": "demilitarized",
370
+ "demilitarises": "demilitarizes",
371
+ "demilitarising": "demilitarizing",
372
+ "demobilisation": "demobilization",
373
+ "demobilise": "demobilize",
374
+ "demobilised": "demobilized",
375
+ "demobilises": "demobilizes",
376
+ "demobilising": "demobilizing",
377
+ "democratisation": "democratization",
378
+ "democratise": "democratize",
379
+ "democratised": "democratized",
380
+ "democratises": "democratizes",
381
+ "democratising": "democratizing",
382
+ "demonise": "demonize",
383
+ "demonised": "demonized",
384
+ "demonises": "demonizes",
385
+ "demonising": "demonizing",
386
+ "demoralisation": "demoralization",
387
+ "demoralise": "demoralize",
388
+ "demoralised": "demoralized",
389
+ "demoralises": "demoralizes",
390
+ "demoralising": "demoralizing",
391
+ "denationalisation": "denationalization",
392
+ "denationalise": "denationalize",
393
+ "denationalised": "denationalized",
394
+ "denationalises": "denationalizes",
395
+ "denationalising": "denationalizing",
396
+ "deodorise": "deodorize",
397
+ "deodorised": "deodorized",
398
+ "deodorises": "deodorizes",
399
+ "deodorising": "deodorizing",
400
+ "depersonalise": "depersonalize",
401
+ "depersonalised": "depersonalized",
402
+ "depersonalises": "depersonalizes",
403
+ "depersonalising": "depersonalizing",
404
+ "deputise": "deputize",
405
+ "deputised": "deputized",
406
+ "deputises": "deputizes",
407
+ "deputising": "deputizing",
408
+ "desensitisation": "desensitization",
409
+ "desensitise": "desensitize",
410
+ "desensitised": "desensitized",
411
+ "desensitises": "desensitizes",
412
+ "desensitising": "desensitizing",
413
+ "destabilisation": "destabilization",
414
+ "destabilise": "destabilize",
415
+ "destabilised": "destabilized",
416
+ "destabilises": "destabilizes",
417
+ "destabilising": "destabilizing",
418
+ "dialled": "dialed",
419
+ "dialling": "dialing",
420
+ "dialogue": "dialog",
421
+ "dialogues": "dialogs",
422
+ "diarrhoea": "diarrhea",
423
+ "digitise": "digitize",
424
+ "digitised": "digitized",
425
+ "digitises": "digitizes",
426
+ "digitising": "digitizing",
427
+ "disc": "disk",
428
+ "discolour": "discolor",
429
+ "discoloured": "discolored",
430
+ "discolouring": "discoloring",
431
+ "discolours": "discolors",
432
+ "discs": "disks",
433
+ "disembowelled": "disemboweled",
434
+ "disembowelling": "disemboweling",
435
+ "disfavour": "disfavor",
436
+ "dishevelled": "disheveled",
437
+ "dishonour": "dishonor",
438
+ "dishonourable": "dishonorable",
439
+ "dishonourably": "dishonorably",
440
+ "dishonoured": "dishonored",
441
+ "dishonouring": "dishonoring",
442
+ "dishonours": "dishonors",
443
+ "disorganisation": "disorganization",
444
+ "disorganised": "disorganized",
445
+ "distil": "distill",
446
+ "distils": "distills",
447
+ "dramatisation": "dramatization",
448
+ "dramatisations": "dramatizations",
449
+ "dramatise": "dramatize",
450
+ "dramatised": "dramatized",
451
+ "dramatises": "dramatizes",
452
+ "dramatising": "dramatizing",
453
+ "draught": "draft",
454
+ "draughtboard": "draftboard",
455
+ "draughtboards": "draftboards",
456
+ "draughtier": "draftier",
457
+ "draughtiest": "draftiest",
458
+ "draughts": "drafts",
459
+ "draughtsman": "draftsman",
460
+ "draughtsmanship": "draftsmanship",
461
+ "draughtsmen": "draftsmen",
462
+ "draughtswoman": "draftswoman",
463
+ "draughtswomen": "draftswomen",
464
+ "draughty": "drafty",
465
+ "drivelled": "driveled",
466
+ "drivelling": "driveling",
467
+ "duelled": "dueled",
468
+ "duelling": "dueling",
469
+ "economise": "economize",
470
+ "economised": "economized",
471
+ "economises": "economizes",
472
+ "economising": "economizing",
473
+ "edoema": "edema",
474
+ "editorialise": "editorialize",
475
+ "editorialised": "editorialized",
476
+ "editorialises": "editorializes",
477
+ "editorialising": "editorializing",
478
+ "empathise": "empathize",
479
+ "empathised": "empathized",
480
+ "empathises": "empathizes",
481
+ "empathising": "empathizing",
482
+ "emphasise": "emphasize",
483
+ "emphasised": "emphasized",
484
+ "emphasises": "emphasizes",
485
+ "emphasising": "emphasizing",
486
+ "enamelled": "enameled",
487
+ "enamelling": "enameling",
488
+ "enamoured": "enamored",
489
+ "encyclopaedia": "encyclopedia",
490
+ "encyclopaedias": "encyclopedias",
491
+ "encyclopaedic": "encyclopedic",
492
+ "endeavour": "endeavor",
493
+ "endeavoured": "endeavored",
494
+ "endeavouring": "endeavoring",
495
+ "endeavours": "endeavors",
496
+ "energise": "energize",
497
+ "energised": "energized",
498
+ "energises": "energizes",
499
+ "energising": "energizing",
500
+ "enrol": "enroll",
501
+ "enrols": "enrolls",
502
+ "enthral": "enthrall",
503
+ "enthrals": "enthralls",
504
+ "epaulette": "epaulet",
505
+ "epaulettes": "epaulets",
506
+ "epicentre": "epicenter",
507
+ "epicentres": "epicenters",
508
+ "epilogue": "epilog",
509
+ "epilogues": "epilogs",
510
+ "epitomise": "epitomize",
511
+ "epitomised": "epitomized",
512
+ "epitomises": "epitomizes",
513
+ "epitomising": "epitomizing",
514
+ "equalisation": "equalization",
515
+ "equalise": "equalize",
516
+ "equalised": "equalized",
517
+ "equaliser": "equalizer",
518
+ "equalisers": "equalizers",
519
+ "equalises": "equalizes",
520
+ "equalising": "equalizing",
521
+ "eulogise": "eulogize",
522
+ "eulogised": "eulogized",
523
+ "eulogises": "eulogizes",
524
+ "eulogising": "eulogizing",
525
+ "evangelise": "evangelize",
526
+ "evangelised": "evangelized",
527
+ "evangelises": "evangelizes",
528
+ "evangelising": "evangelizing",
529
+ "exorcise": "exorcize",
530
+ "exorcised": "exorcized",
531
+ "exorcises": "exorcizes",
532
+ "exorcising": "exorcizing",
533
+ "extemporisation": "extemporization",
534
+ "extemporise": "extemporize",
535
+ "extemporised": "extemporized",
536
+ "extemporises": "extemporizes",
537
+ "extemporising": "extemporizing",
538
+ "externalisation": "externalization",
539
+ "externalisations": "externalizations",
540
+ "externalise": "externalize",
541
+ "externalised": "externalized",
542
+ "externalises": "externalizes",
543
+ "externalising": "externalizing",
544
+ "factorise": "factorize",
545
+ "factorised": "factorized",
546
+ "factorises": "factorizes",
547
+ "factorising": "factorizing",
548
+ "faecal": "fecal",
549
+ "faeces": "feces",
550
+ "familiarisation": "familiarization",
551
+ "familiarise": "familiarize",
552
+ "familiarised": "familiarized",
553
+ "familiarises": "familiarizes",
554
+ "familiarising": "familiarizing",
555
+ "fantasise": "fantasize",
556
+ "fantasised": "fantasized",
557
+ "fantasises": "fantasizes",
558
+ "fantasising": "fantasizing",
559
+ "favour": "favor",
560
+ "favourable": "favorable",
561
+ "favourably": "favorably",
562
+ "favoured": "favored",
563
+ "favouring": "favoring",
564
+ "favourite": "favorite",
565
+ "favourites": "favorites",
566
+ "favouritism": "favoritism",
567
+ "favours": "favors",
568
+ "feminise": "feminize",
569
+ "feminised": "feminized",
570
+ "feminises": "feminizes",
571
+ "feminising": "feminizing",
572
+ "fertilisation": "fertilization",
573
+ "fertilise": "fertilize",
574
+ "fertilised": "fertilized",
575
+ "fertiliser": "fertilizer",
576
+ "fertilisers": "fertilizers",
577
+ "fertilises": "fertilizes",
578
+ "fertilising": "fertilizing",
579
+ "fervour": "fervor",
580
+ "fibre": "fiber",
581
+ "fibreglass": "fiberglass",
582
+ "fibres": "fibers",
583
+ "fictionalisation": "fictionalization",
584
+ "fictionalisations": "fictionalizations",
585
+ "fictionalise": "fictionalize",
586
+ "fictionalised": "fictionalized",
587
+ "fictionalises": "fictionalizes",
588
+ "fictionalising": "fictionalizing",
589
+ "fillet": "filet",
590
+ "filleted": "fileted",
591
+ "filleting": "fileting",
592
+ "fillets": "filets",
593
+ "finalisation": "finalization",
594
+ "finalise": "finalize",
595
+ "finalised": "finalized",
596
+ "finalises": "finalizes",
597
+ "finalising": "finalizing",
598
+ "flautist": "flutist",
599
+ "flautists": "flutists",
600
+ "flavour": "flavor",
601
+ "flavoured": "flavored",
602
+ "flavouring": "flavoring",
603
+ "flavourings": "flavorings",
604
+ "flavourless": "flavorless",
605
+ "flavours": "flavors",
606
+ "flavoursome": "flavorsome",
607
+ "flyer / flier": "flier / flyer",
608
+ "foetal": "fetal",
609
+ "foetid": "fetid",
610
+ "foetus": "fetus",
611
+ "foetuses": "fetuses",
612
+ "formalisation": "formalization",
613
+ "formalise": "formalize",
614
+ "formalised": "formalized",
615
+ "formalises": "formalizes",
616
+ "formalising": "formalizing",
617
+ "fossilisation": "fossilization",
618
+ "fossilise": "fossilize",
619
+ "fossilised": "fossilized",
620
+ "fossilises": "fossilizes",
621
+ "fossilising": "fossilizing",
622
+ "fraternisation": "fraternization",
623
+ "fraternise": "fraternize",
624
+ "fraternised": "fraternized",
625
+ "fraternises": "fraternizes",
626
+ "fraternising": "fraternizing",
627
+ "fulfil": "fulfill",
628
+ "fulfilment": "fulfillment",
629
+ "fulfils": "fulfills",
630
+ "funnelled": "funneled",
631
+ "funnelling": "funneling",
632
+ "galvanise": "galvanize",
633
+ "galvanised": "galvanized",
634
+ "galvanises": "galvanizes",
635
+ "galvanising": "galvanizing",
636
+ "gambolled": "gamboled",
637
+ "gambolling": "gamboling",
638
+ "gaol": "jail",
639
+ "gaolbird": "jailbird",
640
+ "gaolbirds": "jailbirds",
641
+ "gaolbreak": "jailbreak",
642
+ "gaolbreaks": "jailbreaks",
643
+ "gaoled": "jailed",
644
+ "gaoler": "jailer",
645
+ "gaolers": "jailers",
646
+ "gaoling": "jailing",
647
+ "gaols": "jails",
648
+ "gasses": "gases",
649
+ "gage": "gauge",
650
+ "gaged": "gauged",
651
+ "gages": "gauges",
652
+ "gaging": "gauging",
653
+ "generalisation": "generalization",
654
+ "generalisations": "generalizations",
655
+ "generalise": "generalize",
656
+ "generalised": "generalized",
657
+ "generalises": "generalizes",
658
+ "generalising": "generalizing",
659
+ "ghettoise": "ghettoize",
660
+ "ghettoised": "ghettoized",
661
+ "ghettoises": "ghettoizes",
662
+ "ghettoising": "ghettoizing",
663
+ "gipsies": "gypsies",
664
+ "glamorise": "glamorize",
665
+ "glamorised": "glamorized",
666
+ "glamorises": "glamorizes",
667
+ "glamorising": "glamorizing",
668
+ "glamor": "glamour",
669
+ "globalisation": "globalization",
670
+ "globalise": "globalize",
671
+ "globalised": "globalized",
672
+ "globalises": "globalizes",
673
+ "globalising": "globalizing",
674
+ "glueing": "gluing",
675
+ "goitre": "goiter",
676
+ "goitres": "goiters",
677
+ "gonorrhoea": "gonorrhea",
678
+ "gramme": "gram",
679
+ "grammes": "grams",
680
+ "gravelled": "graveled",
681
+ "grey": "gray",
682
+ "greyed": "grayed",
683
+ "greying": "graying",
684
+ "greyish": "grayish",
685
+ "greyness": "grayness",
686
+ "greys": "grays",
687
+ "grovelled": "groveled",
688
+ "grovelling": "groveling",
689
+ "groyne": "groin",
690
+ "groynes": "groins",
691
+ "gruelling": "grueling",
692
+ "gruellingly": "gruelingly",
693
+ "gryphon": "griffin",
694
+ "gryphons": "griffins",
695
+ "gynaecological": "gynecological",
696
+ "gynaecologist": "gynecologist",
697
+ "gynaecologists": "gynecologists",
698
+ "gynaecology": "gynecology",
699
+ "haematological": "hematological",
700
+ "haematologist": "hematologist",
701
+ "haematologists": "hematologists",
702
+ "haematology": "hematology",
703
+ "haemoglobin": "hemoglobin",
704
+ "haemophilia": "hemophilia",
705
+ "haemophiliac": "hemophiliac",
706
+ "haemophiliacs": "hemophiliacs",
707
+ "haemorrhage": "hemorrhage",
708
+ "haemorrhaged": "hemorrhaged",
709
+ "haemorrhages": "hemorrhages",
710
+ "haemorrhaging": "hemorrhaging",
711
+ "haemorrhoids": "hemorrhoids",
712
+ "harbour": "harbor",
713
+ "harboured": "harbored",
714
+ "harbouring": "harboring",
715
+ "harbours": "harbors",
716
+ "harmonisation": "harmonization",
717
+ "harmonise": "harmonize",
718
+ "harmonised": "harmonized",
719
+ "harmonises": "harmonizes",
720
+ "harmonising": "harmonizing",
721
+ "homoeopath": "homeopath",
722
+ "homoeopathic": "homeopathic",
723
+ "homoeopaths": "homeopaths",
724
+ "homoeopathy": "homeopathy",
725
+ "homogenise": "homogenize",
726
+ "homogenised": "homogenized",
727
+ "homogenises": "homogenizes",
728
+ "homogenising": "homogenizing",
729
+ "honour": "honor",
730
+ "honourable": "honorable",
731
+ "honourably": "honorably",
732
+ "honoured": "honored",
733
+ "honouring": "honoring",
734
+ "honours": "honors",
735
+ "hospitalisation": "hospitalization",
736
+ "hospitalise": "hospitalize",
737
+ "hospitalised": "hospitalized",
738
+ "hospitalises": "hospitalizes",
739
+ "hospitalising": "hospitalizing",
740
+ "humanise": "humanize",
741
+ "humanised": "humanized",
742
+ "humanises": "humanizes",
743
+ "humanising": "humanizing",
744
+ "humour": "humor",
745
+ "humoured": "humored",
746
+ "humouring": "humoring",
747
+ "humourless": "humorless",
748
+ "humours": "humors",
749
+ "hybridise": "hybridize",
750
+ "hybridised": "hybridized",
751
+ "hybridises": "hybridizes",
752
+ "hybridising": "hybridizing",
753
+ "hypnotise": "hypnotize",
754
+ "hypnotised": "hypnotized",
755
+ "hypnotises": "hypnotizes",
756
+ "hypnotising": "hypnotizing",
757
+ "hypothesise": "hypothesize",
758
+ "hypothesised": "hypothesized",
759
+ "hypothesises": "hypothesizes",
760
+ "hypothesising": "hypothesizing",
761
+ "idealisation": "idealization",
762
+ "idealise": "idealize",
763
+ "idealised": "idealized",
764
+ "idealises": "idealizes",
765
+ "idealising": "idealizing",
766
+ "idolise": "idolize",
767
+ "idolised": "idolized",
768
+ "idolises": "idolizes",
769
+ "idolising": "idolizing",
770
+ "immobilisation": "immobilization",
771
+ "immobilise": "immobilize",
772
+ "immobilised": "immobilized",
773
+ "immobiliser": "immobilizer",
774
+ "immobilisers": "immobilizers",
775
+ "immobilises": "immobilizes",
776
+ "immobilising": "immobilizing",
777
+ "immortalise": "immortalize",
778
+ "immortalised": "immortalized",
779
+ "immortalises": "immortalizes",
780
+ "immortalising": "immortalizing",
781
+ "immunisation": "immunization",
782
+ "immunise": "immunize",
783
+ "immunised": "immunized",
784
+ "immunises": "immunizes",
785
+ "immunising": "immunizing",
786
+ "impanelled": "impaneled",
787
+ "impanelling": "impaneling",
788
+ "imperilled": "imperiled",
789
+ "imperilling": "imperiling",
790
+ "individualise": "individualize",
791
+ "individualised": "individualized",
792
+ "individualises": "individualizes",
793
+ "individualising": "individualizing",
794
+ "industrialise": "industrialize",
795
+ "industrialised": "industrialized",
796
+ "industrialises": "industrializes",
797
+ "industrialising": "industrializing",
798
+ "inflexion": "inflection",
799
+ "inflexions": "inflections",
800
+ "initialise": "initialize",
801
+ "initialised": "initialized",
802
+ "initialises": "initializes",
803
+ "initialising": "initializing",
804
+ "initialled": "initialed",
805
+ "initialling": "initialing",
806
+ "instal": "install",
807
+ "instalment": "installment",
808
+ "instalments": "installments",
809
+ "instals": "installs",
810
+ "instil": "instill",
811
+ "instils": "instills",
812
+ "institutionalisation": "institutionalization",
813
+ "institutionalise": "institutionalize",
814
+ "institutionalised": "institutionalized",
815
+ "institutionalises": "institutionalizes",
816
+ "institutionalising": "institutionalizing",
817
+ "intellectualise": "intellectualize",
818
+ "intellectualised": "intellectualized",
819
+ "intellectualises": "intellectualizes",
820
+ "intellectualising": "intellectualizing",
821
+ "internalisation": "internalization",
822
+ "internalise": "internalize",
823
+ "internalised": "internalized",
824
+ "internalises": "internalizes",
825
+ "internalising": "internalizing",
826
+ "internationalisation": "internationalization",
827
+ "internationalise": "internationalize",
828
+ "internationalised": "internationalized",
829
+ "internationalises": "internationalizes",
830
+ "internationalising": "internationalizing",
831
+ "ionisation": "ionization",
832
+ "ionise": "ionize",
833
+ "ionised": "ionized",
834
+ "ioniser": "ionizer",
835
+ "ionisers": "ionizers",
836
+ "ionises": "ionizes",
837
+ "ionising": "ionizing",
838
+ "italicise": "italicize",
839
+ "italicised": "italicized",
840
+ "italicises": "italicizes",
841
+ "italicising": "italicizing",
842
+ "itemise": "itemize",
843
+ "itemised": "itemized",
844
+ "itemises": "itemizes",
845
+ "itemising": "itemizing",
846
+ "jeopardise": "jeopardize",
847
+ "jeopardised": "jeopardized",
848
+ "jeopardises": "jeopardizes",
849
+ "jeopardising": "jeopardizing",
850
+ "jewelled": "jeweled",
851
+ "jeweller": "jeweler",
852
+ "jewellers": "jewelers",
853
+ "jewellery": "jewelry",
854
+ "judgement": "judgment",
855
+ "kilogramme": "kilogram",
856
+ "kilogrammes": "kilograms",
857
+ "kilometre": "kilometer",
858
+ "kilometres": "kilometers",
859
+ "labelled": "labeled",
860
+ "labelling": "labeling",
861
+ "labour": "labor",
862
+ "laboured": "labored",
863
+ "labourer": "laborer",
864
+ "labourers": "laborers",
865
+ "labouring": "laboring",
866
+ "labours": "labors",
867
+ "lacklustre": "lackluster",
868
+ "legalisation": "legalization",
869
+ "legalise": "legalize",
870
+ "legalised": "legalized",
871
+ "legalises": "legalizes",
872
+ "legalising": "legalizing",
873
+ "legitimise": "legitimize",
874
+ "legitimised": "legitimized",
875
+ "legitimises": "legitimizes",
876
+ "legitimising": "legitimizing",
877
+ "leukaemia": "leukemia",
878
+ "levelled": "leveled",
879
+ "leveller": "leveler",
880
+ "levellers": "levelers",
881
+ "levelling": "leveling",
882
+ "libelled": "libeled",
883
+ "libelling": "libeling",
884
+ "libellous": "libelous",
885
+ "liberalisation": "liberalization",
886
+ "liberalise": "liberalize",
887
+ "liberalised": "liberalized",
888
+ "liberalises": "liberalizes",
889
+ "liberalising": "liberalizing",
890
+ "licence": "license",
891
+ "licenced": "licensed",
892
+ "licences": "licenses",
893
+ "licencing": "licensing",
894
+ "likeable": "likable",
895
+ "lionisation": "lionization",
896
+ "lionise": "lionize",
897
+ "lionised": "lionized",
898
+ "lionises": "lionizes",
899
+ "lionising": "lionizing",
900
+ "liquidise": "liquidize",
901
+ "liquidised": "liquidized",
902
+ "liquidiser": "liquidizer",
903
+ "liquidisers": "liquidizers",
904
+ "liquidises": "liquidizes",
905
+ "liquidising": "liquidizing",
906
+ "litre": "liter",
907
+ "litres": "liters",
908
+ "localise": "localize",
909
+ "localised": "localized",
910
+ "localises": "localizes",
911
+ "localising": "localizing",
912
+ "louvre": "louver",
913
+ "louvred": "louvered",
914
+ "louvres": "louvers",
915
+ "lustre": "luster",
916
+ "magnetise": "magnetize",
917
+ "magnetised": "magnetized",
918
+ "magnetises": "magnetizes",
919
+ "magnetising": "magnetizing",
920
+ "manoeuvrability": "maneuverability",
921
+ "manoeuvrable": "maneuverable",
922
+ "manoeuvre": "maneuver",
923
+ "manoeuvred": "maneuvered",
924
+ "manoeuvres": "maneuvers",
925
+ "manoeuvring": "maneuvering",
926
+ "manoeuvrings": "maneuverings",
927
+ "marginalisation": "marginalization",
928
+ "marginalise": "marginalize",
929
+ "marginalised": "marginalized",
930
+ "marginalises": "marginalizes",
931
+ "marginalising": "marginalizing",
932
+ "marshalled": "marshaled",
933
+ "marshalling": "marshaling",
934
+ "marvelled": "marveled",
935
+ "marvelling": "marveling",
936
+ "marvellous": "marvelous",
937
+ "marvellously": "marvelously",
938
+ "materialisation": "materialization",
939
+ "materialise": "materialize",
940
+ "materialised": "materialized",
941
+ "materialises": "materializes",
942
+ "materialising": "materializing",
943
+ "maximisation": "maximization",
944
+ "maximise": "maximize",
945
+ "maximised": "maximized",
946
+ "maximises": "maximizes",
947
+ "maximising": "maximizing",
948
+ "meagre": "meager",
949
+ "mechanisation": "mechanization",
950
+ "mechanise": "mechanize",
951
+ "mechanised": "mechanized",
952
+ "mechanises": "mechanizes",
953
+ "mechanising": "mechanizing",
954
+ "mediaeval": "medieval",
955
+ "memorialise": "memorialize",
956
+ "memorialised": "memorialized",
957
+ "memorialises": "memorializes",
958
+ "memorialising": "memorializing",
959
+ "memorise": "memorize",
960
+ "memorised": "memorized",
961
+ "memorises": "memorizes",
962
+ "memorising": "memorizing",
963
+ "mesmerise": "mesmerize",
964
+ "mesmerised": "mesmerized",
965
+ "mesmerises": "mesmerizes",
966
+ "mesmerising": "mesmerizing",
967
+ "metabolise": "metabolize",
968
+ "metabolised": "metabolized",
969
+ "metabolises": "metabolizes",
970
+ "metabolising": "metabolizing",
971
+ "metre": "meter",
972
+ "metres": "meters",
973
+ "micrometre": "micrometer",
974
+ "micrometres": "micrometers",
975
+ "militarise": "militarize",
976
+ "militarised": "militarized",
977
+ "militarises": "militarizes",
978
+ "militarising": "militarizing",
979
+ "milligramme": "milligram",
980
+ "milligrammes": "milligrams",
981
+ "millilitre": "milliliter",
982
+ "millilitres": "milliliters",
983
+ "millimetre": "millimeter",
984
+ "millimetres": "millimeters",
985
+ "miniaturisation": "miniaturization",
986
+ "miniaturise": "miniaturize",
987
+ "miniaturised": "miniaturized",
988
+ "miniaturises": "miniaturizes",
989
+ "miniaturising": "miniaturizing",
990
+ "minibusses": "minibuses",
991
+ "minimise": "minimize",
992
+ "minimised": "minimized",
993
+ "minimises": "minimizes",
994
+ "minimising": "minimizing",
995
+ "misbehaviour": "misbehavior",
996
+ "misdemeanour": "misdemeanor",
997
+ "misdemeanours": "misdemeanors",
998
+ "misspelt": "misspelled",
999
+ "mitre": "miter",
1000
+ "mitres": "miters",
1001
+ "mobilisation": "mobilization",
1002
+ "mobilise": "mobilize",
1003
+ "mobilised": "mobilized",
1004
+ "mobilises": "mobilizes",
1005
+ "mobilising": "mobilizing",
1006
+ "modelled": "modeled",
1007
+ "modeller": "modeler",
1008
+ "modellers": "modelers",
1009
+ "modelling": "modeling",
1010
+ "modernise": "modernize",
1011
+ "modernised": "modernized",
1012
+ "modernises": "modernizes",
1013
+ "modernising": "modernizing",
1014
+ "moisturise": "moisturize",
1015
+ "moisturised": "moisturized",
1016
+ "moisturiser": "moisturizer",
1017
+ "moisturisers": "moisturizers",
1018
+ "moisturises": "moisturizes",
1019
+ "moisturising": "moisturizing",
1020
+ "monologue": "monolog",
1021
+ "monologues": "monologs",
1022
+ "monopolisation": "monopolization",
1023
+ "monopolise": "monopolize",
1024
+ "monopolised": "monopolized",
1025
+ "monopolises": "monopolizes",
1026
+ "monopolising": "monopolizing",
1027
+ "moralise": "moralize",
1028
+ "moralised": "moralized",
1029
+ "moralises": "moralizes",
1030
+ "moralising": "moralizing",
1031
+ "motorised": "motorized",
1032
+ "mould": "mold",
1033
+ "moulded": "molded",
1034
+ "moulder": "molder",
1035
+ "mouldered": "moldered",
1036
+ "mouldering": "moldering",
1037
+ "moulders": "molders",
1038
+ "mouldier": "moldier",
1039
+ "mouldiest": "moldiest",
1040
+ "moulding": "molding",
1041
+ "mouldings": "moldings",
1042
+ "moulds": "molds",
1043
+ "mouldy": "moldy",
1044
+ "moult": "molt",
1045
+ "moulted": "molted",
1046
+ "moulting": "molting",
1047
+ "moults": "molts",
1048
+ "moustache": "mustache",
1049
+ "moustached": "mustached",
1050
+ "moustaches": "mustaches",
1051
+ "moustachioed": "mustachioed",
1052
+ "multicoloured": "multicolored",
1053
+ "nationalisation": "nationalization",
1054
+ "nationalisations": "nationalizations",
1055
+ "nationalise": "nationalize",
1056
+ "nationalised": "nationalized",
1057
+ "nationalises": "nationalizes",
1058
+ "nationalising": "nationalizing",
1059
+ "naturalisation": "naturalization",
1060
+ "naturalise": "naturalize",
1061
+ "naturalised": "naturalized",
1062
+ "naturalises": "naturalizes",
1063
+ "naturalising": "naturalizing",
1064
+ "neighbour": "neighbor",
1065
+ "neighbourhood": "neighborhood",
1066
+ "neighbourhoods": "neighborhoods",
1067
+ "neighbouring": "neighboring",
1068
+ "neighbourliness": "neighborliness",
1069
+ "neighbourly": "neighborly",
1070
+ "neighbours": "neighbors",
1071
+ "neutralisation": "neutralization",
1072
+ "neutralise": "neutralize",
1073
+ "neutralised": "neutralized",
1074
+ "neutralises": "neutralizes",
1075
+ "neutralising": "neutralizing",
1076
+ "normalisation": "normalization",
1077
+ "normalise": "normalize",
1078
+ "normalised": "normalized",
1079
+ "normalises": "normalizes",
1080
+ "normalising": "normalizing",
1081
+ "odour": "odor",
1082
+ "odourless": "odorless",
1083
+ "odours": "odors",
1084
+ "oesophagus": "esophagus",
1085
+ "oesophaguses": "esophaguses",
1086
+ "oestrogen": "estrogen",
1087
+ "offence": "offense",
1088
+ "offences": "offenses",
1089
+ "omelette": "omelet",
1090
+ "omelettes": "omelets",
1091
+ "optimise": "optimize",
1092
+ "optimised": "optimized",
1093
+ "optimises": "optimizes",
1094
+ "optimising": "optimizing",
1095
+ "organisation": "organization",
1096
+ "organisational": "organizational",
1097
+ "organisations": "organizations",
1098
+ "organise": "organize",
1099
+ "organised": "organized",
1100
+ "organiser": "organizer",
1101
+ "organisers": "organizers",
1102
+ "organises": "organizes",
1103
+ "organising": "organizing",
1104
+ "orthopaedic": "orthopedic",
1105
+ "orthopaedics": "orthopedics",
1106
+ "ostracise": "ostracize",
1107
+ "ostracised": "ostracized",
1108
+ "ostracises": "ostracizes",
1109
+ "ostracising": "ostracizing",
1110
+ "outmanoeuvre": "outmaneuver",
1111
+ "outmanoeuvred": "outmaneuvered",
1112
+ "outmanoeuvres": "outmaneuvers",
1113
+ "outmanoeuvring": "outmaneuvering",
1114
+ "overemphasise": "overemphasize",
1115
+ "overemphasised": "overemphasized",
1116
+ "overemphasises": "overemphasizes",
1117
+ "overemphasising": "overemphasizing",
1118
+ "oxidisation": "oxidization",
1119
+ "oxidise": "oxidize",
1120
+ "oxidised": "oxidized",
1121
+ "oxidises": "oxidizes",
1122
+ "oxidising": "oxidizing",
1123
+ "paederast": "pederast",
1124
+ "paederasts": "pederasts",
1125
+ "paediatric": "pediatric",
1126
+ "paediatrician": "pediatrician",
1127
+ "paediatricians": "pediatricians",
1128
+ "paediatrics": "pediatrics",
1129
+ "paedophile": "pedophile",
1130
+ "paedophiles": "pedophiles",
1131
+ "paedophilia": "pedophilia",
1132
+ "palaeolithic": "paleolithic",
1133
+ "palaeontologist": "paleontologist",
1134
+ "palaeontologists": "paleontologists",
1135
+ "palaeontology": "paleontology",
1136
+ "panelled": "paneled",
1137
+ "panelling": "paneling",
1138
+ "panellist": "panelist",
1139
+ "panellists": "panelists",
1140
+ "paralyse": "paralyze",
1141
+ "paralysed": "paralyzed",
1142
+ "paralyses": "paralyzes",
1143
+ "paralysing": "paralyzing",
1144
+ "parcelled": "parceled",
1145
+ "parcelling": "parceling",
1146
+ "parlour": "parlor",
1147
+ "parlours": "parlors",
1148
+ "particularise": "particularize",
1149
+ "particularised": "particularized",
1150
+ "particularises": "particularizes",
1151
+ "particularising": "particularizing",
1152
+ "passivisation": "passivization",
1153
+ "passivise": "passivize",
1154
+ "passivised": "passivized",
1155
+ "passivises": "passivizes",
1156
+ "passivising": "passivizing",
1157
+ "pasteurisation": "pasteurization",
1158
+ "pasteurise": "pasteurize",
1159
+ "pasteurised": "pasteurized",
1160
+ "pasteurises": "pasteurizes",
1161
+ "pasteurising": "pasteurizing",
1162
+ "patronise": "patronize",
1163
+ "patronised": "patronized",
1164
+ "patronises": "patronizes",
1165
+ "patronising": "patronizing",
1166
+ "patronisingly": "patronizingly",
1167
+ "pedalled": "pedaled",
1168
+ "pedalling": "pedaling",
1169
+ "pedestrianisation": "pedestrianization",
1170
+ "pedestrianise": "pedestrianize",
1171
+ "pedestrianised": "pedestrianized",
1172
+ "pedestrianises": "pedestrianizes",
1173
+ "pedestrianising": "pedestrianizing",
1174
+ "penalise": "penalize",
1175
+ "penalised": "penalized",
1176
+ "penalises": "penalizes",
1177
+ "penalising": "penalizing",
1178
+ "pencilled": "penciled",
1179
+ "pencilling": "penciling",
1180
+ "personalise": "personalize",
1181
+ "personalised": "personalized",
1182
+ "personalises": "personalizes",
1183
+ "personalising": "personalizing",
1184
+ "pharmacopoeia": "pharmacopeia",
1185
+ "pharmacopoeias": "pharmacopeias",
1186
+ "philosophise": "philosophize",
1187
+ "philosophised": "philosophized",
1188
+ "philosophises": "philosophizes",
1189
+ "philosophising": "philosophizing",
1190
+ "philtre": "filter",
1191
+ "philtres": "filters",
1192
+ "phoney": "phony",
1193
+ "plagiarise": "plagiarize",
1194
+ "plagiarised": "plagiarized",
1195
+ "plagiarises": "plagiarizes",
1196
+ "plagiarising": "plagiarizing",
1197
+ "plough": "plow",
1198
+ "ploughed": "plowed",
1199
+ "ploughing": "plowing",
1200
+ "ploughman": "plowman",
1201
+ "ploughmen": "plowmen",
1202
+ "ploughs": "plows",
1203
+ "ploughshare": "plowshare",
1204
+ "ploughshares": "plowshares",
1205
+ "polarisation": "polarization",
1206
+ "polarise": "polarize",
1207
+ "polarised": "polarized",
1208
+ "polarises": "polarizes",
1209
+ "polarising": "polarizing",
1210
+ "politicisation": "politicization",
1211
+ "politicise": "politicize",
1212
+ "politicised": "politicized",
1213
+ "politicises": "politicizes",
1214
+ "politicising": "politicizing",
1215
+ "popularisation": "popularization",
1216
+ "popularise": "popularize",
1217
+ "popularised": "popularized",
1218
+ "popularises": "popularizes",
1219
+ "popularising": "popularizing",
1220
+ "pouffe": "pouf",
1221
+ "pouffes": "poufs",
1222
+ "practise": "practice",
1223
+ "practised": "practiced",
1224
+ "practises": "practices",
1225
+ "practising": "practicing",
1226
+ "praesidium": "presidium",
1227
+ "praesidiums": "presidiums",
1228
+ "pressurisation": "pressurization",
1229
+ "pressurise": "pressurize",
1230
+ "pressurised": "pressurized",
1231
+ "pressurises": "pressurizes",
1232
+ "pressurising": "pressurizing",
1233
+ "pretence": "pretense",
1234
+ "pretences": "pretenses",
1235
+ "primaeval": "primeval",
1236
+ "prioritisation": "prioritization",
1237
+ "prioritise": "prioritize",
1238
+ "prioritised": "prioritized",
1239
+ "prioritises": "prioritizes",
1240
+ "prioritising": "prioritizing",
1241
+ "privatisation": "privatization",
1242
+ "privatisations": "privatizations",
1243
+ "privatise": "privatize",
1244
+ "privatised": "privatized",
1245
+ "privatises": "privatizes",
1246
+ "privatising": "privatizing",
1247
+ "professionalisation": "professionalization",
1248
+ "professionalise": "professionalize",
1249
+ "professionalised": "professionalized",
1250
+ "professionalises": "professionalizes",
1251
+ "professionalising": "professionalizing",
1252
+ "programme": "program",
1253
+ "programmes": "programs",
1254
+ "prologue": "prolog",
1255
+ "prologues": "prologs",
1256
+ "propagandise": "propagandize",
1257
+ "propagandised": "propagandized",
1258
+ "propagandises": "propagandizes",
1259
+ "propagandising": "propagandizing",
1260
+ "proselytise": "proselytize",
1261
+ "proselytised": "proselytized",
1262
+ "proselytiser": "proselytizer",
1263
+ "proselytisers": "proselytizers",
1264
+ "proselytises": "proselytizes",
1265
+ "proselytising": "proselytizing",
1266
+ "psychoanalyse": "psychoanalyze",
1267
+ "psychoanalysed": "psychoanalyzed",
1268
+ "psychoanalyses": "psychoanalyzes",
1269
+ "psychoanalysing": "psychoanalyzing",
1270
+ "publicise": "publicize",
1271
+ "publicised": "publicized",
1272
+ "publicises": "publicizes",
1273
+ "publicising": "publicizing",
1274
+ "pulverisation": "pulverization",
1275
+ "pulverise": "pulverize",
1276
+ "pulverised": "pulverized",
1277
+ "pulverises": "pulverizes",
1278
+ "pulverising": "pulverizing",
1279
+ "pummelled": "pummel",
1280
+ "pummelling": "pummeled",
1281
+ "pyjama": "pajama",
1282
+ "pyjamas": "pajamas",
1283
+ "pzazz": "pizzazz",
1284
+ "quarrelled": "quarreled",
1285
+ "quarrelling": "quarreling",
1286
+ "radicalise": "radicalize",
1287
+ "radicalised": "radicalized",
1288
+ "radicalises": "radicalizes",
1289
+ "radicalising": "radicalizing",
1290
+ "rancour": "rancor",
1291
+ "randomise": "randomize",
1292
+ "randomised": "randomized",
1293
+ "randomises": "randomizes",
1294
+ "randomising": "randomizing",
1295
+ "rationalisation": "rationalization",
1296
+ "rationalisations": "rationalizations",
1297
+ "rationalise": "rationalize",
1298
+ "rationalised": "rationalized",
1299
+ "rationalises": "rationalizes",
1300
+ "rationalising": "rationalizing",
1301
+ "ravelled": "raveled",
1302
+ "ravelling": "raveling",
1303
+ "realisable": "realizable",
1304
+ "realisation": "realization",
1305
+ "realisations": "realizations",
1306
+ "realise": "realize",
1307
+ "realised": "realized",
1308
+ "realises": "realizes",
1309
+ "realising": "realizing",
1310
+ "recognisable": "recognizable",
1311
+ "recognisably": "recognizably",
1312
+ "recognisance": "recognizance",
1313
+ "recognise": "recognize",
1314
+ "recognised": "recognized",
1315
+ "recognises": "recognizes",
1316
+ "recognising": "recognizing",
1317
+ "reconnoitre": "reconnoiter",
1318
+ "reconnoitred": "reconnoitered",
1319
+ "reconnoitres": "reconnoiters",
1320
+ "reconnoitring": "reconnoitering",
1321
+ "refuelled": "refueled",
1322
+ "refuelling": "refueling",
1323
+ "regularisation": "regularization",
1324
+ "regularise": "regularize",
1325
+ "regularised": "regularized",
1326
+ "regularises": "regularizes",
1327
+ "regularising": "regularizing",
1328
+ "remodelled": "remodeled",
1329
+ "remodelling": "remodeling",
1330
+ "remould": "remold",
1331
+ "remoulded": "remolded",
1332
+ "remoulding": "remolding",
1333
+ "remoulds": "remolds",
1334
+ "reorganisation": "reorganization",
1335
+ "reorganisations": "reorganizations",
1336
+ "reorganise": "reorganize",
1337
+ "reorganised": "reorganized",
1338
+ "reorganises": "reorganizes",
1339
+ "reorganising": "reorganizing",
1340
+ "revelled": "reveled",
1341
+ "reveller": "reveler",
1342
+ "revellers": "revelers",
1343
+ "revelling": "reveling",
1344
+ "revitalise": "revitalize",
1345
+ "revitalised": "revitalized",
1346
+ "revitalises": "revitalizes",
1347
+ "revitalising": "revitalizing",
1348
+ "revolutionise": "revolutionize",
1349
+ "revolutionised": "revolutionized",
1350
+ "revolutionises": "revolutionizes",
1351
+ "revolutionising": "revolutionizing",
1352
+ "rhapsodise": "rhapsodize",
1353
+ "rhapsodised": "rhapsodized",
1354
+ "rhapsodises": "rhapsodizes",
1355
+ "rhapsodising": "rhapsodizing",
1356
+ "rigour": "rigor",
1357
+ "rigours": "rigors",
1358
+ "ritualised": "ritualized",
1359
+ "rivalled": "rivaled",
1360
+ "rivalling": "rivaling",
1361
+ "romanticise": "romanticize",
1362
+ "romanticised": "romanticized",
1363
+ "romanticises": "romanticizes",
1364
+ "romanticising": "romanticizing",
1365
+ "rumour": "rumor",
1366
+ "rumoured": "rumored",
1367
+ "rumours": "rumors",
1368
+ "sabre": "saber",
1369
+ "sabres": "sabers",
1370
+ "saltpetre": "saltpeter",
1371
+ "sanitise": "sanitize",
1372
+ "sanitised": "sanitized",
1373
+ "sanitises": "sanitizes",
1374
+ "sanitising": "sanitizing",
1375
+ "satirise": "satirize",
1376
+ "satirised": "satirized",
1377
+ "satirises": "satirizes",
1378
+ "satirising": "satirizing",
1379
+ "saviour": "savior",
1380
+ "saviours": "saviors",
1381
+ "savour": "savor",
1382
+ "savoured": "savored",
1383
+ "savouries": "savories",
1384
+ "savouring": "savoring",
1385
+ "savours": "savors",
1386
+ "savoury": "savory",
1387
+ "scandalise": "scandalize",
1388
+ "scandalised": "scandalized",
1389
+ "scandalises": "scandalizes",
1390
+ "scandalising": "scandalizing",
1391
+ "sceptic": "skeptic",
1392
+ "sceptical": "skeptical",
1393
+ "sceptically": "skeptically",
1394
+ "scepticism": "skepticism",
1395
+ "sceptics": "skeptics",
1396
+ "sceptre": "scepter",
1397
+ "sceptres": "scepters",
1398
+ "scrutinise": "scrutinize",
1399
+ "scrutinised": "scrutinized",
1400
+ "scrutinises": "scrutinizes",
1401
+ "scrutinising": "scrutinizing",
1402
+ "secularisation": "secularization",
1403
+ "secularise": "secularize",
1404
+ "secularised": "secularized",
1405
+ "secularises": "secularizes",
1406
+ "secularising": "secularizing",
1407
+ "sensationalise": "sensationalize",
1408
+ "sensationalised": "sensationalized",
1409
+ "sensationalises": "sensationalizes",
1410
+ "sensationalising": "sensationalizing",
1411
+ "sensitise": "sensitize",
1412
+ "sensitised": "sensitized",
1413
+ "sensitises": "sensitizes",
1414
+ "sensitising": "sensitizing",
1415
+ "sentimentalise": "sentimentalize",
1416
+ "sentimentalised": "sentimentalized",
1417
+ "sentimentalises": "sentimentalizes",
1418
+ "sentimentalising": "sentimentalizing",
1419
+ "sepulchre": "sepulcher",
1420
+ "sepulchres": "sepulchers",
1421
+ "serialisation": "serialization",
1422
+ "serialisations": "serializations",
1423
+ "serialise": "serialize",
1424
+ "serialised": "serialized",
1425
+ "serialises": "serializes",
1426
+ "serialising": "serializing",
1427
+ "sermonise": "sermonize",
1428
+ "sermonised": "sermonized",
1429
+ "sermonises": "sermonizes",
1430
+ "sermonising": "sermonizing",
1431
+ "sheikh": "sheik",
1432
+ "shovelled": "shoveled",
1433
+ "shovelling": "shoveling",
1434
+ "shrivelled": "shriveled",
1435
+ "shrivelling": "shriveling",
1436
+ "signalise": "signalize",
1437
+ "signalised": "signalized",
1438
+ "signalises": "signalizes",
1439
+ "signalising": "signalizing",
1440
+ "signalled": "signaled",
1441
+ "signalling": "signaling",
1442
+ "smoulder": "smolder",
1443
+ "smouldered": "smoldered",
1444
+ "smouldering": "smoldering",
1445
+ "smoulders": "smolders",
1446
+ "snivelled": "sniveled",
1447
+ "snivelling": "sniveling",
1448
+ "snorkelled": "snorkeled",
1449
+ "snorkelling": "snorkeling",
1450
+ "snowplough": "snowplow",
1451
+ "snowploughs": "snowplow",
1452
+ "socialisation": "socialization",
1453
+ "socialise": "socialize",
1454
+ "socialised": "socialized",
1455
+ "socialises": "socializes",
1456
+ "socialising": "socializing",
1457
+ "sodomise": "sodomize",
1458
+ "sodomised": "sodomized",
1459
+ "sodomises": "sodomizes",
1460
+ "sodomising": "sodomizing",
1461
+ "solemnise": "solemnize",
1462
+ "solemnised": "solemnized",
1463
+ "solemnises": "solemnizes",
1464
+ "solemnising": "solemnizing",
1465
+ "sombre": "somber",
1466
+ "specialisation": "specialization",
1467
+ "specialisations": "specializations",
1468
+ "specialise": "specialize",
1469
+ "specialised": "specialized",
1470
+ "specialises": "specializes",
1471
+ "specialising": "specializing",
1472
+ "spectre": "specter",
1473
+ "spectres": "specters",
1474
+ "spiralled": "spiraled",
1475
+ "spiralling": "spiraling",
1476
+ "splendour": "splendor",
1477
+ "splendours": "splendors",
1478
+ "squirrelled": "squirreled",
1479
+ "squirrelling": "squirreling",
1480
+ "stabilisation": "stabilization",
1481
+ "stabilise": "stabilize",
1482
+ "stabilised": "stabilized",
1483
+ "stabiliser": "stabilizer",
1484
+ "stabilisers": "stabilizers",
1485
+ "stabilises": "stabilizes",
1486
+ "stabilising": "stabilizing",
1487
+ "standardisation": "standardization",
1488
+ "standardise": "standardize",
1489
+ "standardised": "standardized",
1490
+ "standardises": "standardizes",
1491
+ "standardising": "standardizing",
1492
+ "stencilled": "stenciled",
1493
+ "stencilling": "stenciling",
1494
+ "sterilisation": "sterilization",
1495
+ "sterilisations": "sterilizations",
1496
+ "sterilise": "sterilize",
1497
+ "sterilised": "sterilized",
1498
+ "steriliser": "sterilizer",
1499
+ "sterilisers": "sterilizers",
1500
+ "sterilises": "sterilizes",
1501
+ "sterilising": "sterilizing",
1502
+ "stigmatisation": "stigmatization",
1503
+ "stigmatise": "stigmatize",
1504
+ "stigmatised": "stigmatized",
1505
+ "stigmatises": "stigmatizes",
1506
+ "stigmatising": "stigmatizing",
1507
+ "storey": "story",
1508
+ "storeys": "stories",
1509
+ "subsidisation": "subsidization",
1510
+ "subsidise": "subsidize",
1511
+ "subsidised": "subsidized",
1512
+ "subsidiser": "subsidizer",
1513
+ "subsidisers": "subsidizers",
1514
+ "subsidises": "subsidizes",
1515
+ "subsidising": "subsidizing",
1516
+ "succour": "succor",
1517
+ "succoured": "succored",
1518
+ "succouring": "succoring",
1519
+ "succours": "succors",
1520
+ "sulphate": "sulfate",
1521
+ "sulphates": "sulfates",
1522
+ "sulphide": "sulfide",
1523
+ "sulphides": "sulfides",
1524
+ "sulphur": "sulfur",
1525
+ "sulphurous": "sulfurous",
1526
+ "summarise": "summarize",
1527
+ "summarised": "summarized",
1528
+ "summarises": "summarizes",
1529
+ "summarising": "summarizing",
1530
+ "swivelled": "swiveled",
1531
+ "swivelling": "swiveling",
1532
+ "symbolise": "symbolize",
1533
+ "symbolised": "symbolized",
1534
+ "symbolises": "symbolizes",
1535
+ "symbolising": "symbolizing",
1536
+ "sympathise": "sympathize",
1537
+ "sympathised": "sympathized",
1538
+ "sympathiser": "sympathizer",
1539
+ "sympathisers": "sympathizers",
1540
+ "sympathises": "sympathizes",
1541
+ "sympathising": "sympathizing",
1542
+ "synchronisation": "synchronization",
1543
+ "synchronise": "synchronize",
1544
+ "synchronised": "synchronized",
1545
+ "synchronises": "synchronizes",
1546
+ "synchronising": "synchronizing",
1547
+ "synthesise": "synthesize",
1548
+ "synthesised": "synthesized",
1549
+ "synthesiser": "synthesizer",
1550
+ "synthesisers": "synthesizers",
1551
+ "synthesises": "synthesizes",
1552
+ "synthesising": "synthesizing",
1553
+ "syphon": "siphon",
1554
+ "syphoned": "siphoned",
1555
+ "syphoning": "siphoning",
1556
+ "syphons": "siphons",
1557
+ "systematisation": "systematization",
1558
+ "systematise": "systematize",
1559
+ "systematised": "systematized",
1560
+ "systematises": "systematizes",
1561
+ "systematising": "systematizing",
1562
+ "tantalise": "tantalize",
1563
+ "tantalised": "tantalized",
1564
+ "tantalises": "tantalizes",
1565
+ "tantalising": "tantalizing",
1566
+ "tantalisingly": "tantalizingly",
1567
+ "tasselled": "tasseled",
1568
+ "technicolour": "technicolor",
1569
+ "temporise": "temporize",
1570
+ "temporised": "temporized",
1571
+ "temporises": "temporizes",
1572
+ "temporising": "temporizing",
1573
+ "tenderise": "tenderize",
1574
+ "tenderised": "tenderized",
1575
+ "tenderises": "tenderizes",
1576
+ "tenderising": "tenderizing",
1577
+ "terrorise": "terrorize",
1578
+ "terrorised": "terrorized",
1579
+ "terrorises": "terrorizes",
1580
+ "terrorising": "terrorizing",
1581
+ "theatre": "theater",
1582
+ "theatregoer": "theatergoer",
1583
+ "theatregoers": "theatergoers",
1584
+ "theatres": "theaters",
1585
+ "theorise": "theorize",
1586
+ "theorised": "theorized",
1587
+ "theorises": "theorizes",
1588
+ "theorising": "theorizing",
1589
+ "tonne": "ton",
1590
+ "tonnes": "tons",
1591
+ "towelled": "toweled",
1592
+ "towelling": "toweling",
1593
+ "toxaemia": "toxemia",
1594
+ "tranquillise": "tranquilize",
1595
+ "tranquillised": "tranquilized",
1596
+ "tranquilliser": "tranquilizer",
1597
+ "tranquillisers": "tranquilizers",
1598
+ "tranquillises": "tranquilizes",
1599
+ "tranquillising": "tranquilizing",
1600
+ "tranquillity": "tranquility",
1601
+ "tranquillize": "tranquilize",
1602
+ "tranquillized": "tranquilized",
1603
+ "tranquillizer": "tranquilizer",
1604
+ "tranquillizers": "tranquilizers",
1605
+ "tranquillizes": "tranquilizes",
1606
+ "tranquillizing": "tranquilizing",
1607
+ "tranquilly": "tranquility",
1608
+ "transistorised": "transistorized",
1609
+ "traumatise": "traumatize",
1610
+ "traumatised": "traumatized",
1611
+ "traumatises": "traumatizes",
1612
+ "traumatising": "traumatizing",
1613
+ "travelled": "traveled",
1614
+ "traveller": "traveler",
1615
+ "travellers": "travelers",
1616
+ "travelling": "traveling",
1617
+ "travelog": "travelogue",
1618
+ "travelogs": "travelogues",
1619
+ "trialled": "trialed",
1620
+ "trialling": "trialing",
1621
+ "tricolour": "tricolor",
1622
+ "tricolours": "tricolors",
1623
+ "trivialise": "trivialize",
1624
+ "trivialised": "trivialized",
1625
+ "trivialises": "trivializes",
1626
+ "trivialising": "trivializing",
1627
+ "tumour": "tumor",
1628
+ "tumours": "tumors",
1629
+ "tunnelled": "tunneled",
1630
+ "tunnelling": "tunneling",
1631
+ "tyrannise": "tyrannize",
1632
+ "tyrannised": "tyrannized",
1633
+ "tyrannises": "tyrannizes",
1634
+ "tyrannising": "tyrannizing",
1635
+ "tyre": "tire",
1636
+ "tyres": "tires",
1637
+ "unauthorised": "unauthorized",
1638
+ "uncivilised": "uncivilized",
1639
+ "underutilised": "underutilized",
1640
+ "unequalled": "unequaled",
1641
+ "unfavourable": "unfavorable",
1642
+ "unfavourably": "unfavorably",
1643
+ "unionisation": "unionization",
1644
+ "unionise": "unionize",
1645
+ "unionised": "unionized",
1646
+ "unionises": "unionizes",
1647
+ "unionising": "unionizing",
1648
+ "unorganised": "unorganized",
1649
+ "unravelled": "unraveled",
1650
+ "unravelling": "unraveling",
1651
+ "unrecognisable": "unrecognizable",
1652
+ "unrecognised": "unrecognized",
1653
+ "unrivalled": "unrivaled",
1654
+ "unsavoury": "unsavory",
1655
+ "untrammelled": "untrammeled",
1656
+ "urbanisation": "urbanization",
1657
+ "urbanise": "urbanize",
1658
+ "urbanised": "urbanized",
1659
+ "urbanises": "urbanizes",
1660
+ "urbanising": "urbanizing",
1661
+ "utilisable": "utilizable",
1662
+ "utilisation": "utilization",
1663
+ "utilise": "utilize",
1664
+ "utilised": "utilized",
1665
+ "utilises": "utilizes",
1666
+ "utilising": "utilizing",
1667
+ "valour": "valor",
1668
+ "vandalise": "vandalize",
1669
+ "vandalised": "vandalized",
1670
+ "vandalises": "vandalizes",
1671
+ "vandalising": "vandalizing",
1672
+ "vaporisation": "vaporization",
1673
+ "vaporise": "vaporize",
1674
+ "vaporised": "vaporized",
1675
+ "vaporises": "vaporizes",
1676
+ "vaporising": "vaporizing",
1677
+ "vapour": "vapor",
1678
+ "vapours": "vapors",
1679
+ "verbalise": "verbalize",
1680
+ "verbalised": "verbalized",
1681
+ "verbalises": "verbalizes",
1682
+ "verbalising": "verbalizing",
1683
+ "victimisation": "victimization",
1684
+ "victimise": "victimize",
1685
+ "victimised": "victimized",
1686
+ "victimises": "victimizes",
1687
+ "victimising": "victimizing",
1688
+ "videodisc": "videodisk",
1689
+ "videodiscs": "videodisks",
1690
+ "vigour": "vigor",
1691
+ "visualisation": "visualization",
1692
+ "visualisations": "visualizations",
1693
+ "visualise": "visualize",
1694
+ "visualised": "visualized",
1695
+ "visualises": "visualizes",
1696
+ "visualising": "visualizing",
1697
+ "vocalisation": "vocalization",
1698
+ "vocalisations": "vocalizations",
1699
+ "vocalise": "vocalize",
1700
+ "vocalised": "vocalized",
1701
+ "vocalises": "vocalizes",
1702
+ "vocalising": "vocalizing",
1703
+ "vulcanised": "vulcanized",
1704
+ "vulgarisation": "vulgarization",
1705
+ "vulgarise": "vulgarize",
1706
+ "vulgarised": "vulgarized",
1707
+ "vulgarises": "vulgarizes",
1708
+ "vulgarising": "vulgarizing",
1709
+ "waggon": "wagon",
1710
+ "waggons": "wagons",
1711
+ "watercolour": "watercolor",
1712
+ "watercolours": "watercolors",
1713
+ "weaselled": "weaseled",
1714
+ "weaselling": "weaseling",
1715
+ "westernisation": "westernization",
1716
+ "westernise": "westernize",
1717
+ "westernised": "westernized",
1718
+ "westernises": "westernizes",
1719
+ "westernising": "westernizing",
1720
+ "womanise": "womanize",
1721
+ "womanised": "womanized",
1722
+ "womaniser": "womanizer",
1723
+ "womanisers": "womanizers",
1724
+ "womanises": "womanizes",
1725
+ "womanising": "womanizing",
1726
+ "woollen": "woolen",
1727
+ "woollens": "woolens",
1728
+ "woollies": "woolies",
1729
+ "woolly": "wooly",
1730
+ "worshipped": "worshiped",
1731
+ "worshipping": "worshiping",
1732
+ "worshipper": "worshiper",
1733
+ "yodelled": "yodeled",
1734
+ "yodelling": "yodeling",
1735
+ "yoghourt": "yogurt",
1736
+ "yoghourts": "yogurts",
1737
+ "yoghurt": "yogurt",
1738
+ "yoghurts": "yogurts",
1739
+ "mhm": "hmm",
1740
+ "mmm": "hmm"
1741
+ }
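The entries above close out english.json, the British-to-American spelling map consumed by the EnglishSpellingNormalizer defined in whisper_stream/normalizers/english.py below. A minimal sketch of how such a word-level mapping is applied (the json path is an assumption based on this commit's layout):

    import json
    import os

    # Load the mapping added above; path assumed relative to the repository root.
    mapping_path = os.path.join("whisper_stream", "normalizers", "english.json")
    with open(mapping_path) as f:
        mapping = json.load(f)

    def normalize_spelling(text: str) -> str:
        # Word-by-word lookup; words without an entry pass through unchanged.
        return " ".join(mapping.get(word, word) for word in text.split())

    print(normalize_spelling("the theatre programme was a rumour"))
    # -> "the theater program was a rumor"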
whisper_stream/normalizers/english.py ADDED
@@ -0,0 +1,550 @@
1
+ import json
2
+ import os
3
+ import re
4
+ from fractions import Fraction
5
+ from typing import Iterator, List, Match, Optional, Union
6
+
7
+ from more_itertools import windowed
8
+
9
+ from .basic import remove_symbols_and_diacritics
10
+
11
+
12
+ class EnglishNumberNormalizer:
13
+ """
14
+ Convert any spelled-out numbers into arabic numbers, while handling:
15
+
16
+ - remove any commas
17
+ - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
18
+ - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
19
+ - spell out `one` and `ones`
20
+ - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
21
+ """
22
+
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ self.zeros = {"o", "oh", "zero"}
27
+ self.ones = {
28
+ name: i
29
+ for i, name in enumerate(
30
+ [
31
+ "one",
32
+ "two",
33
+ "three",
34
+ "four",
35
+ "five",
36
+ "six",
37
+ "seven",
38
+ "eight",
39
+ "nine",
40
+ "ten",
41
+ "eleven",
42
+ "twelve",
43
+ "thirteen",
44
+ "fourteen",
45
+ "fifteen",
46
+ "sixteen",
47
+ "seventeen",
48
+ "eighteen",
49
+ "nineteen",
50
+ ],
51
+ start=1,
52
+ )
53
+ }
54
+ self.ones_plural = {
55
+ "sixes" if name == "six" else name + "s": (value, "s")
56
+ for name, value in self.ones.items()
57
+ }
58
+ self.ones_ordinal = {
59
+ "zeroth": (0, "th"),
60
+ "first": (1, "st"),
61
+ "second": (2, "nd"),
62
+ "third": (3, "rd"),
63
+ "fifth": (5, "th"),
64
+ "twelfth": (12, "th"),
65
+ **{
66
+ name + ("h" if name.endswith("t") else "th"): (value, "th")
67
+ for name, value in self.ones.items()
68
+ if value > 3 and value != 5 and value != 12
69
+ },
70
+ }
71
+ self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
72
+
73
+ self.tens = {
74
+ "twenty": 20,
75
+ "thirty": 30,
76
+ "forty": 40,
77
+ "fifty": 50,
78
+ "sixty": 60,
79
+ "seventy": 70,
80
+ "eighty": 80,
81
+ "ninety": 90,
82
+ }
83
+ self.tens_plural = {
84
+ name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
85
+ }
86
+ self.tens_ordinal = {
87
+ name.replace("y", "ieth"): (value, "th")
88
+ for name, value in self.tens.items()
89
+ }
90
+ self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
91
+
92
+ self.multipliers = {
93
+ "hundred": 100,
94
+ "thousand": 1_000,
95
+ "million": 1_000_000,
96
+ "billion": 1_000_000_000,
97
+ "trillion": 1_000_000_000_000,
98
+ "quadrillion": 1_000_000_000_000_000,
99
+ "quintillion": 1_000_000_000_000_000_000,
100
+ "sextillion": 1_000_000_000_000_000_000_000,
101
+ "septillion": 1_000_000_000_000_000_000_000_000,
102
+ "octillion": 1_000_000_000_000_000_000_000_000_000,
103
+ "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
104
+ "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
105
+ }
106
+ self.multipliers_plural = {
107
+ name + "s": (value, "s") for name, value in self.multipliers.items()
108
+ }
109
+ self.multipliers_ordinal = {
110
+ name + "th": (value, "th") for name, value in self.multipliers.items()
111
+ }
112
+ self.multipliers_suffixed = {
113
+ **self.multipliers_plural,
114
+ **self.multipliers_ordinal,
115
+ }
116
+ self.decimals = {*self.ones, *self.tens, *self.zeros}
117
+
118
+ self.preceding_prefixers = {
119
+ "minus": "-",
120
+ "negative": "-",
121
+ "plus": "+",
122
+ "positive": "+",
123
+ }
124
+ self.following_prefixers = {
125
+ "pound": "£",
126
+ "pounds": "£",
127
+ "euro": "€",
128
+ "euros": "€",
129
+ "dollar": "$",
130
+ "dollars": "$",
131
+ "cent": "¢",
132
+ "cents": "¢",
133
+ }
134
+ self.prefixes = set(
135
+ list(self.preceding_prefixers.values())
136
+ + list(self.following_prefixers.values())
137
+ )
138
+ self.suffixers = {
139
+ "per": {"cent": "%"},
140
+ "percent": "%",
141
+ }
142
+ self.specials = {"and", "double", "triple", "point"}
143
+
144
+ self.words = set(
145
+ [
146
+ key
147
+ for mapping in [
148
+ self.zeros,
149
+ self.ones,
150
+ self.ones_suffixed,
151
+ self.tens,
152
+ self.tens_suffixed,
153
+ self.multipliers,
154
+ self.multipliers_suffixed,
155
+ self.preceding_prefixers,
156
+ self.following_prefixers,
157
+ self.suffixers,
158
+ self.specials,
159
+ ]
160
+ for key in mapping
161
+ ]
162
+ )
163
+ self.literal_words = {"one", "ones"}
164
+
165
+ def process_words(self, words: List[str]) -> Iterator[str]:
166
+ prefix: Optional[str] = None
167
+ value: Optional[Union[str, int]] = None
168
+ skip = False
169
+
170
+ def to_fraction(s: str):
171
+ try:
172
+ return Fraction(s)
173
+ except ValueError:
174
+ return None
175
+
176
+ def output(result: Union[str, int]):
177
+ nonlocal prefix, value
178
+ result = str(result)
179
+ if prefix is not None:
180
+ result = prefix + result
181
+ value = None
182
+ prefix = None
183
+ return result
184
+
185
+ if len(words) == 0:
186
+ return
187
+
188
+ for prev, current, next in windowed([None] + words + [None], 3):
189
+ if skip:
190
+ skip = False
191
+ continue
192
+
193
+ next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
194
+ has_prefix = current[0] in self.prefixes
195
+ current_without_prefix = current[1:] if has_prefix else current
196
+ if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
197
+ # arabic numbers (potentially with signs and fractions)
198
+ f = to_fraction(current_without_prefix)
199
+ assert f is not None
200
+ if value is not None:
201
+ if isinstance(value, str) and value.endswith("."):
202
+ # concatenate decimals / ip address components
203
+ value = str(value) + str(current)
204
+ continue
205
+ else:
206
+ yield output(value)
207
+
208
+ prefix = current[0] if has_prefix else prefix
209
+ if f.denominator == 1:
210
+ value = f.numerator # store integers as int
211
+ else:
212
+ value = current_without_prefix
213
+ elif current not in self.words:
214
+ # non-numeric words
215
+ if value is not None:
216
+ yield output(value)
217
+ yield output(current)
218
+ elif current in self.zeros:
219
+ value = str(value or "") + "0"
220
+ elif current in self.ones:
221
+ ones = self.ones[current]
222
+
223
+ if value is None:
224
+ value = ones
225
+ elif isinstance(value, str) or prev in self.ones:
226
+ if (
227
+ prev in self.tens and ones < 10
228
+ ): # replace the last zero with the digit
229
+ assert value[-1] == "0"
230
+ value = value[:-1] + str(ones)
231
+ else:
232
+ value = str(value) + str(ones)
233
+ elif ones < 10:
234
+ if value % 10 == 0:
235
+ value += ones
236
+ else:
237
+ value = str(value) + str(ones)
238
+ else: # eleven to nineteen
239
+ if value % 100 == 0:
240
+ value += ones
241
+ else:
242
+ value = str(value) + str(ones)
243
+ elif current in self.ones_suffixed:
244
+ # ordinal or cardinal; yield the number right away
245
+ ones, suffix = self.ones_suffixed[current]
246
+ if value is None:
247
+ yield output(str(ones) + suffix)
248
+ elif isinstance(value, str) or prev in self.ones:
249
+ if prev in self.tens and ones < 10:
250
+ assert value[-1] == "0"
251
+ yield output(value[:-1] + str(ones) + suffix)
252
+ else:
253
+ yield output(str(value) + str(ones) + suffix)
254
+ elif ones < 10:
255
+ if value % 10 == 0:
256
+ yield output(str(value + ones) + suffix)
257
+ else:
258
+ yield output(str(value) + str(ones) + suffix)
259
+ else: # eleven to nineteen
260
+ if value % 100 == 0:
261
+ yield output(str(value + ones) + suffix)
262
+ else:
263
+ yield output(str(value) + str(ones) + suffix)
264
+ value = None
265
+ elif current in self.tens:
266
+ tens = self.tens[current]
267
+ if value is None:
268
+ value = tens
269
+ elif isinstance(value, str):
270
+ value = str(value) + str(tens)
271
+ else:
272
+ if value % 100 == 0:
273
+ value += tens
274
+ else:
275
+ value = str(value) + str(tens)
276
+ elif current in self.tens_suffixed:
277
+ # ordinal or cardinal; yield the number right away
278
+ tens, suffix = self.tens_suffixed[current]
279
+ if value is None:
280
+ yield output(str(tens) + suffix)
281
+ elif isinstance(value, str):
282
+ yield output(str(value) + str(tens) + suffix)
283
+ else:
284
+ if value % 100 == 0:
285
+ yield output(str(value + tens) + suffix)
286
+ else:
287
+ yield output(str(value) + str(tens) + suffix)
288
+ elif current in self.multipliers:
289
+ multiplier = self.multipliers[current]
290
+ if value is None:
291
+ value = multiplier
292
+ elif isinstance(value, str) or value == 0:
293
+ f = to_fraction(value)
294
+ p = f * multiplier if f is not None else None
295
+ if f is not None and p.denominator == 1:
296
+ value = p.numerator
297
+ else:
298
+ yield output(value)
299
+ value = multiplier
300
+ else:
301
+ before = value // 1000 * 1000
302
+ residual = value % 1000
303
+ value = before + residual * multiplier
304
+ elif current in self.multipliers_suffixed:
305
+ multiplier, suffix = self.multipliers_suffixed[current]
306
+ if value is None:
307
+ yield output(str(multiplier) + suffix)
308
+ elif isinstance(value, str):
309
+ f = to_fraction(value)
310
+ p = f * multiplier if f is not None else None
311
+ if f is not None and p.denominator == 1:
312
+ yield output(str(p.numerator) + suffix)
313
+ else:
314
+ yield output(value)
315
+ yield output(str(multiplier) + suffix)
316
+ else: # int
317
+ before = value // 1000 * 1000
318
+ residual = value % 1000
319
+ value = before + residual * multiplier
320
+ yield output(str(value) + suffix)
321
+ value = None
322
+ elif current in self.preceding_prefixers:
323
+ # apply prefix (positive, minus, etc.) if it precedes a number
324
+ if value is not None:
325
+ yield output(value)
326
+
327
+ if next in self.words or next_is_numeric:
328
+ prefix = self.preceding_prefixers[current]
329
+ else:
330
+ yield output(current)
331
+ elif current in self.following_prefixers:
332
+ # apply prefix (dollars, cents, etc.) only after a number
333
+ if value is not None:
334
+ prefix = self.following_prefixers[current]
335
+ yield output(value)
336
+ else:
337
+ yield output(current)
338
+ elif current in self.suffixers:
339
+ # apply suffix symbols (percent -> '%')
340
+ if value is not None:
341
+ suffix = self.suffixers[current]
342
+ if isinstance(suffix, dict):
343
+ if next in suffix:
344
+ yield output(str(value) + suffix[next])
345
+ skip = True
346
+ else:
347
+ yield output(value)
348
+ yield output(current)
349
+ else:
350
+ yield output(str(value) + suffix)
351
+ else:
352
+ yield output(current)
353
+ elif current in self.specials:
354
+ if next not in self.words and not next_is_numeric:
355
+ # apply special handling only if the next word can be numeric
356
+ if value is not None:
357
+ yield output(value)
358
+ yield output(current)
359
+ elif current == "and":
360
+ # ignore "and" after hundreds, thousands, etc.
361
+ if prev not in self.multipliers:
362
+ if value is not None:
363
+ yield output(value)
364
+ yield output(current)
365
+ elif current == "double" or current == "triple":
366
+ if next in self.ones or next in self.zeros:
367
+ repeats = 2 if current == "double" else 3
368
+ ones = self.ones.get(next, 0)
369
+ value = str(value or "") + str(ones) * repeats
370
+ skip = True
371
+ else:
372
+ if value is not None:
373
+ yield output(value)
374
+ yield output(current)
375
+ elif current == "point":
376
+ if next in self.decimals or next_is_numeric:
377
+ value = str(value or "") + "."
378
+ else:
379
+ # should all have been covered at this point
380
+ raise ValueError(f"Unexpected token: {current}")
381
+ else:
382
+ # all should have been covered at this point
383
+ raise ValueError(f"Unexpected token: {current}")
384
+
385
+ if value is not None:
386
+ yield output(value)
387
+
388
+ def preprocess(self, s: str):
389
+ # replace "<number> and a half" with "<number> point five"
390
+ results = []
391
+
392
+ segments = re.split(r"\band\s+a\s+half\b", s)
393
+ for i, segment in enumerate(segments):
394
+ if len(segment.strip()) == 0:
395
+ continue
396
+ if i == len(segments) - 1:
397
+ results.append(segment)
398
+ else:
399
+ results.append(segment)
400
+ last_word = segment.rsplit(maxsplit=2)[-1]
401
+ if last_word in self.decimals or last_word in self.multipliers:
402
+ results.append("point five")
403
+ else:
404
+ results.append("and a half")
405
+
406
+ s = " ".join(results)
407
+
408
+ # put a space at number/letter boundary
409
+ s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
410
+ s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
411
+
412
+ # but remove spaces which could be a suffix
413
+ s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
414
+
415
+ return s
416
+
417
+ def postprocess(self, s: str):
418
+ def combine_cents(m: Match):
419
+ try:
420
+ currency = m.group(1)
421
+ integer = m.group(2)
422
+ cents = int(m.group(3))
423
+ return f"{currency}{integer}.{cents:02d}"
424
+ except ValueError:
425
+ return m.string
426
+
427
+ def extract_cents(m: Match):
428
+ try:
429
+ return f"¢{int(m.group(1))}"
430
+ except ValueError:
431
+ return m.string
432
+
433
+ # apply currency postprocessing; "$2 and ¢7" -> "$2.07"
434
+ s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
435
+ s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
436
+
437
+ # write "one(s)" instead of "1(s)", just for the readability
438
+ s = re.sub(r"\b1(s?)\b", r"one\1", s)
439
+
440
+ return s
441
+
442
+ def __call__(self, s: str):
443
+ s = self.preprocess(s)
444
+ s = " ".join(word for word in self.process_words(s.split()) if word is not None)
445
+ s = self.postprocess(s)
446
+
447
+ return s
448
+
449
+
450
+ class EnglishSpellingNormalizer:
451
+ """
452
+ Applies British-American spelling mappings as listed in [1].
453
+
454
+ [1] https://www.tysto.com/uk-us-spelling-list.html
455
+ """
456
+
457
+ def __init__(self):
458
+ mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
459
+ self.mapping = json.load(open(mapping_path))
460
+
461
+ def __call__(self, s: str):
462
+ return " ".join(self.mapping.get(word, word) for word in s.split())
463
+
464
+
465
+ class EnglishTextNormalizer:
466
+ def __init__(self):
467
+ self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
468
+ self.replacers = {
469
+ # common contractions
470
+ r"\bwon't\b": "will not",
471
+ r"\bcan't\b": "can not",
472
+ r"\blet's\b": "let us",
473
+ r"\bain't\b": "aint",
474
+ r"\by'all\b": "you all",
475
+ r"\bwanna\b": "want to",
476
+ r"\bgotta\b": "got to",
477
+ r"\bgonna\b": "going to",
478
+ r"\bi'ma\b": "i am going to",
479
+ r"\bimma\b": "i am going to",
480
+ r"\bwoulda\b": "would have",
481
+ r"\bcoulda\b": "could have",
482
+ r"\bshoulda\b": "should have",
483
+ r"\bma'am\b": "madam",
484
+ # contractions in titles/prefixes
485
+ r"\bmr\b": "mister ",
486
+ r"\bmrs\b": "missus ",
487
+ r"\bst\b": "saint ",
488
+ r"\bdr\b": "doctor ",
489
+ r"\bprof\b": "professor ",
490
+ r"\bcapt\b": "captain ",
491
+ r"\bgov\b": "governor ",
492
+ r"\bald\b": "alderman ",
493
+ r"\bgen\b": "general ",
494
+ r"\bsen\b": "senator ",
495
+ r"\brep\b": "representative ",
496
+ r"\bpres\b": "president ",
497
+ r"\brev\b": "reverend ",
498
+ r"\bhon\b": "honorable ",
499
+ r"\basst\b": "assistant ",
500
+ r"\bassoc\b": "associate ",
501
+ r"\blt\b": "lieutenant ",
502
+ r"\bcol\b": "colonel ",
503
+ r"\bjr\b": "junior ",
504
+ r"\bsr\b": "senior ",
505
+ r"\besq\b": "esquire ",
506
+ # perfect tenses; ideally it should handle any past participle, but that's harder..
507
+ r"'d been\b": " had been",
508
+ r"'s been\b": " has been",
509
+ r"'d gone\b": " had gone",
510
+ r"'s gone\b": " has gone",
511
+ r"'d done\b": " had done", # "'s done" is ambiguous
512
+ r"'s got\b": " has got",
513
+ # general contractions
514
+ r"n't\b": " not",
515
+ r"'re\b": " are",
516
+ r"'s\b": " is",
517
+ r"'d\b": " would",
518
+ r"'ll\b": " will",
519
+ r"'t\b": " not",
520
+ r"'ve\b": " have",
521
+ r"'m\b": " am",
522
+ }
523
+ self.standardize_numbers = EnglishNumberNormalizer()
524
+ self.standardize_spellings = EnglishSpellingNormalizer()
525
+
526
+ def __call__(self, s: str):
527
+ s = s.lower()
528
+
529
+ s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
530
+ s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
531
+ s = re.sub(self.ignore_patterns, "", s)
532
+ s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
533
+
534
+ for pattern, replacement in self.replacers.items():
535
+ s = re.sub(pattern, replacement, s)
536
+
537
+ s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
538
+ s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
539
+ s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
540
+
541
+ s = self.standardize_numbers(s)
542
+ s = self.standardize_spellings(s)
543
+
544
+ # now remove prefix/suffix symbols that are not preceded/followed by numbers
545
+ s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
546
+ s = re.sub(r"([^0-9])%", r"\1 ", s)
547
+
548
+ s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
549
+
550
+ return s
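Taken together, EnglishNumberNormalizer, EnglishSpellingNormalizer, and EnglishTextNormalizer form the reference-text normalization pipeline added in this file. A minimal usage sketch, assuming the package is importable as whisper_stream and that english.json from this commit sits next to this module:

    from whisper_stream.normalizers.english import EnglishTextNormalizer

    normalizer = EnglishTextNormalizer()

    # Lowercases, expands contractions and titles, converts spelled-out numbers,
    # and maps British spellings to American ones via english.json.
    text = "Mrs. Robinson couldn't summarise the programme in ninety five seconds"
    print(normalizer(text))
    # -> roughly "missus robinson could not summarize the program in 95 seconds"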
whisper_stream/streaming_decoding.py ADDED
@@ -0,0 +1,1198 @@
1
+ from dataclasses import dataclass, field, replace
2
+ from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import Tensor
8
+ from torch.distributions import Categorical
9
+
10
+ from .audio import CHUNK_LENGTH
11
+ from .tokenizer import Tokenizer, get_tokenizer
12
+ from .utils import compression_ratio
13
+
14
+ if TYPE_CHECKING:
15
+ from .streaming_model import StreamingWhisper
16
+ from .model import Whisper
17
+
18
+
19
+ @dataclass(frozen=False)
20
+ class DecodingOptions:
21
+ # whether to perform X->X "transcribe" or X->English "translate"
22
+ task: str = "transcribe"
23
+
24
+ # language that the audio is in; uses detected language if None
25
+ language: Optional[str] = None
26
+
27
+ # sampling-related options
28
+ temperature: float = 0.0
29
+ sample_len: Optional[int] = None # maximum number of tokens to sample
30
+ best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
31
+ beam_size: Optional[int] = None # number of beams in beam search, if t == 0
32
+ patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
33
+
34
+ # "alpha" in Google NMT, or None for length norm, when ranking generations
35
+ # to select which to return among the beams or best-of-N samples
36
+ length_penalty: Optional[float] = None
37
+
38
+ # text or tokens to feed as the prompt or the prefix; for more info:
39
+ # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
40
+ prompt: Optional[Union[str, List[int]]] = None # for the previous context
41
+ prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
42
+
43
+ # list of tokens ids (or comma-separated token ids) to suppress
44
+ # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
45
+ suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
46
+ suppress_blank: bool = True # this will suppress blank outputs
47
+
48
+ # timestamp sampling options
49
+ without_timestamps: bool = True # use <|notimestamps|> to sample text tokens only
50
+ max_initial_timestamp: Optional[float] = 1.0
51
+
52
+ # implementation details
53
+ fp16: bool = False # use fp16 for most of the calculation
54
+
55
+ # Advisor & Streaming params
56
+ advised: bool = False
57
+ attentive_advisor: bool = False
58
+ use_sa: bool = False
59
+ ctx: int = 3 # CA / SA ctx given to advisor.
60
+ gran: int = 15 # granularity of encoder embeddings (!) 15 = 300 msec (each frame equals 20msec)
61
+ pad_audio_features: bool = True
62
+ single_frame_mel: bool = True
63
+ pad_input: bool = False
64
+
65
+ look_ahead_blocks: int = 0
66
+ maximal_seconds_context: int = 30 # after 30 seconds, reset the mel.
67
+
68
+ use_kv_cache: bool = False
69
+ use_ca_kv_cache: bool = False
70
+
71
+ # streaming decoding args
72
+ stream_decode: bool = True
73
+ tokens_per_frame: int = 2
74
+ n_tokens_look_back: int = 2
75
+ streaming_timestamps: bool = True
76
+ wait_for_all: bool = False
77
+ force_first_tokens_timestamps: bool = False
78
+
79
+
80
+ @dataclass(frozen=False)
81
+ class DecodingResult:
82
+ audio_features: Tensor
83
+ language: str
84
+ language_probs: Optional[Dict[str, float]] = None
85
+ tokens: List[int] = field(default_factory=list)
86
+ text: str = ""
87
+ avg_logprob: float = np.nan
88
+ no_speech_prob: float = np.nan
89
+ temperature: float = np.nan
90
+ compression_ratio: float = np.nan
91
+ timestamps: Dict[Tuple, Tuple] = field(default_factory=tuple)
92
+ timed_tokens: List[int] = field(default_factory=list)
93
+ timed_text: str = ""
94
+
95
+
96
+ class Inference:
97
+ def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
98
+ """Perform a forward pass on the decoder and return per-token logits"""
99
+ raise NotImplementedError
100
+
101
+ def rearrange_kv_cache(self, source_indices) -> None:
102
+ """Update the key-value cache according to the updated beams"""
103
+ raise NotImplementedError
104
+
105
+ def cleanup_caching(self) -> None:
106
+ """Clean up any resources or hooks after decoding is finished"""
107
+ pass
108
+
109
+ def flush_tokens_from_cache(self) -> None:
110
+ """flush irrelevant tokens from cache during streaming"""
111
+ pass
112
+
113
+
114
+ class PyTorchInference(Inference):
115
+ def __init__(self, model: "StreamingWhisper", initial_token_length: int, use_kv_cache: bool = False, dump_type: str = None, n_tokens_look_back: int = 2, use_beam: bool = False):
116
+ self.model: "StreamingWhisper" = model
117
+ self.initial_token_length = initial_token_length
118
+ self.hooks = []
119
+ # custom
120
+ self.use_kv_cache = use_kv_cache
121
+ self.kv_cache = {} if use_kv_cache else None
122
+ self.dump_type = dump_type
123
+ self.n_tokens_look_back = n_tokens_look_back
124
+ self.cached_logits = None
125
+ self.use_beam = use_beam
126
+
127
+ key_modules = [block.attn.key for block in self.model.decoder.blocks]
128
+ value_modules = [block.attn.value for block in self.model.decoder.blocks]
129
+ self.kv_modules = key_modules + value_modules
130
+
131
+ def logits(self, tokens: Tensor, audio_features: Tensor, first_prediction: bool = False, beam_indices: list[list] = None) -> Tensor:
132
+
133
+ if not self.kv_cache and self.use_kv_cache:
134
+ self.kv_cache, self.hooks = self.model.install_decoder_kv_cache_hooks()
135
+
136
+ if tokens.shape[-1] > self.initial_token_length and self.use_kv_cache:
137
+ # only need to use the last token except in the first forward pass
138
+ if not self.use_beam:
139
+ tokens = tokens[:, -self.n_tokens_look_back:] if first_prediction else tokens[:, -1:]
140
+ else:
141
+ n_beams = tokens.shape[0]
142
+ self.kv_cache["beam_indices"] = beam_indices # an elegant way to send it to decoder ?
143
+ tokens = tokens[beam_indices[0], beam_indices[1]].view(n_beams, -1)
144
+
145
+
146
+ return self._concat_logits_if_needed(self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache))
147
+
148
+ def _concat_logits_if_needed(self, logits: Tensor):
149
+ if not self.use_kv_cache: return logits
150
+
151
+ if self.cached_logits is None:
152
+ self.cached_logits = logits
153
+
154
+ return logits
155
+
156
+ if not self.use_beam: # greedy
157
+ self.cached_logits = torch.cat([self.cached_logits, logits], dim=1)
158
+ return self.cached_logits
159
+
160
+ # beam kv_cache
161
+ n_beams, n_ctx, n_vocab = logits.shape
162
+
163
+ for i, (beam, index, output_index) in enumerate(zip(*self.kv_cache["beam_indices"])):
164
+ if index < self.cached_logits.shape[1]:
165
+ self.cached_logits[beam, index] = logits[beam, output_index]
166
+ else:
167
+ self.cached_logits = torch.cat([self.cached_logits, logits[:, (output_index):]], dim=1)
168
+
169
+ return self.cached_logits
170
+
171
+ def cleanup_caching(self):
172
+ if not self.use_kv_cache:
173
+ return
174
+
175
+ for hook in self.hooks:
176
+ hook.remove()
177
+
178
+ del self.kv_cache
179
+ del self.hooks
180
+ self.kv_cache = {}
181
+ self.hooks = []
182
+
183
+ def flush_tokens_from_cache(self):
184
+ for key in self.kv_cache.keys():
185
+ if key == "beam_indices": continue
186
+ self.kv_cache[key] = self.kv_cache[key][:, :-self.n_tokens_look_back].detach()
187
+
188
+ self.cached_logits = self.cached_logits[:, :-self.n_tokens_look_back].detach()
189
+
190
+ def rearrange_kv_cache(self, source_indices):
191
+ if not self.use_kv_cache: return
192
+
193
+ if source_indices != list(range(len(source_indices))):
194
+
195
+ for module in self.kv_modules:
196
+ # update the key/value cache to contain the selected sequences
197
+ self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
198
+
199
+ self.cached_logits = self.cached_logits[source_indices].detach()
200
+
201
+
202
+ class SequenceRanker:
203
+ def rank(
204
+ self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
205
+ ) -> List[int]:
206
+ """
207
+ Given a list of groups of samples and their cumulative log probabilities,
208
+ return the indices of the samples in each group to select as the final result
209
+ """
210
+ raise NotImplementedError
211
+
212
+
213
+ class MaximumLikelihoodRanker(SequenceRanker):
214
+ """
215
+ Select the sample with the highest log probabilities, penalized using either
216
+ a simple length normalization or Google NMT paper's length penalty
217
+ """
218
+
219
+ def __init__(self, length_penalty: Optional[float]):
220
+ self.length_penalty = length_penalty
221
+
222
+ def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
223
+ def scores(logprobs, lengths):
224
+ result = []
225
+ for logprob, length in zip(logprobs, lengths):
226
+ if self.length_penalty is None:
227
+ penalty = length
228
+ else:
229
+ # from the Google NMT paper
230
+ penalty = ((5 + length) / 6) ** self.length_penalty
231
+ result.append(logprob / penalty)
232
+ return result
233
+
234
+ # get the sequence with the highest score
235
+ lengths = [[len(t) for t in s] for s in tokens]
236
+
237
+ return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
238
+
239
+
240
+ class TokenDecoder:
241
+ def reset(self):
242
+ """Initialize any stateful variables for decoding a new sequence"""
243
+
244
+ def update(
245
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
246
+ ) -> Tuple[Tensor, bool]:
247
+ """Specify how to select the next token, based on the current trace and logits
248
+
249
+ Parameters
250
+ ----------
251
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
252
+ all tokens in the context so far, including the prefix and sot_sequence tokens
253
+
254
+ logits : Tensor, shape = (n_batch, vocab_size)
255
+ per-token logits of the probability distribution at the current step
256
+
257
+ sum_logprobs : Tensor, shape = (n_batch)
258
+ cumulative log probabilities for each sequence
259
+
260
+ Returns
261
+ -------
262
+ tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
263
+ the tokens, appended with the selected next token
264
+
265
+ completed : bool
266
+ True if all sequences has reached the end of text
267
+
268
+ """
269
+ raise NotImplementedError
270
+
271
+ def finalize(
272
+ self, tokens: Tensor, sum_logprobs: Tensor
273
+ ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
274
+ """Finalize search and return the final candidate sequences
275
+
276
+ Parameters
277
+ ----------
278
+ tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
279
+ all tokens in the context so far, including the prefix and sot_sequence
280
+
281
+ sum_logprobs : Tensor, shape = (n_audio, n_group)
282
+ cumulative log probabilities for each sequence
283
+
284
+ Returns
285
+ -------
286
+ tokens : Sequence[Sequence[Tensor]], length = n_audio
287
+ sequence of Tensors containing candidate token sequences, for each audio input
288
+
289
+ sum_logprobs : List[List[float]], length = n_audio
290
+ sequence of cumulative log probabilities corresponding to the above
291
+
292
+ """
293
+ raise NotImplementedError
294
+
295
+
296
+ class GreedyDecoder(TokenDecoder):
297
+ def __init__(self, temperature: float, eot: int):
298
+ self.temperature = temperature
299
+ self.eot = eot
300
+
301
+ def update(
302
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
303
+ ) -> Tuple[Tensor, bool]:
304
+ if self.temperature == 0:
305
+ next_tokens = logits.argmax(dim=-1)
306
+ else:
307
+ next_tokens = Categorical(logits=logits / self.temperature).sample()
308
+
309
+ logprobs = F.log_softmax(logits.float(), dim=-1)
310
+
311
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
312
+ sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
313
+
314
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
315
+ tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
316
+
317
+ completed = (tokens[:, -1] == self.eot).all()
318
+ return tokens, completed
319
+
320
+ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
321
+ # make sure each sequence has at least one EOT token at the end
322
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
323
+ return tokens, sum_logprobs.tolist()
324
+
325
+
326
+ class BeamSearchDecoder(TokenDecoder):
327
+ def __init__(
328
+ self,
329
+ beam_size: int,
330
+ eot: int,
331
+ inference: Inference,
332
+ patience: Optional[float] = None,
333
+ ):
334
+ self.beam_size = beam_size
335
+ self.eot = eot
336
+ self.inference = inference
337
+ self.patience = patience or 1.0
338
+ self.max_candidates: int = round(beam_size * self.patience)
339
+ self.finished_sequences = None
340
+
341
+ assert (
342
+ self.max_candidates > 0
343
+ ), f"Invalid beam size ({beam_size}) or patience ({patience})"
344
+
345
+ def reset(self):
346
+ self.finished_sequences = None
347
+
348
+ def update(
349
+ self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
350
+ ) -> Tuple[Tensor, bool]:
351
+ if tokens.shape[0] % self.beam_size != 0:
352
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
353
+
354
+ n_audio = tokens.shape[0] // self.beam_size
355
+ if self.finished_sequences is None: # for the first update
356
+ self.finished_sequences = [{} for _ in range(n_audio)]
357
+
358
+ logprobs = F.log_softmax(logits.float(), dim=-1)
359
+ next_tokens, source_indices, finished_sequences = [], [], []
360
+ for i in range(n_audio): # in our case n_audio = 1
361
+ scores, sources, finished = {}, {}, {}
362
+
363
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
364
+ for j in range(self.beam_size): # go over each trajectory of beam
365
+ idx = i * self.beam_size + j
366
+ prefix = tokens[idx].tolist()
367
+ for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
368
+ new_logprob = (sum_logprobs[idx] + logprob).item()
369
+ sequence = tuple(prefix + [token.item()])
370
+ scores[sequence] = new_logprob
371
+ sources[sequence] = idx
372
+
373
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
374
+ saved = 0
375
+ for sequence in sorted(scores, key=scores.get, reverse=True):
376
+ if sequence[-1] == self.eot:
377
+ finished[sequence] = scores[sequence]
378
+ else:
379
+ sum_logprobs[len(next_tokens)] = scores[sequence]
380
+ next_tokens.append(sequence)
381
+ source_indices.append(sources[sequence])
382
+
383
+ saved += 1
384
+ if saved == self.beam_size:
385
+ break
386
+
387
+ finished_sequences.append(finished)
388
+
389
+ tokens = torch.tensor(next_tokens, device=tokens.device)
390
+ self.inference.rearrange_kv_cache(source_indices)
391
+
392
+ # add newly finished sequences to self.finished_sequences
393
+ assert len(self.finished_sequences) == len(finished_sequences)
394
+ for previously_finished, newly_finished in zip(
395
+ self.finished_sequences, finished_sequences
396
+ ):
397
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
398
+ if len(previously_finished) >= self.max_candidates:
399
+ break # the candidate list is full
400
+ previously_finished[seq] = newly_finished[seq]
401
+
402
+ # mark as completed if all audio has enough number of samples
403
+ completed = all(
404
+ len(sequences) >= self.max_candidates
405
+ for sequences in self.finished_sequences
406
+ )
407
+ return tokens, completed
408
+
409
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
410
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
411
+ sum_logprobs = sum_logprobs.cpu()
412
+ for i, sequences in enumerate(self.finished_sequences):
413
+ if (
414
+ len(sequences) < self.beam_size
415
+ ): # when not enough sequences are finished
416
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
417
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
418
+ sequences[tuple(sequence)] = sum_logprobs[i][j].item()
419
+ if len(sequences) >= self.beam_size:
420
+ break
421
+
422
+ tokens: List[List[Tensor]] = [
423
+ [torch.tensor(seq) for seq in sequences.keys()]
424
+ for sequences in self.finished_sequences
425
+ ]
426
+ sum_logprobs: List[List[float]] = [
427
+ list(sequences.values()) for sequences in self.finished_sequences
428
+ ]
429
+ return tokens, sum_logprobs
430
+
431
+
432
+ class StreamingDecoder(TokenDecoder):
433
+ def __init__(self,
434
+ temperature: float,
435
+ tokens_per_frame: int,
436
+ eot: int,
437
+ inference: Inference,
438
+ n_tokens_look_back: int = 2,
439
+ streaming_timestamps: bool = False
440
+ ):
441
+ self.tokens_per_frame = tokens_per_frame
442
+ self.eot = eot
443
+ self.inference = inference
444
+ self.last_logits: Tensor = None
445
+ self.temperature = temperature
446
+ self.tokens_look_back = n_tokens_look_back
447
+ self.check_token_index = -n_tokens_look_back - 1
448
+ self.streaming_timestamps = streaming_timestamps
449
+ self.timestamps_map = {-1: 0}
450
+ self._reset_timestamps()
451
+
452
+ def _insert_timestamps(self, audio_features: Tensor, tokens: Tensor, enc_emb_gran: int):
453
+ """
454
+ Given the tokens and the stable lists.
455
+ If the decode is marked complete, it means that we should put a timestamp on the tokens.
456
+ """
457
+ # Given T, and a gran g, forward pass T/g times.
458
+ # Observe each vector, and try to find swap points for each token.
459
+ # Overall, extra T/g forward passes.
460
+ examined_token_index = 3 # this index is the first non-initial token index to be predicted
461
+ for i in range(enc_emb_gran, audio_features.shape[-1] + 1, enc_emb_gran):
462
+ i_logits = self.inference.logits(tokens, audio_features[:, :i])
463
+
464
+ if tokens[0, examined_token_index] in i_logits.argmax(dim=-1): # stamp of this token.
465
+ self.timestamps_map[examined_token_index - 3] = i * 20
466
+ examined_token_index += 1
467
+
468
+ if examined_token_index >= tokens.shape[-1]: break
469
+
470
+ def _check_last_tokens(self, logits: Tensor, tokens: Tensor, next_tokens: Tensor, check_tokens: bool):
471
+ stable = []
472
+
473
+ if not check_tokens:
474
+ return stable, tokens
475
+
476
+ examine_tokens_indices = range(tokens.shape[1] + self.check_token_index, tokens.shape[1] - 1)
477
+ for i in examine_tokens_indices:
478
+
479
+ token_index = i + 1
480
+ token_prob_index = i
481
+
482
+ examined_token = tokens[:, token_index] # This is the predicted token for this index.
483
+
484
+ # It means that the prediction is stabilizing. We can move on, check next token.
485
+ # or - Prob is down, but token has still highest prob. Continue.
486
+ if self.last_logits[:, token_prob_index, examined_token] <= logits[:, token_prob_index, examined_token] \
487
+ or \
488
+ (self.last_logits[:, token_prob_index, examined_token] > logits[:, token_prob_index, examined_token] and next_tokens[:, token_prob_index] == examined_token):
489
+ stable.append(True)
490
+ continue
491
+
492
+ else:
493
+ tokens = tokens[:, :token_index] # Crop next tokens - Irrelevant after flush.
494
+ stable.append(False)
495
+ break
496
+
497
+ return stable, tokens
498
+
499
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, check_tokens: bool = False, index: int = 0, first_frame: bool = False) -> Tuple[Tensor, bool]:
500
+ # calc next_tokens, logprob, completed
501
+ if self.temperature == 0:
502
+ next_tokens = logits.argmax(dim=-1)
503
+ else:
504
+ next_tokens = Categorical(logits=logits / self.temperature).sample()
505
+
506
+ logprobs = F.log_softmax(logits.float(), dim=-1)
507
+ completed = next_tokens[0, -1] == self.eot
508
+
509
+ # means we are at the beginning, append tokens greedily.
510
+ if first_frame:
511
+ if next_tokens[0, -1] != self.eot:
512
+ tokens = torch.cat([tokens, next_tokens[None, :, -1]], dim=-1)
513
+ self.last_logits = logits
514
+ return tokens, completed
515
+
516
+ # Otherwise, check if last tokens are stable.
517
+ stable, tokens = self._check_last_tokens(logits, tokens, next_tokens, check_tokens)
518
+
519
+ if all(stable) and next_tokens[0, -1] != self.eot:
520
+ tokens = torch.cat([tokens, next_tokens[None, :, -1]], dim=-1)
521
+
522
+ if tokens.shape[-1] - 1 - 4 not in self.timestamps_map.keys(): # mark start of timestamp
523
+ self.timestamps_map[tokens.shape[-1] - 1 - 4] = index - 1
524
+
525
+ sum_logprobs += logprobs[:, -1, next_tokens[0, -1]]
526
+
527
+ # take last tokens logits to compare on next decode step
528
+ self.last_logits = logits
529
+
530
+ return tokens, completed
531
+
532
+ def _reset_timestamps(self):
533
+ self.timestamp_tokens_indices = [None, None]
534
+ self.timestamps_indices = [None, None]
535
+
536
+ def reset(self):
537
+ self.check_token_index = -self.tokens_look_back - 1
538
+
539
+ def finalize(self, tokens, sum_logprobs):
540
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
541
+ return tokens, sum_logprobs.tolist()
542
+
543
+
544
+ class BeamStreamingDecoder(TokenDecoder):
545
+ def __init__(self,
546
+ temperature: float,
547
+ tokens_per_frame: int,
548
+ eot: int,
549
+ inference: Inference,
550
+ n_tokens_look_back: int = 2,
551
+ n_beams: int = 1,
552
+ pad_token: int = None,
553
+ wait_for_all: bool = False,
554
+ ):
555
+ self.tokens_per_frame = tokens_per_frame
556
+ self.eot = eot
557
+ self.pad_token = pad_token # the sot_lm token is used as padding
558
+ self.inference = inference
559
+ self.last_logits: Tensor = None
560
+ self.temperature = temperature
561
+ self.tokens_look_back = n_tokens_look_back
562
+ self.check_token_index = [4 + 1 for _ in range(n_beams)] # Here I'll save the last index that is relevant for checking.
563
+ self.n_beams = n_beams
564
+ self.sum_logprobs = torch.zeros(n_beams)
565
+ self.timestamps_map = {}
566
+ self.wait_for_all = wait_for_all
567
+ self.finished_sequences = {}
568
+
569
+ def _get_last_valid_token_index(self, prefix: Tensor):
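+ # index of the last real token (before any EOT / pad token), clamped to stay past the initial SOT sequence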
570
+ indices = torch.where((prefix == self.eot) | (prefix == self.pad_token))[0]
571
+ last_valid_token_index = (prefix.shape[-1] - 1) if indices.shape == torch.Size([0]) else (indices.min().item() - 1)
572
+ last_valid_token_index = max(last_valid_token_index, 3)
573
+
574
+ return last_valid_token_index
575
+
576
+ def _check_last_tokens(self, prefix: Tensor, logits: Tensor, check_tokens: bool):
577
+ """
578
+ given the last k logits, and the tokens, determine where we should proceed from.
579
+ prefix - a list with the tokens, including padding tokens and eot. Start from eot / pad token index. of shape (l)
580
+ logits - of shape (l, n_vocab). The logits of the specific beam.
581
+ """
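+ # returns (index to resume sampling from, whether all looked-back tokens stayed within the top-k candidates)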
582
+ last_valid_token_index = self._get_last_valid_token_index(prefix)
583
+
584
+ if not check_tokens:
585
+ return last_valid_token_index, True
586
+
587
+ examined_prob_indices = range(max(last_valid_token_index - self.tokens_look_back, 3), last_valid_token_index)
588
+ for examined_prob_index in examined_prob_indices:
589
+
590
+ examined_token_index = examined_prob_index + 1
591
+ examined_token = prefix[examined_token_index]
592
+ examined_candidates = logits[examined_prob_index].topk(self.n_beams).indices
593
+
594
+ if examined_token not in examined_candidates: # the previously predicted token fell out of our top-k.
595
+ # pad with irrelevant tokens
596
+ prefix[examined_token_index:] = self.pad_token
597
+ return examined_prob_index, False
598
+
599
+ # If nothing was returned during the check, it means that all tokens were stable.
600
+ # Continue decoding from the last valid token index.
601
+ return last_valid_token_index, True
602
+
603
+ def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, first_frame: bool = False, check_tokens: bool = False) -> Tuple[Tensor, bool]:
604
+ if tokens.shape[0] % self.n_beams != 0:
605
+ raise ValueError(f"{tokens.shape}[0] % {self.n_beams} != 0")
606
+
607
+ scores, sources, finished = {}, {}, {}
608
+ next_tokens, source_indices, finished_sequences = [], [], []
609
+ logprobs = F.log_softmax(logits.float(), dim=-1)
610
+
611
+ for beam in range(self.n_beams):
612
+
613
+ prefix = tokens[beam] # tokens in this beam.
614
+ sampling_index, stable = self._check_last_tokens(prefix, logits[beam], check_tokens) if not first_frame else ((tokens.shape[-1] - 1), False)
615
+ prefix = prefix.tolist() # using list, easier to append to.
616
+ # Calculate candidates from the last token index we should check.
617
+ for logprob, token in zip(*logprobs[beam, sampling_index].topk(self.n_beams + 1)):
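+ # candidate score = cumulative logprob of the already-kept prefix tokens plus this candidate's logprob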
618
+ new_logprob = (logprobs[beam, range(3, sampling_index - 1), tokens[beam, 4:sampling_index]].sum() + logprob).item()
619
+
620
+ token_index = sampling_index + 1
621
+ if token_index == len(prefix):
622
+ prefix.append(token.item())
623
+ else:
624
+ prefix[token_index] = token.item()
625
+
626
+ sequence = prefix[:prefix.index(self.pad_token)] if self.pad_token in prefix else prefix
627
+ sequence = tuple(sequence)
628
+ scores[sequence] = new_logprob
629
+ sources[sequence] = beam
630
+
631
+ # After all beams were checked, and tokens were calculated
632
+ # Get top n_beams sequences.
633
+ saved = 0
634
+ for sequence in sorted(scores, key=scores.get, reverse=True):
635
+ if self.wait_for_all and sequence[-1] == self.eot:
636
+ finished[sequence] = scores[sequence]
637
+ else:
638
+ sum_logprobs[len(next_tokens)] = scores[sequence]
639
+ next_tokens.append(Tensor(sequence).long())
640
+ source_indices.append(sources[sequence])
641
+
642
+ saved += 1
643
+ if saved == self.n_beams:
644
+ break
645
+
646
+ tokens = torch.nn.utils.rnn.pad_sequence(next_tokens, batch_first=True, padding_value=self.pad_token).to(tokens.device)
647
+ self.inference.rearrange_kv_cache(source_indices)
648
+
649
+ if not self.wait_for_all: # greedy stop mode.
650
+ completed = any([self.eot in s for s in next_tokens]) # Greedy stop - Believe any beam that says enough.
651
+ return tokens, completed
652
+
653
+ for sequence in sorted(finished, key=finished.get, reverse=True):
654
+ if len(self.finished_sequences) >= self.n_beams: break
655
+ self.finished_sequences[sequence] = finished[sequence]
656
+
657
+ # we have enough trajectories that reached EOT, in this specific frame.
658
+ completed = len(self.finished_sequences) >= self.n_beams
659
+ return tokens, completed
660
+
661
+ def reset(self):
662
+ self.check_token_index = -self.tokens_look_back
663
+
664
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
665
+ if not self.wait_for_all:
666
+ preceding_tokens = F.pad(preceding_tokens, (0, 1), value=self.eot)
667
+ return preceding_tokens, sum_logprobs.tolist()
668
+
669
+ sum_logprobs = sum_logprobs.cpu()
670
+ if len(self.finished_sequences) < self.n_beams:
671
+ for j in list(np.argsort(sum_logprobs[0]))[::-1]:
672
+ sequence = preceding_tokens[0][j].tolist() + [self.eot]
673
+ self.finished_sequences[tuple(sequence)] = sum_logprobs[0][j].item()
674
+ if len(self.finished_sequences) >= self.n_beams: break
675
+
676
+ tokens: List[List[Tensor]] = [
677
+ [torch.tensor(seq) for seq in self.finished_sequences.keys()]
678
+ ]
679
+
680
+ sum_logprobs: List[List[float]] = [
681
+ list(self.finished_sequences.values())
682
+ ]
683
+
684
+ return tokens, sum_logprobs
685
+
686
+
687
+ class LogitFilter:
688
+ def apply(self, logits: Tensor, tokens: Tensor) -> None:
689
+ """Apply any filtering or masking to logits in-place
690
+
691
+ Parameters
692
+ ----------
693
+ logits : Tensor, shape = (n_batch, vocab_size)
694
+ per-token logits of the probability distribution at the current step
695
+
696
+ tokens : Tensor, shape = (n_batch, current_sequence_length)
697
+ all tokens in the context so far, including the prefix and sot_sequence tokens
698
+
699
+ """
700
+ raise NotImplementedError
701
+
702
+
703
+ class SuppressBlank(LogitFilter):
704
+ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
705
+ self.tokenizer = tokenizer
706
+ self.sample_begin = sample_begin
707
+
708
+ def apply(self, logits: Tensor, tokens: Tensor):
709
+ if tokens.shape[1] == self.sample_begin:
710
+ logits[..., self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
711
+
712
+
713
+ class SuppressTokens(LogitFilter):
714
+ def __init__(self, suppress_tokens: Sequence[int]):
715
+ self.suppress_tokens = list(suppress_tokens)
716
+
717
+ def apply(self, logits: Tensor, tokens: Tensor):
718
+ logits[..., self.suppress_tokens] = -np.inf
719
+
720
+
721
+ class DecodingTask:
722
+ inference: Inference
723
+ sequence_ranker: SequenceRanker
724
+ decoder: TokenDecoder
725
+ logit_filters: List[LogitFilter]
726
+
727
+ def __init__(self, model: "StreamingWhisper", options: DecodingOptions):
728
+ self.model = model
729
+
730
+ language = options.language or "en"
731
+ tokenizer = get_tokenizer(
732
+ model.is_multilingual,
733
+ num_languages=model.num_languages,
734
+ language=language,
735
+ task=options.task,
736
+ )
737
+ self.tokenizer: Tokenizer = tokenizer
738
+ self.options: DecodingOptions = self._verify_options(options)
739
+
740
+ self.n_group: int = options.beam_size or options.best_of or 1
741
+ self.n_ctx: int = model.dims.n_text_ctx
742
+ self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
743
+
744
+ self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
745
+ if self.options.without_timestamps:
746
+ self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
747
+
748
+ self.initial_tokens: Tuple[int] = self._get_initial_tokens()
749
+ self.sample_begin: int = len(self.initial_tokens)
750
+ self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
751
+
752
+ # inference: implements the forward pass through the decoder, including kv caching
753
+ dump_type = "ATT" if self.options.attentive_advisor else None
754
+ self.inference = PyTorchInference(model, len(self.initial_tokens), self.options.use_kv_cache, dump_type=dump_type, n_tokens_look_back=options.n_tokens_look_back, use_beam=options.beam_size > 0)
755
+
756
+ # sequence ranker: implements how to rank a group of sampled sequences
757
+ self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
758
+
759
+ # decoder: implements how to select the next tokens, given the autoregressive distribution
760
+ if options.stream_decode and options.beam_size in [None, 0]:
761
+ self.decoder = StreamingDecoder(
762
+ options.temperature, options.tokens_per_frame, tokenizer.eot, self.inference, options.n_tokens_look_back, options.streaming_timestamps
763
+ )
764
+ elif options.stream_decode and options.beam_size > 0:
765
+ self.decoder = BeamStreamingDecoder(
766
+ options.temperature, options.tokens_per_frame, tokenizer.eot, self.inference, options.n_tokens_look_back, options.beam_size, tokenizer.sot_lm, options.wait_for_all
767
+ )
768
+ elif options.beam_size is not None:
769
+ self.decoder = BeamSearchDecoder(
770
+ options.beam_size, tokenizer.eot, self.inference, options.patience
771
+ )
772
+ else:
773
+ self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
774
+
775
+ # logit filters: applies various rules to suppress or penalize certain tokens
776
+ self.logit_filters = []
777
+ if self.options.suppress_blank:
778
+ self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
779
+ if self.options.suppress_tokens:
780
+ self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
781
+
782
+ self.mel = None
783
+ self.index = 0
784
+
785
+ # repeat text tensors by the group size, for beam search or best-of-n sampling
786
+ self.tokens: Tensor = torch.tensor([self.initial_tokens]) # no need to use repeat, batch is meaningless in stream.
787
+ self.tokens = self.tokens.repeat_interleave(self.n_group, dim=0).to(model.device)
788
+
789
+ n_batch = self.n_group # streaming, batch_size larger than 1 is not allowed. only beam size is relevant
790
+ self.sum_logprobs: Tensor = torch.zeros(n_batch, device=model.device)
791
+ self.no_speech_probs = [np.nan] * n_batch
792
+
793
+ # Causal encoder inference
794
+ self._init_enc_kv_caching()
795
+
796
+ if options.use_ca_kv_cache:
797
+ self._init_ca_kv_caching()
798
+
799
+ self.audio_features = torch.zeros((1, model.dims.n_audio_ctx, model.dims.n_audio_state)).to(model.device)
800
+ self.frame_counter = 0
801
+
802
+ def _init_enc_kv_caching(self):
803
+ self.enc_kv_cache, self.enc_hooks = self.model.install_encoder_kv_cache_hooks()
804
+
805
+ def _init_ca_kv_caching(self):
806
+ self.ca_kv_cache, self.dec_ca_hooks = self.model.install_cross_attn_kv_cache_hooks()
807
+
808
+ def _cleanup_encoder_caching(self):
809
+ for hook in self.enc_hooks:
810
+ hook.remove()
811
+
812
+ del self.enc_kv_cache
813
+ del self.enc_hooks
814
+ self.enc_kv_cache = {}
815
+ self.enc_hooks = []
816
+
817
+ def _cleanup_ca_caching(self):
818
+ if not self.options.use_ca_kv_cache:
819
+ return
820
+
821
+ for hook in self.dec_ca_hooks:
822
+ hook.remove()
823
+
824
+ del self.ca_kv_cache
825
+ del self.dec_ca_hooks
826
+ self.ca_kv_cache = {}
827
+ self.dec_ca_hooks = []
828
+
829
+ def __del__(self):
830
+ self._cleanup_encoder_caching()
831
+ self._cleanup_ca_caching()
832
+ self.inference.cleanup_caching()
833
+
834
+ def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
835
+ if options.beam_size is not None and options.best_of is not None:
836
+ raise ValueError("beam_size and best_of can't be given together")
837
+ if options.temperature == 0:
838
+ if options.best_of is not None:
839
+ raise ValueError("best_of with greedy sampling (T=0) is not compatible")
840
+ if options.patience is not None and options.beam_size is None:
841
+ raise ValueError("patience requires beam_size to be given")
842
+ if options.length_penalty is not None and not (
843
+ 0 <= options.length_penalty <= 1
844
+ ):
845
+ raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
846
+
847
+ return options
848
+
849
+ def _get_initial_tokens(self) -> Tuple[int]:
850
+ tokens = list(self.sot_sequence)
851
+
852
+ if prefix := self.options.prefix:
853
+ prefix_tokens = (
854
+ self.tokenizer.encode(" " + prefix.strip())
855
+ if isinstance(prefix, str)
856
+ else prefix
857
+ )
858
+ if self.sample_len is not None:
859
+ max_prefix_len = self.n_ctx // 2 - self.sample_len
860
+ prefix_tokens = prefix_tokens[-max_prefix_len:]
861
+ tokens = tokens + prefix_tokens
862
+
863
+ if prompt := self.options.prompt:
864
+ prompt_tokens = (
865
+ self.tokenizer.encode(" " + prompt.strip())
866
+ if isinstance(prompt, str)
867
+ else prompt
868
+ )
869
+ tokens = (
870
+ [self.tokenizer.sot_prev]
871
+ + prompt_tokens[-(self.n_ctx // 2 - 1) :]
872
+ + tokens
873
+ )
874
+
875
+ return tuple(tokens)
876
+
877
+ def _get_suppress_tokens(self) -> Tuple[int]:
878
+ suppress_tokens = self.options.suppress_tokens
879
+
880
+ if isinstance(suppress_tokens, str):
881
+ suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
882
+
883
+ if -1 in suppress_tokens:
884
+ suppress_tokens = [t for t in suppress_tokens if t >= 0]
885
+ suppress_tokens.extend(self.tokenizer.non_speech_tokens)
886
+ elif suppress_tokens is None or len(suppress_tokens) == 0:
887
+ suppress_tokens = [] # interpret empty string as an empty list
888
+ else:
889
+ assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
890
+
891
+ suppress_tokens.extend(
892
+ [
893
+ self.tokenizer.transcribe,
894
+ self.tokenizer.translate,
895
+ self.tokenizer.sot,
896
+ self.tokenizer.sot_prev,
897
+ self.tokenizer.sot_lm,
898
+ ]
899
+ )
900
+ if self.tokenizer.no_speech is not None:
901
+ # no-speech probability is collected separately
902
+ suppress_tokens.append(self.tokenizer.no_speech)
903
+
904
+ return tuple(sorted(set(suppress_tokens)))
905
+
906
+ def _get_audio_features(self, mel: Tensor, index: list = None):
907
+
908
+ if self.options.fp16:
909
+ mel = mel.half()
910
+
911
+ audio_features: Tensor = self.model.encoder(mel, kv_cache=self.enc_kv_cache, mask=None)
912
+
913
+ if audio_features.dtype != (
914
+ torch.float16 if self.options.fp16 else torch.float32
915
+ ):
916
+ raise TypeError(
917
+ f"audio_features has an incorrect dtype: {audio_features.dtype}"
918
+ )
919
+
920
+ # Usually will run with this config
921
+ if self.options.use_ca_kv_cache:
922
+ return audio_features
923
+
924
+ # update audio_features
925
+ end_index = (self.mel.shape[-1] // 2) - 1
926
+
927
+ if end_index == (self.model.encoder.gran * (1 + self.model.encoder.extra_gran_blocks)):
928
+ start_index = 0
929
+ else:
930
+ start_index = end_index - self.model.encoder.gran
931
+
932
+ if start_index % self.model.encoder.gran != 0:
933
+ modulo_res = start_index % self.model.encoder.gran
934
+ steps = self.model.encoder.gran - modulo_res
935
+ start_index += steps
936
+ end_index = start_index + self.model.encoder.gran
937
+
938
+ self.audio_features[:, start_index:end_index] = audio_features
939
+
940
+ def _detect_language(self, audio_features: Tensor, tokens: Tensor):
941
+ languages = [self.options.language] * audio_features.shape[0]
942
+ lang_probs = None
943
+
944
+ if self.options.language is None or self.options.task == "lang_id":
945
+ lang_tokens, lang_probs = self.model.detect_language(
946
+ audio_features, self.tokenizer
947
+ )
948
+ languages = [max(probs, key=probs.get) for probs in lang_probs]
949
+ if self.options.language is None:
950
+ tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
951
+
952
+ return languages, lang_probs
953
+
954
+ def _set_ca_kv_cache(self, value: bool):
955
+ self.model.use_ca_cache_hook = value
956
+
957
+ def _main_loop(self, audio_features: Tensor):
958
+ """
959
+ In streaming we keep self.tokens across calls, since the previously decoded tokens remain part of the context.
960
+ """
961
+ is_first_frame = self.index == (self.options.gran * (1 + self.options.look_ahead_blocks))
962
+ self._set_ca_kv_cache(True)
963
+ beam_indices = None
964
+
965
+ try:
966
+ for i in range(self.sample_len // 8):
967
+
968
+ if self.tokens.shape[-1] > self.n_ctx: break
969
+
970
+ if self.options.stream_decode and self.options.beam_size > 0 and self.options.use_kv_cache: # We are in beam search. kv cache
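+ # beam_indices packs, per recomputed token: its beam (row), its absolute position in the sequence, and its position within this step's input; the decoder uses it for positional embeddings and per-beam kv-cache writes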
971
+ last_valid_token_indices = [self.decoder._get_last_valid_token_index(self.tokens[i]) for i in range(self.tokens.shape[0])]
972
+ beam_indices_cols = last_valid_token_indices if i > 0 else sum([list(range(index - self.options.n_tokens_look_back, index + 1)) for index in last_valid_token_indices], [])
973
+ beam_indices_rows = [i // (self.options.n_tokens_look_back + 1) for i in range(self.options.beam_size * (self.options.n_tokens_look_back + 1))] if i==0 else list(range(self.options.beam_size))
974
+ beam_indices_cols_input = [item - beam_indices_cols[(j // (len(beam_indices_cols) // self.options.beam_size)) * (len(beam_indices_cols) // self.options.beam_size)] for j, item in enumerate(beam_indices_cols)]
975
+ beam_indices = [beam_indices_rows, beam_indices_cols, beam_indices_cols_input]
976
+
977
+ logits = self.inference.logits(self.tokens, audio_features, i==0, beam_indices) # run inference through decoder
978
+
979
+ # after the first decoder forward pass, no need to cache
980
+ if i == 0: self._set_ca_kv_cache(False)
981
+
982
+ if i == 0 and self.tokenizer.no_speech is not None: # save no_speech_probs
983
+ probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
984
+ self.no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
985
+
986
+ if isinstance(self.decoder, StreamingDecoder) or isinstance(self.decoder, BeamStreamingDecoder): # We need tokens for context
987
+ logits = logits[:, :]
988
+ else: # we need to consider the logits at the last token only
989
+ logits = logits[:, -1]
990
+
991
+ # apply the logit filters, e.g. for suppressing or applying penalty to
992
+ for logit_filter in self.logit_filters:
993
+ logit_filter.apply(logits, self.tokens)
994
+
995
+ # expand the tokens tensor with the selected next tokens
996
+ if isinstance(self.decoder, BeamStreamingDecoder):
997
+ self.tokens, completed = self.decoder.update(self.tokens, logits, self.sum_logprobs, first_frame=is_first_frame, check_tokens=(i == 0))
998
+ elif isinstance(self.decoder, StreamingDecoder):
999
+ self.tokens, completed = self.decoder.update(self.tokens, logits, self.sum_logprobs, first_frame=is_first_frame, check_tokens=(i == 0), index=(self.index * 20))
1000
+ else:
1001
+ self.tokens, completed = self.decoder.update(self.tokens, logits, self.sum_logprobs)
1002
+
1003
+ if completed: # ctx is unlimited, we are limited only by i
1004
+ break
1005
+ finally:
1006
+ if self.options.use_kv_cache:
1007
+ self.inference.flush_tokens_from_cache()
1008
+
1009
+ if isinstance(self.decoder, StreamingDecoder) or isinstance(self.decoder, BeamStreamingDecoder):
1010
+
1011
+ if is_first_frame and self.options.streaming_timestamps and self.options.force_first_tokens_timestamps:
1012
+ self.decoder._insert_timestamps(audio_features, self.tokens, self.options.gran)
1013
+
1014
+ if self.tokens.shape[1] > logits.shape[1]:
1015
+ self.tokens = self.tokens[:, :logits.shape[1]]
1016
+
1017
+ self.decoder.reset()
1018
+
1019
+ return self.sum_logprobs, self.no_speech_probs
1020
+
1021
+ def _pad_audio_features(self, audio_features: Tensor):
1022
+ to_pad = 1500 - audio_features.shape[1]
1023
+ padding = torch.zeros((1, to_pad, audio_features.shape[2])).to(audio_features.device)
1024
+ audio_features = torch.cat([audio_features, padding], dim=1).to(audio_features.device)
1025
+ return audio_features
1026
+
1027
+ def _empty_results(self):
1028
+ return DecodingResult(
1029
+ audio_features=None,
1030
+ language="en",
1031
+ tokens=[],
1032
+ text="",
1033
+ avg_logprob=None,
1034
+ no_speech_prob=None,
1035
+ temperature=self.options.temperature,
1036
+ compression_ratio=0,
1037
+ )
1038
+
1039
+ def _caching_inner_reset(self):
1040
+ # Encoder kv cache clean
1041
+ self._cleanup_encoder_caching()
1042
+ self._init_enc_kv_caching()
1043
+
1044
+ # Decoder CA kv-cache
1045
+ if self.options.use_ca_kv_cache:
1046
+ self._cleanup_ca_caching()
1047
+ self._init_ca_kv_caching()
1048
+
1049
+ # Decoder SA kv-cache
1050
+ if self.options.use_kv_cache:
1051
+ self.inference.cleanup_caching()
1052
+ # Caching will be triggered on the logits function
1053
+
1054
+ def _reset_after_maximal_context(self, new_mel_frame: Tensor):
1055
+ print("Reset context...")
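+ # keep only the trailing mel frames needed to rebuild the encoder's look-ahead context (roughly 2 mel frames per encoder position)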
1056
+ num_old_mels = (self.options.gran * (self.options.look_ahead_blocks) * 2) + 2 if self.options.look_ahead_blocks > 0 else (self.options.gran * 2) + 2
1057
+ self.mel = torch.cat([self.mel[..., -num_old_mels:], new_mel_frame], dim=-1)
1058
+ self.options.prefix = self.tokens[:, len(self.sot_sequence):].tolist()[0][-self.options.n_tokens_look_back-1:]
1059
+ print(f"Modifying tokens! {self.options.prefix=}")
1060
+ self.initial_tokens = self._get_initial_tokens()
1061
+ print("Modified tokens!")
1062
+ self.tokens = torch.tensor([list(self.initial_tokens)]) # no need to use repeat, batch is meaningless in stream.
1063
+ self.tokens = self.tokens.repeat_interleave(self.n_group, dim=0).to(self.model.device)
1064
+ self._caching_inner_reset()
1065
+ self.audio_features = torch.zeros((1, self.model.dims.n_audio_ctx, self.model.dims.n_audio_state)).to(self.model.device)
1066
+ print("Finished reset...")
1067
+
1068
+ @torch.no_grad()
1069
+ def run(self, mel_frame: Tensor) -> List[DecodingResult]:
1070
+ """
1071
+ mel_frame - a tensor containing the last frame we got from the stream.
1072
+
1073
+ The function needs to:
1074
+ 1. Take the mel frame.
1075
+ 2. Extract features from the given frame using kv caching in the encoder.
1076
+ 3. Note that the caching in the encoder will do it automatically.
1077
+ """
1078
+ # concat mel
1079
+ if self.options.single_frame_mel: # each time we get a single frame of mel
1080
+ self.mel = torch.cat([self.mel, mel_frame], dim=-1) if self.mel is not None else mel_frame
1081
+ else: # we are getting a whole frame [0, t_curr]
1082
+ self.mel = mel_frame
1083
+ tokenizer: Tokenizer = self.tokenizer
1084
+ n_audio: int = mel_frame.shape[0]
1085
+
1086
+ self.index += self.options.gran # on each decoding call we add more context
1087
+ self.frame_counter += 1
1088
+ print(f"Collected {self.frame_counter} frames...")
1089
+
1090
+ if self.mel.shape[-1] >= self.options.maximal_seconds_context * 100:
1091
+ self._reset_after_maximal_context(mel_frame)
1092
+
1093
+ if self.mel.shape[-1] < (self.options.gran * 2 * (self.options.look_ahead_blocks + 1)):
1094
+ print("Decoding Task: skipping first frames...")
1095
+ return self._empty_results()
1096
+
1097
+ # call the main sampling loop
1098
+ if not self.options.use_ca_kv_cache:
1099
+ self._get_audio_features(self.mel) # encoder forward pass, updates self.audio_features
1100
+ audio_features = self.audio_features
1101
+ sum_logprobs, no_speech_probs = self._main_loop(audio_features[:, :self.index])
1102
+ else:
1103
+ audio_features = self._get_audio_features(self.mel)
1104
+ sum_logprobs, no_speech_probs = self._main_loop(audio_features)
1105
+
1106
+
1107
+ # reshape the tensors to have (n_audio, n_group) as the first two dimensions
1108
+ audio_features = audio_features[:: self.n_group]
1109
+ no_speech_probs = no_speech_probs[:: self.n_group]
1110
+
1111
+ self.tokens = self.tokens.reshape(n_audio, self.n_group, -1)
1112
+ sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
1113
+
1114
+ # get the final candidates for each group, and slice between the first sampled token and EOT
1115
+ tokens, sum_logprobs = self.decoder.finalize(self.tokens, sum_logprobs)
1116
+ tokens: List[List[Tensor]] = [[t[self.sample_begin: ((t == tokenizer.eot) | (t == tokenizer.sot_lm)).nonzero()[0, 0]] for t in s] for s in tokens]
1117
+
1118
+
1119
+ # if any of the suggested beams is empty (no token), it means that the predicted token was EOT. Add it, so ML Ranker won't fail.
1120
+ for i in range(len(tokens[0])):
1121
+ if tokens[0][i].shape[0] > 0: continue
1122
+ else: tokens[0][i] = torch.tensor([self.tokenizer.eot]).to(tokens[0][i].device)
1123
+
1124
+ # select the top-ranked sample in each group
1125
+ selected = self.sequence_ranker.rank(tokens, sum_logprobs)
1126
+ tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
1127
+ texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
1128
+ sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
1129
+ avg_logprobs: List[float] = [
1130
+ lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
1131
+ ]
1132
+
1133
+ fields = (texts, tokens, audio_features, avg_logprobs, no_speech_probs)
1134
+
1135
+ self.tokens = self.tokens.squeeze(0)
1136
+
1137
+ if len(set(map(len, fields))) != 1:
1138
+ raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
1139
+
1140
+ # apply timestamps
1141
+ if self.options.beam_size == 0 or self.options.beam_size is None:
1142
+ timed_tokens = tokens.copy()[0]
1143
+ for i, index in enumerate(sorted(self.decoder.timestamps_map.keys())):
1144
+ timed_tokens.insert(index + i + 1, self.tokenizer.timestamp_begin + (self.decoder.timestamps_map[index] // 20))
1145
+ else:
1146
+ timed_tokens = [50257]
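+ # placeholder single-token output; timed tokens are not produced in beam-search mode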
1147
+
1148
+ return DecodingResult(
1149
+ audio_features=audio_features,
1150
+ language="en",
1151
+ tokens=tokens,
1152
+ text=texts[0],
1153
+ avg_logprob=avg_logprobs,
1154
+ no_speech_prob=no_speech_probs,
1155
+ temperature=self.options.temperature,
1156
+ compression_ratio=compression_ratio(texts[0]),
1157
+ timestamps=self.decoder.timestamps_map,
1158
+ timed_tokens=timed_tokens,
1159
+ timed_text=self.tokenizer.decode_with_timestamps(timed_tokens)
1160
+ )
1161
+
1162
+
1163
+ @torch.no_grad()
1164
+ def decode(
1165
+ model: "Whisper",
1166
+ mel: Tensor,
1167
+ task: DecodingTask,
1168
+ options: DecodingOptions = DecodingOptions(),
1169
+ **kwargs,
1170
+ ) -> Union[DecodingResult, List[DecodingResult]]:
1171
+ """
1172
+ Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
1173
+
1174
+ Parameters
1175
+ ----------
1176
+ model: Whisper
1177
+ the Whisper model instance
1178
+
1179
+ mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
1180
+ A tensor containing the Mel spectrogram(s)
1181
+
1182
+ options: DecodingOptions
1183
+ A dataclass that contains all necessary options for decoding 30-second segments
1184
+
1185
+ Returns
1186
+ -------
1187
+ result: Union[DecodingResult, List[DecodingResult]]
1188
+ The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
1189
+ """
1190
+ if mel.ndim == 2:
1191
+ mel = mel.unsqueeze(0)
1192
+
1193
+ if kwargs:
1194
+ options = replace(options, **kwargs)
1195
+
1196
+ result = DecodingTask(model, options).run(mel)
1197
+ return result
1198
+
whisper_stream/streaming_model.py ADDED
@@ -0,0 +1,444 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from torch import Tensor
7
+ from typing import Optional
8
+ from .model import AudioEncoder, TextDecoder, Whisper
9
+
10
+ from .streaming_decoding import DecodingTask, DecodingOptions, DecodingResult
11
+ from .streaming_transcribe import transcribe as transcribe_function
12
+ from .decoding import decode as non_causal_decode_function
13
+ from .audio import SpectrogramStream
14
+
15
+ from dataclasses import replace
16
+
17
+ from pytorch_lightning import LightningModule
18
+
19
+
20
+ class LoraLayer(LightningModule):
21
+ def __init__(self, input_dim, output_dim, rank=8, alpha=None, *args, **kwargs):
22
+ super().__init__(*args, **kwargs)
23
+
24
+ self.lora_A = nn.Parameter(torch.zeros(input_dim, rank))
25
+ self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))
26
+
27
+ self.alpha = rank if alpha is None else alpha
28
+ self.rank = rank
29
+ self.scale = self.alpha / self.rank
30
+
31
+ self._init_weights()
32
+
33
+ def _init_weights(self):
34
+ nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
35
+ nn.init.zeros_(self.lora_B)
36
+
37
+ def forward(self, x):
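+ # low-rank update: x @ (A @ B), scaled by alpha / rank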
38
+ return x @ (self.lora_A @ self.lora_B) * self.scale
39
+
40
+
41
+ class LoraLinearLayer(LightningModule):
42
+ def __init__(self, base_layer: nn.Linear, rank: int = 8, bias: bool = True, *args, **kwargs):
43
+ super().__init__(*args, **kwargs)
44
+
45
+ self.base_layer = base_layer
46
+
47
+ self.lora_layer = LoraLayer(base_layer.in_features, base_layer.out_features, rank=rank)
48
+ self.aggregate_lora = True
49
+
50
+ def turn_on_lora(self):
51
+ self.aggregate_lora = True
52
+
53
+ def turn_off_lora(self):
54
+ self.aggregate_lora = False
55
+
56
+ def forward(self, x: Tensor):
57
+ if not self.aggregate_lora:
58
+ return self.base_layer(x)
59
+
60
+ return self.base_layer(x) + self.lora_layer(x)
61
+
62
+
63
+ class LoRAMultiHeadAttention(LightningModule):
64
+ def __init__(self, n_head, query, key, value, out, rank, *args, **kwargs):
65
+ super().__init__(*args, **kwargs)
66
+
67
+ self.n_head = n_head
68
+ self.query = LoraLinearLayer(query, rank)
69
+ self.key = LoraLinearLayer(key, rank, bias=False)
70
+ self.value = LoraLinearLayer(value, rank)
71
+ self.out = LoraLinearLayer(out, rank)
72
+
73
+ def forward(
74
+ self,
75
+ x: Tensor,
76
+ xa: Tensor = None,
77
+ mask: Tensor = None,
78
+ kv_cache: dict = None,
79
+ *args, **kwargs
80
+ ):
81
+ q = self.query(x)
82
+
83
+ if kv_cache is None or xa is None or self.key not in kv_cache:
84
+ k = self.key(x if xa is None else xa)
85
+ v = self.value(x if xa is None else xa)
86
+
87
+ else:
88
+ k = kv_cache[self.key]
89
+ v = kv_cache[self.value]
90
+
91
+ wv, qk = self.qkv_attention(q, k, v, mask, kv_cache)
92
+
93
+ return self.out(wv), qk
94
+
95
+ def qkv_attention(
96
+ self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None, kv_cache: dict = None
97
+ ):
98
+ n_batch, n_ctx, n_state = q.shape
99
+ _, k_ctx, _ = k.shape
100
+ scale = (n_state // self.n_head) ** -0.25
101
+ q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
102
+ k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
103
+ v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
104
+ qk = q @ k
105
+
106
+ # apply causal mask
107
+ if mask is not None:
108
+
109
+ # kv_cache for beam search decoding case
110
+ if kv_cache is not None and "beam_indices" in kv_cache.keys():
111
+ for i in range(n_batch):
112
+ qk[i] = qk[i] + mask[kv_cache["beam_indices"][1][i * n_ctx]:kv_cache["beam_indices"][1][i * n_ctx] + n_ctx, :k_ctx]
113
+
114
+ # For training, encoder/decoder causal masks
115
+ elif k_ctx == n_ctx:
116
+ qk = qk + mask[:n_ctx, :n_ctx]
117
+
118
+ # kv_cache in the greedy decoding case
119
+ else:
120
+ qk = qk + mask[k_ctx - n_ctx:k_ctx, :k_ctx]
121
+
122
+ qk = qk.float()
123
+
124
+ w = F.softmax(qk, dim=-1).to(q.dtype)
125
+
126
+ return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
127
+
128
+
129
+ class StreamingAudioEncoder(AudioEncoder):
130
+
131
+ def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer, cache_gran, gran, rank, extra_gran_blocks):
132
+ super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)
133
+
134
+ self.gran = gran
135
+ self.extra_gran_blocks = extra_gran_blocks
136
+
137
+ for block in self.blocks:
138
+ block.attn = LoRAMultiHeadAttention(self.n_head,
139
+ block.attn.query,
140
+ block.attn.key,
141
+ block.attn.value,
142
+ block.attn.out,
143
+ rank)
144
+
145
+ self.use_stream = False
146
+
147
+ # mask for training
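+ # block-causal: each gran-sized block attends to itself and all earlier blocks; the first (extra_gran_blocks + 1) blocks can also attend to each other (initial look-ahead)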
148
+ matrix_size = n_ctx
149
+ block_size = gran
150
+ extra_blocks = extra_gran_blocks
151
+ mask = torch.full((matrix_size, matrix_size), float('-inf'))
152
+
153
+ for i in range(0, matrix_size, block_size):
154
+ if (i // block_size) <= extra_blocks:
155
+ zero_cols = (block_size * (extra_blocks + 1))
156
+ else:
157
+ zero_cols = (block_size * (extra_blocks + 1)) + ((i // block_size) - extra_blocks) * block_size
158
+
159
+ mask[i:i + block_size, :zero_cols] = 1
160
+
161
+ self.register_buffer("mask", mask, persistent=False)
162
+
163
+ def _use_stream(self, use_stream: bool):
164
+ self.use_stream = use_stream
165
+
166
+ def forward(self, x: Tensor, index: list = [0, 1500], kv_cache = None, mask = True):
167
+ """
168
+ Streaming forward pass: when use_stream is set, only the newly arrived chunk is processed and earlier keys/values come from the self-attention kv cache; otherwise the block-causal mask simulates streaming during training.
169
+ """
170
+ x = F.gelu(self.conv1(x))
171
+ x = F.gelu(self.conv2(x))
172
+ x = x.permute(0, 2, 1)
173
+
174
+ if self.use_stream:
175
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
176
+ x = x[:, offset:offset + self.gran + (int(offset == 0) * (self.extra_gran_blocks * self.gran))] # offset
177
+ x = (x + self.positional_embedding[offset:offset + self.gran + (int(offset == 0) * (self.extra_gran_blocks * self.gran))]).to(x.dtype)
178
+ else: # use during training
179
+ x = x[:, index[0]:index[1]] # offset
180
+ x = (x + self.positional_embedding[index[0]:index[1]]).to(x.dtype)
181
+
182
+ for block in self.blocks:
183
+ chosen_mask = mask[..., :index[1], :index[1]] if isinstance(mask, Tensor) else self.mask if (mask is not None) and (self.use_stream) else None
184
+ x = block(x, mask=chosen_mask, kv_cache=kv_cache)
185
+
186
+ x = self.ln_post(x)
187
+
188
+ return x
189
+
190
+
191
+ class StreamingTextDecoder(TextDecoder):
192
+ def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer, rank):
193
+ super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)
194
+
195
+ self.n_ctx = n_ctx
196
+ self.n_state = n_state
197
+
198
+ for block in self.blocks:
199
+ block.attn = LoRAMultiHeadAttention(n_head,
200
+ block.attn.query,
201
+ block.attn.key,
202
+ block.attn.value,
203
+ block.attn.out,
204
+ rank)
205
+
206
+ block.cross_attn = LoRAMultiHeadAttention(n_head,
207
+ block.cross_attn.query,
208
+ block.cross_attn.key,
209
+ block.cross_attn.value,
210
+ block.cross_attn.out,
211
+ rank)
212
+
213
+ def forward(self, x: Tensor, xa: Tensor, kv_cache: dict = None, dump_type: str = None, **kwargs):
214
+ """
215
+ x : torch.LongTensor, shape = (batch_size, <= n_ctx)
216
+ the text tokens
217
+ xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
218
+ the encoded audio features to be attended on
219
+ """
220
+ if kv_cache is not None and "beam_indices" in kv_cache.keys():
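+ # gather per-beam positional embeddings at the absolute positions being recomputed for each beam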
221
+ x = self.token_embedding(x) + self.positional_embedding.unsqueeze(0).expand(x.shape[0], self.positional_embedding.shape[0], self.positional_embedding.shape[1])[kv_cache["beam_indices"][0], kv_cache["beam_indices"][1]].view(x.shape[0], -1, self.n_state)
222
+ else:
223
+ offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
224
+ x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
225
+
226
+ x = x.to(xa.dtype)
227
+
228
+ for block in self.blocks:
229
+ x = block(x, xa, mask=self.mask, kv_cache=kv_cache)
230
+
231
+ x = self.ln(x)
232
+ logits = (
233
+ x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
234
+ ).float()
235
+
236
+ return logits
237
+
238
+
239
+ class StreamingWhisper(Whisper):
240
+ def __init__(self, dims, cache_gran: bool = True, gran: int = 16, rank: int = 0, extra_gran_blocks: int = 0):
241
+ super().__init__(dims)
242
+
243
+ self.cache_gran = cache_gran
244
+ self.gran = gran
245
+ self.rank = rank
246
+ self.extra_gran_blocks = extra_gran_blocks
247
+
248
+ print(f"Running a streaming whisper model, using chunk size: {gran * 20}[msec] and {extra_gran_blocks} extra chunks for initialization.")
249
+
250
+ # The only difference is a streaming encoder
251
+ self.encoder = StreamingAudioEncoder(
252
+ self.dims.n_mels,
253
+ self.dims.n_audio_ctx,
254
+ self.dims.n_audio_state,
255
+ self.dims.n_audio_head,
256
+ self.dims.n_audio_layer,
257
+ cache_gran=cache_gran,
258
+ gran=gran,
259
+ rank=rank,
260
+ extra_gran_blocks=extra_gran_blocks
261
+ )
262
+
263
+ self.decoder = StreamingTextDecoder(
264
+ self.dims.n_vocab,
265
+ self.dims.n_text_ctx,
266
+ self.dims.n_text_state,
267
+ self.dims.n_text_head,
268
+ self.dims.n_text_layer,
269
+ rank=rank
270
+ )
271
+
272
+ # Advisor params - Dropped.
273
+ self.advisor_type = None
274
+ self.n_advisor_class = 0
275
+ self.decoder_advisor = None
276
+
277
+ self.decoding_task = None
278
+ self.spec_streamer = SpectrogramStream()
279
+
280
+ self.use_ca_cache_hook = True # relevant only when ca_kv_cache is installed
281
+
282
+ def reset(self, use_stream: bool = True, clean_task: bool = True):
283
+ self.encoder._use_stream(use_stream)
284
+ del self.decoding_task # trigger clean encoder kv caching
285
+ self.decoding_task = None
286
+ self.spec_streamer.reset()
287
+
288
+ @torch.no_grad()
289
+ def decode(self, mel: Tensor, options: DecodingOptions = DecodingOptions(), use_frames: bool = False, **kwargs) -> DecodingResult:
290
+ if kwargs: options = replace(options, **kwargs)
291
+
292
+ if use_frames: # mel is frames of audio, need to calc mel
293
+ mel = self.spec_streamer.calc_mel_with_new_frame(mel).squeeze(0)
294
+
295
+ if self.encoder.gran != options.gran:
296
+ print(f"Encoder gran & options gran differ. forcing options to be on encoder's gran: {self.encoder.gran}")
297
+ options.gran = self.encoder.gran
298
+
299
+ if not self.decoding_task:
300
+ self.decoding_task = DecodingTask(self, options)
301
+
302
+ return self.decoding_task.run(mel.unsqueeze(0))
303
+
304
+ def _turn_off_lora(self):
305
+ for _, layer in self.encoder.named_modules():
306
+ if isinstance(layer, LoraLinearLayer):
307
+ layer.turn_off_lora()
308
+
309
+ def _turn_on_lora(self):
310
+ for _, layer in self.encoder.named_modules():
311
+ if isinstance(layer, LoraLinearLayer):
312
+ layer.turn_on_lora()
313
+
314
+ def _cancel_streaming_mode(self):
315
+ self._turn_off_lora()
316
+ self.encoder._use_stream(False)
317
+
318
+ def _revert_streaming_mode(self):
319
+ self._turn_on_lora()
320
+ self.encoder._use_stream(True)
321
+
322
+ @torch.no_grad()
323
+ def non_causal_decode(self, mel: Tensor, options: DecodingOptions = DecodingOptions(), **kwargs) -> DecodingResult:
324
+ self._cancel_streaming_mode()
325
+ results = non_causal_decode_function(self, mel, options, **kwargs)
326
+ self._revert_streaming_mode()
327
+ return results
328
+
329
+ def remove_encoder_kv_cache_hooks(self):
330
+ for hook in self.encoder._forward_hooks.values():
331
+ hook.remove()
332
+
333
+ def install_encoder_kv_cache_hooks(self, cache = None):
334
+ cache = {**cache} if cache is not None else {}
335
+ hooks = []
336
+
337
+ def save_to_cache(module, _, output):
338
+ if module not in cache or output.shape[1] > self.dims.n_audio_ctx:
339
+ # save as-is, for the first token or cross attention
340
+ cache[module] = output
341
+ else:
342
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
343
+ return cache[module]
344
+
345
+ def install_hooks(layer: nn.Module):
346
+ if isinstance(layer, LoRAMultiHeadAttention):
347
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
348
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
349
+
350
+ self.encoder.apply(install_hooks)
351
+ return cache, hooks
352
+
353
+ def install_decoder_kv_cache_hooks(self, cache = None):
354
+ cache = {**cache} if cache is not None else {}
355
+ hooks = []
356
+
357
+ def save_to_cache(module, _, output):
358
+ if module not in cache or output.shape[1] > self.dims.n_text_ctx:
359
+ cache[module] = output
360
+ else:
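+ # with beam search, recomputed positions are written back per beam; otherwise the new keys/values are simply appended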
361
+ if "beam_indices" not in cache.keys() or all([index == (cache[module].shape[1]) for index in cache["beam_indices"][1]]):
362
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
363
+ else:
364
+ for _, (beam, index, output_index) in enumerate(zip(*cache["beam_indices"])):
365
+ if index < cache[module].shape[1]:
366
+ cache[module][beam, index] = output[beam, output_index]
367
+ else:
368
+ cache[module] = torch.cat([cache[module], output[:, (output_index):]], dim=1).detach()
369
+
370
+ return cache[module]
371
+
372
+ for name, module in self.decoder.named_modules():
373
+ if isinstance(module, LoRAMultiHeadAttention) and "attn" in name and "cross" not in name:
374
+ hooks.append(module.key.register_forward_hook(save_to_cache))
375
+ hooks.append(module.value.register_forward_hook(save_to_cache))
376
+
377
+ return cache, hooks
378
+
379
+ def install_cross_attn_kv_cache_hooks(self, cache=None):
380
+ cache = {**cache} if cache is not None else {}
381
+ hooks = []
382
+
383
+ def save_to_cache(module, _, output):
384
+ if not self.use_ca_cache_hook:
385
+ return cache[module]
386
+
387
+ if module not in cache or output.shape[1] > self.dims.n_audio_ctx:
388
+ # save as-is, for the first token or cross attention
389
+ cache[module] = output
390
+ else:
391
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
392
+
393
+ return cache[module]
394
+
395
+ def check_if_calculation_is_needed(module, _):
396
+ if not self.use_ca_cache_hook:
397
+ return cache[module]
398
+
399
+ for name, module in self.decoder.named_modules():
400
+ if isinstance(module, LoRAMultiHeadAttention) and "cross_attn" in name:
401
+ hooks.append(module.key.register_forward_hook(save_to_cache))
402
+ hooks.append(module.key.register_forward_pre_hook(check_if_calculation_is_needed))
403
+ hooks.append(module.value.register_forward_hook(save_to_cache))
404
+ hooks.append(module.value.register_forward_pre_hook(check_if_calculation_is_needed))
405
+
406
+ return cache, hooks
407
+
408
+ # For non-causal decoding compatibility
409
+ def install_kv_cache_hooks(self, cache: Optional[dict] = None):
410
+ """
411
+ The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
412
+ tensors calculated for the previous positions. This method returns a dictionary that stores
413
+ all caches, and the necessary hooks for the key and value projection modules that save the
414
+ intermediate tensors to be reused during later calculations.
415
+
416
+ Returns
417
+ -------
418
+ cache : Dict[nn.Module, torch.Tensor]
419
+ A dictionary object mapping the key/value projection modules to its cache
420
+ hooks : List[RemovableHandle]
421
+ List of PyTorch RemovableHandle objects to stop the hooks to be called
422
+ """
423
+ cache = {**cache} if cache is not None else {}
424
+ hooks = []
425
+
426
+ def save_to_cache(module, _, output):
427
+ if module not in cache or output.shape[1] > self.dims.n_text_ctx:
428
+ # save as-is, for the first token or cross attention
429
+ cache[module] = output
430
+ else:
431
+ cache[module] = torch.cat([cache[module], output], dim=1).detach()
432
+ return cache[module]
433
+
434
+ def install_hooks(layer: nn.Module):
435
+ if isinstance(layer, LoRAMultiHeadAttention):
436
+ hooks.append(layer.key.register_forward_hook(save_to_cache))
437
+ hooks.append(layer.value.register_forward_hook(save_to_cache))
438
+
439
+ self.decoder.apply(install_hooks)
440
+ return cache, hooks
441
+
442
+ # refers to function from streaming_decoding, streaming_transcribe library
443
+ transcribe = transcribe_function
444
+
whisper_stream/streaming_transcribe.py ADDED
@@ -0,0 +1,208 @@
1
+ import argparse
2
+ import os
3
+ import time
4
+ import traceback
5
+ import warnings
6
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ import librosa
12
+
13
+ from .audio import (
14
+ FRAMES_PER_SECOND,
15
+ HOP_LENGTH,
16
+ N_FRAMES,
17
+ N_SAMPLES,
18
+ SAMPLE_RATE,
19
+ log_mel_spectrogram,
20
+ pad_or_trim,
21
+ load_audio,
22
+ SpectrogramStream,
23
+ MyStream
24
+ )
25
+ from .streaming_decoding import DecodingOptions, DecodingResult, DecodingTask
26
+ from .timing import add_word_timestamps
27
+ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
28
+ from .utils import (
29
+ exact_div,
30
+ format_timestamp,
31
+ get_end,
32
+ get_writer,
33
+ make_safe,
34
+ optional_float,
35
+ optional_int,
36
+ str2bool,
37
+ )
38
+
39
+ if TYPE_CHECKING:
40
+ from .streaming_model import StreamingWhisper
41
+
42
+
43
+ def transcribe(
44
+ model: "StreamingWhisper" = None,
45
+ filename: str = None,
46
+ channels: int = 2,
47
+ simulate_stream: bool = False,
48
+ wav_file: str = None,
49
+ single_frame_mel: bool = True,
50
+ advised: bool = False,
51
+ attentive_advisor: bool = False,
52
+ get_mel: bool = False,
53
+ temperature: float = 0,
54
+ beam_size: int = 5,
55
+ stream_decode: bool = True,
56
+ use_ca_kv_cache: bool = False,
57
+ use_sa_kv_cache: bool = False,
58
+ language: str = "en",
59
+ wait_for_all: bool = False,
60
+ use_latency: bool = False,
61
+ get_times: bool = False,
62
+ pad_trim: bool = True,
63
+ max_sec_context: int = 30,
64
+ streaming_timestamps: bool = False,
65
+ force_first_tokens_timestamps: bool = False,
66
+ use_remote_stream: bool = False,
67
+ ) -> List[str]:
68
+ """
69
+ Open a stream and transcribe it using streaming whisper model
70
+
71
+ Parameters
72
+ ----------
73
+ model: Whisper
74
+ The Whisper model instance
75
+
76
+ Returns
77
+ -------
78
+ A list of DecodingResult objects (text, tokens, ...) covering everything transcribed until the stream stopped, plus optional per-frame timing or mel data.
79
+ """
80
+ model.reset(use_stream=True) # we first reset the model before starting a stream, cleaning any cache.
81
+ model.eval()
82
+
83
+ # Instantiate streaming instance and open a stream
84
+ ms_gran = model.encoder.gran * 20
85
+ stream_instance = MyStream(ms_gran,
86
+ channels=channels,
87
+ filename=filename,
88
+ simulate_stream=simulate_stream,
89
+ wav_file=wav_file,
90
+ use_latency=use_latency,
91
+ pad_trim=pad_trim,
92
+ use_remote_machine=use_remote_stream)
93
+
94
+ stream_instance.open_stream()
95
+
96
+ # frames - used only when filename is given, in order to save a long wav at the end of the conversation.
97
+ frames = []
98
+
99
+ # build the decoding options used for every streaming decode call
100
+ decoding_options = DecodingOptions(
101
+ language=language,
102
+ gran=(ms_gran // 20),
103
+ single_frame_mel=single_frame_mel,
104
+ without_timestamps=True,
105
+ beam_size=beam_size if temperature == 0 else None,
106
+ temperature=temperature,
107
+ length_penalty=None,
108
+ advised=advised,
109
+ attentive_advisor=attentive_advisor,
110
+ look_ahead_blocks=model.encoder.extra_gran_blocks,
111
+ patience=None,
112
+ stream_decode=stream_decode,
113
+ use_kv_cache=use_sa_kv_cache,
114
+ use_ca_kv_cache=use_ca_kv_cache,
115
+ wait_for_all=wait_for_all,
116
+ streaming_timestamps=streaming_timestamps,
117
+ force_first_tokens_timestamps=force_first_tokens_timestamps
118
+ )
119
+
120
+ streamed_spectrogram = SpectrogramStream(n_mels=model.dims.n_mels) # default values are whisper default values
121
+
122
+ texts = []
123
+ times = []
124
+ mel = None
125
+ reset_len = (max_sec_context * SAMPLE_RATE) + 360 # 360 is for the mel padding
126
+ try:
127
+ for frame in stream_instance.read():
128
+ # save frames for optional save
129
+ frames.extend(frame)
130
+
131
+ if len(frames) > reset_len: # When we surpass the max_sec_context - reset model (positional embeddings constrain us)
132
+ frame = np.concatenate((frames[-360:], frame))
133
+ frames = []
134
+ frames.extend(frame.tolist())
135
+ model.reset(use_stream=True)
136
+ streamed_spectrogram.reset()
137
+
138
+ if get_times:
139
+ torch.cuda.synchronize()
140
+ start = time.time()
141
+
142
+ frame_tensor = torch.from_numpy(frame).pin_memory()
143
+ mel_frame = streamed_spectrogram.calc_mel_with_new_frame(frame_tensor.to(model.device, non_blocking=True))
144
+
145
+ # decode given the new mel frame and print results
146
+ result = model.decode(mel_frame.squeeze(0), decoding_options)
147
+
148
+ if get_times:
149
+ torch.cuda.synchronize()
150
+ end = time.time()
151
+
152
+ print(result.text)
153
+
154
+ # append metrics
155
+ if get_times:
156
+ times.append(end - start)
157
+ texts.append(result)
158
+
159
+ except KeyboardInterrupt:
160
+ stream_instance.close_stream(frames)
161
+
162
+ print("Finished capturing audio.")
163
+
164
+ if get_mel: return texts, mel
165
+ if get_times: return texts, times
166
+
167
+ return texts, []
168
+
169
+
170
+ def cli():
171
+ parser = argparse.ArgumentParser(description="Transcribe streaming audio with customizable options")
172
+
173
+ # Model choices
174
+ parser.add_argument("--model", type=str, default="small", help="Model size to transcribe with")
175
+ parser.add_argument("--chunk_size", type=int, default=300, help="Chunk size for streaming")
176
+ parser.add_argument("--multilingual", action="store_true", help="Use a multilingual checkpoint if one exists.")
177
+
178
+ # Required/optional file inputs
179
+ parser.add_argument("--filename", type=str, help="Path to the input audio file")
180
+ parser.add_argument("--wav_file", type=str, help="Optional WAV file path to stream")
181
+
182
+ # Audio configuration
183
+ parser.add_argument("--channels", type=int, default=2, help="Number of audio channels")
184
+
185
+ # Streaming behavior
186
+ parser.add_argument("--simulate_stream", action="store_true", help="Simulate a stream from a file")
187
+ parser.add_argument("--single_frame_mel", action="store_true", default=True, help="Use single frame MELs")
188
+ parser.add_argument("--stream_decode", action="store_true", default=True, help="Use streaming decode")
189
+ parser.add_argument("--use_ca_kv_cache", action="store_true", help="Use cross-attention key-value cache")
190
+ parser.add_argument("--use_sa_kv_cache", action="store_true", help="Use self-attention key-value cache")
191
+ parser.add_argument("--wait_for_all", action="store_true", help="Wait for all results before outputting")
192
+ parser.add_argument("--use_latency", action="store_true", help="Track latency for metrics")
193
+ parser.add_argument("--pad_trim", action="store_true", default=True, help="Enable padding and trimming")
194
+ parser.add_argument("--streaming_timestamps", action="store_true", help="Use timestamps in streaming")
195
+ parser.add_argument("--force_first_tokens_timestamps", action="store_true", help="Force timestamps on first tokens")
196
+ parser.add_argument("--use_remote_stream", action="store_true", help="Using remote stream")
197
+
198
+ # Model behavior
199
+ parser.add_argument("--advised", action="store_true", help="Enable advised decoding")
200
+ parser.add_argument("--attentive_advisor", action="store_true", help="Use attentive advisor logic")
201
+ parser.add_argument("--get_mel", action="store_true", help="Return MEL spectrogram")
202
+ parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
203
+ parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search decoding")
204
+ parser.add_argument("--language", type=str, default="en", help="Language of transcription")
205
+ parser.add_argument("--get_times", action="store_true", help="Return word-level timing information")
206
+ parser.add_argument("--max_sec_context", type=int, default=30, help="Max context window size in seconds")
207
+
208
+ return parser.parse_args()
whisper_stream/timing.py ADDED
@@ -0,0 +1,386 @@
1
+ import itertools
2
+ import subprocess
3
+ import warnings
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING, List
6
+
7
+ import numba
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
13
+ from .tokenizer import Tokenizer
14
+
15
+ if TYPE_CHECKING:
16
+ from .model import Whisper
17
+
18
+
19
+ def median_filter(x: torch.Tensor, filter_width: int):
20
+ """Apply a median filter of width `filter_width` along the last dimension of `x`"""
21
+ pad_width = filter_width // 2
22
+ if x.shape[-1] <= pad_width:
23
+ # F.pad requires the padding width to be smaller than the input dimension
24
+ return x
25
+
26
+ if (ndim := x.ndim) <= 2:
27
+ # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
28
+ x = x[None, None, :]
29
+
30
+ assert (
31
+ filter_width > 0 and filter_width % 2 == 1
32
+ ), "`filter_width` should be an odd number"
33
+
34
+ result = None
35
+ x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
36
+ if x.is_cuda:
37
+ try:
38
+ from .triton_ops import median_filter_cuda
39
+
40
+ result = median_filter_cuda(x, filter_width)
41
+ except (RuntimeError, subprocess.CalledProcessError):
42
+ warnings.warn(
43
+ "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
44
+ "falling back to a slower median kernel implementation..."
45
+ )
46
+
47
+ if result is None:
48
+ # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
49
+ result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
50
+
51
+ if ndim <= 2:
52
+ result = result[0, 0]
53
+
54
+ return result
55
+
56
+
57
+ @numba.jit(nopython=True)
58
+ def backtrace(trace: np.ndarray):
59
+ i = trace.shape[0] - 1
60
+ j = trace.shape[1] - 1
61
+ trace[0, :] = 2
62
+ trace[:, 0] = 1
63
+
64
+ result = []
65
+ while i > 0 or j > 0:
66
+ result.append((i - 1, j - 1))
67
+
68
+ if trace[i, j] == 0:
69
+ i -= 1
70
+ j -= 1
71
+ elif trace[i, j] == 1:
72
+ i -= 1
73
+ elif trace[i, j] == 2:
74
+ j -= 1
75
+ else:
76
+ raise ValueError("Unexpected trace[i, j]")
77
+
78
+ result = np.array(result)
79
+ return result[::-1, :].T
80
+
81
+
82
+ @numba.jit(nopython=True, parallel=True)
83
+ def dtw_cpu(x: np.ndarray):
84
+ N, M = x.shape
85
+ cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
86
+ trace = -np.ones((N + 1, M + 1), dtype=np.float32)
87
+
88
+ cost[0, 0] = 0
89
+ for j in range(1, M + 1):
90
+ for i in range(1, N + 1):
91
+ c0 = cost[i - 1, j - 1]
92
+ c1 = cost[i - 1, j]
93
+ c2 = cost[i, j - 1]
94
+
95
+ if c0 < c1 and c0 < c2:
96
+ c, t = c0, 0
97
+ elif c1 < c0 and c1 < c2:
98
+ c, t = c1, 1
99
+ else:
100
+ c, t = c2, 2
101
+
102
+ cost[i, j] = x[i - 1, j - 1] + c
103
+ trace[i, j] = t
104
+
105
+ return backtrace(trace)
106
+
107
+
108
+ def dtw_cuda(x, BLOCK_SIZE=1024):
109
+ from .triton_ops import dtw_kernel
110
+
111
+ M, N = x.shape
112
+ assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
113
+
114
+ x_skew = (
115
+ F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
116
+ )
117
+ x_skew = x_skew.T.contiguous()
118
+ cost = torch.ones(N + M + 2, M + 2) * np.inf
119
+ cost[0, 0] = 0
120
+ cost = cost.cuda()
121
+ trace = torch.zeros_like(cost, dtype=torch.int32)
122
+
123
+ dtw_kernel[(1,)](
124
+ cost,
125
+ trace,
126
+ x_skew,
127
+ x_skew.stride(0),
128
+ cost.stride(0),
129
+ trace.stride(0),
130
+ N,
131
+ M,
132
+ BLOCK_SIZE=BLOCK_SIZE,
133
+ )
134
+
135
+ trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
136
+ :, : N + 1
137
+ ]
138
+ return backtrace(trace.cpu().numpy())
139
+
140
+
141
+ def dtw(x: torch.Tensor) -> np.ndarray:
142
+ if x.is_cuda:
143
+ try:
144
+ return dtw_cuda(x)
145
+ except (RuntimeError, subprocess.CalledProcessError):
146
+ warnings.warn(
147
+ "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
148
+ "falling back to a slower DTW implementation..."
149
+ )
150
+
151
+ return dtw_cpu(x.double().cpu().numpy())
152
+
153
+
154
+ @dataclass
155
+ class WordTiming:
156
+ word: str
157
+ tokens: List[int]
158
+ start: float
159
+ end: float
160
+ probability: float
161
+
162
+
163
+ def find_alignment(
164
+ model: "Whisper",
165
+ tokenizer: Tokenizer,
166
+ text_tokens: List[int],
167
+ mel: torch.Tensor,
168
+ num_frames: int,
169
+ *,
170
+ medfilt_width: int = 7,
171
+ qk_scale: float = 1.0,
172
+ ) -> List[WordTiming]:
173
+ if len(text_tokens) == 0:
174
+ return []
175
+
176
+ tokens = torch.tensor(
177
+ [
178
+ *tokenizer.sot_sequence,
179
+ tokenizer.no_timestamps,
180
+ *text_tokens,
181
+ tokenizer.eot,
182
+ ]
183
+ ).to(model.device)
184
+
185
+ # install hooks on the cross attention layers to retrieve the attention weights
186
+ QKs = [None] * model.dims.n_text_layer
187
+ hooks = [
188
+ block.cross_attn.register_forward_hook(
189
+ lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
190
+ )
191
+ for i, block in enumerate(model.decoder.blocks)
192
+ ]
193
+
194
+ with torch.no_grad():
195
+ logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
196
+ sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
197
+ token_probs = sampled_logits.softmax(dim=-1)
198
+ text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
199
+ text_token_probs = text_token_probs.tolist()
200
+
201
+ for hook in hooks:
202
+ hook.remove()
203
+
204
+ # heads * tokens * frames
205
+ weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
206
+ weights = weights[:, :, : num_frames // 2]
207
+ weights = (weights * qk_scale).softmax(dim=-1)
208
+ std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
209
+ weights = (weights - mean) / std
210
+ weights = median_filter(weights, medfilt_width)
211
+
212
+ matrix = weights.mean(axis=0)
213
+ matrix = matrix[len(tokenizer.sot_sequence) : -1]
214
+ text_indices, time_indices = dtw(-matrix)
215
+
216
+ words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
217
+ if len(word_tokens) <= 1:
218
+ # return on eot only
219
+ # >>> np.pad([], (1, 0))
220
+ # array([0.])
221
+ # This results in crashes when we lookup jump_times with float, like
222
+ # IndexError: arrays used as indices must be of integer (or boolean) type
223
+ return []
224
+ word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
225
+
226
+ jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
227
+ jump_times = time_indices[jumps] / TOKENS_PER_SECOND
228
+ start_times = jump_times[word_boundaries[:-1]]
229
+ end_times = jump_times[word_boundaries[1:]]
230
+ word_probabilities = [
231
+ np.mean(text_token_probs[i:j])
232
+ for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
233
+ ]
234
+
235
+ return [
236
+ WordTiming(word, tokens, start, end, probability)
237
+ for word, tokens, start, end, probability in zip(
238
+ words, word_tokens, start_times, end_times, word_probabilities
239
+ )
240
+ ]
241
+
242
+
243
+ def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
244
+ # merge prepended punctuations
245
+ i = len(alignment) - 2
246
+ j = len(alignment) - 1
247
+ while i >= 0:
248
+ previous = alignment[i]
249
+ following = alignment[j]
250
+ if previous.word.startswith(" ") and previous.word.strip() in prepended:
251
+ # prepend it to the following word
252
+ following.word = previous.word + following.word
253
+ following.tokens = previous.tokens + following.tokens
254
+ previous.word = ""
255
+ previous.tokens = []
256
+ else:
257
+ j = i
258
+ i -= 1
259
+
260
+ # merge appended punctuations
261
+ i = 0
262
+ j = 1
263
+ while j < len(alignment):
264
+ previous = alignment[i]
265
+ following = alignment[j]
266
+ if not previous.word.endswith(" ") and following.word in appended:
267
+ # append it to the previous word
268
+ previous.word = previous.word + following.word
269
+ previous.tokens = previous.tokens + following.tokens
270
+ following.word = ""
271
+ following.tokens = []
272
+ else:
273
+ i = j
274
+ j += 1
275
+
276
+
277
+ def add_word_timestamps(
278
+ *,
279
+ segments: List[dict],
280
+ model: "Whisper",
281
+ tokenizer: Tokenizer,
282
+ mel: torch.Tensor,
283
+ num_frames: int,
284
+ prepend_punctuations: str = "\"'“¿([{-",
285
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
286
+ last_speech_timestamp: float,
287
+ **kwargs,
288
+ ):
289
+ if len(segments) == 0:
290
+ return
291
+
292
+ text_tokens_per_segment = [
293
+ [token for token in segment["tokens"] if token < tokenizer.eot]
294
+ for segment in segments
295
+ ]
296
+
297
+ text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
298
+ alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
299
+ word_durations = np.array([t.end - t.start for t in alignment])
300
+ word_durations = word_durations[word_durations.nonzero()]
301
+ median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
302
+ median_duration = min(0.7, float(median_duration))
303
+ max_duration = median_duration * 2
304
+
305
+ # hack: truncate long words at sentence boundaries.
306
+ # a better segmentation algorithm based on VAD should be able to replace this.
307
+ if len(word_durations) > 0:
308
+ sentence_end_marks = ".。!!??"
309
+ # ensure words at sentence boundaries are not longer than twice the median word duration.
310
+ for i in range(1, len(alignment)):
311
+ if alignment[i].end - alignment[i].start > max_duration:
312
+ if alignment[i].word in sentence_end_marks:
313
+ alignment[i].end = alignment[i].start + max_duration
314
+ elif alignment[i - 1].word in sentence_end_marks:
315
+ alignment[i].start = alignment[i].end - max_duration
316
+
317
+ merge_punctuations(alignment, prepend_punctuations, append_punctuations)
318
+
319
+ time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
320
+ word_index = 0
321
+
322
+ for segment, text_tokens in zip(segments, text_tokens_per_segment):
323
+ saved_tokens = 0
324
+ words = []
325
+
326
+ while word_index < len(alignment) and saved_tokens < len(text_tokens):
327
+ timing = alignment[word_index]
328
+
329
+ if timing.word:
330
+ words.append(
331
+ dict(
332
+ word=timing.word,
333
+ start=round(time_offset + timing.start, 2),
334
+ end=round(time_offset + timing.end, 2),
335
+ probability=timing.probability,
336
+ )
337
+ )
338
+
339
+ saved_tokens += len(timing.tokens)
340
+ word_index += 1
341
+
342
+ # hack: truncate long words at segment boundaries.
343
+ # a better segmentation algorithm based on VAD should be able to replace this.
344
+ if len(words) > 0:
345
+ # ensure the first and second word after a pause is not longer than
346
+ # twice the median word duration.
347
+ if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
348
+ words[0]["end"] - words[0]["start"] > max_duration
349
+ or (
350
+ len(words) > 1
351
+ and words[1]["end"] - words[0]["start"] > max_duration * 2
352
+ )
353
+ ):
354
+ if (
355
+ len(words) > 1
356
+ and words[1]["end"] - words[1]["start"] > max_duration
357
+ ):
358
+ boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
359
+ words[0]["end"] = words[1]["start"] = boundary
360
+ words[0]["start"] = max(0, words[0]["end"] - max_duration)
361
+
362
+ # prefer the segment-level start timestamp if the first word is too long.
363
+ if (
364
+ segment["start"] < words[0]["end"]
365
+ and segment["start"] - 0.5 > words[0]["start"]
366
+ ):
367
+ words[0]["start"] = max(
368
+ 0, min(words[0]["end"] - median_duration, segment["start"])
369
+ )
370
+ else:
371
+ segment["start"] = words[0]["start"]
372
+
373
+ # prefer the segment-level end timestamp if the last word is too long.
374
+ if (
375
+ segment["end"] > words[-1]["start"]
376
+ and segment["end"] + 0.5 < words[-1]["end"]
377
+ ):
378
+ words[-1]["end"] = max(
379
+ words[-1]["start"] + median_duration, segment["end"]
380
+ )
381
+ else:
382
+ segment["end"] = words[-1]["end"]
383
+
384
+ last_speech_timestamp = segment["end"]
385
+
386
+ segment["words"] = words
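
As a standalone illustration (not part of the commit) of the alignment utilities above: dtw_cpu consumes a cost matrix and returns the monotonic alignment path that find_alignment then splits into per-word boundaries, passing the negated, median-filtered attention weights as the cost.

    import numpy as np
    # from whisper_stream.timing import dtw_cpu  # assumed import path for this file

    cost = np.array([[0.0, 1.0, 2.0],
                     [1.0, 0.0, 1.0],
                     [2.0, 1.0, 0.0]], dtype=np.float32)
    text_idx, time_idx = dtw_cpu(cost)
    # text_idx == [0, 1, 2], time_idx == [0, 1, 2]: the cheapest path follows the diagonal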
whisper_stream/tokenizer.py ADDED
@@ -0,0 +1,395 @@
1
+ import base64
2
+ import os
3
+ import string
4
+ from dataclasses import dataclass, field
5
+ from functools import cached_property, lru_cache
6
+ from typing import Dict, List, Optional, Tuple
7
+
8
+ import tiktoken
9
+
10
+ LANGUAGES = {
11
+ "en": "english",
12
+ "zh": "chinese",
13
+ "de": "german",
14
+ "es": "spanish",
15
+ "ru": "russian",
16
+ "ko": "korean",
17
+ "fr": "french",
18
+ "ja": "japanese",
19
+ "pt": "portuguese",
20
+ "tr": "turkish",
21
+ "pl": "polish",
22
+ "ca": "catalan",
23
+ "nl": "dutch",
24
+ "ar": "arabic",
25
+ "sv": "swedish",
26
+ "it": "italian",
27
+ "id": "indonesian",
28
+ "hi": "hindi",
29
+ "fi": "finnish",
30
+ "vi": "vietnamese",
31
+ "he": "hebrew",
32
+ "uk": "ukrainian",
33
+ "el": "greek",
34
+ "ms": "malay",
35
+ "cs": "czech",
36
+ "ro": "romanian",
37
+ "da": "danish",
38
+ "hu": "hungarian",
39
+ "ta": "tamil",
40
+ "no": "norwegian",
41
+ "th": "thai",
42
+ "ur": "urdu",
43
+ "hr": "croatian",
44
+ "bg": "bulgarian",
45
+ "lt": "lithuanian",
46
+ "la": "latin",
47
+ "mi": "maori",
48
+ "ml": "malayalam",
49
+ "cy": "welsh",
50
+ "sk": "slovak",
51
+ "te": "telugu",
52
+ "fa": "persian",
53
+ "lv": "latvian",
54
+ "bn": "bengali",
55
+ "sr": "serbian",
56
+ "az": "azerbaijani",
57
+ "sl": "slovenian",
58
+ "kn": "kannada",
59
+ "et": "estonian",
60
+ "mk": "macedonian",
61
+ "br": "breton",
62
+ "eu": "basque",
63
+ "is": "icelandic",
64
+ "hy": "armenian",
65
+ "ne": "nepali",
66
+ "mn": "mongolian",
67
+ "bs": "bosnian",
68
+ "kk": "kazakh",
69
+ "sq": "albanian",
70
+ "sw": "swahili",
71
+ "gl": "galician",
72
+ "mr": "marathi",
73
+ "pa": "punjabi",
74
+ "si": "sinhala",
75
+ "km": "khmer",
76
+ "sn": "shona",
77
+ "yo": "yoruba",
78
+ "so": "somali",
79
+ "af": "afrikaans",
80
+ "oc": "occitan",
81
+ "ka": "georgian",
82
+ "be": "belarusian",
83
+ "tg": "tajik",
84
+ "sd": "sindhi",
85
+ "gu": "gujarati",
86
+ "am": "amharic",
87
+ "yi": "yiddish",
88
+ "lo": "lao",
89
+ "uz": "uzbek",
90
+ "fo": "faroese",
91
+ "ht": "haitian creole",
92
+ "ps": "pashto",
93
+ "tk": "turkmen",
94
+ "nn": "nynorsk",
95
+ "mt": "maltese",
96
+ "sa": "sanskrit",
97
+ "lb": "luxembourgish",
98
+ "my": "myanmar",
99
+ "bo": "tibetan",
100
+ "tl": "tagalog",
101
+ "mg": "malagasy",
102
+ "as": "assamese",
103
+ "tt": "tatar",
104
+ "haw": "hawaiian",
105
+ "ln": "lingala",
106
+ "ha": "hausa",
107
+ "ba": "bashkir",
108
+ "jw": "javanese",
109
+ "su": "sundanese",
110
+ "yue": "cantonese",
111
+ }
112
+
113
+ # language code lookup by name, with a few language aliases
114
+ TO_LANGUAGE_CODE = {
115
+ **{language: code for code, language in LANGUAGES.items()},
116
+ "burmese": "my",
117
+ "valencian": "ca",
118
+ "flemish": "nl",
119
+ "haitian": "ht",
120
+ "letzeburgesch": "lb",
121
+ "pushto": "ps",
122
+ "panjabi": "pa",
123
+ "moldavian": "ro",
124
+ "moldovan": "ro",
125
+ "sinhalese": "si",
126
+ "castilian": "es",
127
+ "mandarin": "zh",
128
+ }
129
+
130
+
131
+ @dataclass
132
+ class Tokenizer:
133
+ """A thin wrapper around `tiktoken` providing quick access to special tokens"""
134
+
135
+ encoding: tiktoken.Encoding
136
+ num_languages: int
137
+ language: Optional[str] = None
138
+ task: Optional[str] = None
139
+ sot_sequence: Tuple[int] = ()
140
+ special_tokens: Dict[str, int] = field(default_factory=dict)
141
+
142
+ def __post_init__(self):
143
+ for special in self.encoding.special_tokens_set:
144
+ special_token = self.encoding.encode_single_token(special)
145
+ self.special_tokens[special] = special_token
146
+
147
+ sot: int = self.special_tokens["<|startoftranscript|>"]
148
+ translate: int = self.special_tokens["<|translate|>"]
149
+ transcribe: int = self.special_tokens["<|transcribe|>"]
150
+
151
+ langs = tuple(LANGUAGES.keys())[: self.num_languages]
152
+ sot_sequence = [sot]
153
+ if self.language is not None:
154
+ sot_sequence.append(sot + 1 + langs.index(self.language))
155
+ if self.task is not None:
156
+ task_token: int = transcribe if self.task == "transcribe" else translate
157
+ sot_sequence.append(task_token)
158
+
159
+ self.sot_sequence = tuple(sot_sequence)
160
+
161
+ def encode(self, text, **kwargs):
162
+ return self.encoding.encode(text, **kwargs)
163
+
164
+ def decode(self, token_ids: List[int], **kwargs) -> str:
165
+ token_ids = [t for t in token_ids if t < self.timestamp_begin]
166
+ return self.encoding.decode(token_ids, **kwargs)
167
+
168
+ def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
169
+ """
170
+ Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
171
+ This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
172
+ """
173
+ return self.encoding.decode(token_ids, **kwargs)
174
+
175
+ @cached_property
176
+ def eot(self) -> int:
177
+ return self.encoding.eot_token
178
+
179
+ @cached_property
180
+ def transcribe(self) -> int:
181
+ return self.special_tokens["<|transcribe|>"]
182
+
183
+ @cached_property
184
+ def translate(self) -> int:
185
+ return self.special_tokens["<|translate|>"]
186
+
187
+ @cached_property
188
+ def sot(self) -> int:
189
+ return self.special_tokens["<|startoftranscript|>"]
190
+
191
+ @cached_property
192
+ def sot_lm(self) -> int:
193
+ return self.special_tokens["<|startoflm|>"]
194
+
195
+ @cached_property
196
+ def sot_prev(self) -> int:
197
+ return self.special_tokens["<|startofprev|>"]
198
+
199
+ @cached_property
200
+ def no_speech(self) -> int:
201
+ return self.special_tokens["<|nospeech|>"]
202
+
203
+ @cached_property
204
+ def no_timestamps(self) -> int:
205
+ return self.special_tokens["<|notimestamps|>"]
206
+
207
+ @cached_property
208
+ def timestamp_begin(self) -> int:
209
+ return self.special_tokens["<|0.00|>"]
210
+
211
+ @cached_property
212
+ def language_token(self) -> int:
213
+ """Returns the token id corresponding to the value of the `language` field"""
214
+ if self.language is None:
215
+ raise ValueError("This tokenizer does not have language token configured")
216
+
217
+ return self.to_language_token(self.language)
218
+
219
+ def to_language_token(self, language):
220
+ if token := self.special_tokens.get(f"<|{language}|>", None):
221
+ return token
222
+
223
+ raise KeyError(f"Language {language} not found in tokenizer.")
224
+
225
+ @cached_property
226
+ def all_language_tokens(self) -> Tuple[int]:
227
+ result = []
228
+ for token, token_id in self.special_tokens.items():
229
+ if token.strip("<|>") in LANGUAGES:
230
+ result.append(token_id)
231
+ return tuple(result)[: self.num_languages]
232
+
233
+ @cached_property
234
+ def all_language_codes(self) -> Tuple[str]:
235
+ return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
236
+
237
+ @cached_property
238
+ def sot_sequence_including_notimestamps(self) -> Tuple[int]:
239
+ return tuple(list(self.sot_sequence) + [self.no_timestamps])
240
+
241
+ @cached_property
242
+ def non_speech_tokens(self) -> Tuple[int]:
243
+ """
244
+ Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
245
+ annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
246
+
247
+ - ♪♪♪
248
+ - ( SPEAKING FOREIGN LANGUAGE )
249
+ - [DAVID] Hey there,
250
+
251
+ keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
252
+ """
253
+ symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
254
+ symbols += (
255
+ "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
256
+ )
257
+
258
+ # symbols that may be a single token or multiple tokens depending on the tokenizer.
259
+ # In case they're multiple tokens, suppress the first token, which is safe because:
260
+ # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
261
+ # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
262
+ miscellaneous = set("♩♪♫♬♭♮♯")
263
+ assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
264
+
265
+ # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
266
+ result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
267
+ for symbol in symbols + list(miscellaneous):
268
+ for tokens in [
269
+ self.encoding.encode(symbol),
270
+ self.encoding.encode(" " + symbol),
271
+ ]:
272
+ if len(tokens) == 1 or symbol in miscellaneous:
273
+ result.add(tokens[0])
274
+
275
+ return tuple(sorted(result))
276
+
277
+ def split_to_word_tokens(self, tokens: List[int]):
278
+ if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
279
+ # These languages don't typically use spaces, so it is difficult to split words
280
+ # without morpheme analysis. Here, we instead split words at any
281
+ # position where the tokens are decoded as valid unicode points
282
+ return self.split_tokens_on_unicode(tokens)
283
+
284
+ return self.split_tokens_on_spaces(tokens)
285
+
286
+ def split_tokens_on_unicode(self, tokens: List[int]):
287
+ decoded_full = self.decode_with_timestamps(tokens)
288
+ replacement_char = "\ufffd"
289
+
290
+ words = []
291
+ word_tokens = []
292
+ current_tokens = []
293
+ unicode_offset = 0
294
+
295
+ for token in tokens:
296
+ current_tokens.append(token)
297
+ decoded = self.decode_with_timestamps(current_tokens)
298
+
299
+ if (
300
+ replacement_char not in decoded
301
+ or decoded_full[unicode_offset + decoded.index(replacement_char)]
302
+ == replacement_char
303
+ ):
304
+ words.append(decoded)
305
+ word_tokens.append(current_tokens)
306
+ current_tokens = []
307
+ unicode_offset += len(decoded)
308
+
309
+ return words, word_tokens
310
+
311
+ def split_tokens_on_spaces(self, tokens: List[int]):
312
+ subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
313
+ words = []
314
+ word_tokens = []
315
+
316
+ for subword, subword_tokens in zip(subwords, subword_tokens_list):
317
+ special = subword_tokens[0] >= self.eot
318
+ with_space = subword.startswith(" ")
319
+ punctuation = subword.strip() in string.punctuation
320
+ if special or with_space or punctuation or len(words) == 0:
321
+ words.append(subword)
322
+ word_tokens.append(subword_tokens)
323
+ else:
324
+ words[-1] = words[-1] + subword
325
+ word_tokens[-1].extend(subword_tokens)
326
+
327
+ return words, word_tokens
328
+
329
+
330
+ @lru_cache(maxsize=None)
331
+ def get_encoding(name: str = "gpt2", num_languages: int = 99):
332
+ vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
333
+ ranks = {
334
+ base64.b64decode(token): int(rank)
335
+ for token, rank in (line.split() for line in open(vocab_path) if line)
336
+ }
337
+ n_vocab = len(ranks)
338
+ special_tokens = {}
339
+
340
+ specials = [
341
+ "<|endoftext|>",
342
+ "<|startoftranscript|>",
343
+ *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
344
+ "<|translate|>",
345
+ "<|transcribe|>",
346
+ "<|startoflm|>",
347
+ "<|startofprev|>",
348
+ "<|nospeech|>",
349
+ "<|notimestamps|>",
350
+ *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
351
+ ]
352
+
353
+ for token in specials:
354
+ special_tokens[token] = n_vocab
355
+ n_vocab += 1
356
+
357
+ return tiktoken.Encoding(
358
+ name=os.path.basename(vocab_path),
359
+ explicit_n_vocab=n_vocab,
360
+ pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
361
+ mergeable_ranks=ranks,
362
+ special_tokens=special_tokens,
363
+ )
364
+
365
+
366
+ @lru_cache(maxsize=None)
367
+ def get_tokenizer(
368
+ multilingual: bool,
369
+ *,
370
+ num_languages: int = 99,
371
+ language: Optional[str] = None,
372
+ task: Optional[str] = None, # Literal["transcribe", "translate", None]
373
+ ) -> Tokenizer:
374
+ if language is not None:
375
+ language = language.lower()
376
+ if language not in LANGUAGES:
377
+ if language in TO_LANGUAGE_CODE:
378
+ language = TO_LANGUAGE_CODE[language]
379
+ else:
380
+ raise ValueError(f"Unsupported language: {language}")
381
+
382
+ if multilingual:
383
+ encoding_name = "multilingual"
384
+ language = language or "en"
385
+ task = task or "transcribe"
386
+ else:
387
+ encoding_name = "gpt2"
388
+ language = None
389
+ task = None
390
+
391
+ encoding = get_encoding(name=encoding_name, num_languages=num_languages)
392
+
393
+ return Tokenizer(
394
+ encoding=encoding, num_languages=num_languages, language=language, task=task
395
+ )
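
A quick sanity-check sketch for the tokenizer (not part of the commit; it assumes the multilingual.tiktoken vocabulary ships under whisper_stream/assets, as get_encoding expects):

    from whisper_stream.tokenizer import get_tokenizer  # assumed package import path

    tok = get_tokenizer(multilingual=True, language="en", task="transcribe")
    print(tok.sot_sequence)       # ids for <|startoftranscript|>, <|en|>, <|transcribe|>
    ids = tok.encode(" hello world")
    print(tok.decode(ids))        # " hello world"; decode() drops timestamp tokens
    print(tok.decode_with_timestamps([tok.timestamp_begin]))  # "<|0.00|>"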
whisper_stream/transcribe.py ADDED
@@ -0,0 +1,605 @@
1
+ import argparse
2
+ import os
3
+ import traceback
4
+ import warnings
5
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ import tqdm
10
+
11
+ from .audio import (
12
+ FRAMES_PER_SECOND,
13
+ HOP_LENGTH,
14
+ N_FRAMES,
15
+ N_SAMPLES,
16
+ SAMPLE_RATE,
17
+ log_mel_spectrogram,
18
+ pad_or_trim,
19
+ )
20
+ from .decoding import DecodingOptions, DecodingResult
21
+ from .timing import add_word_timestamps
22
+ from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
23
+ from .utils import (
24
+ exact_div,
25
+ format_timestamp,
26
+ get_end,
27
+ get_writer,
28
+ make_safe,
29
+ optional_float,
30
+ optional_int,
31
+ str2bool,
32
+ )
33
+
34
+ if TYPE_CHECKING:
35
+ from .model import Whisper
36
+
37
+
38
+ def transcribe(
39
+ model: "Whisper",
40
+ audio: Union[str, np.ndarray, torch.Tensor],
41
+ *,
42
+ verbose: Optional[bool] = None,
43
+ temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
44
+ compression_ratio_threshold: Optional[float] = 2.4,
45
+ logprob_threshold: Optional[float] = -1.0,
46
+ no_speech_threshold: Optional[float] = 0.6,
47
+ condition_on_previous_text: bool = True,
48
+ initial_prompt: Optional[str] = None,
49
+ word_timestamps: bool = False,
50
+ prepend_punctuations: str = "\"'“¿([{-",
51
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
52
+ clip_timestamps: Union[str, List[float]] = "0",
53
+ hallucination_silence_threshold: Optional[float] = None,
54
+ **decode_options,
55
+ ):
56
+ """
57
+ Transcribe an audio file using Whisper
58
+
59
+ Parameters
60
+ ----------
61
+ model: Whisper
62
+ The Whisper model instance
63
+
64
+ audio: Union[str, np.ndarray, torch.Tensor]
65
+ The path to the audio file to open, or the audio waveform
66
+
67
+ verbose: bool
68
+ Whether to display the text being decoded to the console. If True, displays all the details;
69
+ if False, displays minimal details. If None, does not display anything.
70
+
71
+ temperature: Union[float, Tuple[float, ...]]
72
+ Temperature for sampling. It can be a tuple of temperatures, which will be successively used
73
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
74
+
75
+ compression_ratio_threshold: float
76
+ If the gzip compression ratio is above this value, treat as failed
77
+
78
+ logprob_threshold: float
79
+ If the average log probability over sampled tokens is below this value, treat as failed
80
+
81
+ no_speech_threshold: float
82
+ If the no_speech probability is higher than this value AND the average log probability
83
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
84
+
85
+ condition_on_previous_text: bool
86
+ if True, the previous output of the model is provided as a prompt for the next window;
87
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
88
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
89
+
90
+ word_timestamps: bool
91
+ Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
92
+ and include the timestamps for each word in each segment.
93
+
94
+ prepend_punctuations: str
95
+ If word_timestamps is True, merge these punctuation symbols with the next word
96
+
97
+ append_punctuations: str
98
+ If word_timestamps is True, merge these punctuation symbols with the previous word
99
+
100
+ initial_prompt: Optional[str]
101
+ Optional text to provide as a prompt for the first window. This can be used to provide, or
102
+ "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
103
+ to make it more likely to predict those words correctly.
104
+
105
+ decode_options: dict
106
+ Keyword arguments to construct `DecodingOptions` instances
107
+
108
+ clip_timestamps: Union[str, List[float]]
109
+ Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
110
+ The last end timestamp defaults to the end of the file.
111
+
112
+ hallucination_silence_threshold: Optional[float]
113
+ When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
114
+ when a possible hallucination is detected
115
+
116
+ Returns
117
+ -------
118
+ A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
119
+ the spoken language ("language"), which is detected when `decode_options["language"]` is None.
120
+ """
121
+ dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
122
+ if model.device == torch.device("cpu"):
123
+ if torch.cuda.is_available():
124
+ warnings.warn("Performing inference on CPU when CUDA is available")
125
+ if dtype == torch.float16:
126
+ warnings.warn("FP16 is not supported on CPU; using FP32 instead")
127
+ dtype = torch.float32
128
+
129
+ if dtype == torch.float32:
130
+ decode_options["fp16"] = False
131
+
132
+ # Pad 30-seconds of silence to the input audio, for slicing
133
+ mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
134
+ content_frames = mel.shape[-1] - N_FRAMES
135
+ content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
136
+
137
+ if decode_options.get("language", None) is None:
138
+ if not model.is_multilingual:
139
+ decode_options["language"] = "en"
140
+ else:
141
+ if verbose:
142
+ print(
143
+ "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
144
+ )
145
+ mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
146
+ _, probs = model.detect_language(mel_segment)
147
+ decode_options["language"] = max(probs, key=probs.get)
148
+ if verbose is not None:
149
+ print(
150
+ f"Detected language: {LANGUAGES[decode_options['language']].title()}"
151
+ )
152
+
153
+ language: str = decode_options["language"]
154
+ task: str = decode_options.get("task", "transcribe")
155
+ tokenizer = get_tokenizer(
156
+ model.is_multilingual,
157
+ num_languages=model.num_languages,
158
+ language=language,
159
+ task=task,
160
+ )
161
+
162
+ if isinstance(clip_timestamps, str):
163
+ clip_timestamps = [
164
+ float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
165
+ ]
166
+ seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
167
+ if len(seek_points) == 0:
168
+ seek_points.append(0)
169
+ if len(seek_points) % 2 == 1:
170
+ seek_points.append(content_frames)
171
+ seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
172
+
173
+ punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
174
+
175
+ if word_timestamps and task == "translate":
176
+ warnings.warn("Word-level timestamps on translations may not be reliable.")
177
+
178
+ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
179
+ temperatures = (
180
+ [temperature] if isinstance(temperature, (int, float)) else temperature
181
+ )
182
+ decode_result = None
183
+
184
+ for t in temperatures:
185
+ kwargs = {**decode_options}
186
+ if t > 0:
187
+ # disable beam_size and patience when t > 0
188
+ kwargs.pop("beam_size", None)
189
+ kwargs.pop("patience", None)
190
+ else:
191
+ # disable best_of when t == 0
192
+ kwargs.pop("best_of", None)
193
+
194
+ options = DecodingOptions(**kwargs, temperature=t)
195
+ decode_result = model.decode(segment, options)
196
+
197
+ needs_fallback = False
198
+ if (
199
+ compression_ratio_threshold is not None
200
+ and decode_result.compression_ratio > compression_ratio_threshold
201
+ ):
202
+ needs_fallback = True # too repetitive
203
+ if (
204
+ logprob_threshold is not None
205
+ and decode_result.avg_logprob < logprob_threshold
206
+ ):
207
+ needs_fallback = True # average log probability is too low
208
+ if (
209
+ no_speech_threshold is not None
210
+ and decode_result.no_speech_prob > no_speech_threshold
211
+ ):
212
+ needs_fallback = False # silence
213
+ if not needs_fallback:
214
+ break
215
+
216
+ return decode_result
217
+
218
+ clip_idx = 0
219
+ seek = seek_clips[clip_idx][0]
220
+ input_stride = exact_div(
221
+ N_FRAMES, model.dims.n_audio_ctx
222
+ ) # mel frames per output token: 2
223
+ time_precision = (
224
+ input_stride * HOP_LENGTH / SAMPLE_RATE
225
+ ) # time per output token: 0.02 (seconds)
226
+ all_tokens = []
227
+ all_segments = []
228
+ prompt_reset_since = 0
229
+
230
+ if initial_prompt is not None:
231
+ initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
232
+ all_tokens.extend(initial_prompt_tokens)
233
+ else:
234
+ initial_prompt_tokens = []
235
+
236
+ def new_segment(
237
+ *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
238
+ ):
239
+ tokens = tokens.tolist()
240
+ text_tokens = [token for token in tokens if token < tokenizer.eot]
241
+ return {
242
+ "seek": seek,
243
+ "start": start,
244
+ "end": end,
245
+ "text": tokenizer.decode(text_tokens),
246
+ "tokens": tokens,
247
+ "temperature": result.temperature,
248
+ "avg_logprob": result.avg_logprob,
249
+ "compression_ratio": result.compression_ratio,
250
+ "no_speech_prob": result.no_speech_prob,
251
+ }
252
+
253
+ # show the progress bar when verbose is False (if True, transcribed text will be printed)
254
+ with tqdm.tqdm(
255
+ total=content_frames, unit="frames", disable=verbose is not False
256
+ ) as pbar:
257
+ last_speech_timestamp = 0.0
258
+ # NOTE: This loop is obscurely flattened to make the diff readable.
259
+ # A later commit should turn this into a simpler nested loop.
260
+ # for seek_clip_start, seek_clip_end in seek_clips:
261
+ # while seek < seek_clip_end
262
+ while clip_idx < len(seek_clips):
263
+ seek_clip_start, seek_clip_end = seek_clips[clip_idx]
264
+ if seek < seek_clip_start:
265
+ seek = seek_clip_start
266
+ if seek >= seek_clip_end:
267
+ clip_idx += 1
268
+ if clip_idx < len(seek_clips):
269
+ seek = seek_clips[clip_idx][0]
270
+ continue
271
+ time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
272
+ window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
273
+ segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
274
+ mel_segment = mel[:, seek : seek + segment_size]
275
+ segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
276
+ mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
277
+
278
+ decode_options["prompt"] = all_tokens[prompt_reset_since:]
279
+ result: DecodingResult = decode_with_fallback(mel_segment)
280
+ tokens = torch.tensor(result.tokens)
281
+
282
+ if no_speech_threshold is not None:
283
+ # no voice activity check
284
+ should_skip = result.no_speech_prob > no_speech_threshold
285
+ if (
286
+ logprob_threshold is not None
287
+ and result.avg_logprob > logprob_threshold
288
+ ):
289
+ # don't skip if the logprob is high enough, despite the no_speech_prob
290
+ should_skip = False
291
+
292
+ if should_skip:
293
+ seek += segment_size # fast-forward to the next segment boundary
294
+ continue
295
+
296
+ previous_seek = seek
297
+ current_segments = []
298
+
299
+ # anomalous words are very long/short/improbable
300
+ def word_anomaly_score(word: dict) -> float:
301
+ probability = word.get("probability", 0.0)
302
+ duration = word["end"] - word["start"]
303
+ score = 0.0
304
+ if probability < 0.15:
305
+ score += 1.0
306
+ if duration < 0.133:
307
+ score += (0.133 - duration) * 15
308
+ if duration > 2.0:
309
+ score += duration - 2.0
310
+ return score
311
+
312
+ def is_segment_anomaly(segment: Optional[dict]) -> bool:
313
+ if segment is None or not segment["words"]:
314
+ return False
315
+ words = [w for w in segment["words"] if w["word"] not in punctuation]
316
+ words = words[:8]
317
+ score = sum(word_anomaly_score(w) for w in words)
318
+ return score >= 3 or score + 0.01 >= len(words)
319
+
320
+ def next_words_segment(segments: List[dict]) -> Optional[dict]:
321
+ return next((s for s in segments if s["words"]), None)
322
+
323
+ timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
324
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
325
+
326
+ consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
327
+ consecutive.add_(1)
328
+ if len(consecutive) > 0:
329
+ # if the output contains two consecutive timestamp tokens
330
+ slices = consecutive.tolist()
331
+ if single_timestamp_ending:
332
+ slices.append(len(tokens))
333
+
334
+ last_slice = 0
335
+ for current_slice in slices:
336
+ sliced_tokens = tokens[last_slice:current_slice]
337
+ start_timestamp_pos = (
338
+ sliced_tokens[0].item() - tokenizer.timestamp_begin
339
+ )
340
+ end_timestamp_pos = (
341
+ sliced_tokens[-1].item() - tokenizer.timestamp_begin
342
+ )
343
+ current_segments.append(
344
+ new_segment(
345
+ start=time_offset + start_timestamp_pos * time_precision,
346
+ end=time_offset + end_timestamp_pos * time_precision,
347
+ tokens=sliced_tokens,
348
+ result=result,
349
+ )
350
+ )
351
+ last_slice = current_slice
352
+
353
+ if single_timestamp_ending:
354
+ # single timestamp at the end means no speech after the last timestamp.
355
+ seek += segment_size
356
+ else:
357
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
358
+ last_timestamp_pos = (
359
+ tokens[last_slice - 1].item() - tokenizer.timestamp_begin
360
+ )
361
+ seek += last_timestamp_pos * input_stride
362
+ else:
363
+ duration = segment_duration
364
+ timestamps = tokens[timestamp_tokens.nonzero().flatten()]
365
+ if (
366
+ len(timestamps) > 0
367
+ and timestamps[-1].item() != tokenizer.timestamp_begin
368
+ ):
369
+ # no consecutive timestamps but it has a timestamp; use the last one.
370
+ last_timestamp_pos = (
371
+ timestamps[-1].item() - tokenizer.timestamp_begin
372
+ )
373
+ duration = last_timestamp_pos * time_precision
374
+
375
+ current_segments.append(
376
+ new_segment(
377
+ start=time_offset,
378
+ end=time_offset + duration,
379
+ tokens=tokens,
380
+ result=result,
381
+ )
382
+ )
383
+ seek += segment_size
384
+
385
+ if word_timestamps:
386
+ add_word_timestamps(
387
+ segments=current_segments,
388
+ model=model,
389
+ tokenizer=tokenizer,
390
+ mel=mel_segment,
391
+ num_frames=segment_size,
392
+ prepend_punctuations=prepend_punctuations,
393
+ append_punctuations=append_punctuations,
394
+ last_speech_timestamp=last_speech_timestamp,
395
+ )
396
+
397
+ if not single_timestamp_ending:
398
+ last_word_end = get_end(current_segments)
399
+ if last_word_end is not None and last_word_end > time_offset:
400
+ seek = round(last_word_end * FRAMES_PER_SECOND)
401
+
402
+ # skip silence before possible hallucinations
403
+ if hallucination_silence_threshold is not None:
404
+ threshold = hallucination_silence_threshold
405
+ if not single_timestamp_ending:
406
+ last_word_end = get_end(current_segments)
407
+ if last_word_end is not None and last_word_end > time_offset:
408
+ remaining_duration = window_end_time - last_word_end
409
+ if remaining_duration > threshold:
410
+ seek = round(last_word_end * FRAMES_PER_SECOND)
411
+ else:
412
+ seek = previous_seek + segment_size
413
+
414
+ # if first segment might be a hallucination, skip leading silence
415
+ first_segment = next_words_segment(current_segments)
416
+ if first_segment is not None and is_segment_anomaly(first_segment):
417
+ gap = first_segment["start"] - time_offset
418
+ if gap > threshold:
419
+ seek = previous_seek + round(gap * FRAMES_PER_SECOND)
420
+ continue
421
+
422
+ # skip silence before any possible hallucination that is surrounded
423
+ # by silence or more hallucinations
424
+ hal_last_end = last_speech_timestamp
425
+ for si in range(len(current_segments)):
426
+ segment = current_segments[si]
427
+ if not segment["words"]:
428
+ continue
429
+ if is_segment_anomaly(segment):
430
+ next_segment = next_words_segment(
431
+ current_segments[si + 1 :]
432
+ )
433
+ if next_segment is not None:
434
+ hal_next_start = next_segment["words"][0]["start"]
435
+ else:
436
+ hal_next_start = time_offset + segment_duration
437
+ silence_before = (
438
+ segment["start"] - hal_last_end > threshold
439
+ or segment["start"] < threshold
440
+ or segment["start"] - time_offset < 2.0
441
+ )
442
+ silence_after = (
443
+ hal_next_start - segment["end"] > threshold
444
+ or is_segment_anomaly(next_segment)
445
+ or window_end_time - segment["end"] < 2.0
446
+ )
447
+ if silence_before and silence_after:
448
+ seek = round(
449
+ max(time_offset + 1, segment["start"])
450
+ * FRAMES_PER_SECOND
451
+ )
452
+ if content_duration - segment["end"] < threshold:
453
+ seek = content_frames
454
+ current_segments[si:] = []
455
+ break
456
+ hal_last_end = segment["end"]
457
+
458
+ last_word_end = get_end(current_segments)
459
+ if last_word_end is not None:
460
+ last_speech_timestamp = last_word_end
461
+
462
+ if verbose:
463
+ for segment in current_segments:
464
+ start, end, text = segment["start"], segment["end"], segment["text"]
465
+ line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
466
+ print(make_safe(line))
467
+
468
+ # if a segment is instantaneous or does not contain text, clear it
469
+ for i, segment in enumerate(current_segments):
470
+ if segment["start"] == segment["end"] or segment["text"].strip() == "":
471
+ segment["text"] = ""
472
+ segment["tokens"] = []
473
+ segment["words"] = []
474
+
475
+ all_segments.extend(
476
+ [
477
+ {"id": i, **segment}
478
+ for i, segment in enumerate(
479
+ current_segments, start=len(all_segments)
480
+ )
481
+ ]
482
+ )
483
+ all_tokens.extend(
484
+ [token for segment in current_segments for token in segment["tokens"]]
485
+ )
486
+
487
+ if not condition_on_previous_text or result.temperature > 0.5:
488
+ # do not feed the prompt tokens if a high temperature was used
489
+ prompt_reset_since = len(all_tokens)
490
+
491
+ # update progress bar
492
+ pbar.update(min(content_frames, seek) - previous_seek)
493
+
494
+ return dict(
495
+ text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
496
+ segments=all_segments,
497
+ language=language,
498
+ )
499
+
500
+
501
+ def cli():
502
+ from . import available_models
503
+
504
+ def valid_model_name(name):
505
+ if name in available_models() or os.path.exists(name):
506
+ return name
507
+ raise ValueError(
508
+ f"model should be one of {available_models()} or path to a model checkpoint"
509
+ )
510
+
511
+ # fmt: off
512
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
513
+ parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
514
+ parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
515
+ parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
516
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
517
+ parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
518
+ parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
519
+ parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
520
+
521
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
522
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
523
+
524
+ parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
525
+ parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
526
+ parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
527
+ parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
528
+ parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
529
+
530
+ parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
531
+ parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
532
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
533
+ parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
534
+
535
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
536
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
537
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
538
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
539
+ parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
540
+ parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
541
+ parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
542
+ parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
543
+ parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
544
+ parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
545
+ parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
546
+ parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
547
+ parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
548
+ parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
549
+ # fmt: on
550
+
551
+ args = parser.parse_args().__dict__
552
+ model_name: str = args.pop("model")
553
+ model_dir: str = args.pop("model_dir")
554
+ output_dir: str = args.pop("output_dir")
555
+ output_format: str = args.pop("output_format")
556
+ device: str = args.pop("device")
557
+ os.makedirs(output_dir, exist_ok=True)
558
+
559
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
560
+ if args["language"] is not None:
561
+ warnings.warn(
562
+ f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
563
+ )
564
+ args["language"] = "en"
565
+
566
+ temperature = args.pop("temperature")
567
+ if (increment := args.pop("temperature_increment_on_fallback")) is not None:
568
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
569
+ else:
570
+ temperature = [temperature]
571
+
572
+ if (threads := args.pop("threads")) > 0:
573
+ torch.set_num_threads(threads)
574
+
575
+ from . import load_model
576
+
577
+ model = load_model(model_name, device=device, download_root=model_dir)
578
+
579
+ writer = get_writer(output_format, output_dir)
580
+ word_options = [
581
+ "highlight_words",
582
+ "max_line_count",
583
+ "max_line_width",
584
+ "max_words_per_line",
585
+ ]
586
+ if not args["word_timestamps"]:
587
+ for option in word_options:
588
+ if args[option]:
589
+ parser.error(f"--{option} requires --word_timestamps True")
590
+ if args["max_line_count"] and not args["max_line_width"]:
591
+ warnings.warn("--max_line_count has no effect without --max_line_width")
592
+ if args["max_words_per_line"] and args["max_line_width"]:
593
+ warnings.warn("--max_words_per_line has no effect with --max_line_width")
594
+ writer_args = {arg: args.pop(arg) for arg in word_options}
595
+ for audio_path in args.pop("audio"):
596
+ try:
597
+ result = transcribe(model, audio_path, temperature=temperature, **args)
598
+ writer(result, audio_path, **writer_args)
599
+ except Exception as e:
600
+ traceback.print_exc()
601
+ print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
602
+
603
+
604
+ if __name__ == "__main__":
605
+ cli()
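
For reference, a minimal programmatic sketch of the same path the CLI takes (not part of the commit; it assumes load_model is re-exported from the package root, as cli() implies, and that audio.wav exists):

    from whisper_stream import load_model            # re-exported, per `from . import load_model` above
    from whisper_stream.transcribe import transcribe

    model = load_model("small")
    result = transcribe(model, "audio.wav", word_timestamps=True, verbose=False)
    for seg in result["segments"]:
        print(f"[{seg['start']:.2f} --> {seg['end']:.2f}] {seg['text']}")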
whisper_stream/triton_ops.py ADDED
@@ -0,0 +1,109 @@
1
+ from functools import lru_cache
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ try:
7
+ import triton
8
+ import triton.language as tl
9
+ except ImportError:
10
+ raise RuntimeError("triton import failed; try `pip install --pre triton`")
11
+
12
+
13
+ @triton.jit
14
+ def dtw_kernel(
15
+ cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
16
+ ):
17
+ offsets = tl.arange(0, BLOCK_SIZE)
18
+ mask = offsets < M
19
+
20
+ for k in range(1, N + M + 1): # k = i + j
21
+ tl.debug_barrier()
22
+
23
+ p0 = cost + (k - 1) * cost_stride
24
+ p1 = cost + k * cost_stride
25
+ p2 = cost + k * cost_stride + 1
26
+
27
+ c0 = tl.load(p0 + offsets, mask=mask)
28
+ c1 = tl.load(p1 + offsets, mask=mask)
29
+ c2 = tl.load(p2 + offsets, mask=mask)
30
+
31
+ x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
32
+ cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)
33
+
34
+ cost_ptr = cost + (k + 1) * cost_stride + 1
35
+ tl.store(cost_ptr + offsets, cost_row, mask=mask)
36
+
37
+ trace_ptr = trace + (k + 1) * trace_stride + 1
38
+ tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
39
+ tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
40
+ tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
41
+
42
+
43
+ @lru_cache(maxsize=None)
44
+ def median_kernel(filter_width: int):
45
+ @triton.jit
46
+ def kernel(
47
+ y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
48
+ ): # x.shape[-1] == filter_width
49
+ row_idx = tl.program_id(0)
50
+ offsets = tl.arange(0, BLOCK_SIZE)
51
+ mask = offsets < y_stride
52
+
53
+ x_ptr = x + row_idx * x_stride # noqa: F841
54
+ y_ptr = y + row_idx * y_stride
55
+
56
+ LOAD_ALL_ROWS_HERE # noqa: F821
57
+
58
+ BUBBLESORT_HERE # noqa: F821
59
+
60
+ tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask) # noqa: F821
61
+
62
+ kernel = triton.JITFunction(kernel.fn)
63
+ kernel.src = kernel.src.replace(
64
+ " LOAD_ALL_ROWS_HERE",
65
+ "\n".join(
66
+ [
67
+ f" row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
68
+ for i in range(filter_width)
69
+ ]
70
+ ),
71
+ )
72
+ kernel.src = kernel.src.replace(
73
+ " BUBBLESORT_HERE",
74
+ "\n\n".join(
75
+ [
76
+ "\n\n".join(
77
+ [
78
+ "\n".join(
79
+ [
80
+ f" smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
81
+ f" larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
82
+ f" row{j} = smaller",
83
+ f" row{j + 1} = larger",
84
+ ]
85
+ )
86
+ for j in range(filter_width - i - 1)
87
+ ]
88
+ )
89
+ for i in range(filter_width // 2 + 1)
90
+ ]
91
+ ),
92
+ )
93
+ kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")
94
+
95
+ return kernel
96
+
97
+
98
+ def median_filter_cuda(x: torch.Tensor, filter_width: int):
99
+ """Apply a median filter of given width along the last dimension of x"""
100
+ slices = x.contiguous().unfold(-1, filter_width, 1)
101
+ grid = np.prod(slices.shape[:-2])
102
+
103
+ kernel = median_kernel(filter_width)
104
+ y = torch.empty_like(slices[..., 0])
105
+
106
+ BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
107
+ kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)
108
+
109
+ return y
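
A usage sketch for the Triton median kernel (not part of the commit; it requires a CUDA device and a working triton install). As in median_filter above, the input is expected to be pre-padded so that every output position sees a full filter_width window:

    import torch
    # from whisper_stream.triton_ops import median_filter_cuda  # assumed import path

    filter_width = 7
    x = torch.randn(8, 1500 + filter_width - 1, device="cuda")  # pre-padded along the last dim
    y = median_filter_cuda(x, filter_width)  # shape (8, 1500): median over each 7-wide window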
whisper_stream/utils.py ADDED
@@ -0,0 +1,316 @@
1
+ import json
2
+ import os
3
+ import re
4
+ import sys
5
+ import zlib
6
+ from typing import Callable, List, Optional, TextIO
7
+
8
+ system_encoding = sys.getdefaultencoding()
9
+
10
+ if system_encoding != "utf-8":
11
+
12
+ def make_safe(string):
13
+ # replaces any character not representable using the system default encoding with an '?',
14
+ # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
15
+ return string.encode(system_encoding, errors="replace").decode(system_encoding)
16
+
17
+ else:
18
+
19
+ def make_safe(string):
20
+ # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
21
+ return string
22
+
23
+
24
+ def exact_div(x, y):
25
+ assert x % y == 0
26
+ return x // y
27
+
28
+
29
+ def str2bool(string):
30
+ str2val = {"True": True, "False": False}
31
+ if string in str2val:
32
+ return str2val[string]
33
+ else:
34
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
35
+
36
+
37
+ def optional_int(string):
38
+ return None if string == "None" else int(string)
39
+
40
+
41
+ def optional_float(string):
42
+ return None if string == "None" else float(string)
43
+
44
+
45
+ def compression_ratio(text) -> float:
46
+ text_bytes = text.encode("utf-8")
47
+ return len(text_bytes) / len(zlib.compress(text_bytes))
48
+
49
+
50
+ def format_timestamp(
51
+ seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
52
+ ):
53
+ assert seconds >= 0, "non-negative timestamp expected"
54
+ milliseconds = round(seconds * 1000.0)
55
+
56
+ hours = milliseconds // 3_600_000
57
+ milliseconds -= hours * 3_600_000
58
+
59
+ minutes = milliseconds // 60_000
60
+ milliseconds -= minutes * 60_000
61
+
62
+ seconds = milliseconds // 1_000
63
+ milliseconds -= seconds * 1_000
64
+
65
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
66
+ return (
67
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
68
+ )
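(For example, format_timestamp(3661.5) returns "01:01:01.500", and format_timestamp(2.04, always_include_hours=True, decimal_marker=",") returns "00:00:02,040", the SRT-style form.)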
69
+
70
+
71
+ def get_start(segments: List[dict]) -> Optional[float]:
72
+ return next(
73
+ (w["start"] for s in segments for w in s["words"]),
74
+ segments[0]["start"] if segments else None,
75
+ )
76
+
77
+
78
+ def get_end(segments: List[dict]) -> Optional[float]:
79
+ return next(
80
+ (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
81
+ segments[-1]["end"] if segments else None,
82
+ )
83
+
84
+
85
+ class ResultWriter:
86
+ extension: str
87
+
88
+ def __init__(self, output_dir: str):
89
+ self.output_dir = output_dir
90
+
91
+ def __call__(
92
+ self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
93
+ ):
94
+ audio_basename = os.path.basename(audio_path)
95
+ audio_basename = os.path.splitext(audio_basename)[0]
96
+ output_path = os.path.join(
97
+ self.output_dir, audio_basename + "." + self.extension
98
+ )
99
+
100
+ with open(output_path, "w", encoding="utf-8") as f:
101
+ self.write_result(result, file=f, options=options, **kwargs)
102
+
103
+ def write_result(
104
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
105
+ ):
106
+ raise NotImplementedError
107
+
108
+
109
+ class WriteTXT(ResultWriter):
110
+ extension: str = "txt"
111
+
112
+ def write_result(
113
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
114
+ ):
115
+ for segment in result["segments"]:
116
+ print(segment["text"].strip(), file=file, flush=True)
117
+
118
+
119
+ class SubtitlesWriter(ResultWriter):
120
+ always_include_hours: bool
121
+ decimal_marker: str
122
+
123
+ def iterate_result(
124
+ self,
125
+ result: dict,
126
+ options: Optional[dict] = None,
127
+ *,
128
+ max_line_width: Optional[int] = None,
129
+ max_line_count: Optional[int] = None,
130
+ highlight_words: bool = False,
131
+ max_words_per_line: Optional[int] = None,
132
+ ):
133
+ options = options or {}
134
+ max_line_width = max_line_width or options.get("max_line_width")
135
+ max_line_count = max_line_count or options.get("max_line_count")
136
+ highlight_words = highlight_words or options.get("highlight_words", False)
137
+ max_words_per_line = max_words_per_line or options.get("max_words_per_line")
138
+ preserve_segments = max_line_count is None or max_line_width is None
139
+ max_line_width = max_line_width or 1000
140
+ max_words_per_line = max_words_per_line or 1000
141
+
142
+ def iterate_subtitles():
143
+ line_len = 0
144
+ line_count = 1
145
+ # the next subtitle to yield (a list of word timings with whitespace)
146
+ subtitle: List[dict] = []
147
+ last: float = get_start(result["segments"]) or 0.0
148
+ for segment in result["segments"]:
149
+ chunk_index = 0
150
+ words_count = max_words_per_line
151
+ while chunk_index < len(segment["words"]):
152
+ remaining_words = len(segment["words"]) - chunk_index
153
+ if max_words_per_line > len(segment["words"]) - chunk_index:
154
+ words_count = remaining_words
155
+ for i, original_timing in enumerate(
156
+ segment["words"][chunk_index : chunk_index + words_count]
157
+ ):
158
+ timing = original_timing.copy()
159
+ long_pause = (
160
+ not preserve_segments and timing["start"] - last > 3.0
161
+ )
162
+ has_room = line_len + len(timing["word"]) <= max_line_width
163
+ seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
164
+ if (
165
+ line_len > 0
166
+ and has_room
167
+ and not long_pause
168
+ and not seg_break
169
+ ):
170
+ # line continuation
171
+ line_len += len(timing["word"])
172
+ else:
173
+ # new line
174
+ timing["word"] = timing["word"].strip()
175
+ if (
176
+ len(subtitle) > 0
177
+ and max_line_count is not None
178
+ and (long_pause or line_count >= max_line_count)
179
+ or seg_break
180
+ ):
181
+ # subtitle break
182
+ yield subtitle
183
+ subtitle = []
184
+ line_count = 1
185
+ elif line_len > 0:
186
+ # line break
187
+ line_count += 1
188
+ timing["word"] = "\n" + timing["word"]
189
+ line_len = len(timing["word"].strip())
190
+ subtitle.append(timing)
191
+ last = timing["start"]
192
+ chunk_index += max_words_per_line
193
+ if len(subtitle) > 0:
194
+ yield subtitle
195
+
196
+ if len(result["segments"]) > 0 and "words" in result["segments"][0]:
197
+ for subtitle in iterate_subtitles():
198
+ subtitle_start = self.format_timestamp(subtitle[0]["start"])
199
+ subtitle_end = self.format_timestamp(subtitle[-1]["end"])
200
+ subtitle_text = "".join([word["word"] for word in subtitle])
201
+ if highlight_words:
202
+ last = subtitle_start
203
+ all_words = [timing["word"] for timing in subtitle]
204
+ for i, this_word in enumerate(subtitle):
205
+ start = self.format_timestamp(this_word["start"])
206
+ end = self.format_timestamp(this_word["end"])
207
+ if last != start:
208
+ yield last, start, subtitle_text
209
+
210
+ yield start, end, "".join(
211
+ [
212
+ re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
213
+ if j == i
214
+ else word
215
+ for j, word in enumerate(all_words)
216
+ ]
217
+ )
218
+ last = end
219
+ else:
220
+ yield subtitle_start, subtitle_end, subtitle_text
221
+ else:
222
+ for segment in result["segments"]:
223
+ segment_start = self.format_timestamp(segment["start"])
224
+ segment_end = self.format_timestamp(segment["end"])
225
+ segment_text = segment["text"].strip().replace("-->", "->")
226
+ yield segment_start, segment_end, segment_text
227
+
228
+ def format_timestamp(self, seconds: float):
229
+ return format_timestamp(
230
+ seconds=seconds,
231
+ always_include_hours=self.always_include_hours,
232
+ decimal_marker=self.decimal_marker,
233
+ )
234
+
235
+
236
+ class WriteVTT(SubtitlesWriter):
237
+ extension: str = "vtt"
238
+ always_include_hours: bool = False
239
+ decimal_marker: str = "."
240
+
241
+ def write_result(
242
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
243
+ ):
244
+ print("WEBVTT\n", file=file)
245
+ for start, end, text in self.iterate_result(result, options, **kwargs):
246
+ print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
247
+
248
+
249
+ class WriteSRT(SubtitlesWriter):
250
+ extension: str = "srt"
251
+ always_include_hours: bool = True
252
+ decimal_marker: str = ","
253
+
254
+ def write_result(
255
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
256
+ ):
257
+ for i, (start, end, text) in enumerate(
258
+ self.iterate_result(result, options, **kwargs), start=1
259
+ ):
260
+ print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
261
+
262
+
263
+ class WriteTSV(ResultWriter):
264
+ """
265
+ Write a transcript to a file in TSV (tab-separated values) format containing lines like:
266
+ <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
267
+
268
+ Using integer milliseconds as start and end times means there's no chance of interference from
269
+ an environment setting a language encoding that causes the decimal in a floating point number
270
+ to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
271
+ """
272
+
273
+ extension: str = "tsv"
274
+
275
+ def write_result(
276
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
277
+ ):
278
+ print("start", "end", "text", sep="\t", file=file)
279
+ for segment in result["segments"]:
280
+ print(round(1000 * segment["start"]), file=file, end="\t")
281
+ print(round(1000 * segment["end"]), file=file, end="\t")
282
+ print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
283
+
284
+
285
+ class WriteJSON(ResultWriter):
286
+ extension: str = "json"
287
+
288
+ def write_result(
289
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
290
+ ):
291
+ json.dump(result, file)
292
+
293
+
294
+ def get_writer(
295
+ output_format: str, output_dir: str
296
+ ) -> Callable[[dict, TextIO, dict], None]:
297
+ writers = {
298
+ "txt": WriteTXT,
299
+ "vtt": WriteVTT,
300
+ "srt": WriteSRT,
301
+ "tsv": WriteTSV,
302
+ "json": WriteJSON,
303
+ }
304
+
305
+ if output_format == "all":
306
+ all_writers = [writer(output_dir) for writer in writers.values()]
307
+
308
+ def write_all(
309
+ result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
310
+ ):
311
+ for writer in all_writers:
312
+ writer(result, file, options, **kwargs)
313
+
314
+ return write_all
315
+
316
+ return writers[output_format](output_dir)
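A usage sketch for get_writer (the result dict and file names are illustrative; segments without word-level timing fall back to segment start/end times in the subtitle writers):

from whisper_stream.utils import get_writer

result = {
    "segments": [
        {"start": 0.0, "end": 2.5, "text": " Hello world."},
    ]
}
writer = get_writer("srt", output_dir=".")
writer(result, "audio.wav")  # writes ./audio.srt with one numbered cue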
whisper_stream/version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = "20231117"