first commit
- .gitignore +4 -0
- dockerfile +54 -0
- requirements.txt +16 -0
- static/client.html +567 -0
- static/whisper_client.html +595 -0
- unified_socket_server.py +325 -0
- whisper_stream/__init__.py +331 -0
- whisper_stream/__main__.py +3 -0
- whisper_stream/assets/gpt2.tiktoken +0 -0
- whisper_stream/assets/mel_filters.npz +0 -0
- whisper_stream/assets/multilingual.tiktoken +0 -0
- whisper_stream/audio.py +357 -0
- whisper_stream/decoding.py +826 -0
- whisper_stream/model.py +329 -0
- whisper_stream/normalizers/__init__.py +2 -0
- whisper_stream/normalizers/basic.py +76 -0
- whisper_stream/normalizers/english.json +1741 -0
- whisper_stream/normalizers/english.py +550 -0
- whisper_stream/streaming_decoding.py +1198 -0
- whisper_stream/streaming_model.py +444 -0
- whisper_stream/streaming_transcribe.py +208 -0
- whisper_stream/timing.py +386 -0
- whisper_stream/tokenizer.py +395 -0
- whisper_stream/transcribe.py +605 -0
- whisper_stream/triton_ops.py +109 -0
- whisper_stream/utils.py +316 -0
- whisper_stream/version.py +1 -0
.gitignore
ADDED
@@ -0,0 +1,4 @@
+*.ipynb
+old_demo_code/
+htokenf.txt
+__pycache__/
dockerfile
ADDED
@@ -0,0 +1,54 @@
+# Use NVIDIA CUDA base image (full 11.8.0 tag; the short 11.8 tag is not published)
+FROM nvidia/cuda:11.8.0-devel-ubuntu20.04
+
+# Set environment variables
+ENV DEBIAN_FRONTEND=noninteractive
+ENV PYTHONUNBUFFERED=1
+
+# Install system dependencies
+RUN apt-get update && apt-get install -y \
+    python3.9 \
+    python3.9-dev \
+    python3-pip \
+    ffmpeg \
+    git \
+    wget \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Create symlinks for python
+RUN ln -s /usr/bin/python3.9 /usr/bin/python
+
+# Set up a new user named "user" with user ID 1000
+RUN useradd -m -u 1000 user
+
+# Switch to the "user" user
+USER user
+
+# Set home to the user's home directory
+ENV HOME=/home/user \
+    PATH=/home/user/.local/bin:$PATH
+
+# Set the working directory to the user's home directory
+WORKDIR $HOME/app
+
+# Upgrade pip
+RUN python -m pip install --no-cache-dir --upgrade pip
+
+# Install PyTorch with CUDA support first
+RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
+
+# Copy requirements file
+COPY --chown=user requirements.txt .
+
+# Install other Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the application code
+COPY --chown=user . .
+
+# Expose the port
+EXPOSE 7860
+
+# Run the application
+CMD ["python", "unified_socket_server.py", "--host", "0.0.0.0", "--port", "7860"]
requirements.txt
ADDED
@@ -0,0 +1,16 @@
+evaluate==0.4.3
+librosa==0.10.2.post1
+more_itertools==10.7.0
+numba==0.60.0
+numpy==1.26.0
+openai_whisper==20240930
+pandas==2.3.1
+praatio==6.2.0
+pyaudio==0.2.11
+pytorch_lightning==2.5.0.post0
+regex==2024.11.6
+soundfile==0.12.1
+tiktoken==0.8.0
+tqdm==4.66.5
+triton==3.2.0
+websockets==15.0.1
static/client.html
ADDED
@@ -0,0 +1,567 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Real-Time Whisper Transcription</title>
+    <style>
+        body {
+            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
+            max-width: 1200px;
+            margin: 0 auto;
+            padding: 20px;
+            background-color: #f5f5f5;
+        }
+
+        .container {
+            background: white;
+            border-radius: 10px;
+            padding: 30px;
+            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
+        }
+
+        h1 {
+            color: #333;
+            text-align: center;
+            margin-bottom: 30px;
+        }
+
+        .config-section {
+            display: grid;
+            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+            gap: 15px;
+            margin-bottom: 30px;
+            padding: 20px;
+            background: #f8f9fa;
+            border-radius: 8px;
+        }
+
+        .config-group {
+            display: flex;
+            flex-direction: column;
+        }
+
+        label {
+            font-weight: 600;
+            margin-bottom: 5px;
+            color: #555;
+        }
+
+        input, select {
+            padding: 10px;
+            border: 2px solid #ddd;
+            border-radius: 5px;
+            font-size: 14px;
+            transition: border-color 0.3s;
+        }
+
+        input:focus, select:focus {
+            outline: none;
+            border-color: #007bff;
+        }
+
+        .controls {
+            display: flex;
+            gap: 10px;
+            justify-content: center;
+            margin-bottom: 30px;
+        }
+
+        button {
+            padding: 12px 24px;
+            border: none;
+            border-radius: 5px;
+            font-size: 16px;
+            font-weight: 600;
+            cursor: pointer;
+            transition: all 0.3s;
+        }
+
+        .start-btn {
+            background: #28a745;
+            color: white;
+        }
+
+        .start-btn:hover:not(:disabled) {
+            background: #218838;
+        }
+
+        .stop-btn {
+            background: #dc3545;
+            color: white;
+        }
+
+        .stop-btn:hover:not(:disabled) {
+            background: #c82333;
+        }
+
+        .clear-btn {
+            background: #6c757d;
+            color: white;
+        }
+
+        .clear-btn:hover:not(:disabled) {
+            background: #5a6268;
+        }
+
+        button:disabled {
+            opacity: 0.6;
+            cursor: not-allowed;
+        }
+
+        .status {
+            display: flex;
+            align-items: center;
+            justify-content: center;
+            gap: 10px;
+            margin-bottom: 20px;
+            font-weight: 600;
+        }
+
+        .status-indicator {
+            width: 12px;
+            height: 12px;
+            border-radius: 50%;
+            background: #dc3545;
+            animation: pulse 2s infinite;
+        }
+
+        .status-indicator.connected {
+            background: #28a745;
+        }
+
+        .status-indicator.streaming {
+            background: #ffc107;
+        }
+
+        @keyframes pulse {
+            0% { opacity: 1; }
+            50% { opacity: 0.5; }
+            100% { opacity: 1; }
+        }
+
+        .transcription-section {
+            display: grid;
+            grid-template-columns: 1fr 1fr;
+            gap: 20px;
+        }
+
+        .transcription-panel {
+            background: #f8f9fa;
+            border-radius: 8px;
+            padding: 20px;
+        }
+
+        .transcription-panel h3 {
+            margin-top: 0;
+            color: #333;
+        }
+
+        .log-area, .transcription-area {
+            background: #fff;
+            border: 1px solid #ddd;
+            border-radius: 5px;
+            padding: 15px;
+            height: 300px;
+            overflow-y: auto;
+            font-family: 'Courier New', monospace;
+            font-size: 14px;
+            line-height: 1.4;
+            white-space: pre-wrap;
+        }
+
+        .transcription-area {
+            font-family: inherit;
+            font-size: 16px;
+            line-height: 1.6;
+        }
+
+        .stats {
+            display: grid;
+            grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
+            gap: 15px;
+            margin-top: 20px;
+        }
+
+        .stat-item {
+            text-align: center;
+            padding: 10px;
+            background: #e9ecef;
+            border-radius: 5px;
+        }
+
+        .stat-value {
+            font-size: 18px;
+            font-weight: bold;
+            color: #007bff;
+        }
+
+        .stat-label {
+            font-size: 12px;
+            color: #666;
+            margin-top: 5px;
+        }
+
+        @media (max-width: 768px) {
+            .config-section {
+                grid-template-columns: 1fr;
+            }
+
+            .transcription-section {
+                grid-template-columns: 1fr;
+            }
+
+            .controls {
+                flex-direction: column;
+                align-items: center;
+            }
+
+            button {
+                width: 200px;
+            }
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1>🎤 Real-Time Whisper Transcription</h1>
+
+        <div class="config-section">
+            <div class="config-group">
+                <label for="modelSize">Model Size:</label>
+                <select id="modelSize">
+                    <option value="base">Base (74 MB)</option>
+                    <option value="small" selected>Small (244 MB)</option>
+                    <option value="large-v2">Large-v2 (1550 MB)</option>
+                </select>
+            </div>
+
+            <div class="config-group">
+                <label for="chunkSize">Chunk Size (ms):</label>
+                <input type="number" id="chunkSize" value="300" min="100" max="2000" step="100">
+            </div>
+
+            <div class="config-group">
+                <label for="beamSize">Beam Size:</label>
+                <input type="number" id="beamSize" value="0" min="0" max="10">
+            </div>
+
+            <div class="config-group">
+                <label for="temperature">Temperature:</label>
+                <input type="number" id="temperature" value="0.0" min="0.0" max="1.0" step="0.1">
+            </div>
+        </div>
+
+        <div class="controls">
+            <button id="startBtn" class="start-btn" onclick="startStreaming()">🎤 Start Recording</button>
+            <button id="stopBtn" class="stop-btn" onclick="stopStreaming()" disabled>⏹️ Stop Recording</button>
+            <button id="clearBtn" class="clear-btn" onclick="clearAll()">🗑️ Clear All</button>
+        </div>
+
+        <div class="status">
+            <div class="status-indicator" id="statusIndicator"></div>
+            <span id="statusText">Disconnected</span>
+        </div>
+
+        <div class="transcription-section">
+            <div class="transcription-panel">
+                <h3>📝 Transcription</h3>
+                <div id="transcriptionArea" class="transcription-area"></div>
+            </div>
+
+            <div class="transcription-panel">
+                <h3>📋 System Log</h3>
+                <div id="logArea" class="log-area"></div>
+            </div>
+        </div>
+
+        <div class="stats">
+            <div class="stat-item">
+                <div class="stat-value" id="durationStat">0.0s</div>
+                <div class="stat-label">Duration</div>
+            </div>
+            <div class="stat-item">
+                <div class="stat-value" id="chunksStat">0</div>
+                <div class="stat-label">Chunks Sent</div>
+            </div>
+            <div class="stat-item">
+                <div class="stat-value" id="transcriptionsStat">0</div>
+                <div class="stat-label">Transcriptions</div>
+            </div>
+            <div class="stat-item">
+                <div class="stat-value" id="errorsStat">0</div>
+                <div class="stat-label">Errors</div>
+            </div>
+        </div>
+    </div>
+
+    <script>
+        let socket;
+        let audioContext, processor, micStream;
+        let isStreaming = false;
+        let startTime = 0;
+        let stats = {
+            chunks: 0,
+            transcriptions: 0,
+            errors: 0
+        };
+
+        // UI Elements
+        const startBtn = document.getElementById('startBtn');
+        const stopBtn = document.getElementById('stopBtn');
+        const statusIndicator = document.getElementById('statusIndicator');
+        const statusText = document.getElementById('statusText');
+        const logArea = document.getElementById('logArea');
+        const transcriptionArea = document.getElementById('transcriptionArea');
+
+        function log(message, isError = false) {
+            const timestamp = new Date().toLocaleTimeString();
+            const logMessage = `[${timestamp}] ${message}`;
+            logArea.textContent += logMessage + '\n';
+            logArea.scrollTop = logArea.scrollHeight;
+
+            if (isError) {
+                console.error(logMessage);
+                stats.errors++;
+                updateStats();
+            } else {
+                console.log(logMessage);
+            }
+        }
+
+        function updateStatus(status, color) {
+            statusText.textContent = status;
+            statusIndicator.className = `status-indicator ${color}`;
+        }
+
+        function updateStats() {
+            document.getElementById('durationStat').textContent =
+                isStreaming ? ((Date.now() - startTime) / 1000).toFixed(1) + 's' : '0.0s';
+            document.getElementById('chunksStat').textContent = stats.chunks;
+            document.getElementById('transcriptionsStat').textContent = stats.transcriptions;
+            document.getElementById('errorsStat').textContent = stats.errors;
+        }
+
+        function addTranscription(text, timestamp) {
+            const transcriptionText = `[${timestamp.toFixed(1)}s] ${text}\n`;
+            transcriptionArea.textContent += transcriptionText;
+            transcriptionArea.scrollTop = transcriptionArea.scrollHeight;
+            stats.transcriptions++;
+            updateStats();
+        }
+
+        async function startStreaming() {
+            try {
+                // Get configuration
+                const config = {
+                    model_size: document.getElementById('modelSize').value,
+                    chunk_size: parseInt(document.getElementById('chunkSize').value),
+                    beam_size: parseInt(document.getElementById('beamSize').value),
+                    temperature: parseFloat(document.getElementById('temperature').value)
+                };
+
+                log('Starting transcription session...');
+                log(`Config: ${JSON.stringify(config, null, 2)}`);
+
+                // Create WebSocket URL (relative to current page)
+                const wsProtocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
+                const wsUrl = `${wsProtocol}//${window.location.host}/ws`;
+
+                // Initialize WebSocket
+                socket = new WebSocket(wsUrl);
+                socket.binaryType = 'arraybuffer';
+
+                socket.onopen = async () => {
+                    log('WebSocket connected');
+                    updateStatus('Connected', 'connected');
+
+                    // Send configuration
+                    log('Sending configuration...');
+                    socket.send(JSON.stringify(config));
+                };
+
+                socket.onmessage = (event) => {
+                    try {
+                        const data = JSON.parse(event.data);
+
+                        if (data.error) {
+                            log(`Server error: ${data.error}`, true);
+                            return;
+                        }
+
+                        if (data.status === 'CONFIG_RECEIVED') {
+                            log(`Configuration received. GPU: ${data.gpu ? 'Yes' : 'No'}`);
+                            if (data.fallback) {
+                                log('Using fallback Whisper model');
+                            }
+                            // Wait a bit for WebSocket to be fully ready
+                            setTimeout(() => {
+                                startAudioCapture();
+                            }, 100);
+                        } else if (data.text) {
+                            addTranscription(data.text, data.timestamp || 0);
+                        }
+                    } catch (e) {
+                        log(`Server message: ${event.data}`);
+                    }
+                };
+
+                socket.onerror = (error) => {
+                    log(`WebSocket error: ${error}`, true);
+                    updateStatus('Error', 'error');
+                };
+
+                socket.onclose = () => {
+                    log('WebSocket disconnected');
+                    updateStatus('Disconnected', '');
+                    stopStreaming();
+                };
+
+            } catch (error) {
+                log(`Failed to start: ${error.message}`, true);
+                updateStatus('Error', 'error');
+            }
+        }
+
+        async function startAudioCapture() {
+            try {
+                log('Requesting microphone access...');
+
+                const stream = await navigator.mediaDevices.getUserMedia({
+                    audio: {
+                        sampleRate: 16000,
+                        channelCount: 1,
+                        echoCancellation: true,
+                        noiseSuppression: true,
+                        autoGainControl: true
+                    }
+                });
+
+                log('Microphone access granted');
+
+                // Create audio context
+                audioContext = new (window.AudioContext || window.webkitAudioContext)({
+                    sampleRate: 16000
+                });
+
+                // Resume audio context if suspended
+                if (audioContext.state === 'suspended') {
+                    await audioContext.resume();
+                    log('Audio context resumed');
+                }
+
+                // Create media stream source
+                micStream = audioContext.createMediaStreamSource(stream);
+
+                // Create script processor
+                const bufferSize = 4096;
+                processor = audioContext.createScriptProcessor(bufferSize, 1, 1);
+
+                processor.onaudioprocess = (event) => {
+                    if (!isStreaming || !socket || socket.readyState !== WebSocket.OPEN) {
+                        return;
+                    }
+
+                    const inputData = event.inputBuffer.getChannelData(0);
+
+                    // Convert to 16-bit PCM
+                    const int16Array = new Int16Array(inputData.length);
+                    for (let i = 0; i < inputData.length; i++) {
+                        const sample = Math.max(-1, Math.min(1, inputData[i]));
+                        int16Array[i] = sample * 32767;
+                    }
+
+                    // Send to server
+                    socket.send(int16Array.buffer);
+                    stats.chunks++;
+                };
+
+                // Connect audio nodes
+                micStream.connect(processor);
+                processor.connect(audioContext.destination);
+
+                // Update UI
+                isStreaming = true;
+                startTime = Date.now();
+                startBtn.disabled = true;
+                stopBtn.disabled = false;
+                updateStatus('Streaming', 'streaming');
+
+                log('Audio streaming started');
+
+                // Start stats update timer
+                const statsTimer = setInterval(() => {
+                    if (isStreaming) {
+                        updateStats();
+                    } else {
+                        clearInterval(statsTimer);
+                    }
+                }, 100);
+
+            } catch (error) {
+                log(`Audio capture failed: ${error.message}`, true);
+                updateStatus('Error', 'error');
+                startBtn.disabled = false;
+                stopBtn.disabled = true;
+            }
+        }
+
+        function stopStreaming() {
+            isStreaming = false;
+
+            // Stop all tracks
+            if (micStream && micStream.mediaStream) {
+                micStream.mediaStream.getTracks().forEach(track => track.stop());
+            }
+
+            // Disconnect audio nodes
+            if (processor) {
+                processor.disconnect();
+                processor = null;
+            }
+
+            if (micStream) {
+                micStream.disconnect();
+                micStream = null;
+            }
+
+            if (audioContext) {
+                audioContext.close();
+                audioContext = null;
+            }
+
+            // Close WebSocket
+            if (socket && socket.readyState === WebSocket.OPEN) {
+                socket.close();
+            }
+
+            // Update UI
+            startBtn.disabled = false;
+            stopBtn.disabled = true;
+            updateStatus('Disconnected', '');
+
+            log('Streaming stopped');
+            updateStats();
+        }
+
+        function clearAll() {
+            logArea.textContent = '';
+            transcriptionArea.textContent = '';
+            stats = { chunks: 0, transcriptions: 0, errors: 0 };
+            updateStats();
+            log('All content cleared');
+        }
+
+        // Initialize
+        updateStats();
+        log('Real-time transcription client ready');
+
+        // Handle page unload
+        window.addEventListener('beforeunload', () => {
+            if (isStreaming) {
+                stopStreaming();
+            }
+        });
+    </script>
+</body>
+</html>
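The capture loop above clamps each Web Audio float sample to [-1, 1] and scales it by 32767 to produce int16 PCM, while the server below decodes by dividing by 32768, a common convention that keeps decoded values strictly inside [-1, 1). A minimal NumPy sketch of that round trip, useful for sanity-checking a custom client (the sine test signal is a placeholder, not part of the repo):

import numpy as np

# Client side (mirrors the JS onaudioprocess loop): clamp to [-1, 1],
# scale by 32767, and send the raw int16 bytes over the socket.
samples = np.clip(np.sin(np.linspace(0, 2 * np.pi, 4096, dtype=np.float32)), -1.0, 1.0)
payload = (samples * 32767).astype(np.int16).tobytes()  # what socket.send(int16Array.buffer) transmits

# Server side (mirrors transcribe_chunk below): bytes -> float32 in [-1, 1).
decoded = np.frombuffer(payload, dtype=np.int16).astype(np.float32) / 32768.0
assert np.abs(decoded - samples).max() < 1e-3  # only int16 quantization error remains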
static/whisper_client.html
ADDED
@@ -0,0 +1,595 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>Real-Time Whisper Transcription</title>
+    <style>
+        body {
+            font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
+            max-width: 1200px;
+            margin: 0 auto;
+            padding: 20px;
+            background-color: #f5f5f5;
+        }
+
+        .container {
+            background: white;
+            border-radius: 10px;
+            padding: 30px;
+            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
+        }
+
+        h1 {
+            color: #333;
+            text-align: center;
+            margin-bottom: 30px;
+        }
+
+        .config-section {
+            display: grid;
+            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
+            gap: 15px;
+            margin-bottom: 30px;
+            padding: 20px;
+            background: #f8f9fa;
+            border-radius: 8px;
+        }
+
+        .config-group {
+            display: flex;
+            flex-direction: column;
+        }
+
+        label {
+            font-weight: 600;
+            margin-bottom: 5px;
+            color: #555;
+        }
+
+        input, select {
+            padding: 10px;
+            border: 2px solid #ddd;
+            border-radius: 5px;
+            font-size: 14px;
+            transition: border-color 0.3s;
+        }
+
+        input:focus, select:focus {
+            outline: none;
+            border-color: #007bff;
+        }
+
+        .controls {
+            display: flex;
+            gap: 10px;
+            justify-content: center;
+            margin-bottom: 30px;
+        }
+
+        button {
+            padding: 12px 24px;
+            border: none;
+            border-radius: 5px;
+            font-size: 16px;
+            font-weight: 600;
+            cursor: pointer;
+            transition: all 0.3s;
+        }
+
+        .start-btn {
+            background: #28a745;
+            color: white;
+        }
+
+        .start-btn:hover:not(:disabled) {
+            background: #218838;
+        }
+
+        .stop-btn {
+            background: #dc3545;
+            color: white;
+        }
+
+        .stop-btn:hover:not(:disabled) {
+            background: #c82333;
+        }
+
+        .clear-btn {
+            background: #6c757d;
+            color: white;
+        }
+
+        .clear-btn:hover:not(:disabled) {
+            background: #5a6268;
+        }
+
+        button:disabled {
+            opacity: 0.6;
+            cursor: not-allowed;
+        }
+
+        .status {
+            display: flex;
+            align-items: center;
+            justify-content: center;
+            gap: 10px;
+            margin-bottom: 20px;
+            font-weight: 600;
+        }
+
+        .status-indicator {
+            width: 12px;
+            height: 12px;
+            border-radius: 50%;
+            background: #dc3545;
+            animation: pulse 2s infinite;
+        }
+
+        .status-indicator.connected {
+            background: #28a745;
+        }
+
+        .status-indicator.streaming {
+            background: #ffc107;
+        }
+
+        @keyframes pulse {
+            0% { opacity: 1; }
+            50% { opacity: 0.5; }
+            100% { opacity: 1; }
+        }
+
+        .transcription-section {
+            display: grid;
+            grid-template-columns: 1fr 1fr;
+            gap: 20px;
+        }
+
+        .transcription-panel {
+            background: #f8f9fa;
+            border-radius: 8px;
+            padding: 20px;
+        }
+
+        .transcription-panel h3 {
+            margin-top: 0;
+            color: #333;
+        }
+
+        .log-area, .transcription-area {
+            background: #fff;
+            border: 1px solid #ddd;
+            border-radius: 5px;
+            padding: 15px;
+            height: 300px;
+            overflow-y: auto;
+            font-family: 'Courier New', monospace;
+            font-size: 14px;
+            line-height: 1.4;
+            white-space: pre-wrap;
+        }
+
+        .transcription-area {
+            font-family: inherit;
+            font-size: 16px;
+            line-height: 1.6;
+        }
+
+        .stats {
+            display: grid;
+            grid-template-columns: repeat(auto-fit, minmax(150px, 1fr));
+            gap: 15px;
+            margin-top: 20px;
+        }
+
+        .stat-item {
+            text-align: center;
+            padding: 10px;
+            background: #e9ecef;
+            border-radius: 5px;
+        }
+
+        .stat-value {
+            font-size: 18px;
+            font-weight: bold;
+            color: #007bff;
+        }
+
+        .stat-label {
+            font-size: 12px;
+            color: #666;
+            margin-top: 5px;
+        }
+
+        @media (max-width: 768px) {
+            .config-section {
+                grid-template-columns: 1fr;
+            }
+
+            .transcription-section {
+                grid-template-columns: 1fr;
+            }
+
+            .controls {
+                flex-direction: column;
+                align-items: center;
+            }
+
+            button {
+                width: 200px;
+            }
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1>🎤 Real-Time Whisper Transcription</h1>
+
+        <div class="config-section">
+            <div class="config-group">
+                <label for="serverUrl">Server URL:</label>
+                <input type="text" id="serverUrl" value="ws://localhost:8000">
+            </div>
+
+            <div class="config-group">
+                <label for="modelSize">Model Size:</label>
+                <select id="modelSize">
+                    <option value="base">Base (74 MB)</option>
+                    <option value="small" selected>Small (244 MB)</option>
+                    <option value="large-v2">Large-v2 (1550 MB)</option>
+                </select>
+            </div>
+
+            <div class="config-group">
+                <label for="chunkSize">Chunk Size (ms):</label>
+                <input type="number" id="chunkSize" value="300" min="100" max="2000" step="100">
+            </div>
+
+            <div class="config-group">
+                <label for="beamSize">Beam Size:</label>
+                <input type="number" id="beamSize" value="0" min="0" max="10">
+            </div>
+
+            <div class="config-group">
+                <label for="temperature">Temperature:</label>
+                <input type="number" id="temperature" value="0.0" min="0.0" max="1.0" step="0.1">
+            </div>
+        </div>
+
+        <div class="controls">
+            <button id="startBtn" class="start-btn" onclick="startStreaming()">🎤 Start Recording</button>
+            <button id="stopBtn" class="stop-btn" onclick="stopStreaming()" disabled>⏹️ Stop Recording</button>
+            <button id="clearBtn" class="clear-btn" onclick="clearAll()">🗑️ Clear All</button>
+        </div>
+
+        <div class="status">
+            <div class="status-indicator" id="statusIndicator"></div>
+            <span id="statusText">Disconnected</span>
+        </div>
+
+        <div class="transcription-section">
+            <div class="transcription-panel">
+                <h3>📝 Transcription</h3>
+                <div id="transcriptionArea" class="transcription-area"></div>
+            </div>
+
+            <div class="transcription-panel">
+                <h3>📋 System Log</h3>
+                <div id="logArea" class="log-area"></div>
+            </div>
+        </div>
+
+        <div class="stats">
+            <div class="stat-item">
+                <div class="stat-value" id="durationStat">0.0s</div>
+                <div class="stat-label">Duration</div>
+            </div>
+            <div class="stat-item">
+                <div class="stat-value" id="chunksStat">0</div>
+                <div class="stat-label">Chunks Sent</div>
+            </div>
+            <div class="stat-item">
+                <div class="stat-value" id="transcriptionsStat">0</div>
+                <div class="stat-label">Transcriptions</div>
+            </div>
+            <div class="stat-item">
+                <div class="stat-value" id="errorsStat">0</div>
+                <div class="stat-label">Errors</div>
+            </div>
+        </div>
+    </div>
+
+    <script>
+        let socket, audioContext, processor, micStream;
+        let isStreaming = false;
+        let startTime = 0;
+        let stats = {
+            chunks: 0,
+            transcriptions: 0,
+            errors: 0
+        };
+
+        // UI Elements
+        const startBtn = document.getElementById('startBtn');
+        const stopBtn = document.getElementById('stopBtn');
+        const statusIndicator = document.getElementById('statusIndicator');
+        const statusText = document.getElementById('statusText');
+        const logArea = document.getElementById('logArea');
+        const transcriptionArea = document.getElementById('transcriptionArea');
+
+        function log(message, isError = false) {
+            const timestamp = new Date().toLocaleTimeString();
+            const logMessage = `[${timestamp}] ${message}`;
+            logArea.textContent += logMessage + '\n';
+            logArea.scrollTop = logArea.scrollHeight;
+
+            if (isError) {
+                console.error(logMessage);
+                stats.errors++;
+                updateStats();
+            } else {
+                console.log(logMessage);
+            }
+        }
+
+        function updateStatus(status, color) {
+            statusText.textContent = status;
+            statusIndicator.className = `status-indicator ${color}`;
+        }
+
+        function updateStats() {
+            document.getElementById('durationStat').textContent =
+                isStreaming ? ((Date.now() - startTime) / 1000).toFixed(1) + 's' : '0.0s';
+            document.getElementById('chunksStat').textContent = stats.chunks;
+            document.getElementById('transcriptionsStat').textContent = stats.transcriptions;
+            document.getElementById('errorsStat').textContent = stats.errors;
+        }
+
+        function addTranscription(text, timestamp) {
+            const transcriptionText = `[${timestamp.toFixed(1)}s] ${text}\n`;
+            transcriptionArea.textContent += transcriptionText;
+            transcriptionArea.scrollTop = transcriptionArea.scrollHeight;
+            stats.transcriptions++;
+            updateStats();
+        }
+
+        async function startStreaming() {
+            try {
+                // Get configuration
+                const config = {
+                    serverUrl: document.getElementById('serverUrl').value,
+                    model_size: document.getElementById('modelSize').value,
+                    chunk_size: parseInt(document.getElementById('chunkSize').value),
+                    beam_size: parseInt(document.getElementById('beamSize').value),
+                    temperature: parseFloat(document.getElementById('temperature').value)
+                };
+
+                log('Starting transcription session...');
+                log(`Config: ${JSON.stringify(config, null, 2)}`);
+
+                // Initialize WebSocket
+                socket = new WebSocket(config.serverUrl);
+                socket.binaryType = 'arraybuffer';
+
+                socket.onopen = async () => {
+                    log('WebSocket connected');
+                    updateStatus('Connected', 'connected');
+
+                    // Send configuration
+                    log('Sending configuration...');
+                    socket.send(JSON.stringify(config));
+                };
+
+                socket.onmessage = (event) => {
+                    try {
+                        const data = JSON.parse(event.data);
+
+                        if (data.error) {
+                            log(`Server error: ${data.error}`, true);
+                            return;
+                        }
+
+                        if (data.status === 'CONFIG_RECEIVED') {
+                            log(`Configuration received. GPU: ${data.gpu ? 'Yes' : 'No'}`);
+                            if (data.fallback) {
+                                log('Using fallback Whisper model');
+                            }
+                            // Wait a bit for WebSocket to be fully ready
+                            setTimeout(() => {
+                                startAudioCapture();
+                            }, 100);
+                        } else if (data.text) {
+                            addTranscription(data.text, data.timestamp || 0);
+                        }
+                    } catch (e) {
+                        log(`Server message: ${event.data}`);
+                    }
+                };
+
+                socket.onerror = (error) => {
+                    log(`WebSocket error: ${error}`, true);
+                    updateStatus('Error', 'error');
+                };
+
+                socket.onclose = () => {
+                    log('WebSocket disconnected');
+                    updateStatus('Disconnected', '');
+                    stopStreaming();
+                };
+
+            } catch (error) {
+                log(`Failed to start: ${error.message}`, true);
+                updateStatus('Error', 'error');
+            }
+        }
+
+        async function startAudioCapture() {
+            try {
+                // Request microphone access with explicit user gesture
+                log('Requesting microphone access...');
+
+                // Check if we have permission first
+                if (navigator.permissions) {
+                    const permission = await navigator.permissions.query({name: 'microphone'});
+                    log(`Microphone permission state: ${permission.state}`);
+                }
+
+                const stream = await navigator.mediaDevices.getUserMedia({
+                    audio: {
+                        sampleRate: 16000,
+                        channelCount: 1,
+                        echoCancellation: true,
+                        noiseSuppression: true,
+                        autoGainControl: true
+                    }
+                });
+
+                log('Microphone access granted');
+
+                // Create audio context
+                audioContext = new (window.AudioContext || window.webkitAudioContext)({
+                    sampleRate: 16000
+                });
+
+                // Resume audio context if suspended (required by some browsers)
+                if (audioContext.state === 'suspended') {
+                    await audioContext.resume();
+                    log('Audio context resumed');
+                }
+
+                // Create media stream source
+                micStream = audioContext.createMediaStreamSource(stream);
+
+                // Create script processor
+                const bufferSize = 4096;
+                processor = audioContext.createScriptProcessor(bufferSize, 1, 1);
+
+                processor.onaudioprocess = (event) => {
+                    if (!isStreaming || !socket || socket.readyState !== WebSocket.OPEN) {
+                        return;
+                    }
+
+                    const inputData = event.inputBuffer.getChannelData(0);
+
+                    // Convert to 16-bit PCM
+                    const int16Array = new Int16Array(inputData.length);
+                    for (let i = 0; i < inputData.length; i++) {
+                        const sample = Math.max(-1, Math.min(1, inputData[i]));
+                        int16Array[i] = sample * 32767;
+                    }
+
+                    // Send to server
+                    socket.send(int16Array.buffer);
+                    stats.chunks++;
+                };
+
+                // Connect audio nodes
+                micStream.connect(processor);
+                processor.connect(audioContext.destination);
+
+                // Update UI
+                isStreaming = true;
+                startTime = Date.now();
+                startBtn.disabled = true;
+                stopBtn.disabled = false;
+                updateStatus('Streaming', 'streaming');
+
+                log('Audio streaming started');
+
+                // Start stats update timer
+                const statsTimer = setInterval(() => {
+                    if (isStreaming) {
+                        updateStats();
+                    } else {
+                        clearInterval(statsTimer);
+                    }
+                }, 100);
+
+            } catch (error) {
+                log(`Audio capture failed: ${error.message}`, true);
+                log(`Error details: ${error.name} - ${error.message}`);
+
+                // Handle specific permission errors
+                if (error.name === 'NotAllowedError') {
+                    log('Microphone access denied by user. Please allow microphone access and try again.', true);
+                } else if (error.name === 'NotFoundError') {
+                    log('No microphone found. Please check your audio devices.', true);
+                } else if (error.name === 'NotReadableError') {
+                    log('Microphone is already in use by another application.', true);
+                }
+
+                updateStatus('Error', 'error');
+
+                // Reset UI state
+                startBtn.disabled = false;
+                stopBtn.disabled = true;
+            }
+        }
+
+        function stopStreaming() {
+            isStreaming = false;
+
+            // Stop all tracks to properly release the microphone
+            if (micStream && micStream.mediaStream) {
+                micStream.mediaStream.getTracks().forEach(track => track.stop());
+            }
+
+            // Disconnect audio nodes
+            if (processor) {
+                processor.disconnect();
+                processor = null;
+            }
+
+            if (micStream) {
+                micStream.disconnect();
+                micStream = null;
+            }
+
+            if (audioContext) {
+                audioContext.close();
+                audioContext = null;
+            }
+
+            // Close WebSocket
+            if (socket && socket.readyState === WebSocket.OPEN) {
+                socket.close();
+            }
+
+            // Update UI
+            startBtn.disabled = false;
+            stopBtn.disabled = true;
+            updateStatus('Disconnected', '');
+
+            log('Streaming stopped');
+            updateStats();
+        }
+
+        function clearAll() {
+            logArea.textContent = '';
+            transcriptionArea.textContent = '';
+            stats = { chunks: 0, transcriptions: 0, errors: 0 };
+            updateStats();
+            log('All content cleared');
+        }
+
+        // Initialize stats display
+        updateStats();
+
+        // Handle page unload
+        window.addEventListener('beforeunload', () => {
+            if (isStreaming) {
+                stopStreaming();
+            }
+        });
+
+        // Handle errors
+        window.addEventListener('error', (event) => {
+            log(`JavaScript error: ${event.message}`, true);
+        });
+
+        // Show initial log message
+        log('Real-time transcription client ready');
+    </script>
+</body>
+</html>
unified_socket_server.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Unified WebSocket/HTTP Whisper Transcription Server
|
4 |
+
Handles real-time audio streaming, transcription using Whisper, and HTTP serving
|
5 |
+
"""
|
6 |
+
|
7 |
+
import asyncio
|
8 |
+
import websockets
|
9 |
+
import json
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import logging
|
13 |
+
import traceback
|
14 |
+
import os
|
15 |
+
from typing import Dict, Any
|
16 |
+
from aiohttp import web, WSMsgType
|
17 |
+
from aiohttp.web_ws import WebSocketResponse
|
18 |
+
|
19 |
+
|
20 |
+
# Configure logging
|
21 |
+
logging.basicConfig(level=logging.INFO)
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
try:
|
25 |
+
from whisper_stream import load_streaming_model_correct
|
26 |
+
from whisper_stream.streaming_decoding import DecodingOptions
|
27 |
+
except ImportError:
|
28 |
+
logger.error("whisper_stream not found. Please install it or use regular whisper")
|
29 |
+
# Fallback to regular whisper if whisper_stream is not available
|
30 |
+
import whisper
|
31 |
+
|
32 |
+
class UnifiedTranscriptionServer:
|
33 |
+
def __init__(self, host: str = "0.0.0.0", port: int = 8000):
|
34 |
+
self.host = host
|
35 |
+
self.port = port
|
36 |
+
self.clients: Dict[str, Dict[str, Any]] = {}
|
37 |
+
self.app = web.Application()
|
38 |
+
self.setup_routes()
|
39 |
+
|
40 |
+
def setup_routes(self):
|
41 |
+
"""Setup HTTP routes and WebSocket endpoint"""
|
42 |
+
# HTTP routes
|
43 |
+
self.app.router.add_get('/', self.serve_index)
|
44 |
+
self.app.router.add_get('/health', self.health_check)
|
45 |
+
|
46 |
+
# WebSocket endpoint
|
47 |
+
self.app.router.add_get('/ws', self.websocket_handler)
|
48 |
+
|
49 |
+
# Static file serving (if needed)
|
50 |
+
if os.path.exists('static'):
|
51 |
+
self.app.router.add_static('/static/', 'static')
|
52 |
+
|
53 |
+
async def serve_index(self, request):
|
54 |
+
"""Serve the main HTML page"""
|
55 |
+
try:
|
56 |
+
with open("./static/client.html", "r", encoding='utf-8') as f:
|
57 |
+
html_content = f.read()
|
58 |
+
return web.Response(text=html_content, content_type='text/html')
|
59 |
+
except FileNotFoundError:
|
60 |
+
return web.Response(text="client.html not found!", status=404)
|
61 |
+
except Exception as e:
|
62 |
+
logger.error(f"Error serving client.html! {e}")
|
63 |
+
return web.Response(text="Error loading page...", status=500)
|
64 |
+
|
65 |
+
async def health_check(self, request):
|
66 |
+
"""Health check endpoint"""
|
67 |
+
return web.json_response({"status": "healthy", "cuda": torch.cuda.is_available()})
|
68 |
+
|
69 |
+
async def websocket_handler(self, request):
|
70 |
+
"""Handle WebSocket connections"""
|
71 |
+
ws = WebSocketResponse()
|
72 |
+
await ws.prepare(request)
|
73 |
+
|
74 |
+
# Generate client ID
|
75 |
+
client_id = f"{request.remote}:{id(ws)}"
|
76 |
+
logger.info(f"New WebSocket client connected: {client_id}")
|
77 |
+
|
78 |
+
# Initialize client state
|
79 |
+
self.clients[client_id] = {
|
80 |
+
'websocket': ws,
|
81 |
+
'model': None,
|
82 |
+
'config': None,
|
83 |
+
'buffer': bytearray(),
|
84 |
+
'total_samples': 0,
|
85 |
+
'is_first_chunk': True
|
86 |
+
}
|
87 |
+
|
88 |
+
try:
|
89 |
+
await self.process_websocket_messages(client_id)
|
90 |
+
except Exception as e:
|
91 |
+
logger.error(f"Error handling WebSocket client {client_id}: {e}")
|
92 |
+
logger.error(traceback.format_exc())
|
93 |
+
finally:
|
94 |
+
# Cleanup
|
95 |
+
if client_id in self.clients:
|
96 |
+
del self.clients[client_id]
|
97 |
+
|
98 |
+
if not ws.closed:
|
99 |
+
await ws.close()
|
100 |
+
|
101 |
+
return ws
|
102 |
+
|
103 |
+
async def process_websocket_messages(self, client_id: str):
|
104 |
+
"""Process messages from a WebSocket client"""
|
105 |
+
client = self.clients[client_id]
|
106 |
+
ws = client['websocket']
|
107 |
+
|
108 |
+
async for msg in ws:
|
109 |
+
if msg.type == WSMsgType.TEXT:
|
110 |
+
# Handle configuration message
|
111 |
+
await self.handle_config_message(client_id, msg.data)
|
112 |
+
elif msg.type == WSMsgType.BINARY:
|
113 |
+
# Handle audio data
|
114 |
+
await self.handle_audio_data(client_id, msg.data)
|
115 |
+
elif msg.type == WSMsgType.ERROR:
|
116 |
+
logger.error(f'WebSocket error for client {client_id}: {ws.exception()}')
|
117 |
+
break
|
118 |
+
|
119 |
+
async def handle_config_message(self, client_id: str, message: str):
|
120 |
+
"""Handle configuration message from client"""
|
121 |
+
client = self.clients[client_id]
|
122 |
+
ws = client['websocket']
|
123 |
+
|
124 |
+
try:
|
125 |
+
config = json.loads(message)
|
126 |
+
logger.info(f"Received config from {client_id}: {config}")
|
127 |
+
|
128 |
+
# Validate config
|
129 |
+
required_fields = ['model_size', 'chunk_size', 'beam_size', 'temperature']
|
130 |
+
for field in required_fields:
|
131 |
+
if field not in config:
|
132 |
+
await ws.send_str(json.dumps({"error": f"Missing required field: {field}"}))
|
133 |
+
return
|
134 |
+
|
135 |
+
# Load model
|
136 |
+
model_size = config['model_size']
|
137 |
+
chunk_size = config['chunk_size']
|
138 |
+
|
139 |
+
logger.info(f"Loading model {model_size} for client {client_id}")
|
140 |
+
|
141 |
+
# Try to use whisper_stream, fallback to regular whisper
|
142 |
+
try:
|
143 |
+
model = load_streaming_model_correct(model_size, chunk_size)
|
144 |
+
client['first_chunk'] = True
|
145 |
+
if torch.cuda.is_available():
|
146 |
+
model = model.to("cuda")
|
147 |
+
logger.info(f"Model loaded on GPU for client {client_id}")
|
148 |
+
else:
|
149 |
+
logger.info(f"Model loaded on CPU for client {client_id}")
|
150 |
+
|
151 |
+
model.reset(use_stream=True)
|
152 |
+
model.eval()
|
153 |
+
|
154 |
+
client['model'] = model
|
155 |
+
client['config'] = config
|
156 |
+
|
157 |
+
await ws.send_str(json.dumps({"status": "CONFIG_RECEIVED", "gpu": torch.cuda.is_available()}))
|
158 |
+
|
159 |
+
except Exception as e:
|
160 |
+
logger.error(f"Error loading streaming model: {e}")
|
161 |
+
# Fallback to regular whisper
|
162 |
+
try:
|
163 |
+
model = whisper.load_model(model_size)
|
164 |
+
if torch.cuda.is_available():
|
165 |
+
model = model.to("cuda")
|
166 |
+
|
167 |
+
client['model'] = model
|
168 |
+
client['config'] = config
|
169 |
+
client['use_streaming'] = False
|
170 |
+
|
171 |
+
await ws.send_str(json.dumps({"status": "CONFIG_RECEIVED", "gpu": torch.cuda.is_available(), "fallback": True}))
|
172 |
+
except Exception as e2:
|
173 |
+
logger.error(f"Error loading fallback model: {e2}")
|
174 |
+
await ws.send_str(json.dumps({"error": f"Failed to load model: {e2}"}))
|
175 |
+
|
176 |
+
except json.JSONDecodeError as e:
|
177 |
+
await ws.send_str(json.dumps({"error": f"Invalid JSON: {e}"}))
|
178 |
+
except Exception as e:
|
179 |
+
logger.error(f"Error handling config for client {client_id}: {e}")
|
180 |
+
await ws.send_str(json.dumps({"error": str(e)}))
|
181 |
+
|
182 |
+
async def handle_audio_data(self, client_id: str, audio_data: bytes):
|
183 |
+
"""Handle audio data from client"""
|
184 |
+
client = self.clients[client_id]
|
185 |
+
ws = client['websocket']
|
186 |
+
|
187 |
+
if client['config'] is None:
|
188 |
+
await ws.send_str(json.dumps({"error": "Config not set"}))
|
189 |
+
return
|
190 |
+
|
191 |
+
if client['model'] is None:
|
192 |
+
await ws.send_str(json.dumps({"error": "Model not loaded"}))
|
193 |
+
return
|
194 |
+
|
195 |
+
# Add audio data to buffer
|
196 |
+
client['buffer'].extend(audio_data)
|
197 |
+
|
198 |
+
# Calculate chunk size in bytes
|
199 |
+
chunk_size_ms = client['config']['chunk_size']
|
200 |
+
sample_rate = 16000
|
201 |
+
chunk_samples = int(sample_rate * (chunk_size_ms / 1000))
|
202 |
+
chunk_bytes = chunk_samples * 2 # 16-bit audio = 2 bytes per sample
|
203 |
+
if client.get('first_chunk', True):
|
204 |
+
chunk_bytes += 720
|
205 |
+
|
206 |
+
# Process complete chunks
|
207 |
+
while len(client['buffer']) >= chunk_bytes:
|
208 |
+
chunk = client['buffer'][:chunk_bytes]
|
209 |
+
client['buffer'] = client['buffer'][chunk_bytes:]
|
210 |
+
|
211 |
+
try:
|
212 |
+
if client.get('first_chunk', True):
|
213 |
+
client['first_chunk'] = False
|
214 |
+
await self.transcribe_chunk(client_id, chunk)
|
215 |
+
except Exception as e:
|
216 |
+
logger.error(f"Error transcribing chunk for client {client_id}: {e}")
|
217 |
+
await ws.send_str(json.dumps({"error": f"Transcription error: {str(e)}"}))
|
218 |
+
|
219 |
+
async def transcribe_chunk(self, client_id: str, chunk: bytes):
|
220 |
+
"""Transcribe audio chunk"""
|
221 |
+
client = self.clients[client_id]
|
222 |
+
ws = client['websocket']
|
223 |
+
model = client['model']
|
224 |
+
config = client['config']
|
225 |
+
|
226 |
+
try:
|
227 |
+
# Convert bytes to numpy array
|
228 |
+
pcm = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0
|
229 |
+
|
230 |
+
# Convert to torch tensor
|
231 |
+
audio = torch.tensor(pcm)
|
232 |
+
if torch.cuda.is_available() and next(model.parameters()).is_cuda:
|
233 |
+
audio = audio.to("cuda")
|
234 |
+
|
235 |
+
# Transcribe based on model type
|
236 |
+
if hasattr(model, 'decode') and 'use_streaming' not in client:
|
237 |
+
# Using whisper_stream
|
238 |
+
decoding_options = DecodingOptions(
|
239 |
+
language="en",
|
240 |
+
gran=(config['chunk_size'] // 20),
|
241 |
+
single_frame_mel=True,
|
242 |
+
without_timestamps=True,
|
243 |
+
beam_size=config['beam_size'],
|
244 |
+
temperature=config['temperature'],
|
245 |
+
stream_decode=True,
|
246 |
+
use_ca_kv_cache=True,
|
247 |
+
look_ahead_blocks=model.extra_gran_blocks
|
248 |
+
)
|
249 |
+
result = model.decode(audio, decoding_options, use_frames=True)
|
250 |
+
text = result.text
|
251 |
+
else:
|
252 |
+
# Using regular whisper
|
253 |
+
# Pad audio to minimum length if needed
|
254 |
+
min_length = 16000 # 1 second at 16kHz
|
255 |
+
if len(audio) < min_length:
|
256 |
+
audio = torch.nn.functional.pad(audio, (0, min_length - len(audio)))
|
257 |
+
|
258 |
+
result = model.transcribe(audio.cpu().numpy(),
|
259 |
+
language="en",
|
260 |
+
beam_size=config['beam_size'],
|
261 |
+
temperature=config['temperature'])
|
262 |
+
text = result['text']
|
263 |
+
|
264 |
+
# Send transcription result
|
265 |
+
if text.strip():
|
266 |
+
client['total_samples'] += len(pcm)
|
267 |
+
duration = client['total_samples'] / 16000 # seconds
|
268 |
+
|
269 |
+
await ws.send_str(json.dumps({
|
270 |
+
"text": text.strip(),
|
271 |
+
"timestamp": duration,
|
272 |
+
"chunk_duration": len(pcm) / 16000
|
273 |
+
}))
|
274 |
+
|
275 |
+
except Exception as e:
|
276 |
+
logger.error(f"Error in transcription for client {client_id}: {e}")
|
277 |
+
logger.exception("Exception occurred")
|
278 |
+
raise
|
279 |
+
|
280 |
+
async def start_server(self):
|
281 |
+
"""Start the unified server"""
|
282 |
+
logger.info(f"Starting unified server on {self.host}:{self.port}")
|
283 |
+
logger.info(f"CUDA available: {torch.cuda.is_available()}")
|
284 |
+
|
285 |
+
runner = web.AppRunner(self.app)
|
286 |
+
await runner.setup()
|
287 |
+
site = web.TCPSite(runner, self.host, self.port)
|
288 |
+
await site.start()
|
289 |
+
|
290 |
+
logger.info(f"Server running on http://{self.host}:{self.port}")
|
291 |
+
logger.info(f"WebSocket endpoint: ws://{self.host}:{self.port}/ws")
|
292 |
+
|
293 |
+
# Keep the server running
|
294 |
+
try:
|
295 |
+
await asyncio.Future() # Run forever
|
296 |
+
except KeyboardInterrupt:
|
297 |
+
logger.info("Server stopped by user")
|
298 |
+
finally:
|
299 |
+
await runner.cleanup()
|
300 |
+
|
301 |
+
def main():
|
302 |
+
import argparse
|
303 |
+
|
304 |
+
parser = argparse.ArgumentParser(description='Unified WebSocket/HTTP Whisper Transcription Server')
|
305 |
+
parser.add_argument('--host', default='0.0.0.0', help='Host to bind to')
|
306 |
+
parser.add_argument('--port', type=int, default=8000, help='Port to bind to')
|
307 |
+
parser.add_argument('--log-level', default='INFO', help='Log level')
|
308 |
+
|
309 |
+
args = parser.parse_args()
|
310 |
+
|
311 |
+
# Set log level
|
312 |
+
logging.getLogger().setLevel(getattr(logging, args.log_level.upper()))
|
313 |
+
|
314 |
+
server = UnifiedTranscriptionServer(args.host, args.port)
|
315 |
+
|
316 |
+
try:
|
317 |
+
asyncio.run(server.start_server())
|
318 |
+
except KeyboardInterrupt:
|
319 |
+
logger.info("Server stopped by user")
|
320 |
+
except Exception as e:
|
321 |
+
logger.error(f"Server error: {e}")
|
322 |
+
logger.error(traceback.format_exc())
|
323 |
+
|
324 |
+
if __name__ == '__main__':
|
325 |
+
main()
|
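
For reference, a minimal client sketch for the endpoint above (not part of this commit; the port, chunk size, and reply cadence are assumptions read off the handler: the server parses each binary frame as 16 kHz mono int16 PCM and replies with JSON only when a chunk yields non-empty text):

import asyncio
import json

import aiohttp
import numpy as np

async def stream_pcm(pcm: np.ndarray, url: str = "ws://localhost:7860/ws", chunk_ms: int = 300):
    chunk_samples = 16000 * chunk_ms // 1000
    async with aiohttp.ClientSession() as session:
        async with session.ws_connect(url) as ws:
            for start in range(0, len(pcm), chunk_samples):
                # raw int16 bytes, matching np.frombuffer(chunk, dtype=np.int16) on the server
                await ws.send_bytes(pcm[start:start + chunk_samples].tobytes())
                try:
                    msg = await ws.receive(timeout=1.0)
                    if msg.type == aiohttp.WSMsgType.TEXT:
                        print(json.loads(msg.data))
                except asyncio.TimeoutError:
                    pass  # no reply for silent chunks

# asyncio.run(stream_pcm(np.zeros(16000 * 5, dtype=np.int16)))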
whisper_stream/__init__.py
ADDED
@@ -0,0 +1,331 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union

import torch
from tqdm import tqdm

from .audio import load_audio, pad_or_trim, log_mel_spectrogram
from .model import ModelDimensions, Whisper
from .streaming_model import StreamingWhisper
from .version import __version__

_MODELS = {
    "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
    "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
    "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
    "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
    "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
    "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
    "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
    "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
    "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
    "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
    "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
    "large-v3-turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
    "turbo": "https://openaipublic.azureedge.net/main/whisper/models/aff26ae408abcba5fbf8813c21e62b0941638c5f6eebfb145be0c9839262a19a/large-v3-turbo.pt",
}

_STREAMING_MODELS = {
    "base": {
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.25/checkpoint/checkpoint-epoch=0009.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0009.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0009.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_base_LIBRI-960-ALIGNED_32_full_streaming_eot_fixed_timings_LR-1e-05_r32_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.02/checkpoint/checkpoint-epoch=0006.pt",
    },
    "small": {
        "1000": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g50_eg0_top5_full-streamTrue_random-orderFalse_fraction0.4/checkpoint/checkpoint-epoch=0009.pt",
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.25/checkpoint/checkpoint-epoch=0009.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0009.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0009.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_small_LIBRI-960-ALIGNED_16_full_streaming_eot_fixed_timings_LR-1e-05_r32_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.02/checkpoint/checkpoint-epoch=0009.pt",
    },
    "large-v2": {
        "1000": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g50_eg0_top5_full-streamTrue_random-orderFalse_fraction0.3/checkpoint/checkpoint-epoch=0002.pt",
        "300": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.1/checkpoint/checkpoint-epoch=0002.pt",
        "200": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g10_eg2_top5_full-streamTrue_random-orderFalse_fraction0.07/checkpoint/checkpoint-epoch=0002.pt",
        "100": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g5_eg5_top5_full-streamTrue_random-orderFalse_fraction0.03/checkpoint/checkpoint-epoch=0002.pt",
        "40": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-960-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g2_eg14_top5_full-streamTrue_random-orderFalse_fraction0.01/checkpoint/checkpoint-epoch=0002.pt",
        "300-multi": "/mlspeech/data/tomer/streaming_whisper/models/ckpts/LoRA_streamed_whisper_large-v2_LIBRI-BLEND-ALIGNED_4_full_streaming_eot_fixed_timings_LR-1e-05_r4_g15_eg1_top5_full-streamTrue_random-orderFalse_fraction0.05/checkpoint/checkpoint-epoch=0001.pt",
    }
}

_STREAMING_MODELS_HF = {
    "base": {
        "300": "base_300.pt",
        "200": "base_200.pt",
        "100": "base_100.pt",
        "40": "base_40.pt",
    },
    "small": {
        "1000": "small_1000.pt",
        "300": "small_300.pt",
        "200": "small_200.pt",
        "100": "small_100.pt",
        "40": "small_40.pt",
    },
    "large-v2": {
        "1000": "large-v2_1000.pt",
        "300": "large-v2_300.pt",
        "200": "large-v2_200.pt",
        "100": "large-v2_100.pt",
        "40": "large-v2_40.pt",
        "300-multi": "large-v2_300_multi.pt",
    }
}

# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
    "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
    "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
    "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
    "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
    "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
    "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
    "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
    "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
    "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
    "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
    "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
    "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
    "turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`",
}


def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
    os.makedirs(root, exist_ok=True)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, os.path.basename(url))

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        with open(download_target, "rb") as f:
            model_bytes = f.read()
        if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
            return model_bytes if in_memory else download_target
        else:
            warnings.warn(
                f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
            )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(
            total=int(source.info().get("Content-Length")),
            ncols=80,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    model_bytes = open(download_target, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    return model_bytes if in_memory else download_target


def available_models() -> List[str]:
    """Returns the names of available models"""
    return list(_MODELS.keys())


def load_model(
    name: str,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> Whisper:
    """
    Load a Whisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : Whisper
        The Whisper ASR model instance
    """

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
        default = os.path.join(os.path.expanduser("~"), ".cache")
        download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

    if name in _MODELS:
        checkpoint_file = _download(_MODELS[name], download_root, in_memory)
        alignment_heads = _ALIGNMENT_HEADS[name]
    elif os.path.isfile(name):
        checkpoint_file = open(name, "rb").read() if in_memory else name
        alignment_heads = None
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {available_models()}"
        )

    with (
        io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
    ) as fp:
        checkpoint = torch.load(fp, map_location=device)
    del checkpoint_file

    dims = ModelDimensions(**checkpoint["dims"])
    model = Whisper(dims)
    model.load_state_dict(checkpoint["model_state_dict"])

    if alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)


def load_streaming_model(
    name: str,
    advisor_ckpt_path: str = None,
    ft_model_ckpt_path: str = None,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
    cache_gran: bool = True,
    gran: int = 15,
    rank: int = 8,
    extra_gran_blocks: int = 0,
    n_advisor_class: int = 4,
    **kwargs: any
) -> StreamingWhisper:
    """
    Load a StreamingWhisper ASR model

    Parameters
    ----------
    name : str
        one of the official model names listed by `whisper.available_models()`, or
        path to a model checkpoint containing the model dimensions and the model state_dict.
    device : Union[str, torch.device]
        the PyTorch device to put the model into
    download_root: str
        path to download the model files; by default, it uses "~/.cache/whisper"
    in_memory: bool
        whether to preload the model weights into host memory

    Returns
    -------
    model : StreamingWhisper
        The StreamingWhisper ASR model instance
    """
    if ft_model_ckpt_path is None:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if download_root is None:
            default = os.path.join(os.path.expanduser("~"), ".cache")
            download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")

        if name in _MODELS:
            checkpoint_file = _download(_MODELS[name], download_root, in_memory)
            alignment_heads = _ALIGNMENT_HEADS[name]
        elif os.path.isfile(name):
            checkpoint_file = open(name, "rb").read() if in_memory else name
            alignment_heads = None
        else:
            raise RuntimeError(
                f"Model {name} not found; available models = {available_models()}"
            )

        with (
            io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
        ) as fp:
            checkpoint = torch.load(fp, map_location=device)
        del checkpoint_file
    else:
        checkpoint = torch.load(ft_model_ckpt_path, weights_only=False)

    decoder_advisor_chkpt = torch.load(advisor_ckpt_path, weights_only=False) if advisor_ckpt_path is not None else {"state_dict": {}}
    advisor_state_dict = {k: v for k, v in decoder_advisor_chkpt["state_dict"].items() if "decoder_advisor" in k}

    whisper_dict = checkpoint["model_state_dict"] if "model_state_dict" in checkpoint.keys() else checkpoint["state_dict"]

    whisper_dict = {k.replace("weight", "base_layer.weight") if "attn." in k and "weight" in k
                    else k.replace("bias", "base_layer.bias") if "attn." in k and "bias" in k
                    else k: v for k, v in whisper_dict.items()}

    streaming_whisper_state_dict = {**advisor_state_dict, **whisper_dict}

    dims = ModelDimensions(**checkpoint["dims"])

    model = StreamingWhisper(dims,
                             cache_gran=cache_gran,
                             gran=gran,
                             rank=rank,
                             extra_gran_blocks=extra_gran_blocks,
                             n_advisor_class=n_advisor_class,
                             **kwargs)

    model.load_state_dict(streaming_whisper_state_dict, strict=False)

    # for n, p in model.named_parameters():
    #     print(n, p)

    if ft_model_ckpt_path is None and alignment_heads is not None:
        model.set_alignment_heads(alignment_heads)

    return model.to(device)


def load_streaming_model_correct(
    name: str,
    gran: int = 300,
    multilingual: bool = False,
    device: Optional[Union[str, torch.device]] = None,
    download_root: str = None,
    in_memory: bool = False,
) -> StreamingWhisper:

    subname = (str(gran) + '-multi') if multilingual else str(gran)

    from huggingface_hub import hf_hub_download

    try:
        ckpt_path = hf_hub_download(repo_id="MLSpeech/causal-whisper", filename=_STREAMING_MODELS_HF[name][subname], repo_type="model", token=True)
    except KeyError:
        # re-raise: otherwise ckpt_path would be undefined below
        print(f"A streaming model with size {name}, multilingual: {multilingual} and chunk size: {gran} ms is not available.")
        raise

    checkpoint = torch.load(ckpt_path, weights_only=False)

    dims = ModelDimensions(**checkpoint["dims"])

    model = StreamingWhisper(dims,
                             gran=checkpoint['cfg']['gran'],
                             rank=checkpoint['cfg']['rank'],
                             extra_gran_blocks=checkpoint['cfg']['extra_gran_blocks'])

    model.load_state_dict(checkpoint['state_dict'], strict=False)

    return model.to(device)
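
A short usage sketch for the loaders above (not part of this commit; it assumes access to the Hugging Face repo MLSpeech/causal-whisper, since hf_hub_download is called with token=True, and uses the English base model at 300 ms granularity as an example):

import torch
import whisper_stream

device = "cuda" if torch.cuda.is_available() else "cpu"

# fine-tuned streaming checkpoint, fetched from the Hugging Face Hub
streaming_model = whisper_stream.load_streaming_model_correct("base", gran=300, multilingual=False, device=device)

# the plain Whisper loader behaves like upstream openai-whisper
offline_model = whisper_stream.load_model("base", device=device)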
whisper_stream/__main__.py
ADDED
@@ -0,0 +1,3 @@
from .transcribe import cli

cli()
whisper_stream/assets/gpt2.tiktoken
ADDED
The diff for this file is too large to render.
whisper_stream/assets/mel_filters.npz
ADDED
Binary file (4.27 kB).
whisper_stream/assets/multilingual.tiktoken
ADDED
The diff for this file is too large to render.
whisper_stream/audio.py
ADDED
@@ -0,0 +1,357 @@
import os
import time
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union

import wave
import torch
import pyaudio
import numpy as np
import soundfile as sf
import torch.nn.functional as F

from .utils import exact_div

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions have stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


class MyStream:
    def __init__(self,
                 ms_gran: int = 200,
                 sample_rate: int = 16000,
                 channels: int = 2,
                 filename: str = None,
                 inp_dtype: any = pyaudio.paInt16,
                 simulate_stream: bool = False,
                 wav_file: str = None,
                 relay: bool = False,
                 use_latency: bool = False,
                 pad_trim: bool = True,
                 use_remote_machine: bool = False):

        assert ms_gran % 20 == 0, "ms_gran must be a multiple of 20"

        self.ms_gran = ms_gran
        self.sample_rate = sample_rate
        self.channels = channels
        self.inp_dtype = inp_dtype
        self.relay = relay
        self.use_latency = use_latency
        self.use_remote_machine = use_remote_machine

        rate_fraction = ms_gran / 1000
        self.chunk_size = int(rate_fraction * sample_rate)
        self.filename = filename
        self.streamed_wav_file = wav_file

        self.simulate_stream = simulate_stream
        if self.simulate_stream:
            assert wav_file is not None, "when simulating a stream, a wav file must be provided."
            if pad_trim:
                self.wav_array = pad_or_trim(load_audio(wav_file, sample_rate), length=N_SAMPLES + 180)  # wav array
            else:
                audio = load_audio(wav_file, sample_rate)
                self.wav_array = pad_or_trim(audio, length=audio.shape[-1] + 180)
            print(f"{self.wav_array.shape=}")

    def _simulate_stream_using_wav(self):
        print("Streaming simulation of a wav started...")

        for i in range(self.wav_array.shape[-1] // self.chunk_size):
            if i == 0:
                yield self.wav_array[..., :(((i + 1) * self.chunk_size) + 40 + 320)]  # 320 is an extra 20 msec buffer we need!
            else:
                yield self.wav_array[..., ((i * self.chunk_size) + 40 + 320):(((i + 1) * self.chunk_size) + 40 + 320)]

            if self.use_latency: time.sleep(self.ms_gran / 1000)  # simulating the latency between audio chunks

    def open_stream(self):
        if self.simulate_stream or self.relay or self.use_remote_machine: return

        self.audio = pyaudio.PyAudio()
        self.stream = self.audio.open(input=True, format=self.inp_dtype, channels=self.channels, rate=self.sample_rate, frames_per_buffer=self.chunk_size)

    def _read_from_stream(self):
        print("Streaming instance recording started...")

        while True:
            yield self.stream.read(self.chunk_size)

    def _follow_growing_wav(self):
        while not os.path.exists(self.streamed_wav_file):
            time.sleep(0.1)

        with sf.SoundFile(self.streamed_wav_file, mode='r') as f:
            while True:
                block = f.read(self.chunk_size)
                if len(block) == 0:
                    time.sleep(self.ms_gran / 1000)  # wait for more data
                    continue
                yield block

    def _read_raw_pcm(self):
        samples_per_chunk = int(self.sample_rate * (self.ms_gran / 1000))
        bytes_per_sample = 2  # s16le = 16 bits = 2 bytes
        chunk_size = samples_per_chunk * bytes_per_sample

        while not os.path.exists(self.streamed_wav_file):
            time.sleep(0.1)

        with open(self.streamed_wav_file, 'rb') as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    time.sleep((self.ms_gran / 1000))
                    continue
                yield np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32768.0

    def read(self):
        if self.simulate_stream:
            return self._simulate_stream_using_wav()

        if self.use_remote_machine:
            return self._read_raw_pcm()

        return self._read_from_stream()

    def _save_recording_file(self, frames: list):
        print(f"Saving recorded audio file on path {self.filename}")

        waveFile = wave.open(self.filename, 'wb')
        waveFile.setnchannels(self.channels)
        waveFile.setsampwidth(self.audio.get_sample_size(self.inp_dtype))
        waveFile.setframerate(self.sample_rate)
        waveFile.writeframes(b''.join(frames))
        waveFile.close()

    def close_stream(self, frames: list):
        if self.simulate_stream: return

        # Stop recording
        self.stream.stop_stream()
        self.stream.close()
        self.audio.terminate()

        print("Finished recording, stream and audio terminated.")

        if self.filename: self._save_recording_file(frames)

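A small sketch of driving MyStream in simulation mode (not part of this commit; "sample.wav" is a placeholder path and 300 ms granularity is an example value):

from whisper_stream.audio import MyStream

stream = MyStream(ms_gran=300, simulate_stream=True, wav_file="sample.wav", use_latency=False)
stream.open_stream()  # a no-op when simulating
for chunk in stream.read():
    print(chunk.shape)  # float32 slices of the padded waveform, delivered as a live microphone would
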
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """

    # This launches a subprocess to decode audio while down-mixing
    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
    # fmt: off
    cmd = [
        "ffmpeg",
        "-nostdin",
        "-threads", "0",
        "-i", file,
        "-f", "s16le",
        "-ac", "1",
        "-acodec", "pcm_s16le",
        "-ar", str(sr),
        "-"
    ]
    # fmt: on
    try:
        out = run(cmd, capture_output=True, check=True).stdout
    except CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(
                dim=axis, index=torch.arange(length, device=array.device)
            )

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    Load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling the librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
            mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
        )
    """
    assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"

    filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int = 80,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of the given audio.

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters; only 80 and 128 are supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (n_mels, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec


class SpectrogramStream:
    def __init__(self, n_fft: int = N_FFT, hop_length: int = HOP_LENGTH, n_mels: int = 80, window: Optional[str] = "hann", pad_mode: str = "reflect"):

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.pad_mode = pad_mode
        self.n_mels = n_mels

        self.window = torch.hann_window(n_fft)
        self.window_type = window

        self.ctx_samples = self.n_fft - self.hop_length

        self.reset()

    def reset(self):
        self.is_first = True
        self.audio_ctx = torch.tensor([])
        self.log_spec_max = -torch.inf

    def calc_mel_with_new_frame(self, audio_frame: torch.Tensor, is_last: bool = False):

        self.window = self.window.to(audio_frame.device)

        if len(audio_frame.shape) == 1:
            audio_frame = audio_frame.unsqueeze(0)

        n_batch = audio_frame.shape[0]

        if isinstance(self.log_spec_max, float):
            self.log_spec_max = torch.ones((n_batch)).to(audio_frame.device) * -torch.inf

        # check if we are on the first frame; if so, pad using reflection
        if self.is_first:
            pad = int(self.n_fft // 2) + 1
            audio_input = F.pad(audio_frame, [pad, 0], self.pad_mode)
            self.is_first = False
        else:  # pad with previous context
            audio_input = torch.cat([self.audio_ctx[..., -self.ctx_samples:], audio_frame], dim=-1)

        if is_last:  # reflection-pad the last frame
            pad = int(self.n_fft // 4) + 1
            audio_input = F.pad(audio_input, [pad, 0], self.pad_mode)

        self.audio_ctx = audio_frame  # now the audio context is the last frame

        stft = torch.stft(audio_input, self.n_fft, self.hop_length, window=self.window, return_complex=True, center=False)
        magnitudes = stft.abs() ** 2
        filters = mel_filters(audio_frame.device, self.n_mels)
        mel_spec = filters @ magnitudes

        log_spec = torch.clamp(mel_spec, min=1e-10).log10()  # shape (b, n_mels, audio_frames)
        self.log_spec_max = torch.maximum(log_spec.view(n_batch, -1).max(dim=-1).values, self.log_spec_max).to(log_spec.device)

        log_spec = torch.maximum(log_spec.view(n_batch, -1).permute(1, 0), self.log_spec_max - 8.0).permute(1, 0).view(n_batch, self.n_mels, -1)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec

    def _simulate_streaming_log_spec(self, audio: torch.Tensor, ms_gran: int = 300, total_frames: int = 3000, get_gt: bool = False):
        self.reset()

        samples_gran = HOP_LENGTH * (ms_gran // 10)
        sub_mel_frames = int(total_frames / ms_gran) * 10
        # print(samples_gran, sub_mel_frames)
        pred_mel = torch.cat([self.calc_mel_with_new_frame(audio[..., (i * samples_gran) + (40 * int(i != 0)): ((i + 1) * samples_gran) + 40], is_last=(i == sub_mel_frames - 1)) for i in range(sub_mel_frames)], dim=-1)

        if get_gt:
            gt_mel = log_mel_spectrogram(audio)
            return pred_mel, gt_mel

        return pred_mel
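
A sketch of checking the incremental spectrogram against the offline one, which is what _simulate_streaming_log_spec automates (not part of this commit; "sample.wav" is a placeholder, and frame counts differ slightly at the edges because of the different padding):

import torch
from whisper_stream.audio import N_SAMPLES, SpectrogramStream, load_audio, pad_or_trim

audio = torch.from_numpy(pad_or_trim(load_audio("sample.wav"), length=N_SAMPLES + 180))
stream = SpectrogramStream(n_mels=80)
pred_mel, gt_mel = stream._simulate_streaming_log_spec(audio, ms_gran=300, get_gt=True)
print(pred_mel.shape, gt_mel.shape)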
whisper_stream/decoding.py
ADDED
@@ -0,0 +1,826 @@
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio

if TYPE_CHECKING:
    from .model import Whisper


@torch.no_grad()
def detect_language(
    model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return the ids of the most
    probable language tokens along with the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appear after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(
            model.is_multilingual, num_languages=model.num_languages
        )
    if (
        tokenizer.language is None
        or tokenizer.language_token not in tokenizer.sot_sequence
    ):
        raise ValueError(
            "This model doesn't have language tokens so it can't perform lang id"
        )

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    logits = model.logits(x, mel)[:, 0]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs

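A usage sketch for detect_language (not part of this commit; "sample.wav" is a placeholder, and the snippet assumes a multilingual checkpoint whose model object exposes .device, as upstream Whisper does):

import whisper_stream
from whisper_stream.audio import load_audio, log_mel_spectrogram, pad_or_trim
from whisper_stream.decoding import detect_language

model = whisper_stream.load_model("base")  # a multilingual model is required
mel = log_mel_spectrogram(pad_or_trim(load_audio("sample.wav"))).to(model.device)
_, probs = detect_language(model, mel)
print(max(probs, key=probs.get))  # most probable language code, e.g. "en"
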
@dataclass(frozen=True)
class DecodingOptions:
    # whether to perform X->X "transcribe" or X->English "translate"
    task: str = "transcribe"

    # language that the audio is in; uses detected language if None
    language: Optional[str] = None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)

    # "alpha" in Google NMT, or None for length norm, when ranking generations
    # to select which to return among the beams or best-of-N samples
    length_penalty: Optional[float] = None

    # text or tokens to feed as the prompt or the prefix; for more info:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context

    # list of token ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # this will suppress blank outputs

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation


@dataclass(frozen=True)
class DecodingResult:
    audio_features: Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan


class Inference:
    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""
        pass


class PyTorchInference(Inference):
    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

        key_modules = [block.attn.key for block in self.model.decoder.blocks]
        value_modules = [block.attn.value for block in self.model.decoder.blocks]
        self.kv_modules = key_modules + value_modules

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        if tokens.shape[-1] > self.initial_token_length:
            # only need to use the last token except in the first forward pass
            tokens = tokens[:, -1:]
        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        for hook in self.hooks:
            hook.remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        if source_indices != list(range(len(source_indices))):
            for module in self.kv_modules:
                # update the key/value cache to contain the selected sequences
                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()


class SequenceRanker:
    def rank(
        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Given a list of groups of samples and their cumulative log probabilities,
        return the indices of the samples in each group to select as the final result
        """
        raise NotImplementedError


class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probabilities, penalized using either
    a simple length normalization or Google NMT paper's length penalty
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                if self.length_penalty is None:
                    penalty = length
                else:
                    # from the Google NMT paper
                    penalty = ((5 + length) / 6) ** self.length_penalty
                result.append(logprob / penalty)
            return result

        # get the sequence with the highest score
        lengths = [[len(t) for t in s] for s in tokens]
        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]


class TokenDecoder:
    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token

        completed : bool
            True if all sequences have reached the end of text

        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finalize search and return the final candidate sequences

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence

        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            sequence of Tensors containing candidate token sequences, for each audio input

        sum_logprobs : List[List[float]], length = n_audio
            sequence of cumulative log probabilities corresponding to the above

        """
        raise NotImplementedError


class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)

        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = F.pad(tokens, (0, 1), value=self.eot)
        return tokens, sum_logprobs.tolist()


class BeamSearchDecoder(TokenDecoder):
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if every audio input has enough finished samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if (
                len(sequences) < self.beam_size
            ):  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs

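A quick numeric illustration of the two ranking modes in MaximumLikelihoodRanker (illustrative numbers only): plain length normalization divides the cumulative log probability by the sequence length, while the Google NMT penalty divides by ((5 + length) / 6) ** alpha; the two can disagree, as the first pair of candidates below shows.

import torch
from whisper_stream.decoding import MaximumLikelihoodRanker

# two audio inputs, each with two candidate sequences of different lengths
tokens = [[torch.zeros(10), torch.zeros(20)], [torch.zeros(5), torch.zeros(6)]]
sum_logprobs = [[-12.0, -22.0], [-6.0, -6.5]]

print(MaximumLikelihoodRanker(length_penalty=None).rank(tokens, sum_logprobs))  # plain length norm
print(MaximumLikelihoodRanker(length_penalty=0.6).rank(tokens, sum_logprobs))   # NMT-style penalty, alpha=0.6
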
class LogitFilter:
    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Apply any filtering or masking to logits in-place

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        """
        raise NotImplementedError


class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        if tokens.shape[1] == self.sample_begin:
            logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf


class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        logits[:, self.suppress_tokens] = -np.inf


class ApplyTimestampRules(LogitFilter):
    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: Tensor, tokens: Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = sampled_tokens.tolist()
            last_was_timestamp = (
                len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            )
            penultimate_was_timestamp = (
                len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
            )

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[
                sampled_tokens.ge(self.tokenizer.timestamp_begin)
            ]
            if timestamps.numel() > 0:
                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                # also force each segment to have a nonzero length, to prevent infinite looping
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    timestamp_last = timestamps[-1] + 1
                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = (
                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                )
                logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
                dim=-1
            )
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
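The pairing rule above is easiest to see on a toy vocabulary. A minimal standalone sketch, assuming ids below 10 are text tokens, id 9 is EOT, and ids 10 and above are timestamps (these ids are made up for illustration):

import torch

timestamp_begin, eot = 10, 9  # toy vocabulary of 20 ids

def pairing_mask(seq):
    logits = torch.zeros(20)
    last_was_timestamp = len(seq) >= 1 and seq[-1] >= timestamp_begin
    penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= timestamp_begin
    if last_was_timestamp:
        if penultimate_was_timestamp:       # a pair just closed: next must be text
            logits[timestamp_begin:] = float("-inf")
        else:                               # one timestamp is open: must close the pair
            logits[:eot] = float("-inf")
    return logits

print(pairing_mask([3, 12]))   # open timestamp -> all text tokens masked
print(pairing_mask([12, 13]))  # completed pair -> all timestamp tokens masked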
class DecodingTask:
    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            num_languages=model.num_languages,
            language=language,
            task=options.task,
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: apply various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(
                    self.options.max_initial_timestamp / precision
                )
            self.logit_filters.append(
                ApplyTimestampRules(
                    tokenizer, self.sample_begin, max_initial_timestamp_index
                )
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (
            0 <= options.length_penalty <= 1
        ):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip())
                if isinstance(prefix, str)
                else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip())
                if isinstance(prompt, str)
                else prompt
            )
            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1) :]
                + tokens
            )

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [
                self.tokenizer.transcribe,
                self.tokenizer.translate,
                self.tokenizer.sot,
                self.tokenizer.sot_prev,
                self.tokenizer.sot_lm,
            ]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: Tensor):
        if self.options.fp16:
            mel = mel.half()

        if mel.shape[-2:] == (
            self.model.dims.n_audio_ctx,
            self.model.dims.n_audio_state,
        ):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)

        if audio_features.dtype != (
            torch.float16 if self.options.fp16 else torch.float32
        ):
            raise TypeError(
                f"audio_features has an incorrect dtype: {audio_features.dtype}"
            )

        return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if (
                    i == 0 and self.tokenizer.no_speech is not None
                ):  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or penalizing certain tokens
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(
                    audio_features=features, language=language, language_probs=probs
                )
                for features, language, probs in zip(
                    audio_features, languages, language_probs
                )
            ]

        # repeat text tensors by the group size, for beam search or best-of-n sampling
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
            for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [
            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
        ]

        fields = (
            texts,
            languages,
            tokens,
            audio_features,
            avg_logprobs,
            no_speech_probs,
        )
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
                *fields
            )
        ]


@torch.no_grad()
def decode(
    model: "Whisper",
    mel: Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    if single := mel.ndim == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = DecodingTask(model, options).run(mel)

    return result[0] if single else result
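A minimal usage sketch for `decode`. The loader and audio helper names below are assumed to mirror upstream openai-whisper's entry points and may differ in this package:

import whisper_stream
from whisper_stream.decoding import DecodingOptions, decode

# `load_model`, `load_audio`, `pad_or_trim`, and `log_mel_spectrogram` are
# assumed to keep their upstream openai-whisper names here.
model = whisper_stream.load_model("base")
audio = whisper_stream.pad_or_trim(whisper_stream.load_audio("sample.wav"))
mel = whisper_stream.log_mel_spectrogram(audio).to(model.device)  # (80, 3000)

# suppress_tokens="-1" additionally suppresses the non-speech token set
options = DecodingOptions(language="en", beam_size=5, suppress_tokens="-1")
result = decode(model, mel, options)  # a 2-D mel yields a single DecodingResult
print(result.text, result.avg_logprob, result.no_speech_prob)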
whisper_stream/model.py
ADDED
@@ -0,0 +1,329 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function


@dataclass
class ModelDimensions:
    n_mels: int
    n_audio_ctx: int
    n_audio_state: int
    n_audio_head: int
    n_audio_layer: int
    n_vocab: int
    n_text_ctx: int
    n_text_state: int
    n_text_head: int
    n_text_layer: int


class LayerNorm(nn.LayerNorm):
    def forward(self, x: Tensor) -> Tensor:
        return super().forward(x.float()).type(x.dtype)


class Linear(nn.Linear):
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )


class Conv1d(nn.Conv1d):
    def _conv_forward(
        self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
    ) -> Tensor:
        return super()._conv_forward(
            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
        )


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
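A quick sanity check of the embedding above (illustrative only): the first `channels // 2` columns hold sines, the rest cosines, one row per position.

pe = sinusoids(length=1500, channels=384)
print(pe.shape)           # torch.Size([1500, 384])
print(pe[0, :192].sum())  # sin(0) everywhere in the first half -> tensor(0.)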
class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[Dict[nn.Module, Tensor]] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)

        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
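Note the scaling: multiplying q and k each by `d ** -0.25` is numerically gentler in fp16 than dividing the raw product by `sqrt(d)`, but mathematically identical. A standalone check:

import torch

d = 64
q, k = torch.randn(1, 8, 10, d), torch.randn(1, 8, 10, d)
scale = d ** -0.25
a = (q * scale) @ (k * scale).transpose(-1, -2)   # scale applied to both operands
b = (q @ k.transpose(-1, -2)) / d ** 0.5          # textbook scaled dot-product
print(torch.allclose(a, b, atol=1e-5))            # True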
class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        # self-attention
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]

        # cross-attention
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa)[0]

        # MLP
        x = x + self.mlp(self.mlp_ln(x))

        return x


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.n_head = n_head
        self.n_layer = n_layer
        self.n_state = n_state

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        x = (x + self.positional_embedding).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        x = self.ln_post(x)
        return x


class TextDecoder(nn.Module):
    def __init__(
        self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [
                ResidualAttentionBlock(n_state, n_head, cross_attention=True)
                for _ in range(n_layer)
            ]
        )
        self.ln = LayerNorm(n_state)

        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = (
            self.token_embedding(x)
            + self.positional_embedding[offset : offset + x.shape[-1]]
        )
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits


class Whisper(nn.Module):
    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        # kept dense rather than .to_sparse(): PyTorch Lightning cannot handle sparse buffers
        self.register_buffer("alignment_heads", all_heads, persistent=False)

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()
        mask = torch.from_numpy(array).reshape(
            self.dims.n_text_layer, self.dims.n_text_head
        )
        # kept dense rather than .to_sparse(): PyTorch Lightning cannot handle sparse buffers
        self.register_buffer("alignment_heads", mask, persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        return self.decoder(tokens, audio_features)

    def forward(
        self, mel: torch.Tensor, tokens: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to their caches
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects for removing the installed hooks
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
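A minimal sketch of greedy incremental decoding with these hooks. It assumes `model` is a loaded Whisper instance and `mel` a (1, 80, 3000) spectrogram; a real run starts from the full sot sequence rather than the single toy start token used here:

import torch

audio_features = model.embed_audio(mel)   # encoder forward pass, run once
cache, hooks = model.install_kv_cache_hooks()

tokens = torch.tensor([[50258]])           # <|startoftranscript|> (multilingual vocab)
for _ in range(10):
    # after the first step only the newest token is fed forward; the hooks
    # prepend the cached keys/values for all previous positions
    logits = model.decoder(
        tokens[:, -1:] if cache else tokens, audio_features, kv_cache=cache
    )
    next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_token], dim=-1)

for hook in hooks:
    hook.remove()                          # uninstall the forward hooks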
whisper_stream/normalizers/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer
whisper_stream/normalizers/basic.py
ADDED
@@ -0,0 +1,76 @@
import re
import unicodedata

import regex

# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Replace any other markers, symbols, and punctuations with a space,
    and drop any diacritics (category 'Mn' and some manual mappings)
    """
    return "".join(
        c
        if c in keep
        else ADDITIONAL_DIACRITICS[c]
        if c in ADDITIONAL_DIACRITICS
        else ""
        if unicodedata.category(c) == "Mn"
        else " "
        if unicodedata.category(c)[0] in "MSP"
        else c
        for c in unicodedata.normalize("NFKD", s)
    )


def remove_symbols(s: str):
    """
    Replace any other markers, symbols, punctuations with a space, keeping diacritics
    """
    return "".join(
        " " if unicodedata.category(c)[0] in "MSP" else c
        for c in unicodedata.normalize("NFKC", s)
    )


class BasicTextNormalizer:
    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        self.clean = (
            remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        )
        self.split_letters = split_letters

    def __call__(self, s: str):
        s = s.lower()
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)  # remove words between brackets
        s = re.sub(r"\(([^)]+?)\)", "", s)  # remove words between parentheses
        s = self.clean(s).lower()

        if self.split_letters:
            s = " ".join(regex.findall(r"\X", s, regex.U))

        s = re.sub(
            r"\s+", " ", s
        )  # replace any successive whitespace characters with a space

        return s
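Example of the two cleaning modes (`EnglishTextNormalizer`, defined in english.py, additionally applies the British-to-American spelling mapping from english.json below):

from whisper_stream.normalizers import BasicTextNormalizer

normalizer = BasicTextNormalizer()
print(normalizer("Hello, [NOISE] Café"))   # -> "hello café"

aggressive = BasicTextNormalizer(remove_diacritics=True)
print(aggressive("Hello, [NOISE] Café"))   # -> "hello cafe"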
whisper_stream/normalizers/english.json
ADDED
@@ -0,0 +1,1741 @@
{
    "accessorise": "accessorize",
    "accessorised": "accessorized",
    "accessorises": "accessorizes",
    "accessorising": "accessorizing",
    "acclimatisation": "acclimatization",
    "acclimatise": "acclimatize",
    "acclimatised": "acclimatized",
    "acclimatises": "acclimatizes",
    "acclimatising": "acclimatizing",
    "accoutrements": "accouterments",
    "aeon": "eon",
    "aeons": "eons",
    "aerogramme": "aerogram",
    "aerogrammes": "aerograms",
    "aeroplane": "airplane",
    "aeroplanes": "airplanes",
    "aesthete": "esthete",
    "aesthetes": "esthetes",
    "aesthetic": "esthetic",
    "aesthetically": "esthetically",
    "aesthetics": "esthetics",
    "aetiology": "etiology",
    "ageing": "aging",
    "aggrandisement": "aggrandizement",
    "agonise": "agonize",
    "agonised": "agonized",
    "agonises": "agonizes",
    "agonising": "agonizing",
    "agonisingly": "agonizingly",
    "almanack": "almanac",
    "almanacks": "almanacs",
    "aluminium": "aluminum",
    "amortisable": "amortizable",
    "amortisation": "amortization",
    "amortisations": "amortizations",
    "amortise": "amortize",
    "amortised": "amortized",
    "amortises": "amortizes",
    "amortising": "amortizing",
    "amphitheatre": "amphitheater",
    "amphitheatres": "amphitheaters",
    "anaemia": "anemia",
    "anaemic": "anemic",
    "anaesthesia": "anesthesia",
    "anaesthetic": "anesthetic",
    "anaesthetics": "anesthetics",
    "anaesthetise": "anesthetize",
    "anaesthetised": "anesthetized",
    "anaesthetises": "anesthetizes",
    "anaesthetising": "anesthetizing",
    "anaesthetist": "anesthetist",
    "anaesthetists": "anesthetists",
    "anaesthetize": "anesthetize",
    "anaesthetized": "anesthetized",
    "anaesthetizes": "anesthetizes",
    "anaesthetizing": "anesthetizing",
    "analogue": "analog",
    "analogues": "analogs",
    "analyse": "analyze",
    "analysed": "analyzed",
    "analyses": "analyzes",
    "analysing": "analyzing",
    "anglicise": "anglicize",
    "anglicised": "anglicized",
    "anglicises": "anglicizes",
    "anglicising": "anglicizing",
    "annualised": "annualized",
    "antagonise": "antagonize",
    "antagonised": "antagonized",
    "antagonises": "antagonizes",
    "antagonising": "antagonizing",
    "apologise": "apologize",
    "apologised": "apologized",
    "apologises": "apologizes",
    "apologising": "apologizing",
    "appal": "appall",
    "appals": "appalls",
    "appetiser": "appetizer",
    "appetisers": "appetizers",
    "appetising": "appetizing",
    "appetisingly": "appetizingly",
    "arbour": "arbor",
    "arbours": "arbors",
    "archeological": "archaeological",
    "archaeologically": "archeologically",
    "archaeologist": "archeologist",
    "archaeologists": "archeologists",
    "archaeology": "archeology",
    "ardour": "ardor",
    "armour": "armor",
    "armoured": "armored",
    "armourer": "armorer",
    "armourers": "armorers",
    "armouries": "armories",
    "armoury": "armory",
    "artefact": "artifact",
    "artefacts": "artifacts",
    "authorise": "authorize",
    "authorised": "authorized",
    "authorises": "authorizes",
    "authorising": "authorizing",
    "axe": "ax",
    "backpedalled": "backpedaled",
    "backpedalling": "backpedaling",
    "bannister": "banister",
    "bannisters": "banisters",
    "baptise": "baptize",
    "baptised": "baptized",
    "baptises": "baptizes",
    "baptising": "baptizing",
    "bastardise": "bastardize",
    "bastardised": "bastardized",
    "bastardises": "bastardizes",
    "bastardising": "bastardizing",
    "battleax": "battleaxe",
    "baulk": "balk",
    "baulked": "balked",
    "baulking": "balking",
    "baulks": "balks",
    "bedevilled": "bedeviled",
    "bedevilling": "bedeviling",
    "behaviour": "behavior",
    "behavioural": "behavioral",
    "behaviourism": "behaviorism",
    "behaviourist": "behaviorist",
    "behaviourists": "behaviorists",
    "behaviours": "behaviors",
    "behove": "behoove",
    "behoved": "behooved",
    "behoves": "behooves",
    "bejewelled": "bejeweled",
    "belabour": "belabor",
    "belaboured": "belabored",
    "belabouring": "belaboring",
    "belabours": "belabors",
    "bevelled": "beveled",
    "bevvies": "bevies",
    "bevvy": "bevy",
    "biassed": "biased",
    "biassing": "biasing",
    "bingeing": "binging",
    "bougainvillaea": "bougainvillea",
    "bougainvillaeas": "bougainvilleas",
    "bowdlerise": "bowdlerize",
    "bowdlerised": "bowdlerized",
    "bowdlerises": "bowdlerizes",
    "bowdlerising": "bowdlerizing",
    "breathalyse": "breathalyze",
    "breathalysed": "breathalyzed",
    "breathalyser": "breathalyzer",
    "breathalysers": "breathalyzers",
    "breathalyses": "breathalyzes",
    "breathalysing": "breathalyzing",
    "brutalise": "brutalize",
    "brutalised": "brutalized",
    "brutalises": "brutalizes",
    "brutalising": "brutalizing",
    "busses": "buses",
    "bussing": "busing",
    "caesarean": "cesarean",
    "caesareans": "cesareans",
    "calibre": "caliber",
    "calibres": "calibers",
    "calliper": "caliper",
    "callipers": "calipers",
    "callisthenics": "calisthenics",
    "canalise": "canalize",
    "canalised": "canalized",
    "canalises": "canalizes",
    "canalising": "canalizing",
    "cancelation": "cancellation",
    "cancelations": "cancellations",
    "cancelled": "canceled",
    "cancelling": "canceling",
    "candour": "candor",
    "cannibalise": "cannibalize",
    "cannibalised": "cannibalized",
    "cannibalises": "cannibalizes",
    "cannibalising": "cannibalizing",
    "canonise": "canonize",
    "canonised": "canonized",
    "canonises": "canonizes",
    "canonising": "canonizing",
    "capitalise": "capitalize",
    "capitalised": "capitalized",
    "capitalises": "capitalizes",
    "capitalising": "capitalizing",
    "caramelise": "caramelize",
    "caramelised": "caramelized",
    "caramelises": "caramelizes",
    "caramelising": "caramelizing",
    "carbonise": "carbonize",
    "carbonised": "carbonized",
    "carbonises": "carbonizes",
    "carbonising": "carbonizing",
    "carolled": "caroled",
    "carolling": "caroling",
    "catalogue": "catalog",
    "catalogued": "cataloged",
    "catalogues": "catalogs",
    "cataloguing": "cataloging",
    "catalyse": "catalyze",
    "catalysed": "catalyzed",
    "catalyses": "catalyzes",
    "catalysing": "catalyzing",
    "categorise": "categorize",
    "categorised": "categorized",
    "categorises": "categorizes",
    "categorising": "categorizing",
    "cauterise": "cauterize",
    "cauterised": "cauterized",
    "cauterises": "cauterizes",
    "cauterising": "cauterizing",
    "cavilled": "caviled",
    "cavilling": "caviling",
    "centigramme": "centigram",
    "centigrammes": "centigrams",
    "centilitre": "centiliter",
    "centilitres": "centiliters",
    "centimetre": "centimeter",
    "centimetres": "centimeters",
    "centralise": "centralize",
    "centralised": "centralized",
    "centralises": "centralizes",
    "centralising": "centralizing",
    "centre": "center",
    "centred": "centered",
    "centrefold": "centerfold",
    "centrefolds": "centerfolds",
    "centrepiece": "centerpiece",
    "centrepieces": "centerpieces",
    "centres": "centers",
    "channelled": "channeled",
    "channelling": "channeling",
    "characterise": "characterize",
    "characterised": "characterized",
    "characterises": "characterizes",
    "characterising": "characterizing",
    "cheque": "check",
    "chequebook": "checkbook",
    "chequebooks": "checkbooks",
    "chequered": "checkered",
    "cheques": "checks",
    "chilli": "chili",
    "chimaera": "chimera",
    "chimaeras": "chimeras",
    "chiselled": "chiseled",
    "chiselling": "chiseling",
    "circularise": "circularize",
    "circularised": "circularized",
    "circularises": "circularizes",
    "circularising": "circularizing",
    "civilise": "civilize",
    "civilised": "civilized",
    "civilises": "civilizes",
    "civilising": "civilizing",
    "clamour": "clamor",
    "clamoured": "clamored",
    "clamouring": "clamoring",
    "clamours": "clamors",
    "clangour": "clangor",
    "clarinettist": "clarinetist",
    "clarinettists": "clarinetists",
    "collectivise": "collectivize",
    "collectivised": "collectivized",
    "collectivises": "collectivizes",
    "collectivising": "collectivizing",
    "colonisation": "colonization",
    "colonise": "colonize",
    "colonised": "colonized",
    "coloniser": "colonizer",
    "colonisers": "colonizers",
    "colonises": "colonizes",
    "colonising": "colonizing",
    "colour": "color",
    "colourant": "colorant",
    "colourants": "colorants",
    "coloured": "colored",
    "coloureds": "coloreds",
    "colourful": "colorful",
    "colourfully": "colorfully",
    "colouring": "coloring",
    "colourize": "colorize",
    "colourized": "colorized",
    "colourizes": "colorizes",
    "colourizing": "colorizing",
    "colourless": "colorless",
    "colours": "colors",
    "commercialise": "commercialize",
    "commercialised": "commercialized",
    "commercialises": "commercializes",
    "commercialising": "commercializing",
    "compartmentalise": "compartmentalize",
    "compartmentalised": "compartmentalized",
    "compartmentalises": "compartmentalizes",
    "compartmentalising": "compartmentalizing",
    "computerise": "computerize",
    "computerised": "computerized",
    "computerises": "computerizes",
    "computerising": "computerizing",
    "conceptualise": "conceptualize",
    "conceptualised": "conceptualized",
    "conceptualises": "conceptualizes",
    "conceptualising": "conceptualizing",
    "connexion": "connection",
    "connexions": "connections",
    "contextualise": "contextualize",
    "contextualised": "contextualized",
    "contextualises": "contextualizes",
    "contextualising": "contextualizing",
    "cosier": "cozier",
    "cosies": "cozies",
    "cosiest": "coziest",
    "cosily": "cozily",
    "cosiness": "coziness",
    "cosy": "cozy",
    "councillor": "councilor",
    "councillors": "councilors",
    "counselled": "counseled",
    "counselling": "counseling",
    "counsellor": "counselor",
    "counsellors": "counselors",
    "crenelated": "crenellated",
    "criminalise": "criminalize",
    "criminalised": "criminalized",
    "criminalises": "criminalizes",
    "criminalising": "criminalizing",
    "criticise": "criticize",
    "criticised": "criticized",
    "criticises": "criticizes",
    "criticising": "criticizing",
    "crueller": "crueler",
    "cruellest": "cruelest",
    "crystallisation": "crystallization",
    "crystallise": "crystallize",
    "crystallised": "crystallized",
    "crystallises": "crystallizes",
    "crystallising": "crystallizing",
    "cudgelled": "cudgeled",
    "cudgelling": "cudgeling",
    "customise": "customize",
    "customised": "customized",
    "customises": "customizes",
    "customising": "customizing",
    "cypher": "cipher",
    "cyphers": "ciphers",
    "decentralisation": "decentralization",
    "decentralise": "decentralize",
    "decentralised": "decentralized",
    "decentralises": "decentralizes",
    "decentralising": "decentralizing",
    "decriminalisation": "decriminalization",
    "decriminalise": "decriminalize",
    "decriminalised": "decriminalized",
    "decriminalises": "decriminalizes",
    "decriminalising": "decriminalizing",
    "defence": "defense",
    "defenceless": "defenseless",
    "defences": "defenses",
    "dehumanisation": "dehumanization",
    "dehumanise": "dehumanize",
    "dehumanised": "dehumanized",
    "dehumanises": "dehumanizes",
    "dehumanising": "dehumanizing",
    "demeanour": "demeanor",
    "demilitarisation": "demilitarization",
    "demilitarise": "demilitarize",
    "demilitarised": "demilitarized",
    "demilitarises": "demilitarizes",
    "demilitarising": "demilitarizing",
    "demobilisation": "demobilization",
    "demobilise": "demobilize",
    "demobilised": "demobilized",
    "demobilises": "demobilizes",
    "demobilising": "demobilizing",
    "democratisation": "democratization",
    "democratise": "democratize",
    "democratised": "democratized",
    "democratises": "democratizes",
    "democratising": "democratizing",
    "demonise": "demonize",
    "demonised": "demonized",
    "demonises": "demonizes",
    "demonising": "demonizing",
    "demoralisation": "demoralization",
    "demoralise": "demoralize",
    "demoralised": "demoralized",
    "demoralises": "demoralizes",
    "demoralising": "demoralizing",
    "denationalisation": "denationalization",
    "denationalise": "denationalize",
    "denationalised": "denationalized",
    "denationalises": "denationalizes",
    "denationalising": "denationalizing",
    "deodorise": "deodorize",
    "deodorised": "deodorized",
    "deodorises": "deodorizes",
    "deodorising": "deodorizing",
    "depersonalise": "depersonalize",
    "depersonalised": "depersonalized",
    "depersonalises": "depersonalizes",
    "depersonalising": "depersonalizing",
    "deputise": "deputize",
    "deputised": "deputized",
    "deputises": "deputizes",
    "deputising": "deputizing",
    "desensitisation": "desensitization",
    "desensitise": "desensitize",
    "desensitised": "desensitized",
    "desensitises": "desensitizes",
    "desensitising": "desensitizing",
    "destabilisation": "destabilization",
    "destabilise": "destabilize",
    "destabilised": "destabilized",
    "destabilises": "destabilizes",
    "destabilising": "destabilizing",
    "dialled": "dialed",
    "dialling": "dialing",
    "dialogue": "dialog",
    "dialogues": "dialogs",
    "diarrhoea": "diarrhea",
    "digitise": "digitize",
    "digitised": "digitized",
    "digitises": "digitizes",
    "digitising": "digitizing",
    "disc": "disk",
    "discolour": "discolor",
    "discoloured": "discolored",
    "discolouring": "discoloring",
    "discolours": "discolors",
    "discs": "disks",
    "disembowelled": "disemboweled",
    "disembowelling": "disemboweling",
    "disfavour": "disfavor",
    "dishevelled": "disheveled",
    "dishonour": "dishonor",
    "dishonourable": "dishonorable",
    "dishonourably": "dishonorably",
    "dishonoured": "dishonored",
    "dishonouring": "dishonoring",
    "dishonours": "dishonors",
    "disorganisation": "disorganization",
    "disorganised": "disorganized",
    "distil": "distill",
    "distils": "distills",
    "dramatisation": "dramatization",
    "dramatisations": "dramatizations",
    "dramatise": "dramatize",
    "dramatised": "dramatized",
    "dramatises": "dramatizes",
    "dramatising": "dramatizing",
    "draught": "draft",
    "draughtboard": "draftboard",
    "draughtboards": "draftboards",
    "draughtier": "draftier",
    "draughtiest": "draftiest",
    "draughts": "drafts",
    "draughtsman": "draftsman",
    "draughtsmanship": "draftsmanship",
    "draughtsmen": "draftsmen",
    "draughtswoman": "draftswoman",
    "draughtswomen": "draftswomen",
    "draughty": "drafty",
    "drivelled": "driveled",
    "drivelling": "driveling",
    "duelled": "dueled",
    "duelling": "dueling",
    "economise": "economize",
    "economised": "economized",
    "economises": "economizes",
    "economising": "economizing",
    "edoema": "edema",
    "editorialise": "editorialize",
    "editorialised": "editorialized",
    "editorialises": "editorializes",
    "editorialising": "editorializing",
    "empathise": "empathize",
    "empathised": "empathized",
    "empathises": "empathizes",
    "empathising": "empathizing",
    "emphasise": "emphasize",
    "emphasised": "emphasized",
    "emphasises": "emphasizes",
    "emphasising": "emphasizing",
    "enamelled": "enameled",
    "enamelling": "enameling",
    "enamoured": "enamored",
    "encyclopaedia": "encyclopedia",
    "encyclopaedias": "encyclopedias",
    "encyclopaedic": "encyclopedic",
    "endeavour": "endeavor",
    "endeavoured": "endeavored",
    "endeavouring": "endeavoring",
    "endeavours": "endeavors",
    "energise": "energize",
    "energised": "energized",
    "energises": "energizes",
    "energising": "energizing",
    "enrol": "enroll",
    "enrols": "enrolls",
    "enthral": "enthrall",
    "enthrals": "enthralls",
    "epaulette": "epaulet",
    "epaulettes": "epaulets",
    "epicentre": "epicenter",
    "epicentres": "epicenters",
    "epilogue": "epilog",
    "epilogues": "epilogs",
    "epitomise": "epitomize",
    "epitomised": "epitomized",
    "epitomises": "epitomizes",
    "epitomising": "epitomizing",
    "equalisation": "equalization",
    "equalise": "equalize",
    "equalised": "equalized",
    "equaliser": "equalizer",
    "equalisers": "equalizers",
    "equalises": "equalizes",
    "equalising": "equalizing",
    "eulogise": "eulogize",
    "eulogised": "eulogized",
    "eulogises": "eulogizes",
    "eulogising": "eulogizing",
    "evangelise": "evangelize",
    "evangelised": "evangelized",
    "evangelises": "evangelizes",
    "evangelising": "evangelizing",
    "exorcise": "exorcize",
    "exorcised": "exorcized",
    "exorcises": "exorcizes",
    "exorcising": "exorcizing",
    "extemporisation": "extemporization",
    "extemporise": "extemporize",
    "extemporised": "extemporized",
    "extemporises": "extemporizes",
    "extemporising": "extemporizing",
    "externalisation": "externalization",
    "externalisations": "externalizations",
    "externalise": "externalize",
    "externalised": "externalized",
|
542 |
+
"externalises": "externalizes",
|
543 |
+
"externalising": "externalizing",
|
544 |
+
"factorise": "factorize",
|
545 |
+
"factorised": "factorized",
|
546 |
+
"factorises": "factorizes",
|
547 |
+
"factorising": "factorizing",
|
548 |
+
"faecal": "fecal",
|
549 |
+
"faeces": "feces",
|
550 |
+
"familiarisation": "familiarization",
|
551 |
+
"familiarise": "familiarize",
|
552 |
+
"familiarised": "familiarized",
|
553 |
+
"familiarises": "familiarizes",
|
554 |
+
"familiarising": "familiarizing",
|
555 |
+
"fantasise": "fantasize",
|
556 |
+
"fantasised": "fantasized",
|
557 |
+
"fantasises": "fantasizes",
|
558 |
+
"fantasising": "fantasizing",
|
559 |
+
"favour": "favor",
|
560 |
+
"favourable": "favorable",
|
561 |
+
"favourably": "favorably",
|
562 |
+
"favoured": "favored",
|
563 |
+
"favouring": "favoring",
|
564 |
+
"favourite": "favorite",
|
565 |
+
"favourites": "favorites",
|
566 |
+
"favouritism": "favoritism",
|
567 |
+
"favours": "favors",
|
568 |
+
"feminise": "feminize",
|
569 |
+
"feminised": "feminized",
|
570 |
+
"feminises": "feminizes",
|
571 |
+
"feminising": "feminizing",
|
572 |
+
"fertilisation": "fertilization",
|
573 |
+
"fertilise": "fertilize",
|
574 |
+
"fertilised": "fertilized",
|
575 |
+
"fertiliser": "fertilizer",
|
576 |
+
"fertilisers": "fertilizers",
|
577 |
+
"fertilises": "fertilizes",
|
578 |
+
"fertilising": "fertilizing",
|
579 |
+
"fervour": "fervor",
|
580 |
+
"fibre": "fiber",
|
581 |
+
"fibreglass": "fiberglass",
|
582 |
+
"fibres": "fibers",
|
583 |
+
"fictionalisation": "fictionalization",
|
584 |
+
"fictionalisations": "fictionalizations",
|
585 |
+
"fictionalise": "fictionalize",
|
586 |
+
"fictionalised": "fictionalized",
|
587 |
+
"fictionalises": "fictionalizes",
|
588 |
+
"fictionalising": "fictionalizing",
|
589 |
+
"fillet": "filet",
|
590 |
+
"filleted": "fileted",
|
591 |
+
"filleting": "fileting",
|
592 |
+
"fillets": "filets",
|
593 |
+
"finalisation": "finalization",
|
594 |
+
"finalise": "finalize",
|
595 |
+
"finalised": "finalized",
|
596 |
+
"finalises": "finalizes",
|
597 |
+
"finalising": "finalizing",
|
598 |
+
"flautist": "flutist",
|
599 |
+
"flautists": "flutists",
|
600 |
+
"flavour": "flavor",
|
601 |
+
"flavoured": "flavored",
|
602 |
+
"flavouring": "flavoring",
|
603 |
+
"flavourings": "flavorings",
|
604 |
+
"flavourless": "flavorless",
|
605 |
+
"flavours": "flavors",
|
606 |
+
"flavoursome": "flavorsome",
|
607 |
+
"flyer / flier": "flier / flyer",
|
608 |
+
"foetal": "fetal",
|
609 |
+
"foetid": "fetid",
|
610 |
+
"foetus": "fetus",
|
611 |
+
"foetuses": "fetuses",
|
612 |
+
"formalisation": "formalization",
|
613 |
+
"formalise": "formalize",
|
614 |
+
"formalised": "formalized",
|
615 |
+
"formalises": "formalizes",
|
616 |
+
"formalising": "formalizing",
|
617 |
+
"fossilisation": "fossilization",
|
618 |
+
"fossilise": "fossilize",
|
619 |
+
"fossilised": "fossilized",
|
620 |
+
"fossilises": "fossilizes",
|
621 |
+
"fossilising": "fossilizing",
|
622 |
+
"fraternisation": "fraternization",
|
623 |
+
"fraternise": "fraternize",
|
624 |
+
"fraternised": "fraternized",
|
625 |
+
"fraternises": "fraternizes",
|
626 |
+
"fraternising": "fraternizing",
|
627 |
+
"fulfil": "fulfill",
|
628 |
+
"fulfilment": "fulfillment",
|
629 |
+
"fulfils": "fulfills",
|
630 |
+
"funnelled": "funneled",
|
631 |
+
"funnelling": "funneling",
|
632 |
+
"galvanise": "galvanize",
|
633 |
+
"galvanised": "galvanized",
|
634 |
+
"galvanises": "galvanizes",
|
635 |
+
"galvanising": "galvanizing",
|
636 |
+
"gambolled": "gamboled",
|
637 |
+
"gambolling": "gamboling",
|
638 |
+
"gaol": "jail",
|
639 |
+
"gaolbird": "jailbird",
|
640 |
+
"gaolbirds": "jailbirds",
|
641 |
+
"gaolbreak": "jailbreak",
|
642 |
+
"gaolbreaks": "jailbreaks",
|
643 |
+
"gaoled": "jailed",
|
644 |
+
"gaoler": "jailer",
|
645 |
+
"gaolers": "jailers",
|
646 |
+
"gaoling": "jailing",
|
647 |
+
"gaols": "jails",
|
648 |
+
"gasses": "gases",
|
649 |
+
"gage": "gauge",
|
650 |
+
"gaged": "gauged",
|
651 |
+
"gages": "gauges",
|
652 |
+
"gaging": "gauging",
|
653 |
+
"generalisation": "generalization",
|
654 |
+
"generalisations": "generalizations",
|
655 |
+
"generalise": "generalize",
|
656 |
+
"generalised": "generalized",
|
657 |
+
"generalises": "generalizes",
|
658 |
+
"generalising": "generalizing",
|
659 |
+
"ghettoise": "ghettoize",
|
660 |
+
"ghettoised": "ghettoized",
|
661 |
+
"ghettoises": "ghettoizes",
|
662 |
+
"ghettoising": "ghettoizing",
|
663 |
+
"gipsies": "gypsies",
|
664 |
+
"glamorise": "glamorize",
|
665 |
+
"glamorised": "glamorized",
|
666 |
+
"glamorises": "glamorizes",
|
667 |
+
"glamorising": "glamorizing",
|
668 |
+
"glamor": "glamour",
|
669 |
+
"globalisation": "globalization",
|
670 |
+
"globalise": "globalize",
|
671 |
+
"globalised": "globalized",
|
672 |
+
"globalises": "globalizes",
|
673 |
+
"globalising": "globalizing",
|
674 |
+
"glueing": "gluing",
|
675 |
+
"goitre": "goiter",
|
676 |
+
"goitres": "goiters",
|
677 |
+
"gonorrhoea": "gonorrhea",
|
678 |
+
"gramme": "gram",
|
679 |
+
"grammes": "grams",
|
680 |
+
"gravelled": "graveled",
|
681 |
+
"grey": "gray",
|
682 |
+
"greyed": "grayed",
|
683 |
+
"greying": "graying",
|
684 |
+
"greyish": "grayish",
|
685 |
+
"greyness": "grayness",
|
686 |
+
"greys": "grays",
|
687 |
+
"grovelled": "groveled",
|
688 |
+
"grovelling": "groveling",
|
689 |
+
"groyne": "groin",
|
690 |
+
"groynes": "groins",
|
691 |
+
"gruelling": "grueling",
|
692 |
+
"gruellingly": "gruelingly",
|
693 |
+
"gryphon": "griffin",
|
694 |
+
"gryphons": "griffins",
|
695 |
+
"gynaecological": "gynecological",
|
696 |
+
"gynaecologist": "gynecologist",
|
697 |
+
"gynaecologists": "gynecologists",
|
698 |
+
"gynaecology": "gynecology",
|
699 |
+
"haematological": "hematological",
|
700 |
+
"haematologist": "hematologist",
|
701 |
+
"haematologists": "hematologists",
|
702 |
+
"haematology": "hematology",
|
703 |
+
"haemoglobin": "hemoglobin",
|
704 |
+
"haemophilia": "hemophilia",
|
705 |
+
"haemophiliac": "hemophiliac",
|
706 |
+
"haemophiliacs": "hemophiliacs",
|
707 |
+
"haemorrhage": "hemorrhage",
|
708 |
+
"haemorrhaged": "hemorrhaged",
|
709 |
+
"haemorrhages": "hemorrhages",
|
710 |
+
"haemorrhaging": "hemorrhaging",
|
711 |
+
"haemorrhoids": "hemorrhoids",
|
712 |
+
"harbour": "harbor",
|
713 |
+
"harboured": "harbored",
|
714 |
+
"harbouring": "harboring",
|
715 |
+
"harbours": "harbors",
|
716 |
+
"harmonisation": "harmonization",
|
717 |
+
"harmonise": "harmonize",
|
718 |
+
"harmonised": "harmonized",
|
719 |
+
"harmonises": "harmonizes",
|
720 |
+
"harmonising": "harmonizing",
|
721 |
+
"homoeopath": "homeopath",
|
722 |
+
"homoeopathic": "homeopathic",
|
723 |
+
"homoeopaths": "homeopaths",
|
724 |
+
"homoeopathy": "homeopathy",
|
725 |
+
"homogenise": "homogenize",
|
726 |
+
"homogenised": "homogenized",
|
727 |
+
"homogenises": "homogenizes",
|
728 |
+
"homogenising": "homogenizing",
|
729 |
+
"honour": "honor",
|
730 |
+
"honourable": "honorable",
|
731 |
+
"honourably": "honorably",
|
732 |
+
"honoured": "honored",
|
733 |
+
"honouring": "honoring",
|
734 |
+
"honours": "honors",
|
735 |
+
"hospitalisation": "hospitalization",
|
736 |
+
"hospitalise": "hospitalize",
|
737 |
+
"hospitalised": "hospitalized",
|
738 |
+
"hospitalises": "hospitalizes",
|
739 |
+
"hospitalising": "hospitalizing",
|
740 |
+
"humanise": "humanize",
|
741 |
+
"humanised": "humanized",
|
742 |
+
"humanises": "humanizes",
|
743 |
+
"humanising": "humanizing",
|
744 |
+
"humour": "humor",
|
745 |
+
"humoured": "humored",
|
746 |
+
"humouring": "humoring",
|
747 |
+
"humourless": "humorless",
|
748 |
+
"humours": "humors",
|
749 |
+
"hybridise": "hybridize",
|
750 |
+
"hybridised": "hybridized",
|
751 |
+
"hybridises": "hybridizes",
|
752 |
+
"hybridising": "hybridizing",
|
753 |
+
"hypnotise": "hypnotize",
|
754 |
+
"hypnotised": "hypnotized",
|
755 |
+
"hypnotises": "hypnotizes",
|
756 |
+
"hypnotising": "hypnotizing",
|
757 |
+
"hypothesise": "hypothesize",
|
758 |
+
"hypothesised": "hypothesized",
|
759 |
+
"hypothesises": "hypothesizes",
|
760 |
+
"hypothesising": "hypothesizing",
|
761 |
+
"idealisation": "idealization",
|
762 |
+
"idealise": "idealize",
|
763 |
+
"idealised": "idealized",
|
764 |
+
"idealises": "idealizes",
|
765 |
+
"idealising": "idealizing",
|
766 |
+
"idolise": "idolize",
|
767 |
+
"idolised": "idolized",
|
768 |
+
"idolises": "idolizes",
|
769 |
+
"idolising": "idolizing",
|
770 |
+
"immobilisation": "immobilization",
|
771 |
+
"immobilise": "immobilize",
|
772 |
+
"immobilised": "immobilized",
|
773 |
+
"immobiliser": "immobilizer",
|
774 |
+
"immobilisers": "immobilizers",
|
775 |
+
"immobilises": "immobilizes",
|
776 |
+
"immobilising": "immobilizing",
|
777 |
+
"immortalise": "immortalize",
|
778 |
+
"immortalised": "immortalized",
|
779 |
+
"immortalises": "immortalizes",
|
780 |
+
"immortalising": "immortalizing",
|
781 |
+
"immunisation": "immunization",
|
782 |
+
"immunise": "immunize",
|
783 |
+
"immunised": "immunized",
|
784 |
+
"immunises": "immunizes",
|
785 |
+
"immunising": "immunizing",
|
786 |
+
"impanelled": "impaneled",
|
787 |
+
"impanelling": "impaneling",
|
788 |
+
"imperilled": "imperiled",
|
789 |
+
"imperilling": "imperiling",
|
790 |
+
"individualise": "individualize",
|
791 |
+
"individualised": "individualized",
|
792 |
+
"individualises": "individualizes",
|
793 |
+
"individualising": "individualizing",
|
794 |
+
"industrialise": "industrialize",
|
795 |
+
"industrialised": "industrialized",
|
796 |
+
"industrialises": "industrializes",
|
797 |
+
"industrialising": "industrializing",
|
798 |
+
"inflexion": "inflection",
|
799 |
+
"inflexions": "inflections",
|
800 |
+
"initialise": "initialize",
|
801 |
+
"initialised": "initialized",
|
802 |
+
"initialises": "initializes",
|
803 |
+
"initialising": "initializing",
|
804 |
+
"initialled": "initialed",
|
805 |
+
"initialling": "initialing",
|
806 |
+
"instal": "install",
|
807 |
+
"instalment": "installment",
|
808 |
+
"instalments": "installments",
|
809 |
+
"instals": "installs",
|
810 |
+
"instil": "instill",
|
811 |
+
"instils": "instills",
|
812 |
+
"institutionalisation": "institutionalization",
|
813 |
+
"institutionalise": "institutionalize",
|
814 |
+
"institutionalised": "institutionalized",
|
815 |
+
"institutionalises": "institutionalizes",
|
816 |
+
"institutionalising": "institutionalizing",
|
817 |
+
"intellectualise": "intellectualize",
|
818 |
+
"intellectualised": "intellectualized",
|
819 |
+
"intellectualises": "intellectualizes",
|
820 |
+
"intellectualising": "intellectualizing",
|
821 |
+
"internalisation": "internalization",
|
822 |
+
"internalise": "internalize",
|
823 |
+
"internalised": "internalized",
|
824 |
+
"internalises": "internalizes",
|
825 |
+
"internalising": "internalizing",
|
826 |
+
"internationalisation": "internationalization",
|
827 |
+
"internationalise": "internationalize",
|
828 |
+
"internationalised": "internationalized",
|
829 |
+
"internationalises": "internationalizes",
|
830 |
+
"internationalising": "internationalizing",
|
831 |
+
"ionisation": "ionization",
|
832 |
+
"ionise": "ionize",
|
833 |
+
"ionised": "ionized",
|
834 |
+
"ioniser": "ionizer",
|
835 |
+
"ionisers": "ionizers",
|
836 |
+
"ionises": "ionizes",
|
837 |
+
"ionising": "ionizing",
|
838 |
+
"italicise": "italicize",
|
839 |
+
"italicised": "italicized",
|
840 |
+
"italicises": "italicizes",
|
841 |
+
"italicising": "italicizing",
|
842 |
+
"itemise": "itemize",
|
843 |
+
"itemised": "itemized",
|
844 |
+
"itemises": "itemizes",
|
845 |
+
"itemising": "itemizing",
|
846 |
+
"jeopardise": "jeopardize",
|
847 |
+
"jeopardised": "jeopardized",
|
848 |
+
"jeopardises": "jeopardizes",
|
849 |
+
"jeopardising": "jeopardizing",
|
850 |
+
"jewelled": "jeweled",
|
851 |
+
"jeweller": "jeweler",
|
852 |
+
"jewellers": "jewelers",
|
853 |
+
"jewellery": "jewelry",
|
854 |
+
"judgement": "judgment",
|
855 |
+
"kilogramme": "kilogram",
|
856 |
+
"kilogrammes": "kilograms",
|
857 |
+
"kilometre": "kilometer",
|
858 |
+
"kilometres": "kilometers",
|
859 |
+
"labelled": "labeled",
|
860 |
+
"labelling": "labeling",
|
861 |
+
"labour": "labor",
|
862 |
+
"laboured": "labored",
|
863 |
+
"labourer": "laborer",
|
864 |
+
"labourers": "laborers",
|
865 |
+
"labouring": "laboring",
|
866 |
+
"labours": "labors",
|
867 |
+
"lacklustre": "lackluster",
|
868 |
+
"legalisation": "legalization",
|
869 |
+
"legalise": "legalize",
|
870 |
+
"legalised": "legalized",
|
871 |
+
"legalises": "legalizes",
|
872 |
+
"legalising": "legalizing",
|
873 |
+
"legitimise": "legitimize",
|
874 |
+
"legitimised": "legitimized",
|
875 |
+
"legitimises": "legitimizes",
|
876 |
+
"legitimising": "legitimizing",
|
877 |
+
"leukaemia": "leukemia",
|
878 |
+
"levelled": "leveled",
|
879 |
+
"leveller": "leveler",
|
880 |
+
"levellers": "levelers",
|
881 |
+
"levelling": "leveling",
|
882 |
+
"libelled": "libeled",
|
883 |
+
"libelling": "libeling",
|
884 |
+
"libellous": "libelous",
|
885 |
+
"liberalisation": "liberalization",
|
886 |
+
"liberalise": "liberalize",
|
887 |
+
"liberalised": "liberalized",
|
888 |
+
"liberalises": "liberalizes",
|
889 |
+
"liberalising": "liberalizing",
|
890 |
+
"licence": "license",
|
891 |
+
"licenced": "licensed",
|
892 |
+
"licences": "licenses",
|
893 |
+
"licencing": "licensing",
|
894 |
+
"likeable": "likable",
|
895 |
+
"lionisation": "lionization",
|
896 |
+
"lionise": "lionize",
|
897 |
+
"lionised": "lionized",
|
898 |
+
"lionises": "lionizes",
|
899 |
+
"lionising": "lionizing",
|
900 |
+
"liquidise": "liquidize",
|
901 |
+
"liquidised": "liquidized",
|
902 |
+
"liquidiser": "liquidizer",
|
903 |
+
"liquidisers": "liquidizers",
|
904 |
+
"liquidises": "liquidizes",
|
905 |
+
"liquidising": "liquidizing",
|
906 |
+
"litre": "liter",
|
907 |
+
"litres": "liters",
|
908 |
+
"localise": "localize",
|
909 |
+
"localised": "localized",
|
910 |
+
"localises": "localizes",
|
911 |
+
"localising": "localizing",
|
912 |
+
"louvre": "louver",
|
913 |
+
"louvred": "louvered",
|
914 |
+
"louvres": "louvers",
|
915 |
+
"lustre": "luster",
|
916 |
+
"magnetise": "magnetize",
|
917 |
+
"magnetised": "magnetized",
|
918 |
+
"magnetises": "magnetizes",
|
919 |
+
"magnetising": "magnetizing",
|
920 |
+
"manoeuvrability": "maneuverability",
|
921 |
+
"manoeuvrable": "maneuverable",
|
922 |
+
"manoeuvre": "maneuver",
|
923 |
+
"manoeuvred": "maneuvered",
|
924 |
+
"manoeuvres": "maneuvers",
|
925 |
+
"manoeuvring": "maneuvering",
|
926 |
+
"manoeuvrings": "maneuverings",
|
927 |
+
"marginalisation": "marginalization",
|
928 |
+
"marginalise": "marginalize",
|
929 |
+
"marginalised": "marginalized",
|
930 |
+
"marginalises": "marginalizes",
|
931 |
+
"marginalising": "marginalizing",
|
932 |
+
"marshalled": "marshaled",
|
933 |
+
"marshalling": "marshaling",
|
934 |
+
"marvelled": "marveled",
|
935 |
+
"marvelling": "marveling",
|
936 |
+
"marvellous": "marvelous",
|
937 |
+
"marvellously": "marvelously",
|
938 |
+
"materialisation": "materialization",
|
939 |
+
"materialise": "materialize",
|
940 |
+
"materialised": "materialized",
|
941 |
+
"materialises": "materializes",
|
942 |
+
"materialising": "materializing",
|
943 |
+
"maximisation": "maximization",
|
944 |
+
"maximise": "maximize",
|
945 |
+
"maximised": "maximized",
|
946 |
+
"maximises": "maximizes",
|
947 |
+
"maximising": "maximizing",
|
948 |
+
"meagre": "meager",
|
949 |
+
"mechanisation": "mechanization",
|
950 |
+
"mechanise": "mechanize",
|
951 |
+
"mechanised": "mechanized",
|
952 |
+
"mechanises": "mechanizes",
|
953 |
+
"mechanising": "mechanizing",
|
954 |
+
"mediaeval": "medieval",
|
955 |
+
"memorialise": "memorialize",
|
956 |
+
"memorialised": "memorialized",
|
957 |
+
"memorialises": "memorializes",
|
958 |
+
"memorialising": "memorializing",
|
959 |
+
"memorise": "memorize",
|
960 |
+
"memorised": "memorized",
|
961 |
+
"memorises": "memorizes",
|
962 |
+
"memorising": "memorizing",
|
963 |
+
"mesmerise": "mesmerize",
|
964 |
+
"mesmerised": "mesmerized",
|
965 |
+
"mesmerises": "mesmerizes",
|
966 |
+
"mesmerising": "mesmerizing",
|
967 |
+
"metabolise": "metabolize",
|
968 |
+
"metabolised": "metabolized",
|
969 |
+
"metabolises": "metabolizes",
|
970 |
+
"metabolising": "metabolizing",
|
971 |
+
"metre": "meter",
|
972 |
+
"metres": "meters",
|
973 |
+
"micrometre": "micrometer",
|
974 |
+
"micrometres": "micrometers",
|
975 |
+
"militarise": "militarize",
|
976 |
+
"militarised": "militarized",
|
977 |
+
"militarises": "militarizes",
|
978 |
+
"militarising": "militarizing",
|
979 |
+
"milligramme": "milligram",
|
980 |
+
"milligrammes": "milligrams",
|
981 |
+
"millilitre": "milliliter",
|
982 |
+
"millilitres": "milliliters",
|
983 |
+
"millimetre": "millimeter",
|
984 |
+
"millimetres": "millimeters",
|
985 |
+
"miniaturisation": "miniaturization",
|
986 |
+
"miniaturise": "miniaturize",
|
987 |
+
"miniaturised": "miniaturized",
|
988 |
+
"miniaturises": "miniaturizes",
|
989 |
+
"miniaturising": "miniaturizing",
|
990 |
+
"minibusses": "minibuses",
|
991 |
+
"minimise": "minimize",
|
992 |
+
"minimised": "minimized",
|
993 |
+
"minimises": "minimizes",
|
994 |
+
"minimising": "minimizing",
|
995 |
+
"misbehaviour": "misbehavior",
|
996 |
+
"misdemeanour": "misdemeanor",
|
997 |
+
"misdemeanours": "misdemeanors",
|
998 |
+
"misspelt": "misspelled",
|
999 |
+
"mitre": "miter",
|
1000 |
+
"mitres": "miters",
|
1001 |
+
"mobilisation": "mobilization",
|
1002 |
+
"mobilise": "mobilize",
|
1003 |
+
"mobilised": "mobilized",
|
1004 |
+
"mobilises": "mobilizes",
|
1005 |
+
"mobilising": "mobilizing",
|
1006 |
+
"modelled": "modeled",
|
1007 |
+
"modeller": "modeler",
|
1008 |
+
"modellers": "modelers",
|
1009 |
+
"modelling": "modeling",
|
1010 |
+
"modernise": "modernize",
|
1011 |
+
"modernised": "modernized",
|
1012 |
+
"modernises": "modernizes",
|
1013 |
+
"modernising": "modernizing",
|
1014 |
+
"moisturise": "moisturize",
|
1015 |
+
"moisturised": "moisturized",
|
1016 |
+
"moisturiser": "moisturizer",
|
1017 |
+
"moisturisers": "moisturizers",
|
1018 |
+
"moisturises": "moisturizes",
|
1019 |
+
"moisturising": "moisturizing",
|
1020 |
+
"monologue": "monolog",
|
1021 |
+
"monologues": "monologs",
|
1022 |
+
"monopolisation": "monopolization",
|
1023 |
+
"monopolise": "monopolize",
|
1024 |
+
"monopolised": "monopolized",
|
1025 |
+
"monopolises": "monopolizes",
|
1026 |
+
"monopolising": "monopolizing",
|
1027 |
+
"moralise": "moralize",
|
1028 |
+
"moralised": "moralized",
|
1029 |
+
"moralises": "moralizes",
|
1030 |
+
"moralising": "moralizing",
|
1031 |
+
"motorised": "motorized",
|
1032 |
+
"mould": "mold",
|
1033 |
+
"moulded": "molded",
|
1034 |
+
"moulder": "molder",
|
1035 |
+
"mouldered": "moldered",
|
1036 |
+
"mouldering": "moldering",
|
1037 |
+
"moulders": "molders",
|
1038 |
+
"mouldier": "moldier",
|
1039 |
+
"mouldiest": "moldiest",
|
1040 |
+
"moulding": "molding",
|
1041 |
+
"mouldings": "moldings",
|
1042 |
+
"moulds": "molds",
|
1043 |
+
"mouldy": "moldy",
|
1044 |
+
"moult": "molt",
|
1045 |
+
"moulted": "molted",
|
1046 |
+
"moulting": "molting",
|
1047 |
+
"moults": "molts",
|
1048 |
+
"moustache": "mustache",
|
1049 |
+
"moustached": "mustached",
|
1050 |
+
"moustaches": "mustaches",
|
1051 |
+
"moustachioed": "mustachioed",
|
1052 |
+
"multicoloured": "multicolored",
|
1053 |
+
"nationalisation": "nationalization",
|
1054 |
+
"nationalisations": "nationalizations",
|
1055 |
+
"nationalise": "nationalize",
|
1056 |
+
"nationalised": "nationalized",
|
1057 |
+
"nationalises": "nationalizes",
|
1058 |
+
"nationalising": "nationalizing",
|
1059 |
+
"naturalisation": "naturalization",
|
1060 |
+
"naturalise": "naturalize",
|
1061 |
+
"naturalised": "naturalized",
|
1062 |
+
"naturalises": "naturalizes",
|
1063 |
+
"naturalising": "naturalizing",
|
1064 |
+
"neighbour": "neighbor",
|
1065 |
+
"neighbourhood": "neighborhood",
|
1066 |
+
"neighbourhoods": "neighborhoods",
|
1067 |
+
"neighbouring": "neighboring",
|
1068 |
+
"neighbourliness": "neighborliness",
|
1069 |
+
"neighbourly": "neighborly",
|
1070 |
+
"neighbours": "neighbors",
|
1071 |
+
"neutralisation": "neutralization",
|
1072 |
+
"neutralise": "neutralize",
|
1073 |
+
"neutralised": "neutralized",
|
1074 |
+
"neutralises": "neutralizes",
|
1075 |
+
"neutralising": "neutralizing",
|
1076 |
+
"normalisation": "normalization",
|
1077 |
+
"normalise": "normalize",
|
1078 |
+
"normalised": "normalized",
|
1079 |
+
"normalises": "normalizes",
|
1080 |
+
"normalising": "normalizing",
|
1081 |
+
"odour": "odor",
|
1082 |
+
"odourless": "odorless",
|
1083 |
+
"odours": "odors",
|
1084 |
+
"oesophagus": "esophagus",
|
1085 |
+
"oesophaguses": "esophaguses",
|
1086 |
+
"oestrogen": "estrogen",
|
1087 |
+
"offence": "offense",
|
1088 |
+
"offences": "offenses",
|
1089 |
+
"omelette": "omelet",
|
1090 |
+
"omelettes": "omelets",
|
1091 |
+
"optimise": "optimize",
|
1092 |
+
"optimised": "optimized",
|
1093 |
+
"optimises": "optimizes",
|
1094 |
+
"optimising": "optimizing",
|
1095 |
+
"organisation": "organization",
|
1096 |
+
"organisational": "organizational",
|
1097 |
+
"organisations": "organizations",
|
1098 |
+
"organise": "organize",
|
1099 |
+
"organised": "organized",
|
1100 |
+
"organiser": "organizer",
|
1101 |
+
"organisers": "organizers",
|
1102 |
+
"organises": "organizes",
|
1103 |
+
"organising": "organizing",
|
1104 |
+
"orthopaedic": "orthopedic",
|
1105 |
+
"orthopaedics": "orthopedics",
|
1106 |
+
"ostracise": "ostracize",
|
1107 |
+
"ostracised": "ostracized",
|
1108 |
+
"ostracises": "ostracizes",
|
1109 |
+
"ostracising": "ostracizing",
|
1110 |
+
"outmanoeuvre": "outmaneuver",
|
1111 |
+
"outmanoeuvred": "outmaneuvered",
|
1112 |
+
"outmanoeuvres": "outmaneuvers",
|
1113 |
+
"outmanoeuvring": "outmaneuvering",
|
1114 |
+
"overemphasise": "overemphasize",
|
1115 |
+
"overemphasised": "overemphasized",
|
1116 |
+
"overemphasises": "overemphasizes",
|
1117 |
+
"overemphasising": "overemphasizing",
|
1118 |
+
"oxidisation": "oxidization",
|
1119 |
+
"oxidise": "oxidize",
|
1120 |
+
"oxidised": "oxidized",
|
1121 |
+
"oxidises": "oxidizes",
|
1122 |
+
"oxidising": "oxidizing",
|
1123 |
+
"paederast": "pederast",
|
1124 |
+
"paederasts": "pederasts",
|
1125 |
+
"paediatric": "pediatric",
|
1126 |
+
"paediatrician": "pediatrician",
|
1127 |
+
"paediatricians": "pediatricians",
|
1128 |
+
"paediatrics": "pediatrics",
|
1129 |
+
"paedophile": "pedophile",
|
1130 |
+
"paedophiles": "pedophiles",
|
1131 |
+
"paedophilia": "pedophilia",
|
1132 |
+
"palaeolithic": "paleolithic",
|
1133 |
+
"palaeontologist": "paleontologist",
|
1134 |
+
"palaeontologists": "paleontologists",
|
1135 |
+
"palaeontology": "paleontology",
|
1136 |
+
"panelled": "paneled",
|
1137 |
+
"panelling": "paneling",
|
1138 |
+
"panellist": "panelist",
|
1139 |
+
"panellists": "panelists",
|
1140 |
+
"paralyse": "paralyze",
|
1141 |
+
"paralysed": "paralyzed",
|
1142 |
+
"paralyses": "paralyzes",
|
1143 |
+
"paralysing": "paralyzing",
|
1144 |
+
"parcelled": "parceled",
|
1145 |
+
"parcelling": "parceling",
|
1146 |
+
"parlour": "parlor",
|
1147 |
+
"parlours": "parlors",
|
1148 |
+
"particularise": "particularize",
|
1149 |
+
"particularised": "particularized",
|
1150 |
+
"particularises": "particularizes",
|
1151 |
+
"particularising": "particularizing",
|
1152 |
+
"passivisation": "passivization",
|
1153 |
+
"passivise": "passivize",
|
1154 |
+
"passivised": "passivized",
|
1155 |
+
"passivises": "passivizes",
|
1156 |
+
"passivising": "passivizing",
|
1157 |
+
"pasteurisation": "pasteurization",
|
1158 |
+
"pasteurise": "pasteurize",
|
1159 |
+
"pasteurised": "pasteurized",
|
1160 |
+
"pasteurises": "pasteurizes",
|
1161 |
+
"pasteurising": "pasteurizing",
|
1162 |
+
"patronise": "patronize",
|
1163 |
+
"patronised": "patronized",
|
1164 |
+
"patronises": "patronizes",
|
1165 |
+
"patronising": "patronizing",
|
1166 |
+
"patronisingly": "patronizingly",
|
1167 |
+
"pedalled": "pedaled",
|
1168 |
+
"pedalling": "pedaling",
|
1169 |
+
"pedestrianisation": "pedestrianization",
|
1170 |
+
"pedestrianise": "pedestrianize",
|
1171 |
+
"pedestrianised": "pedestrianized",
|
1172 |
+
"pedestrianises": "pedestrianizes",
|
1173 |
+
"pedestrianising": "pedestrianizing",
|
1174 |
+
"penalise": "penalize",
|
1175 |
+
"penalised": "penalized",
|
1176 |
+
"penalises": "penalizes",
|
1177 |
+
"penalising": "penalizing",
|
1178 |
+
"pencilled": "penciled",
|
1179 |
+
"pencilling": "penciling",
|
1180 |
+
"personalise": "personalize",
|
1181 |
+
"personalised": "personalized",
|
1182 |
+
"personalises": "personalizes",
|
1183 |
+
"personalising": "personalizing",
|
1184 |
+
"pharmacopoeia": "pharmacopeia",
|
1185 |
+
"pharmacopoeias": "pharmacopeias",
|
1186 |
+
"philosophise": "philosophize",
|
1187 |
+
"philosophised": "philosophized",
|
1188 |
+
"philosophises": "philosophizes",
|
1189 |
+
"philosophising": "philosophizing",
|
1190 |
+
"philtre": "filter",
|
1191 |
+
"philtres": "filters",
|
1192 |
+
"phoney": "phony",
|
1193 |
+
"plagiarise": "plagiarize",
|
1194 |
+
"plagiarised": "plagiarized",
|
1195 |
+
"plagiarises": "plagiarizes",
|
1196 |
+
"plagiarising": "plagiarizing",
|
1197 |
+
"plough": "plow",
|
1198 |
+
"ploughed": "plowed",
|
1199 |
+
"ploughing": "plowing",
|
1200 |
+
"ploughman": "plowman",
|
1201 |
+
"ploughmen": "plowmen",
|
1202 |
+
"ploughs": "plows",
|
1203 |
+
"ploughshare": "plowshare",
|
1204 |
+
"ploughshares": "plowshares",
|
1205 |
+
"polarisation": "polarization",
|
1206 |
+
"polarise": "polarize",
|
1207 |
+
"polarised": "polarized",
|
1208 |
+
"polarises": "polarizes",
|
1209 |
+
"polarising": "polarizing",
|
1210 |
+
"politicisation": "politicization",
|
1211 |
+
"politicise": "politicize",
|
1212 |
+
"politicised": "politicized",
|
1213 |
+
"politicises": "politicizes",
|
1214 |
+
"politicising": "politicizing",
|
1215 |
+
"popularisation": "popularization",
|
1216 |
+
"popularise": "popularize",
|
1217 |
+
"popularised": "popularized",
|
1218 |
+
"popularises": "popularizes",
|
1219 |
+
"popularising": "popularizing",
|
1220 |
+
"pouffe": "pouf",
|
1221 |
+
"pouffes": "poufs",
|
1222 |
+
"practise": "practice",
|
1223 |
+
"practised": "practiced",
|
1224 |
+
"practises": "practices",
|
1225 |
+
"practising": "practicing",
|
1226 |
+
"praesidium": "presidium",
|
1227 |
+
"praesidiums": "presidiums",
|
1228 |
+
"pressurisation": "pressurization",
|
1229 |
+
"pressurise": "pressurize",
|
1230 |
+
"pressurised": "pressurized",
|
1231 |
+
"pressurises": "pressurizes",
|
1232 |
+
"pressurising": "pressurizing",
|
1233 |
+
"pretence": "pretense",
|
1234 |
+
"pretences": "pretenses",
|
1235 |
+
"primaeval": "primeval",
|
1236 |
+
"prioritisation": "prioritization",
|
1237 |
+
"prioritise": "prioritize",
|
1238 |
+
"prioritised": "prioritized",
|
1239 |
+
"prioritises": "prioritizes",
|
1240 |
+
"prioritising": "prioritizing",
|
1241 |
+
"privatisation": "privatization",
|
1242 |
+
"privatisations": "privatizations",
|
1243 |
+
"privatise": "privatize",
|
1244 |
+
"privatised": "privatized",
|
1245 |
+
"privatises": "privatizes",
|
1246 |
+
"privatising": "privatizing",
|
1247 |
+
"professionalisation": "professionalization",
|
1248 |
+
"professionalise": "professionalize",
|
1249 |
+
"professionalised": "professionalized",
|
1250 |
+
"professionalises": "professionalizes",
|
1251 |
+
"professionalising": "professionalizing",
|
1252 |
+
"programme": "program",
|
1253 |
+
"programmes": "programs",
|
1254 |
+
"prologue": "prolog",
|
1255 |
+
"prologues": "prologs",
|
1256 |
+
"propagandise": "propagandize",
|
1257 |
+
"propagandised": "propagandized",
|
1258 |
+
"propagandises": "propagandizes",
|
1259 |
+
"propagandising": "propagandizing",
|
1260 |
+
"proselytise": "proselytize",
|
1261 |
+
"proselytised": "proselytized",
|
1262 |
+
"proselytiser": "proselytizer",
|
1263 |
+
"proselytisers": "proselytizers",
|
1264 |
+
"proselytises": "proselytizes",
|
1265 |
+
"proselytising": "proselytizing",
|
1266 |
+
"psychoanalyse": "psychoanalyze",
|
1267 |
+
"psychoanalysed": "psychoanalyzed",
|
1268 |
+
"psychoanalyses": "psychoanalyzes",
|
1269 |
+
"psychoanalysing": "psychoanalyzing",
|
1270 |
+
"publicise": "publicize",
|
1271 |
+
"publicised": "publicized",
|
1272 |
+
"publicises": "publicizes",
|
1273 |
+
"publicising": "publicizing",
|
1274 |
+
"pulverisation": "pulverization",
|
1275 |
+
"pulverise": "pulverize",
|
1276 |
+
"pulverised": "pulverized",
|
1277 |
+
"pulverises": "pulverizes",
|
1278 |
+
"pulverising": "pulverizing",
|
1279 |
+
"pummelled": "pummel",
|
1280 |
+
"pummelling": "pummeled",
"pyjama": "pajama",
"pyjamas": "pajamas",
"pzazz": "pizzazz",
"quarrelled": "quarreled",
"quarrelling": "quarreling",
"radicalise": "radicalize",
"radicalised": "radicalized",
"radicalises": "radicalizes",
"radicalising": "radicalizing",
"rancour": "rancor",
"randomise": "randomize",
"randomised": "randomized",
"randomises": "randomizes",
"randomising": "randomizing",
"rationalisation": "rationalization",
"rationalisations": "rationalizations",
"rationalise": "rationalize",
"rationalised": "rationalized",
"rationalises": "rationalizes",
"rationalising": "rationalizing",
"ravelled": "raveled",
"ravelling": "raveling",
"realisable": "realizable",
"realisation": "realization",
"realisations": "realizations",
"realise": "realize",
"realised": "realized",
"realises": "realizes",
"realising": "realizing",
"recognisable": "recognizable",
"recognisably": "recognizably",
"recognisance": "recognizance",
"recognise": "recognize",
"recognised": "recognized",
"recognises": "recognizes",
"recognising": "recognizing",
"reconnoitre": "reconnoiter",
"reconnoitred": "reconnoitered",
"reconnoitres": "reconnoiters",
"reconnoitring": "reconnoitering",
"refuelled": "refueled",
"refuelling": "refueling",
"regularisation": "regularization",
"regularise": "regularize",
"regularised": "regularized",
"regularises": "regularizes",
"regularising": "regularizing",
"remodelled": "remodeled",
"remodelling": "remodeling",
"remould": "remold",
"remoulded": "remolded",
"remoulding": "remolding",
"remoulds": "remolds",
"reorganisation": "reorganization",
"reorganisations": "reorganizations",
"reorganise": "reorganize",
"reorganised": "reorganized",
"reorganises": "reorganizes",
"reorganising": "reorganizing",
"revelled": "reveled",
"reveller": "reveler",
"revellers": "revelers",
"revelling": "reveling",
"revitalise": "revitalize",
"revitalised": "revitalized",
"revitalises": "revitalizes",
"revitalising": "revitalizing",
"revolutionise": "revolutionize",
"revolutionised": "revolutionized",
"revolutionises": "revolutionizes",
"revolutionising": "revolutionizing",
"rhapsodise": "rhapsodize",
"rhapsodised": "rhapsodized",
"rhapsodises": "rhapsodizes",
"rhapsodising": "rhapsodizing",
"rigour": "rigor",
"rigours": "rigors",
"ritualised": "ritualized",
"rivalled": "rivaled",
"rivalling": "rivaling",
"romanticise": "romanticize",
"romanticised": "romanticized",
"romanticises": "romanticizes",
"romanticising": "romanticizing",
"rumour": "rumor",
"rumoured": "rumored",
"rumours": "rumors",
"sabre": "saber",
"sabres": "sabers",
"saltpetre": "saltpeter",
"sanitise": "sanitize",
"sanitised": "sanitized",
"sanitises": "sanitizes",
"sanitising": "sanitizing",
"satirise": "satirize",
"satirised": "satirized",
"satirises": "satirizes",
"satirising": "satirizing",
"saviour": "savior",
"saviours": "saviors",
"savour": "savor",
"savoured": "savored",
"savouries": "savories",
"savouring": "savoring",
"savours": "savors",
"savoury": "savory",
"scandalise": "scandalize",
"scandalised": "scandalized",
"scandalises": "scandalizes",
"scandalising": "scandalizing",
"sceptic": "skeptic",
"sceptical": "skeptical",
"sceptically": "skeptically",
"scepticism": "skepticism",
"sceptics": "skeptics",
"sceptre": "scepter",
"sceptres": "scepters",
"scrutinise": "scrutinize",
"scrutinised": "scrutinized",
"scrutinises": "scrutinizes",
"scrutinising": "scrutinizing",
"secularisation": "secularization",
"secularise": "secularize",
"secularised": "secularized",
"secularises": "secularizes",
"secularising": "secularizing",
"sensationalise": "sensationalize",
"sensationalised": "sensationalized",
"sensationalises": "sensationalizes",
"sensationalising": "sensationalizing",
"sensitise": "sensitize",
"sensitised": "sensitized",
"sensitises": "sensitizes",
"sensitising": "sensitizing",
"sentimentalise": "sentimentalize",
"sentimentalised": "sentimentalized",
"sentimentalises": "sentimentalizes",
"sentimentalising": "sentimentalizing",
"sepulchre": "sepulcher",
"sepulchres": "sepulchers",
"serialisation": "serialization",
"serialisations": "serializations",
"serialise": "serialize",
"serialised": "serialized",
"serialises": "serializes",
"serialising": "serializing",
"sermonise": "sermonize",
"sermonised": "sermonized",
"sermonises": "sermonizes",
"sermonising": "sermonizing",
"sheikh": "sheik",
"shovelled": "shoveled",
"shovelling": "shoveling",
"shrivelled": "shriveled",
"shrivelling": "shriveling",
"signalise": "signalize",
"signalised": "signalized",
"signalises": "signalizes",
"signalising": "signalizing",
"signalled": "signaled",
"signalling": "signaling",
"smoulder": "smolder",
"smouldered": "smoldered",
"smouldering": "smoldering",
"smoulders": "smolders",
"snivelled": "sniveled",
"snivelling": "sniveling",
"snorkelled": "snorkeled",
"snorkelling": "snorkeling",
"snowplough": "snowplow",
"snowploughs": "snowplow",
"socialisation": "socialization",
"socialise": "socialize",
"socialised": "socialized",
"socialises": "socializes",
"socialising": "socializing",
"sodomise": "sodomize",
"sodomised": "sodomized",
"sodomises": "sodomizes",
"sodomising": "sodomizing",
"solemnise": "solemnize",
"solemnised": "solemnized",
"solemnises": "solemnizes",
"solemnising": "solemnizing",
"sombre": "somber",
"specialisation": "specialization",
"specialisations": "specializations",
"specialise": "specialize",
"specialised": "specialized",
"specialises": "specializes",
"specialising": "specializing",
"spectre": "specter",
"spectres": "specters",
"spiralled": "spiraled",
"spiralling": "spiraling",
"splendour": "splendor",
"splendours": "splendors",
"squirrelled": "squirreled",
"squirrelling": "squirreling",
"stabilisation": "stabilization",
"stabilise": "stabilize",
"stabilised": "stabilized",
"stabiliser": "stabilizer",
"stabilisers": "stabilizers",
"stabilises": "stabilizes",
"stabilising": "stabilizing",
"standardisation": "standardization",
"standardise": "standardize",
"standardised": "standardized",
"standardises": "standardizes",
"standardising": "standardizing",
"stencilled": "stenciled",
"stencilling": "stenciling",
"sterilisation": "sterilization",
"sterilisations": "sterilizations",
"sterilise": "sterilize",
"sterilised": "sterilized",
"steriliser": "sterilizer",
"sterilisers": "sterilizers",
"sterilises": "sterilizes",
"sterilising": "sterilizing",
"stigmatisation": "stigmatization",
"stigmatise": "stigmatize",
"stigmatised": "stigmatized",
"stigmatises": "stigmatizes",
"stigmatising": "stigmatizing",
"storey": "story",
"storeys": "stories",
"subsidisation": "subsidization",
"subsidise": "subsidize",
"subsidised": "subsidized",
"subsidiser": "subsidizer",
"subsidisers": "subsidizers",
"subsidises": "subsidizes",
"subsidising": "subsidizing",
"succour": "succor",
"succoured": "succored",
"succouring": "succoring",
"succours": "succors",
"sulphate": "sulfate",
"sulphates": "sulfates",
"sulphide": "sulfide",
"sulphides": "sulfides",
"sulphur": "sulfur",
"sulphurous": "sulfurous",
"summarise": "summarize",
"summarised": "summarized",
"summarises": "summarizes",
"summarising": "summarizing",
"swivelled": "swiveled",
"swivelling": "swiveling",
"symbolise": "symbolize",
"symbolised": "symbolized",
"symbolises": "symbolizes",
"symbolising": "symbolizing",
"sympathise": "sympathize",
"sympathised": "sympathized",
"sympathiser": "sympathizer",
"sympathisers": "sympathizers",
"sympathises": "sympathizes",
"sympathising": "sympathizing",
"synchronisation": "synchronization",
"synchronise": "synchronize",
"synchronised": "synchronized",
"synchronises": "synchronizes",
"synchronising": "synchronizing",
"synthesise": "synthesize",
"synthesised": "synthesized",
"synthesiser": "synthesizer",
"synthesisers": "synthesizers",
"synthesises": "synthesizes",
"synthesising": "synthesizing",
"syphon": "siphon",
"syphoned": "siphoned",
"syphoning": "siphoning",
"syphons": "siphons",
"systematisation": "systematization",
"systematise": "systematize",
"systematised": "systematized",
"systematises": "systematizes",
"systematising": "systematizing",
"tantalise": "tantalize",
"tantalised": "tantalized",
"tantalises": "tantalizes",
"tantalising": "tantalizing",
"tantalisingly": "tantalizingly",
"tasselled": "tasseled",
"technicolour": "technicolor",
"temporise": "temporize",
"temporised": "temporized",
"temporises": "temporizes",
"temporising": "temporizing",
"tenderise": "tenderize",
"tenderised": "tenderized",
"tenderises": "tenderizes",
"tenderising": "tenderizing",
"terrorise": "terrorize",
"terrorised": "terrorized",
"terrorises": "terrorizes",
"terrorising": "terrorizing",
"theatre": "theater",
"theatregoer": "theatergoer",
"theatregoers": "theatergoers",
"theatres": "theaters",
"theorise": "theorize",
"theorised": "theorized",
"theorises": "theorizes",
"theorising": "theorizing",
"tonne": "ton",
"tonnes": "tons",
"towelled": "toweled",
"towelling": "toweling",
"toxaemia": "toxemia",
"tranquillise": "tranquilize",
"tranquillised": "tranquilized",
"tranquilliser": "tranquilizer",
"tranquillisers": "tranquilizers",
"tranquillises": "tranquilizes",
"tranquillising": "tranquilizing",
"tranquillity": "tranquility",
"tranquillize": "tranquilize",
"tranquillized": "tranquilized",
"tranquillizer": "tranquilizer",
"tranquillizers": "tranquilizers",
"tranquillizes": "tranquilizes",
"tranquillizing": "tranquilizing",
"tranquilly": "tranquility",
"transistorised": "transistorized",
"traumatise": "traumatize",
"traumatised": "traumatized",
"traumatises": "traumatizes",
"traumatising": "traumatizing",
"travelled": "traveled",
"traveller": "traveler",
"travellers": "travelers",
"travelling": "traveling",
"travelog": "travelogue",
"travelogs": "travelogues",
"trialled": "trialed",
"trialling": "trialing",
"tricolour": "tricolor",
"tricolours": "tricolors",
"trivialise": "trivialize",
"trivialised": "trivialized",
"trivialises": "trivializes",
"trivialising": "trivializing",
"tumour": "tumor",
"tumours": "tumors",
"tunnelled": "tunneled",
"tunnelling": "tunneling",
"tyrannise": "tyrannize",
"tyrannised": "tyrannized",
"tyrannises": "tyrannizes",
"tyrannising": "tyrannizing",
"tyre": "tire",
"tyres": "tires",
"unauthorised": "unauthorized",
"uncivilised": "uncivilized",
"underutilised": "underutilized",
"unequalled": "unequaled",
"unfavourable": "unfavorable",
"unfavourably": "unfavorably",
"unionisation": "unionization",
"unionise": "unionize",
"unionised": "unionized",
"unionises": "unionizes",
"unionising": "unionizing",
"unorganised": "unorganized",
"unravelled": "unraveled",
"unravelling": "unraveling",
"unrecognisable": "unrecognizable",
"unrecognised": "unrecognized",
"unrivalled": "unrivaled",
"unsavoury": "unsavory",
"untrammelled": "untrammeled",
"urbanisation": "urbanization",
"urbanise": "urbanize",
"urbanised": "urbanized",
"urbanises": "urbanizes",
"urbanising": "urbanizing",
"utilisable": "utilizable",
"utilisation": "utilization",
"utilise": "utilize",
"utilised": "utilized",
"utilises": "utilizes",
"utilising": "utilizing",
"valour": "valor",
"vandalise": "vandalize",
"vandalised": "vandalized",
"vandalises": "vandalizes",
"vandalising": "vandalizing",
"vaporisation": "vaporization",
"vaporise": "vaporize",
"vaporised": "vaporized",
"vaporises": "vaporizes",
"vaporising": "vaporizing",
"vapour": "vapor",
"vapours": "vapors",
"verbalise": "verbalize",
"verbalised": "verbalized",
"verbalises": "verbalizes",
"verbalising": "verbalizing",
"victimisation": "victimization",
"victimise": "victimize",
"victimised": "victimized",
"victimises": "victimizes",
"victimising": "victimizing",
"videodisc": "videodisk",
"videodiscs": "videodisks",
"vigour": "vigor",
"visualisation": "visualization",
"visualisations": "visualizations",
"visualise": "visualize",
"visualised": "visualized",
"visualises": "visualizes",
"visualising": "visualizing",
"vocalisation": "vocalization",
"vocalisations": "vocalizations",
"vocalise": "vocalize",
"vocalised": "vocalized",
"vocalises": "vocalizes",
"vocalising": "vocalizing",
"vulcanised": "vulcanized",
"vulgarisation": "vulgarization",
"vulgarise": "vulgarize",
"vulgarised": "vulgarized",
"vulgarises": "vulgarizes",
"vulgarising": "vulgarizing",
"waggon": "wagon",
"waggons": "wagons",
"watercolour": "watercolor",
"watercolours": "watercolors",
"weaselled": "weaseled",
"weaselling": "weaseling",
"westernisation": "westernization",
"westernise": "westernize",
"westernised": "westernized",
"westernises": "westernizes",
"westernising": "westernizing",
"womanise": "womanize",
"womanised": "womanized",
"womaniser": "womanizer",
"womanisers": "womanizers",
"womanises": "womanizes",
"womanising": "womanizing",
"woollen": "woolen",
"woollens": "woolens",
"woollies": "woolies",
"woolly": "wooly",
"worshipped": "worshiped",
"worshipping": "worshiping",
"worshipper": "worshiper",
"yodelled": "yodeled",
"yodelling": "yodeling",
"yoghourt": "yogurt",
"yoghourts": "yogurts",
"yoghurt": "yogurt",
"yoghurts": "yogurts",
"mhm": "hmm",
"mmm": "hmm"
}
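The file above is pure data: a British-to-American spelling map, closed out by two custom filler-word entries ("mhm"/"mmm" -> "hmm") that fold hesitation sounds together for ASR scoring. As a minimal, hypothetical sketch of how such a map is typically applied word-by-word during normalization (the repo's actual lookup presumably lives in the normalizer classes in english.py below and may differ in detail; the function name americanize is invented for illustration):

import json
import re

with open("whisper_stream/normalizers/english.json") as f:
    mapping = json.load(f)  # e.g. {"colour": "color", "theatre": "theater", ...}

def americanize(text: str) -> str:
    # lowercase, then swap each whole word for its American spelling if one exists
    return re.sub(r"[a-z]+", lambda m: mapping.get(m.group(0), m.group(0)), text.lower())

print(americanize("The colour of the theatre"))  # -> "the color of the theater"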
whisper_stream/normalizers/english.py
ADDED
@@ -0,0 +1,550 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union

from more_itertools import windowed

from .basic import remove_symbols_and_diacritics


class EnglishNumberNormalizer:
    """
    Convert any spelled-out numbers into arabic numbers, while handling:

    - remove any commas
    - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
    - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
    - spell out `one` and `ones`
    - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
    """

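    # Illustrative behavior (editor's sketch based on the docstring above, not part of
    # the upstream file):
    #   ["one", "oh", "one"]  -> ["101"]   (successive single digits read as nominal)
    #   ["twenty", "first"]   -> ["21st"]  (ordinal suffix kept)
    #   a lone "one"/"ones" stays spelled out in the final text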
    def __init__(self):
        super().__init__()

        self.zeros = {"o", "oh", "zero"}
        self.ones = {
            name: i
            for i, name in enumerate(
                [
                    "one",
                    "two",
                    "three",
                    "four",
                    "five",
                    "six",
                    "seven",
                    "eight",
                    "nine",
                    "ten",
                    "eleven",
                    "twelve",
                    "thirteen",
                    "fourteen",
                    "fifteen",
                    "sixteen",
                    "seventeen",
                    "eighteen",
                    "nineteen",
                ],
                start=1,
            )
        }
        self.ones_plural = {
            "sixes" if name == "six" else name + "s": (value, "s")
            for name, value in self.ones.items()
        }
        self.ones_ordinal = {
            "zeroth": (0, "th"),
            "first": (1, "st"),
            "second": (2, "nd"),
            "third": (3, "rd"),
            "fifth": (5, "th"),
            "twelfth": (12, "th"),
            **{
                name + ("h" if name.endswith("t") else "th"): (value, "th")
                for name, value in self.ones.items()
                if value > 3 and value != 5 and value != 12
            },
        }
        self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
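        # e.g. "twos" -> (2, "s"), "sixes" -> (6, "s"), "fourth" -> (4, "th"), "eighth" -> (8, "th")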

        self.tens = {
            "twenty": 20,
            "thirty": 30,
            "forty": 40,
            "fifty": 50,
            "sixty": 60,
            "seventy": 70,
            "eighty": 80,
            "ninety": 90,
        }
        self.tens_plural = {
            name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
        }
        self.tens_ordinal = {
            name.replace("y", "ieth"): (value, "th")
            for name, value in self.tens.items()
        }
        self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
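        # e.g. "twenties" -> (20, "s"), "fortieth" -> (40, "th")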

        self.multipliers = {
            "hundred": 100,
            "thousand": 1_000,
            "million": 1_000_000,
            "billion": 1_000_000_000,
            "trillion": 1_000_000_000_000,
            "quadrillion": 1_000_000_000_000_000,
            "quintillion": 1_000_000_000_000_000_000,
            "sextillion": 1_000_000_000_000_000_000_000,
            "septillion": 1_000_000_000_000_000_000_000_000,
            "octillion": 1_000_000_000_000_000_000_000_000_000,
            "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
            "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
        }
        self.multipliers_plural = {
            name + "s": (value, "s") for name, value in self.multipliers.items()
        }
        self.multipliers_ordinal = {
            name + "th": (value, "th") for name, value in self.multipliers.items()
        }
        self.multipliers_suffixed = {
            **self.multipliers_plural,
            **self.multipliers_ordinal,
        }
        self.decimals = {*self.ones, *self.tens, *self.zeros}
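        # decimals lists the words that can follow "point" (handled further down in
        # process_words; editor's note), e.g. "three point one four" -> "3.14"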

        self.preceding_prefixers = {
            "minus": "-",
            "negative": "-",
            "plus": "+",
            "positive": "+",
        }
        self.following_prefixers = {
            "pound": "£",
            "pounds": "£",
            "euro": "€",
            "euros": "€",
            "dollar": "$",
            "dollars": "$",
            "cent": "¢",
            "cents": "¢",
        }
        self.prefixes = set(
            list(self.preceding_prefixers.values())
            + list(self.following_prefixers.values())
        )
        self.suffixers = {
            "per": {"cent": "%"},
            "percent": "%",
        }
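        # e.g. "twenty percent" -> "20%", and "twenty per cent" -> "20%" via the nested
        # mapping (applied later in process_words; editor's note)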
        self.specials = {"and", "double", "triple", "point"}

        self.words = set(
            [
                key
                for mapping in [
                    self.zeros,
                    self.ones,
                    self.ones_suffixed,
                    self.tens,
                    self.tens_suffixed,
                    self.multipliers,
                    self.multipliers_suffixed,
                    self.preceding_prefixers,
                    self.following_prefixers,
                    self.suffixers,
                    self.specials,
                ]
                for key in mapping
            ]
        )
        self.literal_words = {"one", "ones"}
|
164 |
+
|
165 |
+
def process_words(self, words: List[str]) -> Iterator[str]:
|
166 |
+
prefix: Optional[str] = None
|
167 |
+
value: Optional[Union[str, int]] = None
|
168 |
+
skip = False
|
169 |
+
|
170 |
+
def to_fraction(s: str):
|
171 |
+
try:
|
172 |
+
return Fraction(s)
|
173 |
+
except ValueError:
|
174 |
+
return None
|
175 |
+
|
176 |
+
def output(result: Union[str, int]):
|
177 |
+
nonlocal prefix, value
|
178 |
+
result = str(result)
|
179 |
+
if prefix is not None:
|
180 |
+
result = prefix + result
|
181 |
+
value = None
|
182 |
+
prefix = None
|
183 |
+
return result
|
184 |
+
|
185 |
+
if len(words) == 0:
|
186 |
+
return
|
187 |
+
|
188 |
+
for prev, current, next in windowed([None] + words + [None], 3):
|
189 |
+
if skip:
|
190 |
+
skip = False
|
191 |
+
continue
|
192 |
+
|
193 |
+
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
194 |
+
has_prefix = current[0] in self.prefixes
|
195 |
+
current_without_prefix = current[1:] if has_prefix else current
|
196 |
+
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
197 |
+
# arabic numbers (potentially with signs and fractions)
|
198 |
+
f = to_fraction(current_without_prefix)
|
199 |
+
assert f is not None
|
200 |
+
if value is not None:
|
201 |
+
if isinstance(value, str) and value.endswith("."):
|
202 |
+
# concatenate decimals / ip address components
|
203 |
+
value = str(value) + str(current)
|
204 |
+
continue
|
205 |
+
else:
|
206 |
+
yield output(value)
|
207 |
+
|
208 |
+
prefix = current[0] if has_prefix else prefix
|
209 |
+
if f.denominator == 1:
|
210 |
+
value = f.numerator # store integers as int
|
211 |
+
else:
|
212 |
+
value = current_without_prefix
|
213 |
+
elif current not in self.words:
|
214 |
+
# non-numeric words
|
215 |
+
if value is not None:
|
216 |
+
yield output(value)
|
217 |
+
yield output(current)
|
218 |
+
elif current in self.zeros:
|
219 |
+
value = str(value or "") + "0"
|
220 |
+
elif current in self.ones:
|
221 |
+
ones = self.ones[current]
|
222 |
+
|
223 |
+
if value is None:
|
224 |
+
value = ones
|
225 |
+
elif isinstance(value, str) or prev in self.ones:
|
226 |
+
if (
|
227 |
+
prev in self.tens and ones < 10
|
228 |
+
): # replace the last zero with the digit
|
229 |
+
assert value[-1] == "0"
|
230 |
+
value = value[:-1] + str(ones)
|
231 |
+
else:
|
232 |
+
value = str(value) + str(ones)
|
233 |
+
elif ones < 10:
|
234 |
+
if value % 10 == 0:
|
235 |
+
value += ones
|
236 |
+
else:
|
237 |
+
value = str(value) + str(ones)
|
238 |
+
else: # eleven to nineteen
|
239 |
+
if value % 100 == 0:
|
240 |
+
value += ones
|
241 |
+
else:
|
242 |
+
value = str(value) + str(ones)
|
243 |
+
elif current in self.ones_suffixed:
|
244 |
+
# ordinal or cardinal; yield the number right away
|
245 |
+
ones, suffix = self.ones_suffixed[current]
|
246 |
+
if value is None:
|
247 |
+
yield output(str(ones) + suffix)
|
248 |
+
elif isinstance(value, str) or prev in self.ones:
|
249 |
+
if prev in self.tens and ones < 10:
|
250 |
+
assert value[-1] == "0"
|
251 |
+
yield output(value[:-1] + str(ones) + suffix)
|
252 |
+
else:
|
253 |
+
yield output(str(value) + str(ones) + suffix)
|
254 |
+
elif ones < 10:
|
255 |
+
if value % 10 == 0:
|
256 |
+
yield output(str(value + ones) + suffix)
|
257 |
+
else:
|
258 |
+
yield output(str(value) + str(ones) + suffix)
|
259 |
+
else: # eleven to nineteen
|
260 |
+
if value % 100 == 0:
|
261 |
+
yield output(str(value + ones) + suffix)
|
262 |
+
else:
|
263 |
+
yield output(str(value) + str(ones) + suffix)
|
264 |
+
value = None
|
265 |
+
elif current in self.tens:
|
266 |
+
tens = self.tens[current]
|
267 |
+
if value is None:
|
268 |
+
value = tens
|
269 |
+
elif isinstance(value, str):
|
270 |
+
value = str(value) + str(tens)
|
271 |
+
else:
|
272 |
+
if value % 100 == 0:
|
273 |
+
value += tens
|
274 |
+
else:
|
275 |
+
value = str(value) + str(tens)
|
276 |
+
elif current in self.tens_suffixed:
|
277 |
+
# ordinal or cardinal; yield the number right away
|
278 |
+
tens, suffix = self.tens_suffixed[current]
|
279 |
+
if value is None:
|
280 |
+
yield output(str(tens) + suffix)
|
281 |
+
elif isinstance(value, str):
|
282 |
+
yield output(str(value) + str(tens) + suffix)
|
283 |
+
else:
|
284 |
+
if value % 100 == 0:
|
285 |
+
yield output(str(value + tens) + suffix)
|
286 |
+
else:
|
287 |
+
yield output(str(value) + str(tens) + suffix)
|
288 |
+
elif current in self.multipliers:
|
289 |
+
multiplier = self.multipliers[current]
|
290 |
+
if value is None:
|
291 |
+
value = multiplier
|
292 |
+
elif isinstance(value, str) or value == 0:
|
293 |
+
f = to_fraction(value)
|
294 |
+
p = f * multiplier if f is not None else None
|
295 |
+
if f is not None and p.denominator == 1:
|
296 |
+
value = p.numerator
|
297 |
+
else:
|
298 |
+
yield output(value)
|
299 |
+
value = multiplier
|
300 |
+
else:
|
301 |
+
before = value // 1000 * 1000
|
302 |
+
residual = value % 1000
|
303 |
+
value = before + residual * multiplier
|
304 |
+
elif current in self.multipliers_suffixed:
|
305 |
+
multiplier, suffix = self.multipliers_suffixed[current]
|
306 |
+
if value is None:
|
307 |
+
yield output(str(multiplier) + suffix)
|
308 |
+
elif isinstance(value, str):
|
309 |
+
f = to_fraction(value)
|
310 |
+
p = f * multiplier if f is not None else None
|
311 |
+
if f is not None and p.denominator == 1:
|
312 |
+
yield output(str(p.numerator) + suffix)
|
313 |
+
else:
|
314 |
+
yield output(value)
|
315 |
+
yield output(str(multiplier) + suffix)
|
316 |
+
else: # int
|
317 |
+
before = value // 1000 * 1000
|
318 |
+
residual = value % 1000
|
319 |
+
value = before + residual * multiplier
|
320 |
+
yield output(str(value) + suffix)
|
321 |
+
value = None
|
322 |
+
elif current in self.preceding_prefixers:
|
323 |
+
# apply prefix (positive, minus, etc.) if it precedes a number
|
324 |
+
if value is not None:
|
325 |
+
yield output(value)
|
326 |
+
|
327 |
+
if next in self.words or next_is_numeric:
|
328 |
+
prefix = self.preceding_prefixers[current]
|
329 |
+
else:
|
330 |
+
yield output(current)
|
331 |
+
elif current in self.following_prefixers:
|
332 |
+
# apply prefix (dollars, cents, etc.) only after a number
|
333 |
+
if value is not None:
|
334 |
+
prefix = self.following_prefixers[current]
|
335 |
+
yield output(value)
|
336 |
+
else:
|
337 |
+
yield output(current)
|
338 |
+
elif current in self.suffixers:
|
339 |
+
# apply suffix symbols (percent -> '%')
|
340 |
+
if value is not None:
|
341 |
+
suffix = self.suffixers[current]
|
342 |
+
if isinstance(suffix, dict):
|
343 |
+
if next in suffix:
|
344 |
+
yield output(str(value) + suffix[next])
|
345 |
+
skip = True
|
346 |
+
else:
|
347 |
+
yield output(value)
|
348 |
+
yield output(current)
|
349 |
+
else:
|
350 |
+
yield output(str(value) + suffix)
|
351 |
+
else:
|
352 |
+
yield output(current)
|
353 |
+
elif current in self.specials:
|
354 |
+
if next not in self.words and not next_is_numeric:
|
355 |
+
# apply special handling only if the next word can be numeric
|
356 |
+
if value is not None:
|
357 |
+
yield output(value)
|
358 |
+
yield output(current)
|
359 |
+
elif current == "and":
|
360 |
+
# ignore "and" after hundreds, thousands, etc.
|
361 |
+
if prev not in self.multipliers:
|
362 |
+
if value is not None:
|
363 |
+
yield output(value)
|
364 |
+
yield output(current)
|
365 |
+
elif current == "double" or current == "triple":
|
366 |
+
if next in self.ones or next in self.zeros:
|
367 |
+
repeats = 2 if current == "double" else 3
|
368 |
+
ones = self.ones.get(next, 0)
|
369 |
+
value = str(value or "") + str(ones) * repeats
|
370 |
+
skip = True
|
371 |
+
else:
|
372 |
+
if value is not None:
|
373 |
+
yield output(value)
|
374 |
+
yield output(current)
|
375 |
+
elif current == "point":
|
376 |
+
if next in self.decimals or next_is_numeric:
|
377 |
+
value = str(value or "") + "."
|
378 |
+
else:
|
379 |
+
# should all have been covered at this point
|
380 |
+
raise ValueError(f"Unexpected token: {current}")
|
381 |
+
else:
|
382 |
+
# all should have been covered at this point
|
383 |
+
raise ValueError(f"Unexpected token: {current}")
|
384 |
+
|
385 |
+
if value is not None:
|
386 |
+
yield output(value)
|
387 |
+
|
388 |
+
def preprocess(self, s: str):
|
389 |
+
# replace "<number> and a half" with "<number> point five"
|
390 |
+
results = []
|
391 |
+
|
392 |
+
segments = re.split(r"\band\s+a\s+half\b", s)
|
393 |
+
for i, segment in enumerate(segments):
|
394 |
+
if len(segment.strip()) == 0:
|
395 |
+
continue
|
396 |
+
if i == len(segments) - 1:
|
397 |
+
results.append(segment)
|
398 |
+
else:
|
399 |
+
results.append(segment)
|
400 |
+
last_word = segment.rsplit(maxsplit=2)[-1]
|
401 |
+
if last_word in self.decimals or last_word in self.multipliers:
|
402 |
+
results.append("point five")
|
403 |
+
else:
|
404 |
+
results.append("and a half")
|
405 |
+
|
406 |
+
s = " ".join(results)
|
407 |
+
|
408 |
+
# put a space at number/letter boundary
|
409 |
+
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
410 |
+
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
411 |
+
|
412 |
+
# but remove spaces which could be a suffix
|
413 |
+
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
414 |
+
|
415 |
+
return s
|
416 |
+
|
417 |
+
def postprocess(self, s: str):
|
418 |
+
def combine_cents(m: Match):
|
419 |
+
try:
|
420 |
+
currency = m.group(1)
|
421 |
+
integer = m.group(2)
|
422 |
+
cents = int(m.group(3))
|
423 |
+
return f"{currency}{integer}.{cents:02d}"
|
424 |
+
except ValueError:
|
425 |
+
return m.string
|
426 |
+
|
427 |
+
def extract_cents(m: Match):
|
428 |
+
try:
|
429 |
+
return f"¢{int(m.group(1))}"
|
430 |
+
except ValueError:
|
431 |
+
return m.string
|
432 |
+
|
433 |
+
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
|
434 |
+
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
|
435 |
+
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
|
436 |
+
|
437 |
+
# write "one(s)" instead of "1(s)", just for the readability
|
438 |
+
s = re.sub(r"\b1(s?)\b", r"one\1", s)
|
439 |
+
|
440 |
+
return s
|
441 |
+
|
442 |
+
def __call__(self, s: str):
|
443 |
+
s = self.preprocess(s)
|
444 |
+
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
445 |
+
s = self.postprocess(s)
|
446 |
+
|
447 |
+
return s
|
448 |
+
|
449 |
+
|
450 |
+
class EnglishSpellingNormalizer:
|
451 |
+
"""
|
452 |
+
Applies British-American spelling mappings as listed in [1].
|
453 |
+
|
454 |
+
[1] https://www.tysto.com/uk-us-spelling-list.html
|
455 |
+
"""
|
456 |
+
|
457 |
+
def __init__(self):
|
458 |
+
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
|
459 |
+
self.mapping = json.load(open(mapping_path))
|
460 |
+
|
461 |
+
def __call__(self, s: str):
|
462 |
+
return " ".join(self.mapping.get(word, word) for word in s.split())
|
463 |
+
|
464 |
+
|
465 |
+
class EnglishTextNormalizer:
|
466 |
+
def __init__(self):
|
467 |
+
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
468 |
+
self.replacers = {
|
469 |
+
# common contractions
|
470 |
+
r"\bwon't\b": "will not",
|
471 |
+
r"\bcan't\b": "can not",
|
472 |
+
r"\blet's\b": "let us",
|
473 |
+
r"\bain't\b": "aint",
|
474 |
+
r"\by'all\b": "you all",
|
475 |
+
r"\bwanna\b": "want to",
|
476 |
+
r"\bgotta\b": "got to",
|
477 |
+
r"\bgonna\b": "going to",
|
478 |
+
r"\bi'ma\b": "i am going to",
|
479 |
+
r"\bimma\b": "i am going to",
|
480 |
+
r"\bwoulda\b": "would have",
|
481 |
+
r"\bcoulda\b": "could have",
|
482 |
+
r"\bshoulda\b": "should have",
|
483 |
+
r"\bma'am\b": "madam",
|
484 |
+
# contractions in titles/prefixes
|
485 |
+
r"\bmr\b": "mister ",
|
486 |
+
r"\bmrs\b": "missus ",
|
487 |
+
r"\bst\b": "saint ",
|
488 |
+
r"\bdr\b": "doctor ",
|
489 |
+
r"\bprof\b": "professor ",
|
490 |
+
r"\bcapt\b": "captain ",
|
491 |
+
r"\bgov\b": "governor ",
|
492 |
+
r"\bald\b": "alderman ",
|
493 |
+
r"\bgen\b": "general ",
|
494 |
+
r"\bsen\b": "senator ",
|
495 |
+
r"\brep\b": "representative ",
|
496 |
+
r"\bpres\b": "president ",
|
497 |
+
r"\brev\b": "reverend ",
|
498 |
+
r"\bhon\b": "honorable ",
|
499 |
+
r"\basst\b": "assistant ",
|
500 |
+
r"\bassoc\b": "associate ",
|
501 |
+
r"\blt\b": "lieutenant ",
|
502 |
+
r"\bcol\b": "colonel ",
|
503 |
+
r"\bjr\b": "junior ",
|
504 |
+
r"\bsr\b": "senior ",
|
505 |
+
r"\besq\b": "esquire ",
|
506 |
+
# prefect tenses, ideally it should be any past participles, but it's harder..
|
507 |
+
r"'d been\b": " had been",
|
508 |
+
r"'s been\b": " has been",
|
509 |
+
r"'d gone\b": " had gone",
|
510 |
+
r"'s gone\b": " has gone",
|
511 |
+
r"'d done\b": " had done", # "'s done" is ambiguous
|
512 |
+
r"'s got\b": " has got",
|
513 |
+
# general contractions
|
514 |
+
r"n't\b": " not",
|
515 |
+
r"'re\b": " are",
|
516 |
+
r"'s\b": " is",
|
517 |
+
r"'d\b": " would",
|
518 |
+
r"'ll\b": " will",
|
519 |
+
r"'t\b": " not",
|
520 |
+
r"'ve\b": " have",
|
521 |
+
r"'m\b": " am",
|
522 |
+
}
|
523 |
+
self.standardize_numbers = EnglishNumberNormalizer()
|
524 |
+
self.standardize_spellings = EnglishSpellingNormalizer()
|
525 |
+
|
526 |
+
def __call__(self, s: str):
|
527 |
+
s = s.lower()
|
528 |
+
|
529 |
+
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
530 |
+
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
531 |
+
s = re.sub(self.ignore_patterns, "", s)
|
532 |
+
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
|
533 |
+
|
534 |
+
for pattern, replacement in self.replacers.items():
|
535 |
+
s = re.sub(pattern, replacement, s)
|
536 |
+
|
537 |
+
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
538 |
+
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
539 |
+
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
|
540 |
+
|
541 |
+
s = self.standardize_numbers(s)
|
542 |
+
s = self.standardize_spellings(s)
|
543 |
+
|
544 |
+
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
545 |
+
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
546 |
+
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
547 |
+
|
548 |
+
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
|
549 |
+
|
550 |
+
return s
|
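A quick usage sketch of the normalizer defined above (illustrative only, not part of the diff; it assumes EnglishTextNormalizer is re-exported from whisper_stream.normalizers, and the exact output depends on the spelling mapping in english.json):

# illustrative usage, assuming the export path below
from whisper_stream.normalizers import EnglishTextNormalizer

normalizer = EnglishTextNormalizer()
# numbers, titles, and currency words are canonicalized before scoring
print(normalizer("Mr. Smith paid twenty five dollars and seven cents"))
# expected to print something like: "mister smith paid $25.07"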
whisper_stream/streaming_decoding.py
ADDED
@@ -0,0 +1,1198 @@
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical

from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio

if TYPE_CHECKING:
    from .streaming_model import StreamingWhisper
    from .model import Whisper


@dataclass(frozen=False)
class DecodingOptions:
    # whether to perform X->X "transcribe" or X->English "translate"
    task: str = "transcribe"

    # language that the audio is in; uses detected language if None
    language: Optional[str] = None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)

    # "alpha" in Google NMT, or None for length norm, when ranking generations
    # to select which to return among the beams or best-of-N samples
    length_penalty: Optional[float] = None

    # text or tokens to feed as the prompt or the prefix; for more info:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context

    # list of token ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # this will suppress blank outputs

    # timestamp sampling options
    without_timestamps: bool = True  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0

    # implementation details
    fp16: bool = False  # use fp16 for most of the calculation

    # advisor & streaming params
    advised: bool = False
    attentive_advisor: bool = False
    use_sa: bool = False
    ctx: int = 3  # CA / SA context given to the advisor
    gran: int = 15  # granularity of encoder embeddings; 15 frames = 300 ms (each frame equals 20 ms)
    pad_audio_features: bool = True
    single_frame_mel: bool = True
    pad_input: bool = False

    look_ahead_blocks: int = 0
    maximal_seconds_context: int = 30  # after 30 seconds, reset the mel

    use_kv_cache: bool = False
    use_ca_kv_cache: bool = False

    # streaming decoding args
    stream_decode: bool = True
    tokens_per_frame: int = 2
    n_tokens_look_back: int = 2
    streaming_timestamps: bool = True
    wait_for_all: bool = False
    force_first_tokens_timestamps: bool = False
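For orientation, a minimal sketch of constructing these options for greedy streaming decoding (the field names come from the dataclass above; the concrete values are illustrative only, not part of the diff):

# illustrative configuration
opts = DecodingOptions(
    task="transcribe",
    language="en",
    temperature=0.0,        # greedy token selection
    stream_decode=True,     # use the streaming decoder path
    n_tokens_look_back=2,   # how many trailing tokens are re-checked for stability
    use_kv_cache=True,      # reuse decoder keys/values across incremental steps
)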
@dataclass(frozen=False)
class DecodingResult:
    audio_features: Tensor
    language: str
    language_probs: Optional[Dict[str, float]] = None
    tokens: List[int] = field(default_factory=list)
    text: str = ""
    avg_logprob: float = np.nan
    no_speech_prob: float = np.nan
    temperature: float = np.nan
    compression_ratio: float = np.nan
    timestamps: Dict[Tuple, Tuple] = field(default_factory=tuple)
    timed_tokens: List[int] = field(default_factory=list)
    timed_text: str = ""


class Inference:
    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Perform a forward pass on the decoder and return per-token logits"""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Update the key-value cache according to the updated beams"""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Clean up any resources or hooks after decoding is finished"""
        pass

    def flush_tokens_from_cache(self) -> None:
        """Flush irrelevant tokens from the cache during streaming"""
        pass


class PyTorchInference(Inference):
    def __init__(self, model: "StreamingWhisper", initial_token_length: int, use_kv_cache: bool = False, dump_type: str = None, n_tokens_look_back: int = 2, use_beam: bool = False):
        self.model: "StreamingWhisper" = model
        self.initial_token_length = initial_token_length
        self.hooks = []
        # custom
        self.use_kv_cache = use_kv_cache
        self.kv_cache = {} if use_kv_cache else None
        self.dump_type = dump_type
        self.n_tokens_look_back = n_tokens_look_back
        self.cached_logits = None
        self.use_beam = use_beam

        key_modules = [block.attn.key for block in self.model.decoder.blocks]
        value_modules = [block.attn.value for block in self.model.decoder.blocks]
        self.kv_modules = key_modules + value_modules

    def logits(self, tokens: Tensor, audio_features: Tensor, first_prediction: bool = False, beam_indices: list[list] = None) -> Tensor:
        if not self.kv_cache and self.use_kv_cache:
            self.kv_cache, self.hooks = self.model.install_decoder_kv_cache_hooks()

        if tokens.shape[-1] > self.initial_token_length and self.use_kv_cache:
            # only need to use the last token except in the first forward pass
            if not self.use_beam:
                tokens = tokens[:, -self.n_tokens_look_back:] if first_prediction else tokens[:, -1:]
            else:
                n_beams = tokens.shape[0]
                self.kv_cache["beam_indices"] = beam_indices  # piggyback on the kv_cache dict to pass beam indices to the decoder
                tokens = tokens[beam_indices[0], beam_indices[1]].view(n_beams, -1)

        return self._concat_logits_if_needed(self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache))

    def _concat_logits_if_needed(self, logits: Tensor):
        if not self.use_kv_cache:
            return logits

        if self.cached_logits is None:
            self.cached_logits = logits
            return logits

        if not self.use_beam:  # greedy
            self.cached_logits = torch.cat([self.cached_logits, logits], dim=1)
            return self.cached_logits

        # beam kv_cache
        n_beams, n_ctx, n_vocab = logits.shape

        for i, (beam, index, output_index) in enumerate(zip(*self.kv_cache["beam_indices"])):
            if index < self.cached_logits.shape[1]:
                self.cached_logits[beam, index] = logits[beam, output_index]
            else:
                self.cached_logits = torch.cat([self.cached_logits, logits[:, output_index:]], dim=1)

        return self.cached_logits

    def cleanup_caching(self):
        if not self.use_kv_cache:
            return

        for hook in self.hooks:
            hook.remove()

        del self.kv_cache
        del self.hooks
        self.kv_cache = {}
        self.hooks = []

    def flush_tokens_from_cache(self):
        for key in self.kv_cache.keys():
            if key == "beam_indices":
                continue
            self.kv_cache[key] = self.kv_cache[key][:, :-self.n_tokens_look_back].detach()

        self.cached_logits = self.cached_logits[:, :-self.n_tokens_look_back].detach()

    def rearrange_kv_cache(self, source_indices):
        if not self.use_kv_cache:
            return

        if source_indices != list(range(len(source_indices))):
            for module in self.kv_modules:
                # update the key/value cache to contain the selected sequences
                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()

            self.cached_logits = self.cached_logits[source_indices].detach()
class SequenceRanker:
    def rank(
        self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
    ) -> List[int]:
        """
        Given a list of groups of samples and their cumulative log probabilities,
        return the indices of the samples in each group to select as the final result
        """
        raise NotImplementedError


class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probabilities, penalized using either
    a simple length normalization or Google NMT paper's length penalty
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
        def scores(logprobs, lengths):
            result = []
            for logprob, length in zip(logprobs, lengths):
                if self.length_penalty is None:
                    penalty = length
                else:
                    # from the Google NMT paper
                    penalty = ((5 + length) / 6) ** self.length_penalty
                result.append(logprob / penalty)
            return result

        # get the sequence with the highest score
        lengths = [[len(t) for t in s] for s in tokens]

        return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
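A worked sketch of the ranking rule above, with made-up numbers (illustrative only, not part of the diff): dividing each cumulative log probability by the Google NMT penalty ((5 + length) / 6) ** alpha penalizes long candidates less harshly than raw length normalization.

import numpy as np

sum_logprobs = [-12.0, -9.0]  # cumulative log probabilities of two candidates
lengths = [30, 18]            # token counts of the candidates
alpha = 0.6                   # length_penalty value

scores = [lp / (((5 + n) / 6) ** alpha) for lp, n in zip(sum_logprobs, lengths)]
best = int(np.argmax(scores))  # index of the candidate with the highest normalized score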
class TokenDecoder:
    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        """Specify how to select the next token, based on the current trace and logits

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens

        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token

        completed : bool
            True if all sequences have reached the end of text
        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finalize search and return the final candidate sequences

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence

        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            sequence of Tensors containing candidate token sequences, for each audio input

        sum_logprobs : List[List[float]], length = n_audio
            sequence of cumulative log probabilities corresponding to the above
        """
        raise NotImplementedError


class GreedyDecoder(TokenDecoder):
    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature
        self.eot = eot

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)

        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = F.pad(tokens, (0, 1), value=self.eot)
        return tokens, sum_logprobs.tolist()
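The two selection modes in GreedyDecoder.update can be seen in isolation with toy logits (a standalone sketch, not part of the diff):

import torch
from torch.distributions import Categorical

logits = torch.tensor([[2.0, 0.5, -1.0]])            # shape (n_batch=1, vocab_size=3)
greedy = logits.argmax(dim=-1)                        # temperature == 0: deterministic argmax
sampled = Categorical(logits=logits / 0.7).sample()   # temperature > 0: temperature-scaled sampling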
class BeamSearchDecoder(TokenDecoder):
    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates: int = round(beam_size * self.patience)
        self.finished_sequences = None

        assert (
            self.max_candidates > 0
        ), f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        self.finished_sequences = None

    def update(
        self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Tensor, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):  # in our case n_audio = 1
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):  # go over each trajectory of the beam
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(
            self.finished_sequences, finished_sequences
        ):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio inputs have enough finished samples
        completed = all(
            len(sequences) >= self.max_candidates
            for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if (
                len(sequences) < self.beam_size
            ):  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()]
            for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs
class StreamingDecoder(TokenDecoder):
    def __init__(
        self,
        temperature: float,
        tokens_per_frame: int,
        eot: int,
        inference: Inference,
        n_tokens_look_back: int = 2,
        streaming_timestamps: bool = False,
    ):
        self.tokens_per_frame = tokens_per_frame
        self.eot = eot
        self.inference = inference
        self.last_logits: Tensor = None
        self.temperature = temperature
        self.tokens_look_back = n_tokens_look_back
        self.check_token_index = -n_tokens_look_back - 1
        self.streaming_timestamps = streaming_timestamps
        self.timestamps_map = {-1: 0}
        self._reset_timestamps()

    def _insert_timestamps(self, audio_features: Tensor, tokens: Tensor, enc_emb_gran: int):
        """
        Assign a timestamp to each decoded token once the segment is complete:
        each token is stamped with the earliest audio prefix that predicts it.
        """
        # Given T encoder frames and a granularity g, forward pass T/g times.
        # Observe each prefix and try to find the swap point for each token.
        # Overall, this costs an extra T/g forward passes.
        examined_token_index = 3  # index of the first non-initial token to be predicted
        for i in range(enc_emb_gran, audio_features.shape[-1] + 1, enc_emb_gran):
            i_logits = self.inference.logits(tokens, audio_features[:, :i])

            if tokens[0, examined_token_index] in i_logits.argmax(dim=-1):  # stamp this token
                self.timestamps_map[examined_token_index - 3] = i * 20
                examined_token_index += 1

            if examined_token_index >= tokens.shape[-1]:
                break

    def _check_last_tokens(self, logits: Tensor, tokens: Tensor, next_tokens: Tensor, check_tokens: bool):
        stable = []

        if not check_tokens:
            return stable, tokens

        examine_tokens_indices = range(tokens.shape[1] + self.check_token_index, tokens.shape[1] - 1)
        for i in examine_tokens_indices:
            token_index = i + 1
            token_prob_index = i

            examined_token = tokens[:, token_index]  # the token previously predicted for this index

            # The prediction is stabilizing (its logit did not drop), or the logit
            # dropped but the token still has the highest probability: keep it and check the next one.
            if self.last_logits[:, token_prob_index, examined_token] <= logits[:, token_prob_index, examined_token] \
                    or (self.last_logits[:, token_prob_index, examined_token] > logits[:, token_prob_index, examined_token]
                        and next_tokens[:, token_prob_index] == examined_token):
                stable.append(True)
                continue
            else:
                tokens = tokens[:, :token_index]  # crop the following tokens; they are irrelevant after the flush
                stable.append(False)
                break

        return stable, tokens

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, check_tokens: bool = False, index: int = 0, first_frame: bool = False) -> Tuple[Tensor, bool]:
        # calculate next_tokens, logprobs, completed
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)
        completed = next_tokens[0, -1] == self.eot

        # we are at the beginning, so append tokens greedily
        if first_frame:
            if next_tokens[0, -1] != self.eot:
                tokens = torch.cat([tokens, next_tokens[None, :, -1]], dim=-1)
            self.last_logits = logits
            return tokens, completed

        # otherwise, check whether the last tokens are stable
        stable, tokens = self._check_last_tokens(logits, tokens, next_tokens, check_tokens)

        if all(stable) and next_tokens[0, -1] != self.eot:
            tokens = torch.cat([tokens, next_tokens[None, :, -1]], dim=-1)

            if tokens.shape[-1] - 1 - 4 not in self.timestamps_map.keys():  # mark the start of the timestamp
                self.timestamps_map[tokens.shape[-1] - 1 - 4] = index - 1

            sum_logprobs += logprobs[:, -1, next_tokens[0, -1]]

        # keep the last logits to compare against on the next decode step
        self.last_logits = logits

        return tokens, completed

    def _reset_timestamps(self):
        self.timestamp_tokens_indices = [None, None]
        self.timestamps_indices = [None, None]

    def reset(self):
        self.check_token_index = -self.tokens_look_back - 1

    def finalize(self, tokens, sum_logprobs):
        tokens = F.pad(tokens, (0, 1), value=self.eot)
        return tokens, sum_logprobs.tolist()
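The stability rule in _check_last_tokens above boils down to: a previously emitted token survives if its logit did not drop since the previous frame, or if it remains the argmax despite a drop. A condensed toy sketch of that test (illustrative values only, not part of the diff):

import torch

prev_logits = torch.tensor([[0.2, 1.5, 0.1]])  # logits at this position, previous frame
curr_logits = torch.tensor([[0.3, 1.1, 2.0]])  # same position, current frame
token = 1                                      # the token emitted at this position

not_dropping = bool(curr_logits[0, token] >= prev_logits[0, token])
still_argmax = curr_logits.argmax(dim=-1).item() == token
stable = not_dropping or still_argmax  # False here, so the token would be flushed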
class BeamStreamingDecoder(TokenDecoder):
    def __init__(
        self,
        temperature: float,
        tokens_per_frame: int,
        eot: int,
        inference: Inference,
        n_tokens_look_back: int = 2,
        n_beams: int = 1,
        pad_token: int = None,
        wait_for_all: bool = False,
    ):
        self.tokens_per_frame = tokens_per_frame
        self.eot = eot
        self.pad_token = pad_token  # the sot_lm token
        self.inference = inference
        self.last_logits: Tensor = None
        self.temperature = temperature
        self.tokens_look_back = n_tokens_look_back
        self.check_token_index = [4 + 1 for _ in range(n_beams)]  # per beam, the last index relevant for checking
        self.n_beams = n_beams
        self.sum_logprobs = torch.zeros(n_beams)
        self.timestamps_map = {}
        self.wait_for_all = wait_for_all
        self.finished_sequences = {}

    def _get_last_valid_token_index(self, prefix: Tensor):
        indices = torch.where((prefix == self.eot) | (prefix == self.pad_token))[0]
        last_valid_token_index = (prefix.shape[-1] - 1) if indices.shape == torch.Size([0]) else (indices.min().item() - 1)
        last_valid_token_index = max(last_valid_token_index, 3)

        return last_valid_token_index

    def _check_last_tokens(self, prefix: Tensor, logits: Tensor, check_tokens: bool):
        """
        Given the last k logits and the tokens, determine where decoding should proceed from.
        prefix - the tokens of this beam, including padding tokens and eot, of shape (l)
        logits - of shape (l, n_vocab); the logits of this specific beam
        """
        last_valid_token_index = self._get_last_valid_token_index(prefix)

        if not check_tokens:
            return last_valid_token_index, True

        examined_prob_indices = range(max(last_valid_token_index - self.tokens_look_back, 3), last_valid_token_index)
        for examined_prob_index in examined_prob_indices:
            examined_token_index = examined_prob_index + 1
            examined_token = prefix[examined_token_index]
            examined_candidates = logits[examined_prob_index].topk(self.n_beams).indices

            if examined_token not in examined_candidates:  # the last predicted token fell out of our top-k
                # pad with irrelevant tokens
                prefix[examined_token_index:] = self.pad_token
                return examined_prob_index, False

        # If nothing was returned during the check, all tokens were stable.
        # Continue decoding from the last valid token index.
        return last_valid_token_index, True

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, first_frame: bool = False, check_tokens: bool = False) -> Tuple[Tensor, bool]:
        if tokens.shape[0] % self.n_beams != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.n_beams} != 0")

        scores, sources, finished = {}, {}, {}
        next_tokens, source_indices, finished_sequences = [], [], []
        logprobs = F.log_softmax(logits.float(), dim=-1)

        for beam in range(self.n_beams):
            prefix = tokens[beam]  # tokens in this beam
            sampling_index, stable = self._check_last_tokens(prefix, logits[beam], check_tokens) if not first_frame else ((tokens.shape[-1] - 1), False)
            prefix = prefix.tolist()  # a list is easier to append to
            # calculate candidates from the last token index we should check
            for logprob, token in zip(*logprobs[beam, sampling_index].topk(self.n_beams + 1)):
                new_logprob = (logprobs[beam, range(3, sampling_index - 1), tokens[beam, 4:sampling_index]].sum() + logprob).item()

                token_index = sampling_index + 1
                if token_index == len(prefix):
                    prefix.append(token.item())
                else:
                    prefix[token_index] = token.item()

                sequence = prefix[:prefix.index(self.pad_token)] if self.pad_token in prefix else prefix
                sequence = tuple(sequence)
                scores[sequence] = new_logprob
                sources[sequence] = beam

        # after all beams were checked and candidates were scored,
        # keep the top n_beams sequences
        saved = 0
        for sequence in sorted(scores, key=scores.get, reverse=True):
            if self.wait_for_all and sequence[-1] == self.eot:
                finished[sequence] = scores[sequence]
            else:
                sum_logprobs[len(next_tokens)] = scores[sequence]
                next_tokens.append(Tensor(sequence).long())
                source_indices.append(sources[sequence])

                saved += 1
                if saved == self.n_beams:
                    break

        tokens = torch.nn.utils.rnn.pad_sequence(next_tokens, batch_first=True, padding_value=self.pad_token).to(tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        if not self.wait_for_all:  # greedy stop mode
            completed = any([self.eot in s for s in next_tokens])  # greedy stop: trust any beam that says enough
            return tokens, completed

        for sequence in sorted(finished, key=finished.get, reverse=True):
            if len(self.finished_sequences) >= self.n_beams:
                break
            self.finished_sequences[sequence] = finished[sequence]

        # completed once enough trajectories reached EOT in this specific frame
        completed = len(self.finished_sequences) >= self.n_beams
        return tokens, completed

    def reset(self):
        self.check_token_index = -self.tokens_look_back

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        if not self.wait_for_all:
            preceding_tokens = F.pad(preceding_tokens, (0, 1), value=self.eot)
            return preceding_tokens, sum_logprobs.tolist()

        sum_logprobs = sum_logprobs.cpu()
        if len(self.finished_sequences) < self.n_beams:
            for j in list(np.argsort(sum_logprobs[0]))[::-1]:
                sequence = preceding_tokens[0][j].tolist() + [self.eot]
                self.finished_sequences[tuple(sequence)] = sum_logprobs[0][j].item()
                if len(self.finished_sequences) >= self.n_beams:
                    break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in self.finished_sequences.keys()]
        ]

        sum_logprobs: List[List[float]] = [
            list(self.finished_sequences.values())
        ]

        return tokens, sum_logprobs
class LogitFilter:
    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Apply any filtering or masking to logits in-place

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step

        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        """
        raise NotImplementedError


class SuppressBlank(LogitFilter):
    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        if tokens.shape[1] == self.sample_begin:
            logits[..., self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf


class SuppressTokens(LogitFilter):
    def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        logits[..., self.suppress_tokens] = -np.inf
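The same interface makes it straightforward to add custom filters. As a hypothetical example (not part of the diff; it reuses the module's np and Tensor imports), a filter that forces end-of-text once a maximum length is reached:

class ForceEOTAfter(LogitFilter):  # hypothetical helper, for illustration only
    def __init__(self, eot: int, max_len: int):
        self.eot = eot
        self.max_len = max_len

    def apply(self, logits: Tensor, tokens: Tensor):
        if tokens.shape[1] >= self.max_len:
            # suppress everything except the end-of-text token
            logits[..., : self.eot] = -np.inf
            logits[..., self.eot + 1 :] = -np.inf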
721 |
+
class DecodingTask:
|
722 |
+
inference: Inference
|
723 |
+
sequence_ranker: SequenceRanker
|
724 |
+
decoder: TokenDecoder
|
725 |
+
logit_filters: List[LogitFilter]
|
726 |
+
|
727 |
+
def __init__(self, model: "StreamingWhisper", options: DecodingOptions):
|
728 |
+
self.model = model
|
729 |
+
|
730 |
+
language = options.language or "en"
|
731 |
+
tokenizer = get_tokenizer(
|
732 |
+
model.is_multilingual,
|
733 |
+
num_languages=model.num_languages,
|
734 |
+
language=language,
|
735 |
+
task=options.task,
|
736 |
+
)
|
737 |
+
self.tokenizer: Tokenizer = tokenizer
|
738 |
+
self.options: DecodingOptions = self._verify_options(options)
|
739 |
+
|
740 |
+
self.n_group: int = options.beam_size or options.best_of or 1
|
741 |
+
self.n_ctx: int = model.dims.n_text_ctx
|
742 |
+
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
743 |
+
|
744 |
+
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
745 |
+
if self.options.without_timestamps:
|
746 |
+
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
747 |
+
|
748 |
+
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
749 |
+
self.sample_begin: int = len(self.initial_tokens)
|
750 |
+
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
751 |
+
|
752 |
+
# inference: implements the forward pass through the decoder, including kv caching
|
753 |
+
dump_type = "ATT" if self.options.attentive_advisor else None
|
754 |
+
self.inference = PyTorchInference(model, len(self.initial_tokens), self.options.use_kv_cache, dump_type=dump_type, n_tokens_look_back=options.n_tokens_look_back, use_beam=options.beam_size > 0)
|
755 |
+
|
756 |
+
# sequence ranker: implements how to rank a group of sampled sequences
|
757 |
+
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
758 |
+
|
759 |
+
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
760 |
+
if options.stream_decode and options.beam_size in [None, 0]:
|
761 |
+
self.decoder = StreamingDecoder(
|
762 |
+
options.temperature, options.tokens_per_frame, tokenizer.eot, self.inference, options.n_tokens_look_back, options.streaming_timestamps
|
763 |
+
)
|
764 |
+
elif options.stream_decode and options.beam_size > 0:
|
765 |
+
self.decoder = BeamStreamingDecoder(
|
766 |
+
options.temperature, options.tokens_per_frame, tokenizer.eot, self.inference, options.n_tokens_look_back, options.beam_size, tokenizer.sot_lm, options.wait_for_all
|
767 |
+
)
|
768 |
+
elif options.beam_size is not None:
|
769 |
+
self.decoder = BeamSearchDecoder(
|
770 |
+
options.beam_size, tokenizer.eot, self.inference, options.patience
|
771 |
+
)
|
772 |
+
else:
|
773 |
+
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
774 |
+
|
775 |
+
# logit filters: applies various rules to suppress or penalize certain tokens
|
776 |
+
self.logit_filters = []
|
777 |
+
if self.options.suppress_blank:
|
778 |
+
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
779 |
+
if self.options.suppress_tokens:
|
780 |
+
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
781 |
+
|
782 |
+
self.mel = None
|
783 |
+
self.index = 0
|
784 |
+
|
785 |
+
# repeat text tensors by the group size, for beam search or best-of-n sampling
|
786 |
+
self.tokens: Tensor = torch.tensor([self.initial_tokens]) # no need to use repeat, batch is meaningless in stream.
|
787 |
+
self.tokens = self.tokens.repeat_interleave(self.n_group, dim=0).to(model.device)
|
788 |
+
|
789 |
+
n_batch = self.n_group # streaming, batch_size larger than 1 is not allowed. only beam size is relevant
|
790 |
+
self.sum_logprobs: Tensor = torch.zeros(n_batch, device=model.device)
|
791 |
+
self.no_speech_probs = [np.nan] * n_batch
|
792 |
+
|
793 |
+
# Causal encoder inference
|
794 |
+
self._init_enc_kv_caching()
|
795 |
+
|
796 |
+
if options.use_ca_kv_cache:
|
797 |
+
self._init_ca_kv_caching()
|
798 |
+
|
799 |
+
self.audio_features = torch.zeros((1, model.dims.n_audio_ctx, model.dims.n_audio_state)).to(model.device)
|
800 |
+
self.frame_counter = 0
|
801 |
+
|
802 |
+
def _init_enc_kv_caching(self):
|
803 |
+
self.enc_kv_cache, self.enc_hooks = self.model.install_encoder_kv_cache_hooks()
|
804 |
+
|
805 |
+
def _init_ca_kv_caching(self):
|
806 |
+
self.ca_kv_cache, self.dec_ca_hooks = self.model.install_cross_attn_kv_cache_hooks()
|
807 |
+
|
808 |
+
def _cleanup_encoder_caching(self):
|
809 |
+
for hook in self.enc_hooks:
|
810 |
+
hook.remove()
|
811 |
+
|
812 |
+
del self.enc_kv_cache
|
813 |
+
del self.enc_hooks
|
814 |
+
self.enc_kv_cache = {}
|
815 |
+
self.enc_hooks = []
|
816 |
+
|
817 |
+
def _cleanup_ca_caching(self):
|
818 |
+
if not self.options.use_ca_kv_cache:
|
819 |
+
return
|
820 |
+
|
821 |
+
for hook in self.dec_ca_hooks:
|
822 |
+
hook.remove()
|
823 |
+
|
824 |
+
del self.ca_kv_cache
|
825 |
+
del self.dec_ca_hooks
|
826 |
+
self.ca_kv_cache = {}
|
827 |
+
self.dec_ca_hooks = []
|
828 |
+
|
829 |
+
def __del__(self):
|
830 |
+
self._cleanup_encoder_caching()
|
831 |
+
self._cleanup_ca_caching()
|
832 |
+
self.inference.cleanup_caching()
|
833 |
+
|
    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of is not compatible with greedy sampling (T=0)")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip())
                if isinstance(prefix, str)
                else prefix
            )
            if self.sample_len is not None:
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip())
                if isinstance(prompt, str)
                else prompt
            )
            tokens = (
                [self.tokenizer.sot_prev]
                + prompt_tokens[-(self.n_ctx // 2 - 1):]
                + tokens
            )

        return tuple(tokens)
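    # For illustration, the resulting initial sequence is laid out as
    # (names are placeholders, not real token ids):
    #
    #     [sot_prev, *prompt_tokens, *sot_sequence, *prefix_tokens]
    #
    # i.e. the prompt is prepended as prior context behind <|startofprev|>,
    # while the prefix continues the current segment after the SOT sequence.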
    def _get_suppress_tokens(self) -> Tuple[int]:
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [
                self.tokenizer.transcribe,
                self.tokenizer.translate,
                self.tokenizer.sot,
                self.tokenizer.sot_prev,
                self.tokenizer.sot_lm,
            ]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))
    def _get_audio_features(self, mel: Tensor, index: list = None):
        if self.options.fp16:
            mel = mel.half()

        audio_features: Tensor = self.model.encoder(mel, kv_cache=self.enc_kv_cache, mask=None)

        if audio_features.dtype != (
            torch.float16 if self.options.fp16 else torch.float32
        ):
            raise TypeError(
                f"audio_features has an incorrect dtype: {audio_features.dtype}"
            )

        # usually we run with this config
        if self.options.use_ca_kv_cache:
            return audio_features

        # update self.audio_features in place
        end_index = (self.mel.shape[-1] // 2) - 1

        if end_index == (self.model.encoder.gran * (1 + self.model.encoder.extra_gran_blocks)):
            start_index = 0
        else:
            start_index = end_index - self.model.encoder.gran

        if start_index % self.model.encoder.gran != 0:
            # snap the window to the next granule boundary
            modulo_res = start_index % self.model.encoder.gran
            steps = self.model.encoder.gran - modulo_res
            start_index += steps
            end_index = start_index + self.model.encoder.gran

        self.audio_features[:, start_index:end_index] = audio_features
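    # Worked example of the index arithmetic above (illustrative values):
    # with gran=16, extra_gran_blocks=0 and a mel of 34 frames,
    # end_index = (34 // 2) - 1 = 16 = gran * 1, so the first encoder output
    # fills positions [0:16]; on the next frame the window advances one
    # granule, filling [16:32], and so on.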
    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens

        return languages, lang_probs

    def _set_ca_kv_cache(self, value: bool):
        self.model.use_ca_cache_hook = value
    def _main_loop(self, audio_features: Tensor):
        """
        In streaming we decode from self.tokens, since we need to keep the last
        tokens we got in context.
        """
        is_first_frame = self.index == (self.options.gran * (1 + self.options.look_ahead_blocks))
        self._set_ca_kv_cache(True)
        beam_indices = None

        try:
            for i in range(self.sample_len // 8):
                if self.tokens.shape[-1] > self.n_ctx:
                    break

                # beam search with a self-attention kv cache: compute which cache rows/columns to rewrite
                if self.options.stream_decode and self.options.beam_size > 0 and self.options.use_kv_cache:
                    last_valid_token_indices = [
                        self.decoder._get_last_valid_token_index(self.tokens[j])
                        for j in range(self.tokens.shape[0])
                    ]
                    beam_indices_cols = (
                        last_valid_token_indices
                        if i > 0
                        else sum([list(range(index - self.options.n_tokens_look_back, index + 1)) for index in last_valid_token_indices], [])
                    )
                    beam_indices_rows = (
                        [j // (self.options.n_tokens_look_back + 1) for j in range(self.options.beam_size * (self.options.n_tokens_look_back + 1))]
                        if i == 0
                        else list(range(self.options.beam_size))
                    )
                    beam_indices_cols_input = [
                        item - beam_indices_cols[(j // (len(beam_indices_cols) // self.options.beam_size)) * (len(beam_indices_cols) // self.options.beam_size)]
                        for j, item in enumerate(beam_indices_cols)
                    ]
                    beam_indices = [beam_indices_rows, beam_indices_cols, beam_indices_cols_input]

                logits = self.inference.logits(self.tokens, audio_features, i == 0, beam_indices)  # run inference through the decoder

                # after the first decoder forward pass, there is no need to cache
                if i == 0:
                    self._set_ca_kv_cache(False)

                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    self.no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                if isinstance(self.decoder, (StreamingDecoder, BeamStreamingDecoder)):
                    logits = logits[:, :]  # we need the logits at every position for context
                else:
                    logits = logits[:, -1]  # we need to consider the logits at the last token only

                # apply the logit filters, e.g. for suppressing or penalizing certain tokens
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, self.tokens)

                # expand the tokens tensor with the selected next tokens
                if isinstance(self.decoder, BeamStreamingDecoder):
                    self.tokens, completed = self.decoder.update(
                        self.tokens, logits, self.sum_logprobs,
                        first_frame=is_first_frame, check_tokens=(i == 0),
                    )
                elif isinstance(self.decoder, StreamingDecoder):
                    self.tokens, completed = self.decoder.update(
                        self.tokens, logits, self.sum_logprobs,
                        first_frame=is_first_frame, check_tokens=(i == 0), index=(self.index * 20),
                    )
                else:
                    self.tokens, completed = self.decoder.update(self.tokens, logits, self.sum_logprobs)

                if completed:  # the context is unlimited; we are limited only by i
                    break
        finally:
            if self.options.use_kv_cache:
                self.inference.flush_tokens_from_cache()

            if isinstance(self.decoder, (StreamingDecoder, BeamStreamingDecoder)):

                if is_first_frame and self.options.streaming_timestamps and self.options.force_first_tokens_timestamps:
                    self.decoder._insert_timestamps(audio_features, self.tokens, self.options.gran)

                if self.tokens.shape[1] > logits.shape[1]:
                    self.tokens = self.tokens[:, :logits.shape[1]]

            self.decoder.reset()

        return self.sum_logprobs, self.no_speech_probs
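    # Stripped to its greedy essence, the loop above amounts to (sketch only;
    # the real loop adds kv-cache bookkeeping, logit filters and beam handling):
    #
    #     for i in range(max_steps):
    #         logits = inference.logits(tokens, audio_features, i == 0, None)
    #         next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
    #         tokens = torch.cat([tokens, next_token], dim=-1)
    #         if (next_token == tokenizer.eot).all():
    #             break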
    def _pad_audio_features(self, audio_features: Tensor):
        to_pad = 1500 - audio_features.shape[1]
        padding = torch.zeros((1, to_pad, audio_features.shape[2])).to(audio_features.device)
        audio_features = torch.cat([audio_features, padding], dim=1).to(audio_features.device)
        return audio_features

    def _empty_results(self):
        return DecodingResult(
            audio_features=None,
            language="en",
            tokens=[],
            text="",
            avg_logprob=None,
            no_speech_prob=None,
            temperature=self.options.temperature,
            compression_ratio=0,
        )
    def _caching_inner_reset(self):
        # encoder kv-cache cleanup and re-install
        self._cleanup_encoder_caching()
        self._init_enc_kv_caching()

        # decoder cross-attention kv-cache
        if self.options.use_ca_kv_cache:
            self._cleanup_ca_caching()
            self._init_ca_kv_caching()

        # decoder self-attention kv-cache
        if self.options.use_kv_cache:
            self.inference.cleanup_caching()
            # caching will be re-triggered in the logits function

    def _reset_after_maximal_context(self, new_mel_frame: Tensor):
        print("Reset context...")
        num_old_mels = (
            (self.options.gran * self.options.look_ahead_blocks * 2) + 2
            if self.options.look_ahead_blocks > 0
            else (self.options.gran * 2) + 2
        )
        self.mel = torch.cat([self.mel[..., -num_old_mels:], new_mel_frame], dim=-1)
        self.options.prefix = self.tokens[:, len(self.sot_sequence):].tolist()[0][-self.options.n_tokens_look_back - 1:]
        print(f"Modifying tokens! {self.options.prefix=}")
        self.initial_tokens = self._get_initial_tokens()
        print("Modified tokens!")
        self.tokens = torch.tensor([list(self.initial_tokens)])  # no need to repeat: batching is meaningless in a stream
        self.tokens = self.tokens.repeat_interleave(self.n_group, dim=0).to(self.model.device)
        self._caching_inner_reset()
        self.audio_features = torch.zeros((1, self.model.dims.n_audio_ctx, self.model.dims.n_audio_state)).to(self.model.device)
        print("Finished reset...")
    @torch.no_grad()
    def run(self, mel_frame: Tensor) -> List[DecodingResult]:
        """
        mel_frame - a tensor containing the last frame we got from the stream.

        The function needs to:
        1. Take the mel frame.
        2. Extract features from the given frame, using kv caching in the encoder.
        3. Note that the encoder caching handles the incremental update automatically.
        """
        # concatenate the mel
        if self.options.single_frame_mel:  # each time we get a single frame of mel
            self.mel = torch.cat([self.mel, mel_frame], dim=-1) if self.mel is not None else mel_frame
        else:  # we are getting the whole segment [0, t_curr]
            self.mel = mel_frame
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel_frame.shape[0]

        self.index += self.options.gran  # on each decoding call we add more context
        self.frame_counter += 1
        print(f"Collected {self.frame_counter} frames...")

        if self.mel.shape[-1] >= self.options.maximal_seconds_context * 100:
            self._reset_after_maximal_context(mel_frame)

        if self.mel.shape[-1] < (self.options.gran * 2 * (self.options.look_ahead_blocks + 1)):
            print("Decoding Task: skipping first frames...")
            return self._empty_results()

        # call the main sampling loop
        if not self.options.use_ca_kv_cache:
            self._get_audio_features(self.mel)  # encoder forward pass; updates self.audio_features
            audio_features = self.audio_features
            sum_logprobs, no_speech_probs = self._main_loop(audio_features[:, :self.index])
        else:
            audio_features = self._get_audio_features(self.mel)
            sum_logprobs, no_speech_probs = self._main_loop(audio_features)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]

        self.tokens = self.tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(self.tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin: ((t == tokenizer.eot) | (t == tokenizer.sot_lm)).nonzero()[0, 0]] for t in s]
            for s in tokens
        ]

        # if any of the suggested beams is empty (no token), the predicted token was EOT;
        # add it explicitly so the ML ranker won't fail
        for i in range(len(tokens[0])):
            if tokens[0][i].shape[0] == 0:
                tokens[0][i] = torch.tensor([self.tokenizer.eot]).to(tokens[0][i].device)

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [
            lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
        ]

        fields = (texts, tokens, audio_features, avg_logprobs, no_speech_probs)

        self.tokens = self.tokens.squeeze(0)

        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        # apply timestamps
        if self.options.beam_size == 0 or self.options.beam_size is None:
            timed_tokens = tokens[0].copy()  # copy the inner list so the inserts below don't mutate `tokens`
            for i, index in enumerate(sorted(self.decoder.timestamps_map.keys())):
                timed_tokens.insert(index + i + 1, self.tokenizer.timestamp_begin + (self.decoder.timestamps_map[index] // 20))
        else:
            timed_tokens = [50257]

        return DecodingResult(
            audio_features=audio_features,
            language="en",
            tokens=tokens,
            text=texts[0],
            avg_logprob=avg_logprobs,
            no_speech_prob=no_speech_probs,
            temperature=self.options.temperature,
            compression_ratio=compression_ratio(texts[0]),
            timestamps=self.decoder.timestamps_map,
            timed_tokens=timed_tokens,
            timed_text=self.tokenizer.decode_with_timestamps(timed_tokens)
        )
@torch.no_grad()
def decode(
    model: "Whisper",
    mel: Tensor,
    task: DecodingTask,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    if mel.ndim == 2:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    result = DecodingTask(model, options).run(mel)
    return result
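The decoding task is stateful: one DecodingTask instance carries the mel, token and kv-cache state across calls, and run() is invoked once per incoming frame. A minimal sketch of that loop, assuming a loaded streaming model and an iterable `mel_frames` of per-frame spectrogram tensors (both placeholder names here):

    import torch
    from whisper_stream.streaming_decoding import DecodingOptions, DecodingTask

    options = DecodingOptions(language="en", gran=16, single_frame_mel=True, beam_size=5)
    task = DecodingTask(model, options)      # `model` is assumed to be loaded already
    with torch.no_grad():
        for mel_frame in mel_frames:         # one mel chunk at a time
            result = task.run(mel_frame)
            print(result.text)               # the running hypothesis so far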
whisper_stream/streaming_model.py
ADDED
@@ -0,0 +1,444 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import Optional
from .model import AudioEncoder, TextDecoder, Whisper

from .streaming_decoding import DecodingTask, DecodingOptions, DecodingResult
from .streaming_transcribe import transcribe as transcribe_function
from .decoding import decode as non_causal_decode_function
from .audio import SpectrogramStream

from dataclasses import replace

from pytorch_lightning import LightningModule
class LoraLayer(LightningModule):
    def __init__(self, input_dim, output_dim, rank=8, alpha=None, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.lora_A = nn.Parameter(torch.zeros(input_dim, rank))
        self.lora_B = nn.Parameter(torch.zeros(rank, output_dim))

        self.alpha = rank if alpha is None else alpha
        self.rank = rank
        self.scale = self.alpha / self.rank

        self._init_weights()

    def _init_weights(self):
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x):
        return x @ (self.lora_A @ self.lora_B) * self.scale


class LoraLinearLayer(LightningModule):
    def __init__(self, base_layer: nn.Linear, rank: int = 8, bias: bool = True, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.base_layer = base_layer

        self.lora_layer = LoraLayer(base_layer.in_features, base_layer.out_features, rank=rank)
        self.aggregate_lora = True

    def turn_on_lora(self):
        self.aggregate_lora = True

    def turn_off_lora(self):
        self.aggregate_lora = False

    def forward(self, x: Tensor):
        if not self.aggregate_lora:
            return self.base_layer(x)

        return self.base_layer(x) + self.lora_layer(x)
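# Because lora_B is zero-initialized, a freshly wrapped layer is an exact
# identity over its base layer. A minimal self-contained sketch (this helper
# is illustrative and not used elsewhere in the package):
def _lora_wrapper_sketch():
    base = nn.Linear(8, 8)
    wrapped = LoraLinearLayer(base, rank=4)
    x = torch.randn(2, 8)
    assert torch.allclose(wrapped(x), base(x))  # holds until lora_B is trained
    wrapped.turn_off_lora()                     # bypasses the LoRA branch entirely
    assert torch.allclose(wrapped(x), base(x))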
class LoRAMultiHeadAttention(LightningModule):
    def __init__(self, n_head, query, key, value, out, rank, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.n_head = n_head
        self.query = LoraLinearLayer(query, rank)
        self.key = LoraLinearLayer(key, rank, bias=False)
        self.value = LoraLinearLayer(value, rank)
        self.out = LoraLinearLayer(out, rank)

    def forward(
        self,
        x: Tensor,
        xa: Tensor = None,
        mask: Tensor = None,
        kv_cache: dict = None,
        *args, **kwargs
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)

        else:
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask, kv_cache)

        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor = None, kv_cache: dict = None
    ):
        n_batch, n_ctx, n_state = q.shape
        _, k_ctx, _ = k.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        qk = q @ k

        # apply causal mask
        if mask is not None:

            # kv-cache in the beam search decoding case
            if kv_cache is not None and "beam_indices" in kv_cache.keys():
                for i in range(n_batch):
                    qk[i] = qk[i] + mask[kv_cache["beam_indices"][1][i * n_ctx]:kv_cache["beam_indices"][1][i * n_ctx] + n_ctx, :k_ctx]

            # for training: encoder/decoder causal masks
            elif k_ctx == n_ctx:
                qk = qk + mask[:n_ctx, :n_ctx]

            # kv-cache in the greedy decoding case
            else:
                qk = qk + mask[k_ctx - n_ctx:k_ctx, :k_ctx]

        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)

        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
class StreamingAudioEncoder(AudioEncoder):

    def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer, cache_gran, gran, rank, extra_gran_blocks):
        super().__init__(n_mels, n_ctx, n_state, n_head, n_layer)

        self.gran = gran
        self.extra_gran_blocks = extra_gran_blocks

        for block in self.blocks:
            block.attn = LoRAMultiHeadAttention(self.n_head,
                                                block.attn.query,
                                                block.attn.key,
                                                block.attn.value,
                                                block.attn.out,
                                                rank)

        self.use_stream = False

        # block-causal mask for training: each granule attends to itself, to all
        # previous granules, and to the extra look-ahead granules at the start
        matrix_size = n_ctx
        block_size = gran
        extra_blocks = extra_gran_blocks
        mask = torch.full((matrix_size, matrix_size), float('-inf'))

        for i in range(0, matrix_size, block_size):
            if (i // block_size) <= extra_blocks:
                zero_cols = (block_size * (extra_blocks + 1))
            else:
                zero_cols = (block_size * (extra_blocks + 1)) + ((i // block_size) - extra_blocks) * block_size

            mask[i:i + block_size, :zero_cols] = 1

        self.register_buffer("mask", mask, persistent=False)

    def _use_stream(self, use_stream: bool):
        self.use_stream = use_stream

    def forward(self, x: Tensor, index: list = [0, 1500], kv_cache = None, mask = True):
        """
        Simulate a streaming forward pass, using kv caching in self-attention.
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        if self.use_stream:
            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
            new_ctx = self.gran + (int(offset == 0) * (self.extra_gran_blocks * self.gran))
            x = x[:, offset:offset + new_ctx]  # keep only the new positions
            x = (x + self.positional_embedding[offset:offset + new_ctx]).to(x.dtype)
        else:  # used during training
            x = x[:, index[0]:index[1]]
            x = (x + self.positional_embedding[index[0]:index[1]]).to(x.dtype)

        for block in self.blocks:
            if isinstance(mask, Tensor):
                chosen_mask = mask[..., :index[1], :index[1]]
            elif (mask is not None) and self.use_stream:
                chosen_mask = self.mask
            else:
                chosen_mask = None
            x = block(x, mask=chosen_mask, kv_cache=kv_cache)

        x = self.ln_post(x)

        return x
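# A minimal sketch of the block-causal pattern built above, on a toy size
# (matrix_size=6, block_size=2, extra_blocks=1); this helper is illustrative only:
def _block_causal_mask_sketch():
    matrix_size, block_size, extra_blocks = 6, 2, 1
    mask = torch.full((matrix_size, matrix_size), float('-inf'))
    for i in range(0, matrix_size, block_size):
        if (i // block_size) <= extra_blocks:
            zero_cols = block_size * (extra_blocks + 1)
        else:
            zero_cols = block_size * (extra_blocks + 1) + ((i // block_size) - extra_blocks) * block_size
        mask[i:i + block_size, :zero_cols] = 1
    print(torch.isfinite(mask).int())
    # rows 0-3 see columns 0-3 (the look-ahead window); rows 4-5 additionally see columns 4-5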
class StreamingTextDecoder(TextDecoder):
    def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer, rank):
        super().__init__(n_vocab, n_ctx, n_state, n_head, n_layer)

        self.n_ctx = n_ctx
        self.n_state = n_state

        for block in self.blocks:
            block.attn = LoRAMultiHeadAttention(n_head,
                                                block.attn.query,
                                                block.attn.key,
                                                block.attn.value,
                                                block.attn.out,
                                                rank)

            block.cross_attn = LoRAMultiHeadAttention(n_head,
                                                      block.cross_attn.query,
                                                      block.cross_attn.key,
                                                      block.cross_attn.value,
                                                      block.cross_attn.out,
                                                      rank)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: dict = None, dump_type: str = None, **kwargs):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        if kv_cache is not None and "beam_indices" in kv_cache.keys():
            # gather per-beam positional embeddings at the positions being rewritten
            pos = self.positional_embedding.unsqueeze(0).expand(
                x.shape[0], self.positional_embedding.shape[0], self.positional_embedding.shape[1]
            )[kv_cache["beam_indices"][0], kv_cache["beam_indices"][1]].view(x.shape[0], -1, self.n_state)
            x = self.token_embedding(x) + pos
        else:
            offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
            x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]

        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        logits = (
            x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
        ).float()

        return logits
class StreamingWhisper(Whisper):
    def __init__(self, dims, cache_gran: bool = True, gran: int = 16, rank: int = 0, extra_gran_blocks: int = 0):
        super().__init__(dims)

        self.cache_gran = cache_gran
        self.gran = gran
        self.rank = rank
        self.extra_gran_blocks = extra_gran_blocks

        print(f"Running a streaming whisper model, using a chunk size of {gran * 20}[msec] and {extra_gran_blocks} extra chunks for initialization.")

        # swap in the streaming encoder and the LoRA-wrapped text decoder
        self.encoder = StreamingAudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
            cache_gran=cache_gran,
            gran=gran,
            rank=rank,
            extra_gran_blocks=extra_gran_blocks
        )

        self.decoder = StreamingTextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
            rank=rank
        )

        # advisor params - dropped
        self.advisor_type = None
        self.n_advisor_class = 0
        self.decoder_advisor = None

        self.decoding_task = None
        self.spec_streamer = SpectrogramStream()

        self.use_ca_cache_hook = True  # relevant only when the cross-attention kv cache is installed

    def reset(self, use_stream: bool = True, clean_task: bool = True):
        self.encoder._use_stream(use_stream)
        del self.decoding_task  # triggers a clean encoder kv cache
        self.decoding_task = None
        self.spec_streamer.reset()

    @torch.no_grad()
    def decode(self, mel: Tensor, options: DecodingOptions = DecodingOptions(), use_frames: bool = False, **kwargs) -> DecodingResult:
        if kwargs:
            options = replace(options, **kwargs)

        if use_frames:  # mel is raw audio frames, so the mel needs to be calculated first
            mel = self.spec_streamer.calc_mel_with_new_frame(mel).squeeze(0)

        if self.encoder.gran != options.gran:
            print(f"Encoder gran and options gran differ; forcing options onto the encoder's gran: {self.encoder.gran}")
            options.gran = self.encoder.gran

        if not self.decoding_task:
            self.decoding_task = DecodingTask(self, options)

        return self.decoding_task.run(mel.unsqueeze(0))
    def _turn_off_lora(self):
        for _, layer in self.encoder.named_modules():
            if isinstance(layer, LoraLinearLayer):
                layer.turn_off_lora()

    def _turn_on_lora(self):
        for _, layer in self.encoder.named_modules():
            if isinstance(layer, LoraLinearLayer):
                layer.turn_on_lora()

    def _cancel_streaming_mode(self):
        self._turn_off_lora()
        self.encoder._use_stream(False)

    def _revert_streaming_mode(self):
        self._turn_on_lora()
        self.encoder._use_stream(True)
    @torch.no_grad()
    def non_causal_decode(self, mel: Tensor, options: DecodingOptions = DecodingOptions(), **kwargs) -> DecodingResult:
        self._cancel_streaming_mode()
        results = non_causal_decode_function(self, mel, options, **kwargs)
        self._revert_streaming_mode()
        return results

    def remove_encoder_kv_cache_hooks(self):
        for hook in self.encoder._forward_hooks.values():
            hook.remove()
    def install_encoder_kv_cache_hooks(self, cache = None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_audio_ctx:
                # save as-is for the first chunk, or when the full context is recomputed
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, LoRAMultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.encoder.apply(install_hooks)
        return cache, hooks
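    # Sketch of the intended lifecycle of the cache/hooks pair returned above
    # (illustrative; DecodingTask._init_enc_kv_caching drives this in practice):
    #
    #     cache, hooks = model.install_encoder_kv_cache_hooks()
    #     feats = model.encoder(mel_chunk, kv_cache=cache, mask=None)   # per chunk
    #     ...                     # cached keys/values grow along dim=1 each call
    #     for hook in hooks:
    #         hook.remove()       # uninstall before resetting the stream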
    def install_decoder_kv_cache_hooks(self, cache = None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                cache[module] = output
            else:
                if "beam_indices" not in cache.keys() or all([index == (cache[module].shape[1]) for index in cache["beam_indices"][1]]):
                    # every beam appends at the end of the cache
                    cache[module] = torch.cat([cache[module], output], dim=1).detach()
                else:
                    # beams diverge: rewrite cached positions per beam, appending where needed
                    for _, (beam, index, output_index) in enumerate(zip(*cache["beam_indices"])):
                        if index < cache[module].shape[1]:
                            cache[module][beam, index] = output[beam, output_index]
                        else:
                            cache[module] = torch.cat([cache[module], output[:, (output_index):]], dim=1).detach()

            return cache[module]

        for name, module in self.decoder.named_modules():
            if isinstance(module, LoRAMultiHeadAttention) and "attn" in name and "cross" not in name:
                hooks.append(module.key.register_forward_hook(save_to_cache))
                hooks.append(module.value.register_forward_hook(save_to_cache))

        return cache, hooks

    def install_cross_attn_kv_cache_hooks(self, cache=None):
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if not self.use_ca_cache_hook:
                return cache[module]

            if module not in cache or output.shape[1] > self.dims.n_audio_ctx:
                # save as-is for the first chunk
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()

            return cache[module]

        def check_if_calculation_is_needed(module, _):
            # pre-hook: when the cross-attention cache is frozen, short-circuit the projection
            if not self.use_ca_cache_hook:
                return cache[module]

        for name, module in self.decoder.named_modules():
            if isinstance(module, LoRAMultiHeadAttention) and "cross_attn" in name:
                hooks.append(module.key.register_forward_hook(save_to_cache))
                hooks.append(module.key.register_forward_pre_hook(check_if_calculation_is_needed))
                hooks.append(module.value.register_forward_hook(save_to_cache))
                hooks.append(module.value.register_forward_pre_hook(check_if_calculation_is_needed))

        return cache, hooks

    # for non-causal decoding compatibility
    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, LoRAMultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    # alias to the function from the streaming_transcribe module
    transcribe = transcribe_function
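Constructing the streaming model mirrors plain Whisper. A minimal sketch, assuming a ModelDimensions-style `dims` object and a compatible fine-tuned checkpoint path (both placeholders here):

    import torch
    from whisper_stream.streaming_model import StreamingWhisper

    model = StreamingWhisper(dims, gran=16, rank=8, extra_gran_blocks=1)
    model.load_state_dict(torch.load("streaming_whisper.pt", map_location="cpu"), strict=False)
    model.reset(use_stream=True)   # enable streaming mode and clear any cached state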
whisper_stream/streaming_transcribe.py
ADDED
@@ -0,0 +1,208 @@
import argparse
import os
import time
import traceback
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import torch

import librosa

from .audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
    load_audio,
    SpectrogramStream,
    MyStream
)
from .streaming_decoding import DecodingOptions, DecodingResult, DecodingTask
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
    exact_div,
    format_timestamp,
    get_end,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)

if TYPE_CHECKING:
    from .streaming_model import StreamingWhisper
def transcribe(
    model: "StreamingWhisper" = None,
    filename: str = None,
    channels: int = 2,
    simulate_stream: bool = False,
    wav_file: str = None,
    single_frame_mel: bool = True,
    advised: bool = False,
    attentive_advisor: bool = False,
    get_mel: bool = False,
    temperature: float = 0,
    beam_size: int = 5,
    stream_decode: bool = True,
    use_ca_kv_cache: bool = False,
    use_sa_kv_cache: bool = False,
    language: str = "en",
    wait_for_all: bool = False,
    use_latency: bool = False,
    get_times: bool = False,
    pad_trim: bool = True,
    max_sec_context: int = 30,
    streaming_timestamps: bool = False,
    force_first_tokens_timestamps: bool = False,
    use_remote_stream: bool = False,
) -> Tuple[List[DecodingResult], list]:
    """
    Open a stream and transcribe it using the streaming whisper model.

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    Returns
    -------
    A list of per-frame decoding results covering all of the text that was
    transcribed until the stream stopped, plus an optional metrics list.
    """
    model.reset(use_stream=True)  # reset the model before starting a stream, cleaning any cache
    model.eval()

    # instantiate a streaming instance and open a stream
    ms_gran = model.encoder.gran * 20
    stream_instance = MyStream(ms_gran,
                               channels=channels,
                               filename=filename,
                               simulate_stream=simulate_stream,
                               wav_file=wav_file,
                               use_latency=use_latency,
                               pad_trim=pad_trim,
                               use_remote_machine=use_remote_stream)

    stream_instance.open_stream()

    # frames - used only when a filename is given, to save a long wav at the end of the conversation
    frames = []

    # decoding options shared by every frame
    decoding_options = DecodingOptions(
        language=language,
        gran=(ms_gran // 20),
        single_frame_mel=single_frame_mel,
        without_timestamps=True,
        beam_size=beam_size if temperature == 0 else None,
        temperature=temperature,
        length_penalty=None,
        advised=advised,
        attentive_advisor=attentive_advisor,
        look_ahead_blocks=model.encoder.extra_gran_blocks,
        patience=None,
        stream_decode=stream_decode,
        use_kv_cache=use_sa_kv_cache,
        use_ca_kv_cache=use_ca_kv_cache,
        wait_for_all=wait_for_all,
        streaming_timestamps=streaming_timestamps,
        force_first_tokens_timestamps=force_first_tokens_timestamps
    )

    streamed_spectrogram = SpectrogramStream(n_mels=model.dims.n_mels)  # default values are whisper's default values

    texts = []
    times = []
    mel = None
    reset_len = (max_sec_context * SAMPLE_RATE) + 360  # 360 extra samples account for the mel padding
    try:
        for frame in stream_instance.read():
            # keep frames for an optional save
            frames.extend(frame)

            if len(frames) > reset_len:  # past max_sec_context, reset the model (the positional embeddings constrain us)
                frame = np.concatenate((frames[-360:], frame))
                frames = []
                frames.extend(frame.tolist())
                model.reset(use_stream=True)
                streamed_spectrogram.reset()

            if get_times:
                torch.cuda.synchronize()
                start = time.time()

            frame_tensor = torch.from_numpy(frame).pin_memory()
            mel_frame = streamed_spectrogram.calc_mel_with_new_frame(frame_tensor.to(model.device, non_blocking=True))

            # decode given the new mel frame and print the result
            result = model.decode(mel_frame.squeeze(0), decoding_options)

            if get_times:
                torch.cuda.synchronize()
                end = time.time()

            print(result.text)

            # append metrics
            if get_times:
                times.append(end - start)
            texts.append(result)

    except KeyboardInterrupt:
        stream_instance.close_stream(frames)

    print("Finished capturing audio.")

    if get_mel: return texts, mel
    if get_times: return texts, times

    return texts, []
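# A minimal usage sketch (illustrative; `load_streaming_model` is a placeholder
# for however the checkpoint is loaded in your setup):
#
#     model = load_streaming_model("small")
#     results, _ = transcribe(model, simulate_stream=True, wav_file="audio.wav",
#                             beam_size=5, use_sa_kv_cache=True, use_ca_kv_cache=True)
#     print(results[-1].text)   # hypothesis after the final frame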
def cli():
    parser = argparse.ArgumentParser(description="Transcribe streaming audio with customizable options")

    # model choices
    parser.add_argument("--model", type=str, default="small", help="Model size to transcribe with")
    parser.add_argument("--chunk_size", type=int, default=300, help="Chunk size for streaming")
    parser.add_argument("--multilingual", action="store_true", help="Use a multilingual checkpoint if one exists.")

    # required/optional file inputs
    parser.add_argument("--filename", type=str, help="Path to the input audio file")
    parser.add_argument("--wav_file", type=str, help="Optional WAV file path to stream")

    # audio configuration
    parser.add_argument("--channels", type=int, default=2, help="Number of audio channels")

    # streaming behavior
    parser.add_argument("--simulate_stream", action="store_true", help="Simulate a stream from a file")
    parser.add_argument("--single_frame_mel", action="store_true", default=True, help="Use single-frame MELs")
    parser.add_argument("--stream_decode", action="store_true", default=True, help="Use streaming decode")
    parser.add_argument("--use_ca_kv_cache", action="store_true", help="Use the cross-attention key-value cache")
    parser.add_argument("--use_sa_kv_cache", action="store_true", help="Use the self-attention key-value cache")
    parser.add_argument("--wait_for_all", action="store_true", help="Wait for all results before outputting")
    parser.add_argument("--use_latency", action="store_true", help="Track latency for metrics")
    parser.add_argument("--pad_trim", action="store_true", default=True, help="Enable padding and trimming")
    parser.add_argument("--streaming_timestamps", action="store_true", help="Use timestamps in streaming")
    parser.add_argument("--force_first_tokens_timestamps", action="store_true", help="Force timestamps on the first tokens")
    parser.add_argument("--use_remote_stream", action="store_true", help="Use a remote stream")

    # model behavior
    parser.add_argument("--advised", action="store_true", help="Enable advised decoding")
    parser.add_argument("--attentive_advisor", action="store_true", help="Use attentive advisor logic")
    parser.add_argument("--get_mel", action="store_true", help="Return the MEL spectrogram")
    parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
    parser.add_argument("--beam_size", type=int, default=5, help="Beam size for beam search decoding")
    parser.add_argument("--language", type=str, default="en", help="Language of transcription")
    parser.add_argument("--get_times", action="store_true", help="Return per-frame decoding times")
    parser.add_argument("--max_sec_context", type=int, default=30, help="Max context window size in seconds")

    return parser.parse_args()
whisper_stream/timing.py
ADDED
@@ -0,0 +1,386 @@
import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List

import numba
import numpy as np
import torch
import torch.nn.functional as F

from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer

if TYPE_CHECKING:
    from .model import Whisper
def median_filter(x: torch.Tensor, filter_width: int):
    """Apply a median filter of width `filter_width` along the last dimension of `x`"""
    pad_width = filter_width // 2
    if x.shape[-1] <= pad_width:
        # F.pad requires the padding width to be smaller than the input dimension
        return x

    if (ndim := x.ndim) <= 2:
        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
        x = x[None, None, :]

    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"

    result = None
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
    if x.is_cuda:
        try:
            from .triton_ops import median_filter_cuda

            result = median_filter_cuda(x, filter_width)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower median kernel implementation..."
            )

    if result is None:
        # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
        result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]

    if ndim <= 2:
        result = result[0, 0]

    return result
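# A minimal sketch of median_filter on a 1-D signal (illustrative helper only):
def _median_filter_sketch():
    x = torch.tensor([0.0, 0.0, 10.0, 0.0, 0.0])  # a single-sample spike
    print(median_filter(x, filter_width=3))       # the isolated spike is removed: all zeros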
@numba.jit(nopython=True)
def backtrace(trace: np.ndarray):
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1
    trace[0, :] = 2
    trace[:, 0] = 1

    result = []
    while i > 0 or j > 0:
        result.append((i - 1, j - 1))

        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1:
            i -= 1
        elif trace[i, j] == 2:
            j -= 1
        else:
            raise ValueError("Unexpected trace[i, j]")

    result = np.array(result)
    return result[::-1, :].T
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
    N, M = x.shape
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)

    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]
            c1 = cost[i - 1, j]
            c2 = cost[i, j - 1]

            if c0 < c1 and c0 < c2:
                c, t = c0, 0
            elif c1 < c0 and c1 < c2:
                c, t = c1, 1
            else:
                c, t = c2, 2

            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t

    return backtrace(trace)
def dtw_cuda(x, BLOCK_SIZE=1024):
    from .triton_ops import dtw_kernel

    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"

    x_skew = (
        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    )
    x_skew = x_skew.T.contiguous()
    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
    cost = cost.cuda()
    trace = torch.zeros_like(cost, dtype=torch.int32)

    dtw_kernel[(1,)](
        cost,
        trace,
        x_skew,
        x_skew.stride(0),
        cost.stride(0),
        trace.stride(0),
        N,
        M,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
        :, : N + 1
    ]
    return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
    if x.is_cuda:
        try:
            return dtw_cuda(x)
        except (RuntimeError, subprocess.CalledProcessError):
            warnings.warn(
                "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
                "falling back to a slower DTW implementation..."
            )

    return dtw_cpu(x.double().cpu().numpy())
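# A minimal sketch of the DTW routine on a toy cost matrix (illustrative only):
def _dtw_sketch():
    """With zero cost on the diagonal, the cheapest monotonic path is the
    diagonal itself. Note that find_alignment below feeds `-matrix`, turning
    attention similarity into a cost."""
    cost = torch.tensor([[0.0, 1.0, 2.0],
                         [1.0, 0.0, 1.0],
                         [2.0, 1.0, 0.0]])
    text_indices, time_indices = dtw(cost)
    print(list(zip(text_indices.tolist(), time_indices.tolist())))  # [(0, 0), (1, 1), (2, 2)]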
@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float
def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
    if len(text_tokens) == 0:
        return []

    tokens = torch.tensor(
        [
            *tokenizer.sot_sequence,
            tokenizer.no_timestamps,
            *text_tokens,
            tokenizer.eot,
        ]
    ).to(model.device)

    # install hooks on the cross attention layers to retrieve the attention weights
    QKs = [None] * model.dims.n_text_layer
    hooks = [
        block.cross_attn.register_forward_hook(
            lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
        )
        for i, block in enumerate(model.decoder.blocks)
    ]

    with torch.no_grad():
        logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
        sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
        token_probs = sampled_logits.softmax(dim=-1)
        text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
        text_token_probs = text_token_probs.tolist()

    for hook in hooks:
        hook.remove()

    # heads * tokens * frames
    weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
    weights = weights[:, :, : num_frames // 2]
    weights = (weights * qk_scale).softmax(dim=-1)
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = (weights - mean) / std
    weights = median_filter(weights, medfilt_width)

    matrix = weights.mean(axis=0)
    matrix = matrix[len(tokenizer.sot_sequence) : -1]
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    if len(word_tokens) <= 1:
        # return on eot only
        # >>> np.pad([], (1, 0))
        # array([0.])
        # This results in crashes when we lookup jump_times with float, like
        # IndexError: arrays used as indices must be of integer (or boolean) type
        return []
    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))

    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j])
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]

    return [
        WordTiming(word, tokens, start, end, probability)
        for word, tokens, start, end, probability in zip(
            words, word_tokens, start_times, end_times, word_probabilities
        )
    ]
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
    # merge prepended punctuations
    i = len(alignment) - 2
    j = len(alignment) - 1
    while i >= 0:
        previous = alignment[i]
        following = alignment[j]
        if previous.word.startswith(" ") and previous.word.strip() in prepended:
            # prepend it to the following word
            following.word = previous.word + following.word
            following.tokens = previous.tokens + following.tokens
            previous.word = ""
            previous.tokens = []
        else:
            j = i
        i -= 1

    # merge appended punctuations
    i = 0
    j = 1
    while j < len(alignment):
        previous = alignment[i]
        following = alignment[j]
        if not previous.word.endswith(" ") and following.word in appended:
            # append it to the previous word
            previous.word = previous.word + following.word
            previous.tokens = previous.tokens + following.tokens
            following.word = ""
            following.tokens = []
        else:
            i = j
        j += 1
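# A minimal sketch of the merging behavior (illustrative values):
def _merge_punctuations_sketch():
    alignment = [
        WordTiming(" Hello", [1], 0.0, 0.3, 1.0),
        WordTiming("!", [2], 0.3, 0.4, 1.0),
    ]
    merge_punctuations(alignment, prepended="\"'“¿([{-", appended="\"'.。,,!!??::”)]}、")
    print([t.word for t in alignment if t.word])  # [' Hello!']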
def add_word_timestamps(
    *,
    segments: List[dict],
    model: "Whisper",
    tokenizer: Tokenizer,
    mel: torch.Tensor,
    num_frames: int,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    last_speech_timestamp: float,
    **kwargs,
):
    if len(segments) == 0:
        return

    text_tokens_per_segment = [
        [token for token in segment["tokens"] if token < tokenizer.eot]
        for segment in segments
    ]

    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    word_durations = np.array([t.end - t.start for t in alignment])
    word_durations = word_durations[word_durations.nonzero()]
    median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
    median_duration = min(0.7, float(median_duration))
    max_duration = median_duration * 2

    # hack: truncate long words at sentence boundaries.
    # a better segmentation algorithm based on VAD should be able to replace this.
    if len(word_durations) > 0:
        sentence_end_marks = ".。!!??"
        # ensure words at sentence boundaries are not longer than twice the median word duration.
        for i in range(1, len(alignment)):
            if alignment[i].end - alignment[i].start > max_duration:
                if alignment[i].word in sentence_end_marks:
                    alignment[i].end = alignment[i].start + max_duration
                elif alignment[i - 1].word in sentence_end_marks:
                    alignment[i].start = alignment[i].end - max_duration

    merge_punctuations(alignment, prepend_punctuations, append_punctuations)

    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    word_index = 0

    for segment, text_tokens in zip(segments, text_tokens_per_segment):
        saved_tokens = 0
        words = []

        while word_index < len(alignment) and saved_tokens < len(text_tokens):
            timing = alignment[word_index]

            if timing.word:
                words.append(
                    dict(
                        word=timing.word,
                        start=round(time_offset + timing.start, 2),
                        end=round(time_offset + timing.end, 2),
                        probability=timing.probability,
                    )
                )

            saved_tokens += len(timing.tokens)
            word_index += 1

        # hack: truncate long words at segment boundaries.
        # a better segmentation algorithm based on VAD should be able to replace this.
        if len(words) > 0:
            # ensure the first and second word after a pause is not longer than
            # twice the median word duration.
            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
                words[0]["end"] - words[0]["start"] > max_duration
                or (
                    len(words) > 1
                    and words[1]["end"] - words[0]["start"] > max_duration * 2
                )
            ):
                if (
                    len(words) > 1
                    and words[1]["end"] - words[1]["start"] > max_duration
                ):
                    boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
                    words[0]["end"] = words[1]["start"] = boundary
                words[0]["start"] = max(0, words[0]["end"] - max_duration)

            # prefer the segment-level start timestamp if the first word is too long.
            if (
                segment["start"] < words[0]["end"]
                and segment["start"] - 0.5 > words[0]["start"]
            ):
                words[0]["start"] = max(
                    0, min(words[0]["end"] - median_duration, segment["start"])
                )
            else:
                segment["start"] = words[0]["start"]

            # prefer the segment-level end timestamp if the last word is too long.
            if (
                segment["end"] > words[-1]["start"]
                and segment["end"] + 0.5 < words[-1]["end"]
            ):
                words[-1]["end"] = max(
                    words[-1]["start"] + median_duration, segment["end"]
                )
            else:
                segment["end"] = words[-1]["end"]

            last_speech_timestamp = segment["end"]
+
segment["words"] = words
|
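
To make the in-place behavior of `merge_punctuations` concrete, here is a minimal, hypothetical sketch; the stand-in dataclass mimics only the `word`/`tokens` fields of `WordTiming`, and the token ids are made up:

from dataclasses import dataclass, field
from typing import List

from whisper_stream.timing import merge_punctuations


@dataclass
class Word:  # stand-in for WordTiming; only the fields merge_punctuations touches
    word: str
    tokens: List[int] = field(default_factory=list)


alignment = [Word(" ¿", [1]), Word(" Qué", [2]), Word("?", [3])]
merge_punctuations(alignment, "\"'“¿([{-", "\"'.。,,!!??::”)]}、")
print([(w.word, w.tokens) for w in alignment])
# [('', []), (' ¿ Qué?', [1, 2, 3]), ('', [])] -- the emptied entries carry no tokens,
# so add_word_timestamps() skips them when building each segment's word list
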
whisper_stream/tokenizer.py
ADDED
@@ -0,0 +1,395 @@
import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple

import tiktoken

LANGUAGES = {
    "en": "english",
    "zh": "chinese",
    "de": "german",
    "es": "spanish",
    "ru": "russian",
    "ko": "korean",
    "fr": "french",
    "ja": "japanese",
    "pt": "portuguese",
    "tr": "turkish",
    "pl": "polish",
    "ca": "catalan",
    "nl": "dutch",
    "ar": "arabic",
    "sv": "swedish",
    "it": "italian",
    "id": "indonesian",
    "hi": "hindi",
    "fi": "finnish",
    "vi": "vietnamese",
    "he": "hebrew",
    "uk": "ukrainian",
    "el": "greek",
    "ms": "malay",
    "cs": "czech",
    "ro": "romanian",
    "da": "danish",
    "hu": "hungarian",
    "ta": "tamil",
    "no": "norwegian",
    "th": "thai",
    "ur": "urdu",
    "hr": "croatian",
    "bg": "bulgarian",
    "lt": "lithuanian",
    "la": "latin",
    "mi": "maori",
    "ml": "malayalam",
    "cy": "welsh",
    "sk": "slovak",
    "te": "telugu",
    "fa": "persian",
    "lv": "latvian",
    "bn": "bengali",
    "sr": "serbian",
    "az": "azerbaijani",
    "sl": "slovenian",
    "kn": "kannada",
    "et": "estonian",
    "mk": "macedonian",
    "br": "breton",
    "eu": "basque",
    "is": "icelandic",
    "hy": "armenian",
    "ne": "nepali",
    "mn": "mongolian",
    "bs": "bosnian",
    "kk": "kazakh",
    "sq": "albanian",
    "sw": "swahili",
    "gl": "galician",
    "mr": "marathi",
    "pa": "punjabi",
    "si": "sinhala",
    "km": "khmer",
    "sn": "shona",
    "yo": "yoruba",
    "so": "somali",
    "af": "afrikaans",
    "oc": "occitan",
    "ka": "georgian",
    "be": "belarusian",
    "tg": "tajik",
    "sd": "sindhi",
    "gu": "gujarati",
    "am": "amharic",
    "yi": "yiddish",
    "lo": "lao",
    "uz": "uzbek",
    "fo": "faroese",
    "ht": "haitian creole",
    "ps": "pashto",
    "tk": "turkmen",
    "nn": "nynorsk",
    "mt": "maltese",
    "sa": "sanskrit",
    "lb": "luxembourgish",
    "my": "myanmar",
    "bo": "tibetan",
    "tl": "tagalog",
    "mg": "malagasy",
    "as": "assamese",
    "tt": "tatar",
    "haw": "hawaiian",
    "ln": "lingala",
    "ha": "hausa",
    "ba": "bashkir",
    "jw": "javanese",
    "su": "sundanese",
    "yue": "cantonese",
}

# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
    **{language: code for code, language in LANGUAGES.items()},
    "burmese": "my",
    "valencian": "ca",
    "flemish": "nl",
    "haitian": "ht",
    "letzeburgesch": "lb",
    "pushto": "ps",
    "panjabi": "pa",
    "moldavian": "ro",
    "moldovan": "ro",
    "sinhalese": "si",
    "castilian": "es",
    "mandarin": "zh",
}


@dataclass
class Tokenizer:
    """A thin wrapper around `tiktoken` providing quick access to special tokens"""

    encoding: tiktoken.Encoding
    num_languages: int
    language: Optional[str] = None
    task: Optional[str] = None
    sot_sequence: Tuple[int] = ()
    special_tokens: Dict[str, int] = field(default_factory=dict)

    def __post_init__(self):
        for special in self.encoding.special_tokens_set:
            special_token = self.encoding.encode_single_token(special)
            self.special_tokens[special] = special_token

        sot: int = self.special_tokens["<|startoftranscript|>"]
        translate: int = self.special_tokens["<|translate|>"]
        transcribe: int = self.special_tokens["<|transcribe|>"]

        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sot_sequence = [sot]
        if self.language is not None:
            sot_sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_token: int = transcribe if self.task == "transcribe" else translate
            sot_sequence.append(task_token)

        self.sot_sequence = tuple(sot_sequence)

    def encode(self, text, **kwargs):
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids: List[int], **kwargs) -> str:
        token_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(token_ids, **kwargs)

    def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
        """
        Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
        This method decodes the given tokens with timestamp tokens annotated, e.g. "<|1.08|>".
        """
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self) -> int:
        return self.encoding.eot_token

    @cached_property
    def transcribe(self) -> int:
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self) -> int:
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self) -> int:
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self) -> int:
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self) -> int:
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self) -> int:
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self) -> int:
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self) -> int:
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self) -> int:
        """Returns the token id corresponding to the value of the `language` field"""
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")

        return self.to_language_token(self.language)

    def to_language_token(self, language):
        if token := self.special_tokens.get(f"<|{language}|>", None):
            return token

        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self) -> Tuple[int]:
        result = []
        for token, token_id in self.special_tokens.items():
            if token.strip("<|>") in LANGUAGES:
                result.append(token_id)
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self) -> Tuple[str]:
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self) -> Tuple[int]:
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self) -> Tuple[int]:
        """
        Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
        annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.

        - ♪♪♪
        - ( SPEAKING FOREIGN LANGUAGE )
        - [DAVID] Hey there,

        keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += (
            "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
        )

        # symbols that may be a single token or multiple tokens depending on the tokenizer.
        # In case they're multiple tokens, suppress the first token, which is safe because:
        # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
        # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)

        # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [
                self.encoding.encode(symbol),
                self.encoding.encode(" " + symbol),
            ]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])

        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens: List[int]):
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            # These languages don't typically use spaces, so it is difficult to split words
            # without morpheme analysis. Here, we instead split words at any
            # position where the tokens are decoded as valid unicode points
            return self.split_tokens_on_unicode(tokens)

        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens: List[int]):
        decoded_full = self.decode_with_timestamps(tokens)
        replacement_char = "\ufffd"

        words = []
        word_tokens = []
        current_tokens = []
        unicode_offset = 0

        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)

            if (
                replacement_char not in decoded
                or decoded_full[unicode_offset + decoded.index(replacement_char)]
                == replacement_char
            ):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)

        return words, word_tokens

    def split_tokens_on_spaces(self, tokens: List[int]):
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words = []
        word_tokens = []

        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)

        return words, word_tokens


@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
    vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
    ranks = {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in open(vocab_path) if line)
    }
    n_vocab = len(ranks)
    special_tokens = {}

    specials = [
        "<|endoftext|>",
        "<|startoftranscript|>",
        *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
        "<|translate|>",
        "<|transcribe|>",
        "<|startoflm|>",
        "<|startofprev|>",
        "<|nospeech|>",
        "<|notimestamps|>",
        *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
    ]

    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(
        name=os.path.basename(vocab_path),
        explicit_n_vocab=n_vocab,
        pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
        mergeable_ranks=ranks,
        special_tokens=special_tokens,
    )


@lru_cache(maxsize=None)
def get_tokenizer(
    multilingual: bool,
    *,
    num_languages: int = 99,
    language: Optional[str] = None,
    task: Optional[str] = None,  # Literal["transcribe", "translate", None]
) -> Tokenizer:
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language in TO_LANGUAGE_CODE:
                language = TO_LANGUAGE_CODE[language]
            else:
                raise ValueError(f"Unsupported language: {language}")

    if multilingual:
        encoding_name = "multilingual"
        language = language or "en"
        task = task or "transcribe"
    else:
        encoding_name = "gpt2"
        language = None
        task = None

    encoding = get_encoding(name=encoding_name, num_languages=num_languages)

    return Tokenizer(
        encoding=encoding, num_languages=num_languages, language=language, task=task
    )
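
For orientation, a small, hypothetical usage sketch of `get_tokenizer`; the printed ids are what the multilingual vocabulary yields and should be treated as illustrative:

from whisper_stream.tokenizer import get_tokenizer

tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
# <|startoftranscript|><|en|><|transcribe|>
print(tokenizer.sot_sequence)  # e.g. (50258, 50259, 50359) for the multilingual vocab
tokens = tokenizer.encode(" Hello, world!")
words, word_tokens = tokenizer.split_to_word_tokens(tokens + [tokenizer.eot])
print(words)  # roughly [' Hello', ',', ' world', '!', '<|endoftext|>']
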
whisper_stream/transcribe.py
ADDED
@@ -0,0 +1,605 @@
import argparse
import os
import traceback
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np
import torch
import tqdm

from .audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    log_mel_spectrogram,
    pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
    exact_div,
    format_timestamp,
    get_end,
    get_writer,
    make_safe,
    optional_float,
    optional_int,
    str2bool,
)

if TYPE_CHECKING:
    from .model import Whisper


def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor],
    *,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    initial_prompt: Optional[str] = None,
    word_timestamps: bool = False,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    clip_timestamps: Union[str, List[float]] = "0",
    hallucination_silence_threshold: Optional[float] = None,
    **decode_options,
):
    """
    Transcribe an audio file using Whisper

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details;
        if False, displays minimal details. If None, does not display anything.

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    word_timestamps: bool
        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
        and include the timestamps for each word in each segment.

    prepend_punctuations: str
        If word_timestamps is True, merge these punctuation symbols with the next word

    append_punctuations: str
        If word_timestamps is True, merge these punctuation symbols with the previous word

    initial_prompt: Optional[str]
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
102 |
+
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
103 |
+
to make it more likely to predict those word correctly.
|
104 |
+
|
105 |
+
decode_options: dict
|
106 |
+
Keyword arguments to construct `DecodingOptions` instances
|
107 |
+
|
108 |
+
clip_timestamps: Union[str, List[float]]
|
109 |
+
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
110 |
+
The last end timestamp defaults to the end of the file.
|
111 |
+
|
112 |
+
hallucination_silence_threshold: Optional[float]
|
113 |
+
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
114 |
+
when a possible hallucination is detected
|
115 |
+
|
116 |
+
Returns
|
117 |
+
-------
|
118 |
+
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
119 |
+
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
120 |
+
"""
|
121 |
+
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
122 |
+
if model.device == torch.device("cpu"):
|
123 |
+
if torch.cuda.is_available():
|
124 |
+
warnings.warn("Performing inference on CPU when CUDA is available")
|
125 |
+
if dtype == torch.float16:
|
126 |
+
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
127 |
+
dtype = torch.float32
|
128 |
+
|
129 |
+
if dtype == torch.float32:
|
130 |
+
decode_options["fp16"] = False
|
131 |
+
|
132 |
+
# Pad 30-seconds of silence to the input audio, for slicing
|
133 |
+
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
134 |
+
content_frames = mel.shape[-1] - N_FRAMES
|
135 |
+
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
136 |
+
|
137 |
+
if decode_options.get("language", None) is None:
|
138 |
+
if not model.is_multilingual:
|
139 |
+
decode_options["language"] = "en"
|
140 |
+
else:
|
141 |
+
if verbose:
|
142 |
+
print(
|
143 |
+
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
144 |
+
)
|
145 |
+
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
146 |
+
_, probs = model.detect_language(mel_segment)
|
147 |
+
decode_options["language"] = max(probs, key=probs.get)
|
148 |
+
if verbose is not None:
|
149 |
+
print(
|
150 |
+
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
151 |
+
)
|
152 |
+
|
153 |
+
language: str = decode_options["language"]
|
154 |
+
task: str = decode_options.get("task", "transcribe")
|
155 |
+
tokenizer = get_tokenizer(
|
156 |
+
model.is_multilingual,
|
157 |
+
num_languages=model.num_languages,
|
158 |
+
language=language,
|
159 |
+
task=task,
|
160 |
+
)
|
161 |
+
|
162 |
+
if isinstance(clip_timestamps, str):
|
163 |
+
clip_timestamps = [
|
164 |
+
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
165 |
+
]
|
166 |
+
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
167 |
+
if len(seek_points) == 0:
|
168 |
+
seek_points.append(0)
|
169 |
+
if len(seek_points) % 2 == 1:
|
170 |
+
seek_points.append(content_frames)
|
171 |
+
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
172 |
+
|
173 |
+
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
174 |
+
|
175 |
+
if word_timestamps and task == "translate":
|
176 |
+
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
177 |
+
|
178 |
+
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
179 |
+
temperatures = (
|
180 |
+
[temperature] if isinstance(temperature, (int, float)) else temperature
|
181 |
+
)
|
182 |
+
decode_result = None
|
183 |
+
|
184 |
+
for t in temperatures:
|
185 |
+
kwargs = {**decode_options}
|
186 |
+
if t > 0:
|
187 |
+
# disable beam_size and patience when t > 0
|
188 |
+
kwargs.pop("beam_size", None)
|
189 |
+
kwargs.pop("patience", None)
|
190 |
+
else:
|
191 |
+
# disable best_of when t == 0
|
192 |
+
kwargs.pop("best_of", None)
|
193 |
+
|
194 |
+
options = DecodingOptions(**kwargs, temperature=t)
|
195 |
+
decode_result = model.decode(segment, options)
|
196 |
+
|
197 |
+
needs_fallback = False
|
198 |
+
if (
|
199 |
+
compression_ratio_threshold is not None
|
200 |
+
and decode_result.compression_ratio > compression_ratio_threshold
|
201 |
+
):
|
202 |
+
needs_fallback = True # too repetitive
|
203 |
+
if (
|
204 |
+
logprob_threshold is not None
|
205 |
+
and decode_result.avg_logprob < logprob_threshold
|
206 |
+
):
|
207 |
+
needs_fallback = True # average log probability is too low
|
208 |
+
if (
|
209 |
+
no_speech_threshold is not None
|
210 |
+
and decode_result.no_speech_prob > no_speech_threshold
|
211 |
+
):
|
212 |
+
needs_fallback = False # silence
|
213 |
+
if not needs_fallback:
|
214 |
+
break
|
215 |
+
|
216 |
+
return decode_result
|
217 |
+
|
218 |
+
clip_idx = 0
|
219 |
+
seek = seek_clips[clip_idx][0]
|
220 |
+
input_stride = exact_div(
|
221 |
+
N_FRAMES, model.dims.n_audio_ctx
|
222 |
+
) # mel frames per output token: 2
|
223 |
+
time_precision = (
|
224 |
+
input_stride * HOP_LENGTH / SAMPLE_RATE
|
225 |
+
) # time per output token: 0.02 (seconds)
|
226 |
+
all_tokens = []
|
227 |
+
all_segments = []
|
228 |
+
prompt_reset_since = 0
|
229 |
+
|
230 |
+
if initial_prompt is not None:
|
231 |
+
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
232 |
+
all_tokens.extend(initial_prompt_tokens)
|
233 |
+
else:
|
234 |
+
initial_prompt_tokens = []
|
235 |
+
|
236 |
+
def new_segment(
|
237 |
+
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
238 |
+
):
|
239 |
+
tokens = tokens.tolist()
|
240 |
+
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
241 |
+
return {
|
242 |
+
"seek": seek,
|
243 |
+
"start": start,
|
244 |
+
"end": end,
|
245 |
+
"text": tokenizer.decode(text_tokens),
|
246 |
+
"tokens": tokens,
|
247 |
+
"temperature": result.temperature,
|
248 |
+
"avg_logprob": result.avg_logprob,
|
249 |
+
"compression_ratio": result.compression_ratio,
|
250 |
+
"no_speech_prob": result.no_speech_prob,
|
251 |
+
}
|
252 |
+
|
253 |
+
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
254 |
+
with tqdm.tqdm(
|
255 |
+
total=content_frames, unit="frames", disable=verbose is not False
|
256 |
+
) as pbar:
|
257 |
+
last_speech_timestamp = 0.0
|
258 |
+
# NOTE: This loop is obscurely flattened to make the diff readable.
|
259 |
+
# A later commit should turn this into a simpler nested loop.
|
260 |
+
# for seek_clip_start, seek_clip_end in seek_clips:
|
261 |
+
# while seek < seek_clip_end
|
262 |
+
while clip_idx < len(seek_clips):
|
263 |
+
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
264 |
+
if seek < seek_clip_start:
|
265 |
+
seek = seek_clip_start
|
266 |
+
if seek >= seek_clip_end:
|
267 |
+
clip_idx += 1
|
268 |
+
if clip_idx < len(seek_clips):
|
269 |
+
seek = seek_clips[clip_idx][0]
|
270 |
+
continue
|
271 |
+
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
272 |
+
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
273 |
+
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
274 |
+
mel_segment = mel[:, seek : seek + segment_size]
|
275 |
+
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
276 |
+
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
277 |
+
|
278 |
+
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
279 |
+
result: DecodingResult = decode_with_fallback(mel_segment)
|
280 |
+
tokens = torch.tensor(result.tokens)
|
281 |
+
|
282 |
+
if no_speech_threshold is not None:
|
283 |
+
# no voice activity check
|
284 |
+
should_skip = result.no_speech_prob > no_speech_threshold
|
285 |
+
if (
|
286 |
+
logprob_threshold is not None
|
287 |
+
and result.avg_logprob > logprob_threshold
|
288 |
+
):
|
289 |
+
# don't skip if the logprob is high enough, despite the no_speech_prob
|
290 |
+
should_skip = False
|
291 |
+
|
292 |
+
if should_skip:
|
293 |
+
seek += segment_size # fast-forward to the next segment boundary
|
294 |
+
continue
|
295 |
+
|
296 |
+
previous_seek = seek
|
297 |
+
current_segments = []
|
298 |
+
|
299 |
+
# anomalous words are very long/short/improbable
|
300 |
+
def word_anomaly_score(word: dict) -> float:
|
301 |
+
probability = word.get("probability", 0.0)
|
302 |
+
duration = word["end"] - word["start"]
|
303 |
+
score = 0.0
|
304 |
+
if probability < 0.15:
|
305 |
+
score += 1.0
|
306 |
+
if duration < 0.133:
|
307 |
+
score += (0.133 - duration) * 15
|
308 |
+
if duration > 2.0:
|
309 |
+
score += duration - 2.0
|
310 |
+
return score
|
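
            # Hypothetical example: a 50 ms word with probability 0.10 scores
            # 1.0 + (0.133 - 0.05) * 15 ≈ 2.25, so a couple of such words are enough
            # for is_segment_anomaly() below to flag the segment.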

            def is_segment_anomaly(segment: Optional[dict]) -> bool:
                if segment is None or not segment["words"]:
                    return False
                words = [w for w in segment["words"] if w["word"] not in punctuation]
                words = words[:8]
                score = sum(word_anomaly_score(w) for w in words)
                return score >= 3 or score + 0.01 >= len(words)

            def next_words_segment(segments: List[dict]) -> Optional[dict]:
                return next((s for s in segments if s["words"]), None)

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)
            if len(consecutive) > 0:
                # if the output contains two consecutive timestamp tokens
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    current_segments.append(
                        new_segment(
                            start=time_offset + start_timestamp_pos * time_precision,
                            end=time_offset + end_timestamp_pos * time_precision,
                            tokens=sliced_tokens,
                            result=result,
                        )
                    )
                    last_slice = current_slice

                if single_timestamp_ending:
                    # single timestamp at the end means no speech after the last timestamp.
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_pos = (
                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                    )
                    seek += last_timestamp_pos * input_stride
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    last_timestamp_pos = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_pos * time_precision

                current_segments.append(
                    new_segment(
                        start=time_offset,
                        end=time_offset + duration,
                        tokens=tokens,
                        result=result,
                    )
                )
                seek += segment_size

            if word_timestamps:
                add_word_timestamps(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_frames=segment_size,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                    last_speech_timestamp=last_speech_timestamp,
                )

                if not single_timestamp_ending:
                    last_word_end = get_end(current_segments)
                    if last_word_end is not None and last_word_end > time_offset:
                        seek = round(last_word_end * FRAMES_PER_SECOND)

                # skip silence before possible hallucinations
                if hallucination_silence_threshold is not None:
                    threshold = hallucination_silence_threshold
                    if not single_timestamp_ending:
                        last_word_end = get_end(current_segments)
                        if last_word_end is not None and last_word_end > time_offset:
                            remaining_duration = window_end_time - last_word_end
                            if remaining_duration > threshold:
                                seek = round(last_word_end * FRAMES_PER_SECOND)
                            else:
                                seek = previous_seek + segment_size

                    # if first segment might be a hallucination, skip leading silence
                    first_segment = next_words_segment(current_segments)
                    if first_segment is not None and is_segment_anomaly(first_segment):
                        gap = first_segment["start"] - time_offset
                        if gap > threshold:
                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
                            continue

                    # skip silence before any possible hallucination that is surrounded
                    # by silence or more hallucinations
                    hal_last_end = last_speech_timestamp
                    for si in range(len(current_segments)):
                        segment = current_segments[si]
                        if not segment["words"]:
                            continue
                        if is_segment_anomaly(segment):
                            next_segment = next_words_segment(
                                current_segments[si + 1 :]
                            )
                            if next_segment is not None:
                                hal_next_start = next_segment["words"][0]["start"]
                            else:
                                hal_next_start = time_offset + segment_duration
                            silence_before = (
                                segment["start"] - hal_last_end > threshold
                                or segment["start"] < threshold
                                or segment["start"] - time_offset < 2.0
                            )
                            silence_after = (
                                hal_next_start - segment["end"] > threshold
                                or is_segment_anomaly(next_segment)
                                or window_end_time - segment["end"] < 2.0
                            )
                            if silence_before and silence_after:
                                seek = round(
                                    max(time_offset + 1, segment["start"])
                                    * FRAMES_PER_SECOND
                                )
                                if content_duration - segment["end"] < threshold:
                                    seek = content_frames
                                current_segments[si:] = []
                                break
                            hal_last_end = segment["end"]

                last_word_end = get_end(current_segments)
                if last_word_end is not None:
                    last_speech_timestamp = last_word_end

            if verbose:
                for segment in current_segments:
                    start, end, text = segment["start"], segment["end"], segment["text"]
                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
                    print(make_safe(line))

            # if a segment is instantaneous or does not contain text, clear it
            for i, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []

            all_segments.extend(
                [
                    {"id": i, **segment}
                    for i, segment in enumerate(
                        current_segments, start=len(all_segments)
                    )
                ]
            )
            all_tokens.extend(
                [token for segment in current_segments for token in segment["tokens"]]
            )

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            # update progress bar
            pbar.update(min(content_frames, seek) - previous_seek)

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
        segments=all_segments,
        language=language,
    )


def cli():
    from . import available_models

    def valid_model_name(name):
        if name in available_models() or os.path.exists(name):
            return name
        raise ValueError(
            f"model should be one of {available_models()} or path to a model checkpoint"
        )

    # fmt: off
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
    parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
    parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
    parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
    parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supersedes MKL_NUM_THREADS/OMP_NUM_THREADS")
    parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
    parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
    # fmt: on

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    os.makedirs(output_dir, exist_ok=True)

    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"

    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)

    from . import load_model

    model = load_model(model_name, device=device, download_root=model_dir)

    writer = get_writer(output_format, output_dir)
    word_options = [
        "highlight_words",
        "max_line_count",
        "max_line_width",
        "max_words_per_line",
    ]
    if not args["word_timestamps"]:
        for option in word_options:
            if args[option]:
                parser.error(f"--{option} requires --word_timestamps True")
    if args["max_line_count"] and not args["max_line_width"]:
        warnings.warn("--max_line_count has no effect without --max_line_width")
    if args["max_words_per_line"] and args["max_line_width"]:
        warnings.warn("--max_words_per_line has no effect with --max_line_width")
    writer_args = {arg: args.pop(arg) for arg in word_options}
    for audio_path in args.pop("audio"):
        try:
            result = transcribe(model, audio_path, temperature=temperature, **args)
            writer(result, audio_path, **writer_args)
        except Exception as e:
            traceback.print_exc()
            print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")


if __name__ == "__main__":
    cli()
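
As a usage sketch, `transcribe()` can also be driven directly from Python rather than through `cli()`; `load_model` is the same package-level helper imported inside `cli()` above, and "audio.wav" is a placeholder path:

from whisper_stream import load_model
from whisper_stream.transcribe import transcribe

model = load_model("small")  # model names come from available_models(), as validated in cli()
result = transcribe(model, "audio.wav", word_timestamps=True, verbose=False)
print(result["text"])
for segment in result["segments"]:
    for word in segment["words"]:
        print(f"{word['start']:7.2f} {word['end']:7.2f} {word['word']}")
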
whisper_stream/triton_ops.py
ADDED
@@ -0,0 +1,109 @@
from functools import lru_cache

import numpy as np
import torch

try:
    import triton
    import triton.language as tl
except ImportError:
    raise RuntimeError("triton import failed; try `pip install --pre triton`")


@triton.jit
def dtw_kernel(
    cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr
):
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M

    for k in range(1, N + M + 1):  # k = i + j
        tl.debug_barrier()

        p0 = cost + (k - 1) * cost_stride
        p1 = cost + k * cost_stride
        p2 = cost + k * cost_stride + 1

        c0 = tl.load(p0 + offsets, mask=mask)
        c1 = tl.load(p1 + offsets, mask=mask)
        c2 = tl.load(p2 + offsets, mask=mask)

        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

        cost_ptr = cost + (k + 1) * cost_stride + 1
        tl.store(cost_ptr + offsets, cost_row, mask=mask)

        trace_ptr = trace + (k + 1) * trace_stride + 1
        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))


@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
    @triton.jit
    def kernel(
        y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr
    ):  # x.shape[-1] == filter_width
        row_idx = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < y_stride

        x_ptr = x + row_idx * x_stride  # noqa: F841
        y_ptr = y + row_idx * y_stride

        LOAD_ALL_ROWS_HERE  # noqa: F821

        BUBBLESORT_HERE  # noqa: F821

        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821

    # the placeholders above are rewritten in the kernel source before compilation
    kernel = triton.JITFunction(kernel.fn)
    kernel.src = kernel.src.replace(
        "    LOAD_ALL_ROWS_HERE",
        "\n".join(
            [
                f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)"
                for i in range(filter_width)
            ]
        ),
    )
    kernel.src = kernel.src.replace(
        "    BUBBLESORT_HERE",
        "\n\n".join(
            [
                "\n\n".join(
                    [
                        "\n".join(
                            [
                                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
                                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
                                f"    row{j} = smaller",
                                f"    row{j + 1} = larger",
                            ]
                        )
                        for j in range(filter_width - i - 1)
                    ]
                )
                for i in range(filter_width // 2 + 1)
            ]
        ),
    )
    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")

    return kernel


def median_filter_cuda(x: torch.Tensor, filter_width: int):
    """Apply a median filter of given width along the last dimension of x"""
    slices = x.contiguous().unfold(-1, filter_width, 1)
    grid = np.prod(slices.shape[:-2])

    kernel = median_kernel(filter_width)
    y = torch.empty_like(slices[..., 0])

    BLOCK_SIZE = 1 << (y.stride(-2) - 1).bit_length()
    kernel[(grid,)](y, x, x.stride(-2), y.stride(-2), BLOCK_SIZE=BLOCK_SIZE)

    return y
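
A minimal sketch of calling `median_filter_cuda`, assuming a CUDA device with triton available; the input here is pre-padded by `filter_width - 1` so the filtered output keeps its original width (the caller in timing.py handles its own padding):

import torch
from whisper_stream.triton_ops import median_filter_cuda

x = torch.randn(8, 1500 + 4, device="cuda")  # pre-padded input for a width-5 filter
y = median_filter_cuda(x, filter_width=5)
print(y.shape)  # torch.Size([8, 1500])
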
whisper_stream/utils.py
ADDED
@@ -0,0 +1,316 @@
import json
import os
import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO

system_encoding = sys.getdefaultencoding()

if system_encoding != "utf-8":

    def make_safe(string):
        # replaces any character not representable using the system default encoding with a '?',
        # avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
        return string.encode(system_encoding, errors="replace").decode(system_encoding)

else:

    def make_safe(string):
        # utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
        return string


def exact_div(x, y):
    assert x % y == 0
    return x // y


def str2bool(string):
    str2val = {"True": True, "False": False}
    if string in str2val:
        return str2val[string]
    else:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")


def optional_int(string):
    return None if string == "None" else int(string)


def optional_float(string):
    return None if string == "None" else float(string)

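# Editor's sketch (not part of the commit): str2bool / optional_int /
# optional_float are argparse `type=` converters, letting "--best_of None"
# round-trip to Python None. The flag names below are illustrative only.
def _demo_cli_flags():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", type=str2bool, default=True)
    parser.add_argument("--best_of", type=optional_int, default=5)
    args = parser.parse_args(["--fp16", "False", "--best_of", "None"])
    assert args.fp16 is False and args.best_of is None
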
def compression_ratio(text) -> float:
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))

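# Editor's sketch (not part of the commit): the ratio is raw UTF-8 length over
# zlib-compressed length, so highly repetitive text -- a common transcription
# failure mode -- compresses well and scores high.
def _demo_compression_ratio():
    varied = "The quick brown fox jumps over the lazy dog."
    repetitive = "ha " * 100
    assert compression_ratio(repetitive) > compression_ratio(varied)
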
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )

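# Editor's sketch (not part of the commit): a couple of concrete values.
def _demo_format_timestamp():
    assert format_timestamp(3661.5) == "01:01:01.500"
    # SRT style: hours always present, comma as the decimal marker
    srt_style = format_timestamp(61.5, always_include_hours=True, decimal_marker=",")
    assert srt_style == "00:01:01,500"
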
def get_start(segments: List[dict]) -> Optional[float]:
    return next(
        (w["start"] for s in segments for w in s["words"]),
        segments[0]["start"] if segments else None,
    )


def get_end(segments: List[dict]) -> Optional[float]:
    return next(
        (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
        segments[-1]["end"] if segments else None,
    )


class ResultWriter:
    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(
        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
    ):
        audio_basename = os.path.basename(audio_path)
        audio_basename = os.path.splitext(audio_basename)[0]
        output_path = os.path.join(
            self.output_dir, audio_basename + "." + self.extension
        )

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f, options=options, **kwargs)

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        raise NotImplementedError


class WriteTXT(ResultWriter):
    extension: str = "txt"

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        for segment in result["segments"]:
            print(segment["text"].strip(), file=file, flush=True)


class SubtitlesWriter(ResultWriter):
    always_include_hours: bool
    decimal_marker: str

    def iterate_result(
        self,
        result: dict,
        options: Optional[dict] = None,
        *,
        max_line_width: Optional[int] = None,
        max_line_count: Optional[int] = None,
        highlight_words: bool = False,
        max_words_per_line: Optional[int] = None,
    ):
        options = options or {}
        max_line_width = max_line_width or options.get("max_line_width")
        max_line_count = max_line_count or options.get("max_line_count")
        highlight_words = highlight_words or options.get("highlight_words", False)
        max_words_per_line = max_words_per_line or options.get("max_words_per_line")
        preserve_segments = max_line_count is None or max_line_width is None
        max_line_width = max_line_width or 1000
        max_words_per_line = max_words_per_line or 1000

        def iterate_subtitles():
            line_len = 0
            line_count = 1
            # the next subtitle to yield (a list of word timings with whitespace)
            subtitle: List[dict] = []
            last: float = get_start(result["segments"]) or 0.0
            for segment in result["segments"]:
                chunk_index = 0
                words_count = max_words_per_line
                while chunk_index < len(segment["words"]):
                    remaining_words = len(segment["words"]) - chunk_index
                    if max_words_per_line > len(segment["words"]) - chunk_index:
                        words_count = remaining_words
                    for i, original_timing in enumerate(
                        segment["words"][chunk_index : chunk_index + words_count]
                    ):
                        timing = original_timing.copy()
                        long_pause = (
                            not preserve_segments and timing["start"] - last > 3.0
                        )
                        has_room = line_len + len(timing["word"]) <= max_line_width
                        seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                        if (
                            line_len > 0
                            and has_room
                            and not long_pause
                            and not seg_break
                        ):
                            # line continuation
                            line_len += len(timing["word"])
                        else:
                            # new line
                            timing["word"] = timing["word"].strip()
                            if (
                                len(subtitle) > 0
                                and max_line_count is not None
                                and (long_pause or line_count >= max_line_count)
                                or seg_break
                            ):
                                # subtitle break
                                yield subtitle
                                subtitle = []
                                line_count = 1
                            elif line_len > 0:
                                # line break
                                line_count += 1
                                timing["word"] = "\n" + timing["word"]
                            line_len = len(timing["word"].strip())
                        subtitle.append(timing)
                        last = timing["start"]
                    chunk_index += max_words_per_line
            if len(subtitle) > 0:
                yield subtitle

        if len(result["segments"]) > 0 and "words" in result["segments"][0]:
            for subtitle in iterate_subtitles():
                subtitle_start = self.format_timestamp(subtitle[0]["start"])
                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
                subtitle_text = "".join([word["word"] for word in subtitle])
                if highlight_words:
                    last = subtitle_start
                    all_words = [timing["word"] for timing in subtitle]
                    for i, this_word in enumerate(subtitle):
                        start = self.format_timestamp(this_word["start"])
                        end = self.format_timestamp(this_word["end"])
                        if last != start:
                            yield last, start, subtitle_text

                        yield start, end, "".join(
                            [
                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                if j == i
                                else word
                                for j, word in enumerate(all_words)
                            ]
                        )
                        last = end
                else:
                    yield subtitle_start, subtitle_end, subtitle_text
        else:
            for segment in result["segments"]:
                segment_start = self.format_timestamp(segment["start"])
                segment_end = self.format_timestamp(segment["end"])
                segment_text = segment["text"].strip().replace("-->", "->")
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )


class WriteVTT(SubtitlesWriter):
    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        print("WEBVTT\n", file=file)
        for start, end, text in self.iterate_result(result, options, **kwargs):
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)


class WriteSRT(SubtitlesWriter):
    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        for i, (start, end, text) in enumerate(
            self.iterate_result(result, options, **kwargs), start=1
        ):
            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)

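# Editor's sketch (not part of the commit): what WriteSRT emits for a minimal,
# hand-built result dict (segment values are made up). Without a "words" key,
# iterate_result falls back to one cue per segment.
def _demo_write_srt():
    import io

    result = {"segments": [{"start": 0.0, "end": 2.5, "text": " Hello world."}]}
    buf = io.StringIO()
    WriteSRT(output_dir=".").write_result(result, file=buf)
    assert buf.getvalue() == "1\n00:00:00,000 --> 00:00:02,500\nHello world.\n\n"
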
class WriteTSV(ResultWriter):
    """
    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>

    Using integer milliseconds as start and end times means there's no chance of interference from
    an environment setting a language encoding that causes the decimal in a floating point number
    to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
    """

    extension: str = "tsv"

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            print(round(1000 * segment["start"]), file=file, end="\t")
            print(round(1000 * segment["end"]), file=file, end="\t")
            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)


class WriteJSON(ResultWriter):
    extension: str = "json"

    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
        json.dump(result, file)


def get_writer(
    output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(
            result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
        ):
            for writer in all_writers:
                writer(result, file, options, **kwargs)

        return write_all

    return writers[output_format](output_dir)
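Typical call sites would go through get_writer rather than instantiating writers directly; a hedged sketch (editor's addition, not in the commit; the paths, options, and segment values are illustrative, and `out/` is assumed to exist):

    from whisper_stream.utils import get_writer

    writer = get_writer("srt", output_dir="out")  # or "txt", "vtt", "tsv", "json", "all"
    result = {"segments": [{"start": 0.0, "end": 2.5, "text": "Hello world."}]}
    writer(result, "audio/sample.wav", options={"max_line_width": 42, "max_line_count": 2})
    # -> writes out/sample.srt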
whisper_stream/version.py
ADDED
@@ -0,0 +1 @@
__version__ = "20231117"