ras0k committed on
Commit 9c2d684 · 1 Parent(s): 08384d4

whisperX merged in

.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ whisperx.egg-info/
2
+ **/__pycache__/
EXAMPLES.md ADDED
@@ -0,0 +1,37 @@
1
+ # More Examples
2
+
3
+ ## Other Languages
4
+
5
+ For non-English ASR, it is best to use the `large` Whisper model. The alignment model is picked automatically for the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).
6
+
7
+ Default alignment models are currently provided and tested for {en, fr, de, es, it, ja, zh, nl}.
8
+
9
+
10
+ If the detected language is not in this list, you need to find a phoneme-based ASR model on the [Hugging Face model hub](https://huggingface.co/models), test it on your data, and pass it via `--align_model`.
11
+
12
+ ### French
13
+ whisperx --model large --language fr examples/sample_fr_01.wav
14
+
15
+
16
+ https://user-images.githubusercontent.com/36994049/208298804-31c49d6f-6787-444e-a53f-e93c52706752.mov
17
+
18
+
19
+ ### German
20
+ whisperx --model large --language de examples/sample_de_01.wav
21
+
22
+
23
+ https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
24
+
25
+
26
+ ### Italian
27
+ whisperx --model large --language it examples/sample_it_01.wav
28
+
29
+
30
+ https://user-images.githubusercontent.com/36994049/208298819-6f462b2c-8cae-4c54-b8e1-90855794efc7.mov
31
+
32
+
33
+ ### Japanese
34
+ whisperx --model large --language ja examples/sample_ja_01.wav
35
+
36
+
37
+ https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-809d-aaf3cd7e06f4.mov
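On overriding the alignment model: `--align_model` accepts either a torchaudio pipeline name or a Hugging Face model id (see `whisperx/alignment.py`). As a hedged sketch, reusing the Portuguese wav2vec2 checkpoint already listed in the defaults there (the audio path is illustrative, not a file shipped in this commit):

    whisperx --model large --language pt --align_model jonatasgrosman/wav2vec2-large-xlsr-53-portuguese examples/sample_pt_01.wav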
LICENSE ADDED
@@ -0,0 +1,27 @@
1
+ Copyright (c) 2022, Max Bain
2
+ All rights reserved.
3
+
4
+ Redistribution and use in source and binary forms, with or without
5
+ modification, are permitted provided that the following conditions are met:
6
+ 1. Redistributions of source code must retain the above copyright
7
+ notice, this list of conditions and the following disclaimer.
8
+ 2. Redistributions in binary form must reproduce the above copyright
9
+ notice, this list of conditions and the following disclaimer in the
10
+ documentation and/or other materials provided with the distribution.
11
+ 3. All advertising materials mentioning features or use of this software
12
+ must display the following acknowledgement:
13
+ This product includes software developed by Max Bain.
14
+ 4. Neither the name of Max Bain nor the
15
+ names of its contributors may be used to endorse or promote products
16
+ derived from this software without specific prior written permission.
17
+
18
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ''AS IS'' AND ANY
19
+ EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
27
+ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in ADDED
@@ -0,0 +1,4 @@
1
+ include whisperx/assets/*
2
+ include whisperx/assets/gpt2/*
3
+ include whisperx/assets/multilingual/*
4
+ include whisperx/normalizers/english.json
figures/pipeline.png ADDED
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ numpy
2
+ pandas
3
+ torch >=1.9
4
+ torchaudio >=0.10,<1.0
5
+ tqdm
6
+ more-itertools
7
+ transformers>=4.19.0
8
+ ffmpeg-python==0.2.0
9
+ pyannote.audio
10
+ openai-whisper==20230314
setup.py ADDED
@@ -0,0 +1,28 @@
1
+ import os
2
+
3
+ import pkg_resources
4
+ from setuptools import setup, find_packages
5
+
6
+ setup(
7
+ name="whisperx",
8
+ py_modules=["whisperx"],
9
+ version="2.0",
10
+ description="Time-Accurate Automatic Speech Recognition using Whisper.",
11
+ readme="README.md",
12
+ python_requires=">=3.8",
13
+ author="Max Bain",
14
+ url="https://github.com/m-bain/whisperx",
15
+ license="MIT",
16
+ packages=find_packages(exclude=["tests*"]),
17
+ install_requires=[
18
+ str(r)
19
+ for r in pkg_resources.parse_requirements(
20
+ open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
21
+ )
22
+ ],
23
+ entry_points = {
24
+ 'console_scripts': ['whisperx=whisperx.transcribe:cli'],
25
+ },
26
+ include_package_data=True,
27
+ extras_require={'dev': ['pytest']},
28
+ )
whisperx/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .transcribe import transcribe, transcribe_with_vad
2
+ from .alignment import load_align_model, align
3
+ from .vad import load_vad_model
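As a minimal sketch of how the exports above fit together (the audio file name and device are assumptions, not part of the commit; the Whisper checkpoint downloads on first run):

```python
# Transcribe with Whisper, then refine timestamps with a wav2vec2 alignment model.
import torch
import whisper
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"

model = whisper.load_model("large", device=device)
result = whisperx.transcribe(model, "sample.wav")            # "sample.wav" is assumed

# Pick the default alignment model for the detected language.
align_model, metadata = whisperx.load_align_model(result["language"], device)
result_aligned = whisperx.align(result["segments"], align_model, metadata,
                                "sample.wav", device)

for word in result_aligned["word_segments"][:5]:
    print(word["start"], word["end"], word["text"])
```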
whisperx/__main__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .transcribe import cli
2
+
3
+
4
+ cli()
whisperx/alignment.py ADDED
@@ -0,0 +1,548 @@
1
+ """"
2
+ Forced Alignment with Whisper
3
+ C. Max Bain
4
+ """
5
+ import numpy as np
6
+ import pandas as pd
7
+ from typing import List, Union, Iterator, TYPE_CHECKING
8
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
9
+ import torchaudio
10
+ import torch
11
+ from dataclasses import dataclass
12
+ from whisper.audio import SAMPLE_RATE, load_audio
13
+ from .utils import interpolate_nans
14
+
15
+
16
+ LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
17
+
18
+ DEFAULT_ALIGN_MODELS_TORCH = {
19
+ "en": "WAV2VEC2_ASR_BASE_960H",
20
+ "fr": "VOXPOPULI_ASR_BASE_10K_FR",
21
+ "de": "VOXPOPULI_ASR_BASE_10K_DE",
22
+ "es": "VOXPOPULI_ASR_BASE_10K_ES",
23
+ "it": "VOXPOPULI_ASR_BASE_10K_IT",
24
+ }
25
+
26
+ DEFAULT_ALIGN_MODELS_HF = {
27
+ "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
28
+ "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
29
+ "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
30
+ "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
31
+ "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
32
+ "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
33
+ "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
34
+ "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
35
+ "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
36
+ "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
37
+ "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
38
+ "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
39
+ "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
40
+ }
41
+
42
+
43
+ def load_align_model(language_code, device, model_name=None):
44
+ if model_name is None:
45
+ # use default model
46
+ if language_code in DEFAULT_ALIGN_MODELS_TORCH:
47
+ model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
48
+ elif language_code in DEFAULT_ALIGN_MODELS_HF:
49
+ model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
50
+ else:
51
+ print(f"There is no default alignment model set for this language ({language_code}).\
52
+ Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
53
+ raise ValueError(f"No default align-model for language: {language_code}")
54
+
55
+ if model_name in torchaudio.pipelines.__all__:
56
+ pipeline_type = "torchaudio"
57
+ bundle = torchaudio.pipelines.__dict__[model_name]
58
+ align_model = bundle.get_model().to(device)
59
+ labels = bundle.get_labels()
60
+ align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
61
+ else:
62
+ try:
63
+ processor = Wav2Vec2Processor.from_pretrained(model_name)
64
+ align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
65
+ except Exception as e:
66
+ print(e)
67
+ print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
68
+ raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
69
+ pipeline_type = "huggingface"
70
+ align_model = align_model.to(device)
71
+ labels = processor.tokenizer.get_vocab()
72
+ align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
73
+
74
+ align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
75
+
76
+ return align_model, align_metadata
77
+
78
+
79
+ def align(
80
+ transcript: Iterator[dict],
81
+ model: torch.nn.Module,
82
+ align_model_metadata: dict,
83
+ audio: Union[str, np.ndarray, torch.Tensor],
84
+ device: str,
85
+ extend_duration: float = 0.0,
86
+ start_from_previous: bool = True,
87
+ interpolate_method: str = "nearest",
88
+ ):
89
+ """
90
+ Force align phoneme recognition predictions to known transcription
91
+
92
+ Parameters
93
+ ----------
94
+ transcript: Iterator[dict]
95
+ The transcript segments (from Whisper) to be aligned
96
+
97
+ model: torch.nn.Module
98
+ Alignment model (wav2vec2)
99
+
100
+ audio: Union[str, np.ndarray, torch.Tensor]
101
+ The path to the audio file to open, or the audio waveform
102
+
103
+ device: str
104
+ cuda device
105
+
106
+ align_model_metadata: dict
107
+ Metadata for the alignment model (language, dictionary, type), as returned by load_align_model.
108
+
109
+ extend_duration: float
110
+ Amount (in seconds) to pad input segments by. If not using --vad_filter, padding by 2 seconds is recommended.
111
+
112
+ start_from_previous: bool
+ If True, clip the alignment start time of the current segment to the end time of the last aligned word of the previous segment.
113
+
114
+ interpolate_method: str ["nearest", "linear", "ignore"]
115
+ Method used to assign timestamps to non-aligned words. A word cannot be aligned when none of its characters occur in the alignment model's dictionary.
116
+ "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
117
+
118
+ Returns
119
+ -------
120
+ A dictionary containing the aligned segment-level details ("segments") and
121
+ the word-level timings ("word_segments").
122
+ """
123
+ if not torch.is_tensor(audio):
124
+ if isinstance(audio, str):
125
+ audio = load_audio(audio)
126
+ audio = torch.from_numpy(audio)
127
+ if len(audio.shape) == 1:
128
+ audio = audio.unsqueeze(0)
129
+
130
+ MAX_DURATION = audio.shape[1] / SAMPLE_RATE
131
+
132
+ model_dictionary = align_model_metadata["dictionary"]
133
+ model_lang = align_model_metadata["language"]
134
+ model_type = align_model_metadata["type"]
135
+
136
+ aligned_segments = []
137
+
138
+ prev_t2 = 0
139
+
140
+ char_segments_arr = {
141
+ "segment-idx": [],
142
+ "subsegment-idx": [],
143
+ "word-idx": [],
144
+ "char": [],
145
+ "start": [],
146
+ "end": [],
147
+ "score": [],
148
+ }
149
+
150
+ for sdx, segment in enumerate(transcript):
151
+ while True:
152
+ segment_align_success = False
153
+
154
+ # strip spaces at beginning / end, but keep track of the amount.
155
+ num_leading = len(segment["text"]) - len(segment["text"].lstrip())
156
+ num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
157
+ transcription = segment["text"]
158
+
159
+ # TODO: convert number tokenizer / symbols to phonetic words for alignment.
160
+ # e.g. "$300" -> "three hundred dollars"
161
+ # currently "$300" is ignored since no characters present in the phonetic dictionary
162
+
163
+ # split into words
164
+ if model_lang not in LANGUAGES_WITHOUT_SPACES:
165
+ per_word = transcription.split(" ")
166
+ else:
167
+ per_word = transcription
168
+
169
+ # first check that characters in transcription can be aligned (they are contained in the align model's dictionary)
170
+ clean_char, clean_cdx = [], []
171
+ for cdx, char in enumerate(transcription):
172
+ char_ = char.lower()
173
+ # wav2vec2 models use "|" character to represent spaces
174
+ if model_lang not in LANGUAGES_WITHOUT_SPACES:
175
+ char_ = char_.replace(" ", "|")
176
+
177
+ # ignore whitespace at beginning and end of transcript
178
+ if cdx < num_leading:
179
+ pass
180
+ elif cdx > len(transcription) - num_trailing - 1:
181
+ pass
182
+ elif char_ in model_dictionary.keys():
183
+ clean_char.append(char_)
184
+ clean_cdx.append(cdx)
185
+
186
+ clean_wdx = []
187
+ for wdx, wrd in enumerate(per_word):
188
+ if any([c in model_dictionary.keys() for c in wrd]):
189
+ clean_wdx.append(wdx)
190
+
191
+ # if no characters are in the dictionary, then we skip this segment...
192
+ if len(clean_char) == 0:
193
+ print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
194
+ break
195
+
196
+ transcription_cleaned = "".join(clean_char)
197
+ tokens = [model_dictionary[c] for c in transcription_cleaned]
198
+
199
+ # we only pad if not using VAD filtering
200
+ if "seg_text" not in segment:
201
+ # pad according original timestamps
202
+ t1 = max(segment["start"] - extend_duration, 0)
203
+ t2 = min(segment["end"] + extend_duration, MAX_DURATION)
204
+
205
+ # use prev_t2 as current t1 if it's later
206
+ if start_from_previous and t1 < prev_t2:
207
+ t1 = prev_t2
208
+
209
+ # check if timestamp range is still valid
210
+ if t1 >= MAX_DURATION:
211
+ print("Failed to align segment: original start time longer than audio duration, skipping...")
212
+ break
213
+ if t2 - t1 < 0.02:
214
+ print("Failed to align segment: duration smaller than 0.02s time precision")
215
+ break
216
+
217
+ f1 = int(t1 * SAMPLE_RATE)
218
+ f2 = int(t2 * SAMPLE_RATE)
219
+
220
+ waveform_segment = audio[:, f1:f2]
221
+
222
+ with torch.inference_mode():
223
+ if model_type == "torchaudio":
224
+ emissions, _ = model(waveform_segment.to(device))
225
+ elif model_type == "huggingface":
226
+ emissions = model(waveform_segment.to(device)).logits
227
+ else:
228
+ raise NotImplementedError(f"Align model of type {model_type} not supported.")
229
+ emissions = torch.log_softmax(emissions, dim=-1)
230
+
231
+ emission = emissions[0].cpu().detach()
232
+
233
+ trellis = get_trellis(emission, tokens)
234
+ path = backtrack(trellis, emission, tokens)
235
+ if path is None:
236
+ print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
237
+ break
238
+ char_segments = merge_repeats(path, transcription_cleaned)
239
+ # word_segments = merge_words(char_segments)
240
+
241
+
242
+ # sub-segments
243
+ if "seg-text" not in segment:
244
+ segment["seg-text"] = [transcription]
245
+
246
+ seg_lens = [0] + [len(x) for x in segment["seg-text"]]
247
+ seg_lens_cumsum = list(np.cumsum(seg_lens))
248
+ sub_seg_idx = 0
249
+
250
+ wdx = 0
251
+ duration = t2 - t1
252
+ ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
253
+ for cdx, char in enumerate(transcription + " "):
254
+ is_last = False
255
+ if cdx == len(transcription):
256
+ break
257
+ elif cdx+1 == len(transcription):
258
+ is_last = True
259
+
260
+
261
+ start, end, score = None, None, None
262
+ if cdx in clean_cdx:
263
+ char_seg = char_segments[clean_cdx.index(cdx)]
264
+ start = char_seg.start * ratio + t1
265
+ end = char_seg.end * ratio + t1
266
+ score = char_seg.score
267
+
268
+ char_segments_arr["char"].append(char)
269
+ char_segments_arr["start"].append(start)
270
+ char_segments_arr["end"].append(end)
271
+ char_segments_arr["score"].append(score)
272
+ char_segments_arr["word-idx"].append(wdx)
273
+ char_segments_arr["segment-idx"].append(sdx)
274
+ char_segments_arr["subsegment-idx"].append(sub_seg_idx)
275
+
276
+ # word-level info
277
+ if model_lang in LANGUAGES_WITHOUT_SPACES:
278
+ # character == word
279
+ wdx += 1
280
+ elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
281
+ wdx += 1
282
+
283
+ if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
284
+ wdx = 0
285
+ sub_seg_idx += 1
286
+
287
+ prev_t2 = segment["end"]
288
+
289
+ segment_align_success = True
290
+ # end while True loop
291
+ break
292
+
293
+ # reset prev_t2 due to drifting issues
294
+ if not segment_align_success:
295
+ prev_t2 = 0
296
+
297
+ char_segments_arr = pd.DataFrame(char_segments_arr)
298
+ not_space = char_segments_arr["char"] != " "
299
+
300
+ per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
301
+ char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
302
+ per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
303
+ per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
304
+ per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
305
+ char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
306
+ per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
307
+
308
+ word_segments_arr = {}
309
+
310
+ # start of word is first char with a timestamp
311
+ word_segments_arr["start"] = per_word_grp["start"].min().values
312
+ # end of word is last char with a timestamp
313
+ word_segments_arr["end"] = per_word_grp["end"].max().values
314
+ # score of word is mean (excluding nan)
315
+ word_segments_arr["score"] = per_word_grp["score"].mean().values
316
+
317
+ word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
318
+ word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
319
+ word_segments_arr = pd.DataFrame(word_segments_arr)
320
+
321
+ word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
322
+ segments_arr = {}
323
+ segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
324
+ segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
325
+ segments_arr = pd.DataFrame(segments_arr)
326
+ segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
327
+ segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
328
+
329
+ # interpolate missing words / sub-segments
330
+ if interpolate_method != "ignore":
331
+ wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
332
+ wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
333
+ # we still know which word timestamps are interpolated because their score == nan
334
+ word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
335
+ word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
336
+
337
+ word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
338
+ word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
339
+
340
+ sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
341
+ segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
342
+ segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
343
+
344
+ # merge words & subsegments which are missing times
345
+ word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
346
+
347
+ word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
348
+ word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
349
+ word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
350
+
351
+ seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
352
+ segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
353
+ segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
354
+ segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
355
+ else:
356
+ word_segments_arr.dropna(inplace=True)
357
+ segments_arr.dropna(inplace=True)
358
+
359
+ # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
360
+ segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
361
+ segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
362
+ segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
363
+ segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
364
+
365
+
366
+ aligned_segments = []
367
+ aligned_segments_word = []
368
+
369
+ word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
370
+ char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
371
+
372
+ for sdx, srow in segments_arr.iterrows():
373
+
374
+ seg_idx = int(srow["segment-idx"])
375
+ sub_start = int(srow["subsegment-idx-start"])
376
+ sub_end = int(srow["subsegment-idx-end"])
377
+
378
+ seg = transcript[seg_idx]
379
+ text = "".join(seg["seg-text"][sub_start:sub_end])
380
+
381
+ wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
382
+ wseg["start"].fillna(srow["start"], inplace=True)
383
+ wseg["end"].fillna(srow["end"], inplace=True)
384
+ wseg["segment-text-start"].fillna(0, inplace=True)
385
+ wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
386
+
387
+ cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
388
+ # fixes bug for single segment in transcript
389
+ cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
390
+ cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
391
+ if 'level_1' in cseg: del cseg['level_1']
392
+ if 'level_0' in cseg: del cseg['level_0']
393
+ cseg.reset_index(inplace=True)
394
+ aligned_segments.append(
395
+ {
396
+ "start": srow["start"],
397
+ "end": srow["end"],
398
+ "text": text,
399
+ "word-segments": wseg,
400
+ "char-segments": cseg
401
+ }
402
+ )
403
+
404
+ def get_raw_text(word_row):
405
+ return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
406
+
407
+ wdx = 0
408
+ curr_text = get_raw_text(wseg.iloc[wdx])
409
+ if len(wseg) > 1:
410
+ for _, wrow in wseg.iloc[1:].iterrows():
411
+ if wrow['start'] != wseg.iloc[wdx]['start']:
412
+ aligned_segments_word.append(
413
+ {
414
+ "text": curr_text.strip(),
415
+ "start": wseg.iloc[wdx]["start"],
416
+ "end": wseg.iloc[wdx]["end"],
417
+ }
418
+ )
419
+ curr_text = ""
420
+ curr_text += " " + get_raw_text(wrow)
421
+ wdx += 1
422
+ aligned_segments_word.append(
423
+ {
424
+ "text": curr_text.strip(),
425
+ "start": wseg.iloc[wdx]["start"],
426
+ "end": wseg.iloc[wdx]["end"]
427
+ }
428
+ )
429
+
430
+
431
+ return {"segments": aligned_segments, "word_segments": aligned_segments_word}
432
+
433
+
434
+ """
435
+ source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
436
+ """
437
+ def get_trellis(emission, tokens, blank_id=0):
438
+ num_frame = emission.size(0)
439
+ num_tokens = len(tokens)
440
+
441
+ # Trellis has extra dimensions for both time axis and tokens.
442
+ # The extra dim for tokens represents <SoS> (start-of-sentence)
443
+ # The extra dim for time axis is for simplification of the code.
444
+ trellis = torch.empty((num_frame + 1, num_tokens + 1))
445
+ trellis[0, 0] = 0
446
+ trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
447
+ trellis[0, -num_tokens:] = -float("inf")
448
+ trellis[-num_tokens:, 0] = float("inf")
449
+
450
+ for t in range(num_frame):
451
+ trellis[t + 1, 1:] = torch.maximum(
452
+ # Score for staying at the same token
453
+ trellis[t, 1:] + emission[t, blank_id],
454
+ # Score for changing to the next token
455
+ trellis[t, :-1] + emission[t, tokens],
456
+ )
457
+ return trellis
458
+
459
+ @dataclass
460
+ class Point:
461
+ token_index: int
462
+ time_index: int
463
+ score: float
464
+
465
+ def backtrack(trellis, emission, tokens, blank_id=0):
466
+ # Note:
467
+ # j and t are indices for trellis, which has extra dimensions
468
+ # for time and tokens at the beginning.
469
+ # When referring to time frame index `T` in trellis,
470
+ # the corresponding index in emission is `T-1`.
471
+ # Similarly, when referring to token index `J` in trellis,
472
+ # the corresponding index in transcript is `J-1`.
473
+ j = trellis.size(1) - 1
474
+ t_start = torch.argmax(trellis[:, j]).item()
475
+
476
+ path = []
477
+ for t in range(t_start, 0, -1):
478
+ # 1. Figure out if the current position was stay or change
479
+ # Note (again):
480
+ # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
481
+ # Score for token staying the same from time frame J-1 to T.
482
+ stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
483
+ # Score for token changing from C-1 at T-1 to J at T.
484
+ changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
485
+
486
+ # 2. Store the path with frame-wise probability.
487
+ prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
488
+ # Return token index and time index in non-trellis coordinate.
489
+ path.append(Point(j - 1, t - 1, prob))
490
+
491
+ # 3. Update the token
492
+ if changed > stayed:
493
+ j -= 1
494
+ if j == 0:
495
+ break
496
+ else:
497
+ # failed
498
+ return None
499
+ return path[::-1]
500
+
501
+ # Merge the labels
502
+ @dataclass
503
+ class Segment:
504
+ label: str
505
+ start: int
506
+ end: int
507
+ score: float
508
+
509
+ def __repr__(self):
510
+ return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
511
+
512
+ @property
513
+ def length(self):
514
+ return self.end - self.start
515
+
516
+ def merge_repeats(path, transcript):
517
+ i1, i2 = 0, 0
518
+ segments = []
519
+ while i1 < len(path):
520
+ while i2 < len(path) and path[i1].token_index == path[i2].token_index:
521
+ i2 += 1
522
+ score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
523
+ segments.append(
524
+ Segment(
525
+ transcript[path[i1].token_index],
526
+ path[i1].time_index,
527
+ path[i2 - 1].time_index + 1,
528
+ score,
529
+ )
530
+ )
531
+ i1 = i2
532
+ return segments
533
+
534
+ def merge_words(segments, separator="|"):
535
+ words = []
536
+ i1, i2 = 0, 0
537
+ while i1 < len(segments):
538
+ if i2 >= len(segments) or segments[i2].label == separator:
539
+ if i1 != i2:
540
+ segs = segments[i1:i2]
541
+ word = "".join([seg.label for seg in segs])
542
+ score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
543
+ words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
544
+ i1 = i2 + 1
545
+ i2 = i1
546
+ else:
547
+ i2 += 1
548
+ return words
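A usage note on `load_align_model`: passing `model_name` explicitly bypasses the default tables above, which is how a custom phoneme-level checkpoint is wired in. A hedged sketch (the checkpoint name is taken from `DEFAULT_ALIGN_MODELS_HF` above; running on CPU is an assumption):

```python
# Load a specific Hugging Face wav2vec2 checkpoint instead of the per-language default.
import whisperx

align_model, metadata = whisperx.load_align_model(
    "pt",                                # language code stored in the metadata
    "cpu",                               # assumed device
    model_name="jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
)
print(metadata["type"], len(metadata["dictionary"]))  # "huggingface", vocabulary size
```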
whisperx/asr.py ADDED
@@ -0,0 +1,429 @@
1
+ import warnings
2
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
3
+ import numpy as np
4
+ import torch
5
+ import tqdm
6
+ import ffmpeg
7
+ from whisper.audio import (
8
+ FRAMES_PER_SECOND,
9
+ HOP_LENGTH,
10
+ N_FRAMES,
11
+ N_SAMPLES,
12
+ SAMPLE_RATE,
13
+ CHUNK_LENGTH,
14
+ log_mel_spectrogram,
15
+ pad_or_trim,
16
+ load_audio
17
+ )
18
+ from whisper.decoding import DecodingOptions, DecodingResult
19
+ from whisper.timing import add_word_timestamps
20
+ from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
21
+ from whisper.utils import (
22
+ exact_div,
23
+ format_timestamp,
24
+ make_safe,
25
+ )
26
+
27
+ if TYPE_CHECKING:
28
+ from whisper.model import Whisper
29
+
30
+ from .vad import merge_chunks
31
+
32
+ def transcribe(
33
+ model: "Whisper",
34
+ audio: Union[str, np.ndarray, torch.Tensor] = None,
35
+ mel: np.ndarray = None,
36
+ verbose: Optional[bool] = None,
37
+ temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
38
+ compression_ratio_threshold: Optional[float] = 2.4,
39
+ logprob_threshold: Optional[float] = -1.0,
40
+ no_speech_threshold: Optional[float] = 0.6,
41
+ condition_on_previous_text: bool = True,
42
+ initial_prompt: Optional[str] = None,
43
+ word_timestamps: bool = False,
44
+ prepend_punctuations: str = "\"'“¿([{-",
45
+ append_punctuations: str = "\"'.。,,!!??::”)]}、",
46
+ **decode_options,
47
+ ):
48
+ """
49
+ Transcribe an audio file using Whisper.
50
+ We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio)
51
+
52
+ Parameters
53
+ ----------
54
+ model: Whisper
55
+ The Whisper model instance
56
+
57
+ audio: Union[str, np.ndarray, torch.Tensor]
58
+ The path to the audio file to open, or the audio waveform
59
+
60
+ mel: np.ndarray
61
+ Mel spectrogram of audio segment.
62
+
63
+ verbose: bool
64
+ Whether to display the text being decoded to the console. If True, displays all the details,
65
+ If False, displays minimal details. If None, does not display anything
66
+
67
+ temperature: Union[float, Tuple[float, ...]]
68
+ Temperature for sampling. It can be a tuple of temperatures, which will be successively used
69
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
70
+
71
+ compression_ratio_threshold: float
72
+ If the gzip compression ratio is above this value, treat as failed
73
+
74
+ logprob_threshold: float
75
+ If the average log probability over sampled tokens is below this value, treat as failed
76
+
77
+ no_speech_threshold: float
78
+ If the no_speech probability is higher than this value AND the average log probability
79
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
80
+
81
+ condition_on_previous_text: bool
82
+ if True, the previous output of the model is provided as a prompt for the next window;
83
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
84
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
85
+
86
+ word_timestamps: bool
87
+ Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
88
+ and include the timestamps for each word in each segment.
89
+
90
+ prepend_punctuations: str
91
+ If word_timestamps is True, merge these punctuation symbols with the next word
92
+
93
+ append_punctuations: str
94
+ If word_timestamps is True, merge these punctuation symbols with the previous word
95
+
96
+ initial_prompt: Optional[str]
97
+ Optional text to provide as a prompt for the first window. This can be used to provide, or
98
+ "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
99
+ to make it more likely to predict those words correctly.
100
+
101
+ decode_options: dict
102
+ Keyword arguments to construct `DecodingOptions` instances
103
+
104
+ Returns
105
+ -------
106
+ A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
107
+ the spoken language ("language"), which is detected when `decode_options["language"]` is None.
108
+ """
109
+ dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
110
+ if model.device == torch.device("cpu"):
111
+ if torch.cuda.is_available():
112
+ warnings.warn("Performing inference on CPU when CUDA is available")
113
+ if dtype == torch.float16:
114
+ warnings.warn("FP16 is not supported on CPU; using FP32 instead")
115
+ dtype = torch.float32
116
+
117
+ if dtype == torch.float32:
118
+ decode_options["fp16"] = False
119
+
120
+ # Pad 30-seconds of silence to the input audio, for slicing
121
+ if mel is None:
122
+ if audio is None:
123
+ raise ValueError("Transcribe needs either audio or mel as input, currently both are none.")
124
+ mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
125
+ content_frames = mel.shape[-1] - N_FRAMES
126
+
127
+ if decode_options.get("language", None) is None:
128
+ if not model.is_multilingual:
129
+ decode_options["language"] = "en"
130
+ else:
131
+ if verbose:
132
+ print(
133
+ "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
134
+ )
135
+ mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
136
+ _, probs = model.detect_language(mel_segment)
137
+ decode_options["language"] = max(probs, key=probs.get)
138
+ if verbose is not None:
139
+ print(
140
+ f"Detected language: {LANGUAGES[decode_options['language']].title()}"
141
+ )
142
+
143
+ language: str = decode_options["language"]
144
+ task: str = decode_options.get("task", "transcribe")
145
+ tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
146
+
147
+ if word_timestamps and task == "translate":
148
+ warnings.warn("Word-level timestamps on translations may not be reliable.")
149
+
150
+ def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
151
+ temperatures = (
152
+ [temperature] if isinstance(temperature, (int, float)) else temperature
153
+ )
154
+ decode_result = None
155
+
156
+ for t in temperatures:
157
+ kwargs = {**decode_options}
158
+ if t > 0:
159
+ # disable beam_size and patience when t > 0
160
+ kwargs.pop("beam_size", None)
161
+ kwargs.pop("patience", None)
162
+ else:
163
+ # disable best_of when t == 0
164
+ kwargs.pop("best_of", None)
165
+
166
+ options = DecodingOptions(**kwargs, temperature=t)
167
+ decode_result = model.decode(segment, options)
168
+
169
+ needs_fallback = False
170
+ if (
171
+ compression_ratio_threshold is not None
172
+ and decode_result.compression_ratio > compression_ratio_threshold
173
+ ):
174
+ needs_fallback = True # too repetitive
175
+ if (
176
+ logprob_threshold is not None
177
+ and decode_result.avg_logprob < logprob_threshold
178
+ ):
179
+ needs_fallback = True # average log probability is too low
180
+
181
+ if not needs_fallback:
182
+ break
183
+
184
+ return decode_result
185
+
186
+ seek = 0
187
+ input_stride = exact_div(
188
+ N_FRAMES, model.dims.n_audio_ctx
189
+ ) # mel frames per output token: 2
190
+ time_precision = (
191
+ input_stride * HOP_LENGTH / SAMPLE_RATE
192
+ ) # time per output token: 0.02 (seconds)
193
+ all_tokens = []
194
+ all_segments = []
195
+ prompt_reset_since = 0
196
+
197
+ if initial_prompt is not None:
198
+ initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
199
+ all_tokens.extend(initial_prompt_tokens)
200
+ else:
201
+ initial_prompt_tokens = []
202
+
203
+ def new_segment(
204
+ *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
205
+ ):
206
+ tokens = tokens.tolist()
207
+ text_tokens = [token for token in tokens if token < tokenizer.eot]
208
+ return {
209
+ "seek": seek,
210
+ "start": start,
211
+ "end": end,
212
+ "text": tokenizer.decode(text_tokens),
213
+ "tokens": tokens,
214
+ "temperature": result.temperature,
215
+ "avg_logprob": result.avg_logprob,
216
+ "compression_ratio": result.compression_ratio,
217
+ "no_speech_prob": result.no_speech_prob,
218
+ }
219
+
220
+
221
+ # show the progress bar when verbose is False (if True, transcribed text will be printed)
222
+ with tqdm.tqdm(
223
+ total=content_frames, unit="frames", disable=verbose is not False
224
+ ) as pbar:
225
+ while seek < content_frames:
226
+ time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
227
+ mel_segment = mel[:, seek : seek + N_FRAMES]
228
+ segment_size = min(N_FRAMES, content_frames - seek)
229
+ segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
230
+ mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
231
+
232
+ decode_options["prompt"] = all_tokens[prompt_reset_since:]
233
+ result: DecodingResult = decode_with_fallback(mel_segment)
234
+ tokens = torch.tensor(result.tokens)
235
+ if no_speech_threshold is not None:
236
+ # no voice activity check
237
+ should_skip = result.no_speech_prob > no_speech_threshold
238
+ if (
239
+ logprob_threshold is not None
240
+ and result.avg_logprob > logprob_threshold
241
+ ):
242
+ # don't skip if the logprob is high enough, despite the no_speech_prob
243
+ should_skip = False
244
+
245
+ if should_skip:
246
+ seek += segment_size # fast-forward to the next segment boundary
247
+ continue
248
+
249
+ previous_seek = seek
250
+ current_segments = []
251
+
252
+ timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
253
+ single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
254
+
255
+ consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
256
+ consecutive.add_(1)
257
+ if len(consecutive) > 0:
258
+ # if the output contains two consecutive timestamp tokens
259
+ slices = consecutive.tolist()
260
+ if single_timestamp_ending:
261
+ slices.append(len(tokens))
262
+
263
+ last_slice = 0
264
+ for current_slice in slices:
265
+ sliced_tokens = tokens[last_slice:current_slice]
266
+ start_timestamp_pos = (
267
+ sliced_tokens[0].item() - tokenizer.timestamp_begin
268
+ )
269
+ end_timestamp_pos = (
270
+ sliced_tokens[-1].item() - tokenizer.timestamp_begin
271
+ )
272
+ current_segments.append(
273
+ new_segment(
274
+ start=time_offset + start_timestamp_pos * time_precision,
275
+ end=time_offset + end_timestamp_pos * time_precision,
276
+ tokens=sliced_tokens,
277
+ result=result,
278
+ )
279
+ )
280
+ last_slice = current_slice
281
+
282
+ if single_timestamp_ending:
283
+ # single timestamp at the end means no speech after the last timestamp.
284
+ seek += segment_size
285
+ else:
286
+ # otherwise, ignore the unfinished segment and seek to the last timestamp
287
+ last_timestamp_pos = (
288
+ tokens[last_slice - 1].item() - tokenizer.timestamp_begin
289
+ )
290
+ seek += last_timestamp_pos * input_stride
291
+ else:
292
+ duration = segment_duration
293
+ timestamps = tokens[timestamp_tokens.nonzero().flatten()]
294
+ if (
295
+ len(timestamps) > 0
296
+ and timestamps[-1].item() != tokenizer.timestamp_begin
297
+ ):
298
+ # no consecutive timestamps but it has a timestamp; use the last one.
299
+ last_timestamp_pos = (
300
+ timestamps[-1].item() - tokenizer.timestamp_begin
301
+ )
302
+ duration = last_timestamp_pos * time_precision
303
+
304
+ current_segments.append(
305
+ new_segment(
306
+ start=time_offset,
307
+ end=time_offset + duration,
308
+ tokens=tokens,
309
+ result=result,
310
+ )
311
+ )
312
+ seek += segment_size
313
+
314
+ if not condition_on_previous_text or result.temperature > 0.5:
315
+ # do not feed the prompt tokens if a high temperature was used
316
+ prompt_reset_since = len(all_tokens)
317
+
318
+ if word_timestamps:
319
+ add_word_timestamps(
320
+ segments=current_segments,
321
+ model=model,
322
+ tokenizer=tokenizer,
323
+ mel=mel_segment,
324
+ num_frames=segment_size,
325
+ prepend_punctuations=prepend_punctuations,
326
+ append_punctuations=append_punctuations,
327
+ )
328
+ word_end_timestamps = [
329
+ w["end"] for s in current_segments for w in s["words"]
330
+ ]
331
+ if not single_timestamp_ending and len(word_end_timestamps) > 0:
332
+ seek_shift = round(
333
+ (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
334
+ )
335
+ if seek_shift > 0:
336
+ seek = previous_seek + seek_shift
337
+
338
+ if verbose:
339
+ for segment in current_segments:
340
+ start, end, text = segment["start"], segment["end"], segment["text"]
341
+ line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
342
+ print(make_safe(line))
343
+
344
+ # if a segment is instantaneous or does not contain text, clear it
345
+ for i, segment in enumerate(current_segments):
346
+ if segment["start"] == segment["end"] or segment["text"].strip() == "":
347
+ segment["text"] = ""
348
+ segment["tokens"] = []
349
+ segment["words"] = []
350
+
351
+ all_segments.extend(
352
+ [
353
+ {"id": i, **segment}
354
+ for i, segment in enumerate(
355
+ current_segments, start=len(all_segments)
356
+ )
357
+ ]
358
+ )
359
+ all_tokens.extend(
360
+ [token for segment in current_segments for token in segment["tokens"]]
361
+ )
362
+
363
+ # update progress bar
364
+ pbar.update(min(content_frames, seek) - previous_seek)
365
+
366
+
367
+ return dict(
368
+ text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
369
+ segments=all_segments,
370
+ language=language,
371
+ )
372
+
373
+
374
+ def transcribe_with_vad(
375
+ model: "Whisper",
376
+ audio: str,
377
+ vad_pipeline,
378
+ mel = None,
379
+ verbose: Optional[bool] = None,
380
+ **kwargs
381
+ ):
382
+ """
383
+ Transcribe per VAD segment
384
+ """
385
+
386
+ vad_segments = vad_pipeline(audio)
387
+
388
+ # if not torch.is_tensor(audio):
389
+ # if isinstance(audio, str):
390
+ audio = load_audio(audio)
391
+ audio = torch.from_numpy(audio)
392
+
393
+ prev = 0
394
+ output = {"segments": []}
395
+
396
+ # merge segments into ~30s chunks, the input length Whisper works best with
397
+ vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH)
398
+ if len(vad_segments) == 0:
399
+ return output
400
+
401
+ print(">>Performing transcription...")
402
+ for sdx, seg_t in enumerate(vad_segments):
403
+ if verbose:
404
+ print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
405
+ seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE)
406
+ local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
407
+ audio = audio[local_f_start:] # seek forward
408
+ seg_audio = audio[:local_f_end-local_f_start] # audio for this segment only
409
+ prev = seg_f_start
410
+ local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES)
411
+ # need to pad
412
+
413
+ result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
414
+ seg_t["text"] = result["text"]
415
+ output["segments"].append(
416
+ {
417
+ "start": seg_t["start"],
418
+ "end": seg_t["end"],
419
+ "language": result["language"],
420
+ "text": result["text"],
421
+ "seg-text": [x["text"] for x in result["segments"]],
422
+ "seg-start": [x["start"] for x in result["segments"]],
423
+ "seg-end": [x["end"] for x in result["segments"]],
424
+ }
425
+ )
426
+
427
+ output["language"] = output["segments"][0]["language"]
428
+
429
+ return output
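A minimal sketch of the VAD-gated path above, mirroring the CLI flow in `whisperx/transcribe.py` (the Hugging Face token, device, and `.wav` file name are assumptions; the 0.500/0.363 thresholds are the CLI defaults):

```python
# Pre-segment speech with pyannote VAD, then transcribe each merged ~30s chunk.
import torch
import whisper
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"

vad_model = whisperx.load_vad_model(torch.device(device), 0.500, 0.363,
                                    use_auth_token="hf_...")  # hypothetical token
model = whisper.load_model("large", device=device)

result = whisperx.transcribe_with_vad(model, "sample.wav", vad_model)
for seg in result["segments"]:
    print(seg["start"], seg["end"], seg["text"])
```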
whisperx/diarize.py ADDED
@@ -0,0 +1,76 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ from pyannote.audio import Pipeline
4
+
5
+ class DiarizationPipeline:
6
+ def __init__(
7
+ self,
8
+ model_name="pyannote/[email protected]",
9
+ use_auth_token=None,
10
+ ):
11
+ self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
12
+
13
+ def __call__(self, audio, min_speakers=None, max_speakers=None):
14
+ segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
15
+ diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
16
+ diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
17
+ diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
18
+ return diarize_df
19
+
20
+ def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
21
+ for seg in result_segments:
22
+ wdf = seg['word-segments']
23
+ if len(wdf['start'].dropna()) == 0:
24
+ wdf['start'] = seg['start']
25
+ wdf['end'] = seg['end']
26
+ speakers = []
27
+ for wdx, wrow in wdf.iterrows():
28
+ if not np.isnan(wrow['start']):
29
+ diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
30
+ diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
31
+ # remove no hit
32
+ if not fill_nearest:
33
+ dia_tmp = diarize_df[diarize_df['intersection'] > 0]
34
+ else:
35
+ dia_tmp = diarize_df
36
+ if len(dia_tmp) == 0:
37
+ speaker = None
38
+ else:
39
+ speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
40
+ else:
41
+ speaker = None
42
+ speakers.append(speaker)
43
+ seg['word-segments']['speaker'] = speakers
44
+
45
+ speaker_count = pd.Series(speakers).value_counts()
46
+ if len(speaker_count) == 0:
47
+ seg["speaker"]= "UNKNOWN"
48
+ else:
49
+ seg["speaker"] = speaker_count.index[0]
50
+
51
+ # create word level segments for .srt
52
+ word_seg = []
53
+ for seg in result_segments:
54
+ wseg = pd.DataFrame(seg["word-segments"])
55
+ for wdx, wrow in wseg.iterrows():
56
+ if wrow["start"] is not None:
57
+ speaker = wrow['speaker']
58
+ if speaker is None or speaker == np.nan:
59
+ speaker = "UNKNOWN"
60
+ word_seg.append(
61
+ {
62
+ "start": wrow["start"],
63
+ "end": wrow["end"],
64
+ "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
65
+ }
66
+ )
67
+
68
+ # TODO: create segments but split words on new speaker
69
+
70
+ return result_segments, word_seg
71
+
72
+ class Segment:
73
+ def __init__(self, start, end, speaker=None):
74
+ self.start = start
75
+ self.end = end
76
+ self.speaker = speaker
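A sketch of how these helpers combine with aligned output, continuing from the alignment sketch earlier in this diff (the token, speaker bounds, and file name are assumptions; `result_aligned` is the return value of `whisperx.align`):

```python
# Assign speaker labels to aligned segments using the gated pyannote pipeline.
from whisperx.diarize import DiarizationPipeline, assign_word_speakers

diarize_model = DiarizationPipeline(use_auth_token="hf_...")      # hypothetical token
diarize_df = diarize_model("sample.wav", min_speakers=1, max_speakers=2)

# result_aligned is assumed to come from whisperx.align(...)
segments, word_segments = assign_word_speakers(diarize_df, result_aligned["segments"])
for seg in segments[:3]:
    print(seg["speaker"], seg["text"])
```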
whisperx/transcribe.py ADDED
@@ -0,0 +1,220 @@
1
+ import argparse
2
+ import os
3
+ import gc
4
+ import warnings
5
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
6
+ import numpy as np
7
+ import torch
8
+ import tempfile
9
+ import ffmpeg
10
+ from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
11
+ from whisper.audio import SAMPLE_RATE
12
+ from whisper.utils import (
13
+ optional_float,
14
+ optional_int,
15
+ str2bool,
16
+ )
17
+
18
+ from .alignment import load_align_model, align
19
+ from .asr import transcribe, transcribe_with_vad
20
+ from .diarize import DiarizationPipeline, assign_word_speakers
21
+ from .utils import get_writer
22
+ from .vad import load_vad_model
23
+
24
+ def cli():
25
+ from whisper import available_models
26
+
27
+ # fmt: off
28
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
29
+ parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
30
+ parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
31
+ parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
32
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
33
+ parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
34
+ parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced")
35
+ parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
36
+
37
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
38
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
39
+
40
+ # alignment params
41
+ parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
42
+ parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).")
43
+ parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)")
44
+ parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
45
+ parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
46
+
47
+ # vad params
48
+ parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
49
+ parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
50
+ parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
51
+
52
+ # diarization params
53
+ parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
54
+ parser.add_argument("--min_speakers", default=None, type=int)
55
+ parser.add_argument("--max_speakers", default=None, type=int)
56
+
57
+ parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
58
+ parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
59
+ parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
60
+ parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
61
+ parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
62
+
63
+ parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
64
+ parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
65
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
66
+ parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
67
+
68
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
69
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
70
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
71
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
72
+ parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
73
+ parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
74
+ parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
75
+ parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
76
+
77
+ parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
78
+ # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
79
+ parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
80
+ # fmt: on
81
+
82
+ args = parser.parse_args().__dict__
83
+ model_name: str = args.pop("model")
84
+ model_dir: str = args.pop("model_dir")
85
+ output_dir: str = args.pop("output_dir")
86
+ output_format: str = args.pop("output_format")
87
+ device: str = args.pop("device")
88
+ # model_flush: bool = args.pop("model_flush")
89
+ os.makedirs(output_dir, exist_ok=True)
90
+
91
+ tmp_dir: str = args.pop("tmp_dir")
92
+ if tmp_dir is not None:
93
+ os.makedirs(tmp_dir, exist_ok=True)
94
+
95
+ align_model: str = args.pop("align_model")
96
+ align_extend: float = args.pop("align_extend")
97
+ align_from_prev: bool = args.pop("align_from_prev")
98
+ interpolate_method: str = args.pop("interpolate_method")
99
+ no_align: bool = args.pop("no_align")
100
+
101
+ hf_token: str = args.pop("hf_token")
102
+ vad_filter: bool = args.pop("vad_filter")
103
+ vad_onset: float = args.pop("vad_onset")
104
+ vad_offset: float = args.pop("vad_offset")
105
+
106
+ diarize: bool = args.pop("diarize")
107
+ min_speakers: int = args.pop("min_speakers")
108
+ max_speakers: int = args.pop("max_speakers")
109
+
110
+ if vad_filter:
111
+ from pyannote.audio import Model, Pipeline
113
+ vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
114
+ else:
115
+ vad_model = None
116
+
117
+ # if model_flush:
118
+ # print(">>Model flushing activated... Only loading model after ASR stage")
119
+ # del align_model
120
+ # align_model = ""
121
+
122
+
123
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
124
+ if args["language"] is not None:
125
+ warnings.warn(
126
+ f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
127
+ )
128
+ args["language"] = "en"
129
+
130
+ temperature = args.pop("temperature")
131
+ if (increment := args.pop("temperature_increment_on_fallback")) is not None:
132
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
133
+ else:
134
+ temperature = [temperature]
135
+
136
+ if (threads := args.pop("threads")) > 0:
137
+ torch.set_num_threads(threads)
138
+
139
+ from whisper import load_model
140
+
141
+ writer = get_writer(output_format, output_dir)
142
+
143
+ # Part 1: VAD & ASR Loop
144
+ results = []
145
+ tmp_results = []
146
+ model = load_model(model_name, device=device, download_root=model_dir)
147
+ for audio_path in args.pop("audio"):
148
+ input_audio_path = audio_path
149
+ tfile = None
150
+
151
+ # >> VAD & ASR
152
+ if vad_model is not None:
153
+ if not audio_path.endswith(".wav"):
154
+ print(">>VAD requires .wav format, converting to wav as a tempfile...")
155
+ audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
156
+ if tmp_dir is not None:
157
+ input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
158
+ else:
159
+ input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
160
+ ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
161
+ print(">>Performing VAD...")
162
+ result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
163
+ else:
164
+ print(">>Performing transcription...")
165
+ result = transcribe(model, input_audio_path, temperature=temperature, **args)
166
+
167
+ results.append((result, input_audio_path))
168
+
169
+ # Unload Whisper and VAD
170
+ del model
171
+ del vad_model
172
+ gc.collect()
173
+ torch.cuda.empty_cache()
174
+
175
+ # Part 2: Align Loop
176
+ if not no_align:
177
+ tmp_results = results
178
+ results = []
179
+ align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
180
+ align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
181
+ for result, input_audio_path in tmp_results:
182
+ # >> Align
183
+ if align_model is not None and len(result["segments"]) > 0:
184
+ if result.get("language", "en") != align_metadata["language"]:
185
+ # load new language
186
+ print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
187
+ align_model, align_metadata = load_align_model(result["language"], device)
188
+ print(">>Performing alignment...")
189
+ result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
190
+ extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
191
+ results.append((result, input_audio_path))
192
+
193
+ # Unload align model
194
+ del align_model
195
+ gc.collect()
196
+ torch.cuda.empty_cache()
197
+
198
+ # >> Diarize
199
+ if diarize:
200
+ if hf_token is None:
201
+ print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
202
+ tmp_results = results
203
+ results = []
204
+ diarize_model = DiarizationPipeline(use_auth_token=hf_token)
205
+ for result, input_audio_path in tmp_results:
206
+ diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
207
+ results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
208
+ result = {"segments": results_segments, "word_segments": word_segments}
209
+ results.append((result, input_audio_path))
210
+
211
+ # >> Write
212
+ for result, audio_path in results:
213
+ writer(result, audio_path)
214
+
215
+ # cleanup
216
+ if input_audio_path != audio_path:
217
+ os.remove(input_audio_path)
218
+
219
+ if __name__ == "__main__":
220
+ cli()
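
The cli() above chains Whisper ASR, optional VAD, forced alignment and (optionally) diarization. Below is a minimal programmatic sketch of the ASR + alignment part only. It assumes `transcribe`, `load_align_model` and `align` are importable from the whisperx package under the names used in cli(); the audio path, model size and alignment parameters are illustrative, not defaults confirmed by this diff.

    # Hedged sketch: transcribe one file and word-align it, mirroring Part 1/Part 2 of cli().
    # Import paths and parameter values are assumptions based on the names used in this diff.
    import torch
    from whisper import load_model
    from whisperx import transcribe, load_align_model, align  # assumed package-level exports

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = load_model("small", device=device)

    audio_path = "examples/sample01.wav"                      # illustrative path
    result = transcribe(model, audio_path, temperature=(0.0,), language="en")

    align_model, align_metadata = load_align_model("en", device)
    result_aligned = align(result["segments"], align_model, align_metadata, audio_path, device,
                           extend_duration=2.0, start_from_previous=True, interpolate_method="nearest")
    print(result_aligned["segments"][0])                      # word-aligned segment dict
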
whisperx/utils.py ADDED
@@ -0,0 +1,317 @@
1
+ import os
2
+ import zlib
3
+ from typing import Callable, TextIO, Iterator, Tuple, Union
4
+ import pandas as pd
5
+ import numpy as np
6
+
7
+ def interpolate_nans(x, method='nearest'):
8
+ if x.notnull().sum() > 1:
9
+ return x.interpolate(method=method).ffill().bfill()
10
+ else:
11
+ return x.ffill().bfill()
12
+
13
+
14
+ def write_txt(transcript: Iterator[dict], file: TextIO):
15
+ for segment in transcript:
16
+ print(segment['text'].strip(), file=file, flush=True)
17
+
18
+
19
+ def write_vtt(transcript: Iterator[dict], file: TextIO):
20
+ print("WEBVTT\n", file=file)
21
+ for segment in transcript:
22
+ print(
23
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
24
+ f"{segment['text'].strip().replace('-->', '->')}\n",
25
+ file=file,
26
+ flush=True,
27
+ )
28
+
29
+ def write_tsv(transcript: Iterator[dict], file: TextIO):
30
+ print("start", "end", "text", sep="\t", file=file)
31
+ for segment in transcript:
32
+ print(segment['start'], file=file, end="\t")
33
+ print(segment['end'], file=file, end="\t")
34
+ print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
35
+
36
+
37
+ def write_srt(transcript: Iterator[dict], file: TextIO):
38
+ """
39
+ Write a transcript to a file in SRT format.
40
+
41
+ Example usage:
42
+ from pathlib import Path
43
+ from whisper.utils import write_srt
44
+
45
+ result = transcribe(model, audio_path, temperature=temperature, **args)
46
+
47
+ # save SRT
48
+ audio_basename = Path(audio_path).stem
49
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
50
+ write_srt(result["segments"], file=srt)
51
+ """
52
+ for i, segment in enumerate(transcript, start=1):
53
+ # write srt lines
54
+ print(
55
+ f"{i}\n"
56
+ f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
57
+ f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
58
+ f"{segment['text'].strip().replace('-->', '->')}\n",
59
+ file=file,
60
+ flush=True,
61
+ )
62
+
63
+
64
+ def write_ass(transcript: Iterator[dict],
65
+ file: TextIO,
66
+ resolution: str = "word",
67
+ color: str = None, underline=True,
68
+ prefmt: str = None, suffmt: str = None,
69
+ font: str = None, font_size: int = 24,
70
+ strip=True, **kwargs):
71
+ """
72
+ Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
73
+ Generate an Advanced SubStation Alpha (.ass) file from the results to
74
+ display both phrase-level & word-level timestamps simultaneously by:
75
+ -using segment-level timestamps to display phrases as usual
76
+ -using word-level timestamps to change the format (e.g. color/underline) of the word in the displayed segment
77
+ Note: ass file is used in the same way as srt, vtt, etc.
78
+ Parameters
79
+ ----------
80
+ transcript: Iterator[dict]
81
+ results from modified model
82
+ file: TextIO
83
+ file object to write to
84
+ resolution: str
85
+ "word" or "char", timestamp resolution to highlight.
86
+ color: str
87
+ color code for a word at its corresponding timestamp
88
+ <bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
89
+ underline: bool
90
+ whether to underline a word at its corresponding timestamp
91
+ prefmt: str
92
+ used to specify the format for word-level timestamps (must be used with 'suffmt'; overrides 'color' & 'underline')
93
+ appears as such in the .ass file:
94
+ Hi, {<prefmt>}how{<suffmt>} are you?
95
+ reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
96
+ suffmt: str
97
+ used to specify the format for word-level timestamps (must be used with 'prefmt'; overrides 'color' & 'underline')
98
+ appears as such in the .ass file:
99
+ Hi, {<prefmt>}how{<suffmt>} are you?
100
+ reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
101
+ font: str
102
+ word font (default: Arial)
103
+ font_size: int
104
+ word font size (default: 24)
105
+ kwargs:
106
+ used for format styles:
107
+ 'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
108
+ 'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
109
+ 'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'
110
+
111
+ """
112
+
113
+ fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
114
+ 'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
115
+ 'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
116
+ 'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
117
+ 'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}
118
+
119
+ for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
120
+ kwargs[k] = f'&H{kwargs[k]}'
121
+
122
+ fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)
123
+
124
+ if font:
125
+ fmt_style_dict.update(Fontname=font)
126
+ if font_size:
127
+ fmt_style_dict.update(Fontsize=font_size)
128
+
129
+ fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
130
+
131
+ styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'
132
+
133
+ ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
134
+ f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
135
+ f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
136
+
137
+ if prefmt or suffmt:
138
+ if suffmt:
139
+ assert prefmt, 'prefmt must be used along with suffmt'
140
+ else:
141
+ suffmt = r'\r'
142
+ else:
143
+ if not color:
144
+ color = 'HFF00'
145
+ underline_code = r'\u1' if underline else ''
146
+
147
+ prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
148
+ suffmt = r'{\r}'
149
+
150
+ def secs_to_hhmmss(secs: Union[float, int]):
151
+ mm, ss = divmod(secs, 60)
152
+ hh, mm = divmod(mm, 60)
153
+ return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
154
+
155
+
156
+ def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
157
+ if idx_0 == -1:
158
+ text = chars
159
+ else:
160
+ text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
161
+ return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
162
+ f"Default,,0,0,0,,{text.strip() if strip else text}"
163
+
164
+ if resolution == "word":
165
+ resolution_key = "word-segments"
166
+ elif resolution == "char":
167
+ resolution_key = "char-segments"
168
+ else:
169
+ raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
170
+
171
+ ass_arr = []
172
+
173
+ for segment in transcript:
174
+ # if "12" in segment['text']:
175
+ # import pdb; pdb.set_trace()
176
+ if resolution_key in segment:
177
+ res_segs = pd.DataFrame(segment[resolution_key])
178
+ prev = segment['start']
179
+ if "speaker" in segment:
180
+ speaker_str = f"[{segment['speaker']}]: "
181
+ else:
182
+ speaker_str = ""
183
+ for cdx, crow in res_segs.iterrows():
184
+ if not np.isnan(crow['start']):
185
+ if resolution == "char":
186
+ idx_0 = cdx
187
+ idx_1 = cdx + 1
188
+ elif resolution == "word":
189
+ idx_0 = int(crow["segment-text-start"])
190
+ idx_1 = int(crow["segment-text-end"])
191
+ # fill gap
192
+ if crow['start'] > prev:
193
+ filler_ts = {
194
+ "chars": speaker_str + segment['text'],
195
+ "start": prev,
196
+ "end": crow['start'],
197
+ "idx_0": -1,
198
+ "idx_1": -1
199
+ }
200
+
201
+ ass_arr.append(filler_ts)
202
+ # highlight current word
203
+ f_word_ts = {
204
+ "chars": speaker_str + segment['text'],
205
+ "start": crow['start'],
206
+ "end": crow['end'],
207
+ "idx_0": idx_0 + len(speaker_str),
208
+ "idx_1": idx_1 + len(speaker_str)
209
+ }
210
+ ass_arr.append(f_word_ts)
211
+ prev = crow['end']
212
+
213
+ ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
214
+
215
+ file.write(ass_str)
216
+
217
+
218
+ from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp
219
+
220
+ class WriteASS(ResultWriter):
221
+ extension: str = "ass"
222
+
223
+ def write_result(self, result: dict, file: TextIO):
224
+ write_ass(result["segments"], file, resolution="word")
225
+
226
+ class WriteASSchar(ResultWriter):
227
+ extension: str = "ass"
228
+
229
+ def write_result(self, result: dict, file: TextIO):
230
+ write_ass(result["segments"], file, resolution="char")
231
+
232
+ class WritePickle(ResultWriter):
233
+ extension: str = "ass"
234
+
235
+ def write_result(self, result: dict, file: TextIO):
236
+ pd.DataFrame(result["segments"]).to_pickle(file)
237
+
238
+ class WriteSRTWord(ResultWriter):
239
+ extension: str = "word.srt"
240
+ always_include_hours: bool = True
241
+ decimal_marker: str = ","
242
+
243
+ def iterate_result(self, result: dict):
244
+ for segment in result["word_segments"]:
245
+ segment_start = self.format_timestamp(segment["start"])
246
+ segment_end = self.format_timestamp(segment["end"])
247
+ segment_text = segment["text"].strip().replace("-->", "->")
248
+
249
+ if word_timings := segment.get("words", None):
250
+ all_words = [timing["word"] for timing in word_timings]
251
+ all_words[0] = all_words[0].strip() # remove the leading space, if any
252
+ last = segment_start
253
+ for i, this_word in enumerate(word_timings):
254
+ start = self.format_timestamp(this_word["start"])
255
+ end = self.format_timestamp(this_word["end"])
256
+ if last != start:
257
+ yield last, start, segment_text
258
+
259
+ yield start, end, "".join(
260
+ [
261
+ f"<u>{word}</u>" if j == i else word
262
+ for j, word in enumerate(all_words)
263
+ ]
264
+ )
265
+ last = end
266
+
267
+ if last != segment_end:
268
+ yield last, segment_end, segment_text
269
+ else:
270
+ yield segment_start, segment_end, segment_text
271
+
272
+ def write_result(self, result: dict, file: TextIO):
273
+ if "word_segments" not in result:
274
+ return
275
+ for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
276
+ print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
277
+
278
+ def format_timestamp(self, seconds: float):
279
+ return format_timestamp(
280
+ seconds=seconds,
281
+ always_include_hours=self.always_include_hours,
282
+ decimal_marker=self.decimal_marker,
283
+ )
284
+
285
+ def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
286
+ writers = {
287
+ "txt": WriteTXT,
288
+ "vtt": WriteVTT,
289
+ "srt": WriteSRT,
290
+ "tsv": WriteTSV,
291
+ "ass": WriteASS,
292
+ "srt-word": WriteSRTWord,
293
+ # "ass-char": WriteASSchar,
294
+ # "pickle": WritePickle,
295
+ # "json": WriteJSON,
296
+ }
297
+
298
+ writers_other = {
299
+ "pkl": WritePickle,
300
+ "ass-char": WriteASSchar
301
+ }
302
+
303
+ if output_format == "all":
304
+ all_writers = [writer(output_dir) for writer in writers.values()]
305
+
306
+ def write_all(result: dict, file: TextIO):
307
+ for writer in all_writers:
308
+ writer(result, file)
309
+
310
+ return write_all
311
+
312
+ if output_format in writers:
313
+ return writers[output_format](output_dir)
314
+ elif output_format in writers_other:
315
+ return writers_other[output_format](output_dir)
316
+ else:
317
+ raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}")
whisperx/vad.py ADDED
@@ -0,0 +1,305 @@
1
+ import os
2
+ import urllib
3
+ import pandas as pd
4
+ import numpy as np
5
+ import torch
6
+ import hashlib
7
+ from tqdm import tqdm
8
+ from typing import Optional, Callable, Union, Text
9
+ from pyannote.audio.core.io import AudioFile
10
+ from pyannote.core import Annotation, Segment, SlidingWindowFeature
11
+ from pyannote.audio.pipelines.utils import PipelineModel
12
+ from pyannote.audio import Model
13
+ from pyannote.audio.pipelines import VoiceActivityDetection
14
+ from .diarize import Segment as SegmentX
15
+ from typing import List, Tuple, Optional
16
+
17
+ VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
18
+
19
+ def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
20
+ model_dir = torch.hub._get_torch_home()
21
+ os.makedirs(model_dir, exist_ok = True)
22
+ model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
23
+ if os.path.exists(model_fp) and not os.path.isfile(model_fp):
24
+ raise RuntimeError(f"{model_fp} exists and is not a regular file")
25
+
26
+ if not os.path.isfile(model_fp):
27
+ with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
28
+ with tqdm(
29
+ total=int(source.info().get("Content-Length")),
30
+ ncols=80,
31
+ unit="iB",
32
+ unit_scale=True,
33
+ unit_divisor=1024,
34
+ ) as loop:
35
+ while True:
36
+ buffer = source.read(8192)
37
+ if not buffer:
38
+ break
39
+
40
+ output.write(buffer)
41
+ loop.update(len(buffer))
42
+
43
+ model_bytes = open(model_fp, "rb").read()
44
+ if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
45
+ raise RuntimeError(
46
+ "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
47
+ )
48
+
49
+ vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
50
+ hyperparameters = {"onset": vad_onset,
51
+ "offset": vad_offset,
52
+ "min_duration_on": 0.1,
53
+ "min_duration_off": 0.1}
54
+ vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
55
+ vad_pipeline.instantiate(hyperparameters)
56
+
57
+ return vad_pipeline
58
+
59
+ class Binarize:
60
+ """Binarize detection scores using hysteresis thresholding, with min-cut operation
61
+ to ensure no segment is longer than max_duration.
62
+
63
+ Parameters
64
+ ----------
65
+ onset : float, optional
66
+ Onset threshold. Defaults to 0.5.
67
+ offset : float, optional
68
+ Offset threshold. Defaults to `onset`.
69
+ min_duration_on : float, optional
70
+ Remove active regions shorter than that many seconds. Defaults to 0s.
71
+ min_duration_off : float, optional
72
+ Fill inactive regions shorter than that many seconds. Defaults to 0s.
73
+ pad_onset : float, optional
74
+ Extend active regions by moving their start time by that many seconds.
75
+ Defaults to 0s.
76
+ pad_offset : float, optional
77
+ Extend active regions by moving their end time by that many seconds.
78
+ Defaults to 0s.
79
+ max_duration: float
80
+ The maximum length of an active segment; longer segments are divided at the timestamp with the lowest score.
81
+ Reference
82
+ ---------
83
+ Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
84
+ RNN-based Voice Activity Detection", InterSpeech 2015.
85
+
86
+ Modified by Max Bain to include WhisperX's min-cut operation
87
+ https://arxiv.org/abs/2303.00747
88
+
89
+ Adapted from pyannote-audio.
90
+ """
91
+
92
+ def __init__(
93
+ self,
94
+ onset: float = 0.5,
95
+ offset: Optional[float] = None,
96
+ min_duration_on: float = 0.0,
97
+ min_duration_off: float = 0.0,
98
+ pad_onset: float = 0.0,
99
+ pad_offset: float = 0.0,
100
+ max_duration: float = float('inf')
101
+ ):
102
+
103
+ super().__init__()
104
+
105
+ self.onset = onset
106
+ self.offset = offset or onset
107
+
108
+ self.pad_onset = pad_onset
109
+ self.pad_offset = pad_offset
110
+
111
+ self.min_duration_on = min_duration_on
112
+ self.min_duration_off = min_duration_off
113
+
114
+ self.max_duration = max_duration
115
+
116
+ def __call__(self, scores: SlidingWindowFeature) -> Annotation:
117
+ """Binarize detection scores
118
+ Parameters
119
+ ----------
120
+ scores : SlidingWindowFeature
121
+ Detection scores.
122
+ Returns
123
+ -------
124
+ active : Annotation
125
+ Binarized scores.
126
+ """
127
+
128
+ num_frames, num_classes = scores.data.shape
129
+ frames = scores.sliding_window
130
+ timestamps = [frames[i].middle for i in range(num_frames)]
131
+
132
+ # annotation meant to store 'active' regions
133
+ active = Annotation()
134
+ for k, k_scores in enumerate(scores.data.T):
135
+
136
+ label = k if scores.labels is None else scores.labels[k]
137
+
138
+ # initial state
139
+ start = timestamps[0]
140
+ is_active = k_scores[0] > self.onset
141
+ curr_scores = [k_scores[0]]
142
+ curr_timestamps = [start]
143
+ for t, y in zip(timestamps[1:], k_scores[1:]):
144
+ # currently active
145
+ if is_active:
146
+ curr_duration = t - start
147
+ if curr_duration > self.max_duration:
148
+ # if curr_duration > 15:
149
+ # import pdb; pdb.set_trace()
150
+ search_after = len(curr_scores) // 2
151
+ # divide segment
152
+ min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
153
+ min_score_t = curr_timestamps[min_score_div_idx]
154
+ region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
155
+ active[region, k] = label
156
+ start = curr_timestamps[min_score_div_idx]
157
+ curr_scores = curr_scores[min_score_div_idx+1:]
158
+ curr_timestamps = curr_timestamps[min_score_div_idx+1:]
159
+ # switching from active to inactive
160
+ elif y < self.offset:
161
+ region = Segment(start - self.pad_onset, t + self.pad_offset)
162
+ active[region, k] = label
163
+ start = t
164
+ is_active = False
165
+ curr_scores = []
166
+ curr_timestamps = []
167
+ # currently inactive
168
+ else:
169
+ # switching from inactive to active
170
+ if y > self.onset:
171
+ start = t
172
+ is_active = True
173
+ curr_scores.append(y)
174
+ curr_timestamps.append(t)
175
+
176
+ # if active at the end, add final region
177
+ if is_active:
178
+ region = Segment(start - self.pad_onset, t + self.pad_offset)
179
+ active[region, k] = label
180
+
181
+ # because of padding, some active regions might be overlapping: merge them.
182
+ # also: fill same speaker gaps shorter than min_duration_off
183
+ if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
184
+ if self.max_duration < float("inf"):
185
+ raise NotImplementedError(f"This would break current max_duration param")
186
+ active = active.support(collar=self.min_duration_off)
187
+
188
+ # remove tracks shorter than min_duration_on
189
+ if self.min_duration_on > 0:
190
+ for segment, track in list(active.itertracks()):
191
+ if segment.duration < self.min_duration_on:
192
+ del active[segment, track]
193
+
194
+ return active
195
+
196
+
197
+ class VoiceActivitySegmentation(VoiceActivityDetection):
198
+ def __init__(
199
+ self,
200
+ segmentation: PipelineModel = "pyannote/segmentation",
201
+ fscore: bool = False,
202
+ use_auth_token: Union[Text, None] = None,
203
+ **inference_kwargs,
204
+ ):
205
+
206
+ super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
207
+
208
+ def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
209
+ """Apply voice activity detection
210
+
211
+ Parameters
212
+ ----------
213
+ file : AudioFile
214
+ Processed file.
215
+ hook : callable, optional
216
+ Hook called after each major step of the pipeline with the following
217
+ signature: hook("step_name", step_artefact, file=file)
218
+
219
+ Returns
220
+ -------
221
+ speech : Annotation
222
+ Speech regions.
223
+ """
224
+
225
+ # setup hook (e.g. for debugging purposes)
226
+ hook = self.setup_hook(file, hook=hook)
227
+
228
+ # apply segmentation model (only if needed)
229
+ # output shape is (num_chunks, num_frames, 1)
230
+ if self.training:
231
+ if self.CACHED_SEGMENTATION in file:
232
+ segmentations = file[self.CACHED_SEGMENTATION]
233
+ else:
234
+ segmentations = self._segmentation(file)
235
+ file[self.CACHED_SEGMENTATION] = segmentations
236
+ else:
237
+ segmentations: SlidingWindowFeature = self._segmentation(file)
238
+
239
+ return segmentations
240
+
241
+
242
+ def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
243
+
244
+ active = Annotation()
245
+ for k, vad_t in enumerate(vad_arr):
246
+ region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
247
+ active[region, k] = 1
248
+
249
+
250
+ if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
251
+ active = active.support(collar=min_duration_off)
252
+
253
+ # remove tracks shorter than min_duration_on
254
+ if min_duration_on > 0:
255
+ for segment, track in list(active.itertracks()):
256
+ if segment.duration < min_duration_on:
257
+ del active[segment, track]
258
+
259
+ active = active.for_json()
260
+ active_segs = pd.DataFrame([x['segment'] for x in active['content']])
261
+ return active_segs
262
+
263
+ def merge_chunks(segments, chunk_size):
264
+ """
265
+ Merge operation described in paper
266
+ """
267
+ curr_end = 0
268
+ merged_segments = []
269
+ seg_idxs = []
270
+ speaker_idxs = []
271
+
272
+ assert chunk_size > 0
273
+ binarize = Binarize(max_duration=chunk_size)
274
+ segments = binarize(segments)
275
+ segments_list = []
276
+ for speech_turn in segments.get_timeline():
277
+ segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
278
+
279
+ if len(segments_list) == 0:
280
+ print("No active speech found in audio")
281
+ return []
282
+ # assert segments_list, "segments_list is empty."
283
+ # Make sure the starting point is the start of the segment.
284
+ curr_start = segments_list[0].start
285
+
286
+ for seg in segments_list:
287
+ if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
288
+ merged_segments.append({
289
+ "start": curr_start,
290
+ "end": curr_end,
291
+ "segments": seg_idxs,
292
+ })
293
+ curr_start = seg.start
294
+ seg_idxs = []
295
+ speaker_idxs = []
296
+ curr_end = seg.end
297
+ seg_idxs.append((seg.start, seg.end))
298
+ speaker_idxs.append(seg.speaker)
299
+ # add final
300
+ merged_segments.append({
301
+ "start": curr_start,
302
+ "end": curr_end,
303
+ "segments": seg_idxs,
304
+ })
305
+ return merged_segments
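
Taken together, load_vad_model(), Binarize and merge_chunks() implement the cut & merge pre-processing: run the segmentation model over the audio, binarize the scores with a maximum segment duration, then pack the speech turns into chunks no longer than chunk_size seconds. A hedged usage sketch follows; the onset/offset values, the 30-second chunk size and the audio path are assumptions for illustration, not values taken from this diff.

    # Hedged sketch of the VAD helpers above; requires a .wav file readable by pyannote.
    import torch
    from whisperx.vad import load_vad_model, merge_chunks

    device = "cuda" if torch.cuda.is_available() else "cpu"
    vad_model = load_vad_model(torch.device(device), vad_onset=0.500, vad_offset=0.363)

    scores = vad_model("examples/sample01.wav")      # speech activity scores (SlidingWindowFeature)
    chunks = merge_chunks(scores, chunk_size=30)     # cut & merge into <=30 s chunks
    for c in chunks:
        print(f"{c['start']:.2f}s -> {c['end']:.2f}s ({len(c['segments'])} speech turns)")
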