ras0k commited on
Commit
a36e6e8
·
1 Parent(s): c464395

Revert "whisperX merged in"

Browse files

This reverts commit 9c2d684d1a778e559b6b13bb90112712dbc20568.

.gitignore DELETED
@@ -1,2 +0,0 @@
1
- whisperx.egg-info/
2
- **/__pycache__/
 
 
 
EXAMPLES.md DELETED
@@ -1,37 +0,0 @@
1
- # More Examples
2
-
3
- ## Other Languages
4
-
5
- For non-english ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).
6
-
7
- Currently support default models tested for {en, fr, de, es, it, ja, zh, nl}
8
-
9
-
10
- If the detected language is not in this list, you need to find a phoneme-based ASR model from [huggingface model hub](https://huggingface.co/models) and test it on your data.
11
-
12
- ### French
13
- whisperx --model large --language fr examples/sample_fr_01.wav
14
-
15
-
16
- https://user-images.githubusercontent.com/36994049/208298804-31c49d6f-6787-444e-a53f-e93c52706752.mov
17
-
18
-
19
- ### German
20
- whisperx --model large --language de examples/sample_de_01.wav
21
-
22
-
23
- https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov
24
-
25
-
26
- ### Italian
27
- whisperx --model large --language de examples/sample_it_01.wav
28
-
29
-
30
- https://user-images.githubusercontent.com/36994049/208298819-6f462b2c-8cae-4c54-b8e1-90855794efc7.mov
31
-
32
-
33
- ### Japanese
34
- whisperx --model large --language ja examples/sample_ja_01.wav
35
-
36
-
37
- https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-809d-aaf3cd7e06f4.mov
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
LICENSE DELETED
@@ -1,27 +0,0 @@
1
- Copyright (c) 2022, Max Bain
2
- All rights reserved.
3
-
4
- Redistribution and use in source and binary forms, with or without
5
- modification, are permitted provided that the following conditions are met:
6
- 1. Redistributions of source code must retain the above copyright
7
- notice, this list of conditions and the following disclaimer.
8
- 2. Redistributions in binary form must reproduce the above copyright
9
- notice, this list of conditions and the following disclaimer in the
10
- documentation and/or other materials provided with the distribution.
11
- 3. All advertising materials mentioning features or use of this software
12
- must display the following acknowledgement:
13
- This product includes software developed by Max Bain.
14
- 4. Neither the name of Max Bain nor the
15
- names of its contributors may be used to endorse or promote products
16
- derived from this software without specific prior written permission.
17
-
18
- THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ''AS IS'' AND ANY
19
- EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20
- WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21
- DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22
- FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23
- DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24
- SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25
- CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26
- OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
27
- USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
MANIFEST.in DELETED
@@ -1,4 +0,0 @@
1
- include whisperx/assets/*
2
- include whisperx/assets/gpt2/*
3
- include whisperx/assets/multilingual/*
4
- include whisperx/normalizers/english.json
 
 
 
 
 
figures/pipeline.png DELETED
Binary file (123 kB)
 
requirements.txt DELETED
@@ -1,10 +0,0 @@
1
- numpy
2
- pandas
3
- torch >=1.9
4
- torchaudio >=0.10,<1.0
5
- tqdm
6
- more-itertools
7
- transformers>=4.19.0
8
- ffmpeg-python==0.2.0
9
- pyannote.audio
10
- openai-whisper==20230314
 
 
 
 
 
 
 
 
 
 
 
setup.py DELETED
@@ -1,28 +0,0 @@
1
- import os
2
-
3
- import pkg_resources
4
- from setuptools import setup, find_packages
5
-
6
- setup(
7
- name="whisperx",
8
- py_modules=["whisperx"],
9
- version="2.0",
10
- description="Time-Accurate Automatic Speech Recognition using Whisper.",
11
- readme="README.md",
12
- python_requires=">=3.8",
13
- author="Max Bain",
14
- url="https://github.com/m-bain/whisperx",
15
- license="MIT",
16
- packages=find_packages(exclude=["tests*"]),
17
- install_requires=[
18
- str(r)
19
- for r in pkg_resources.parse_requirements(
20
- open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
21
- )
22
- ],
23
- entry_points = {
24
- 'console_scripts': ['whisperx=whisperx.transcribe:cli'],
25
- },
26
- include_package_data=True,
27
- extras_require={'dev': ['pytest']},
28
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperx/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- from .transcribe import transcribe, transcribe_with_vad
2
- from .alignment import load_align_model, align
3
- from .vad import load_vad_model
 
 
 
 
whisperx/__main__.py DELETED
@@ -1,4 +0,0 @@
1
- from .transcribe import cli
2
-
3
-
4
- cli()
 
 
 
 
 
whisperx/alignment.py DELETED
@@ -1,548 +0,0 @@
1
- """"
2
- Forced Alignment with Whisper
3
- C. Max Bain
4
- """
5
- import numpy as np
6
- import pandas as pd
7
- from typing import List, Union, Iterator, TYPE_CHECKING
8
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
9
- import torchaudio
10
- import torch
11
- from dataclasses import dataclass
12
- from whisper.audio import SAMPLE_RATE, load_audio
13
- from .utils import interpolate_nans
14
-
15
-
16
- LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]
17
-
18
- DEFAULT_ALIGN_MODELS_TORCH = {
19
- "en": "WAV2VEC2_ASR_BASE_960H",
20
- "fr": "VOXPOPULI_ASR_BASE_10K_FR",
21
- "de": "VOXPOPULI_ASR_BASE_10K_DE",
22
- "es": "VOXPOPULI_ASR_BASE_10K_ES",
23
- "it": "VOXPOPULI_ASR_BASE_10K_IT",
24
- }
25
-
26
- DEFAULT_ALIGN_MODELS_HF = {
27
- "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
28
- "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
29
- "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
30
- "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
31
- "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
32
- "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
33
- "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
34
- "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
35
- "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
36
- "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
37
- "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
38
- "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
39
- "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
40
- }
41
-
42
-
43
- def load_align_model(language_code, device, model_name=None):
44
- if model_name is None:
45
- # use default model
46
- if language_code in DEFAULT_ALIGN_MODELS_TORCH:
47
- model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
48
- elif language_code in DEFAULT_ALIGN_MODELS_HF:
49
- model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
50
- else:
51
- print(f"There is no default alignment model set for this language ({language_code}).\
52
- Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
53
- raise ValueError(f"No default align-model for language: {language_code}")
54
-
55
- if model_name in torchaudio.pipelines.__all__:
56
- pipeline_type = "torchaudio"
57
- bundle = torchaudio.pipelines.__dict__[model_name]
58
- align_model = bundle.get_model().to(device)
59
- labels = bundle.get_labels()
60
- align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
61
- else:
62
- try:
63
- processor = Wav2Vec2Processor.from_pretrained(model_name)
64
- align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
65
- except Exception as e:
66
- print(e)
67
- print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
68
- raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
69
- pipeline_type = "huggingface"
70
- align_model = align_model.to(device)
71
- labels = processor.tokenizer.get_vocab()
72
- align_dictionary = {char.lower(): code for char,code in processor.tokenizer.get_vocab().items()}
73
-
74
- align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}
75
-
76
- return align_model, align_metadata
77
-
78
-
79
- def align(
80
- transcript: Iterator[dict],
81
- model: torch.nn.Module,
82
- align_model_metadata: dict,
83
- audio: Union[str, np.ndarray, torch.Tensor],
84
- device: str,
85
- extend_duration: float = 0.0,
86
- start_from_previous: bool = True,
87
- interpolate_method: str = "nearest",
88
- ):
89
- """
90
- Force align phoneme recognition predictions to known transcription
91
-
92
- Parameters
93
- ----------
94
- transcript: Iterator[dict]
95
- The Whisper model instance
96
-
97
- model: torch.nn.Module
98
- Alignment model (wav2vec2)
99
-
100
- audio: Union[str, np.ndarray, torch.Tensor]
101
- The path to the audio file to open, or the audio waveform
102
-
103
- device: str
104
- cuda device
105
-
106
- diarization: pd.DataFrame {'start': List[float], 'end': List[float], 'speaker': List[float]}
107
- diarization segments with speaker labels.
108
-
109
- extend_duration: float
110
- Amount to pad input segments by. If not using vad--filter then recommended to use 2 seconds
111
-
112
- If the gzip compression ratio is above this value, treat as failed
113
-
114
- interpolate_method: str ["nearest", "linear", "ignore"]
115
- Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
116
- "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "drop" removes that word from output.
117
-
118
- Returns
119
- -------
120
- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
121
- the spoken language ("language"), which is detected when `decode_options["language"]` is None.
122
- """
123
- if not torch.is_tensor(audio):
124
- if isinstance(audio, str):
125
- audio = load_audio(audio)
126
- audio = torch.from_numpy(audio)
127
- if len(audio.shape) == 1:
128
- audio = audio.unsqueeze(0)
129
-
130
- MAX_DURATION = audio.shape[1] / SAMPLE_RATE
131
-
132
- model_dictionary = align_model_metadata["dictionary"]
133
- model_lang = align_model_metadata["language"]
134
- model_type = align_model_metadata["type"]
135
-
136
- aligned_segments = []
137
-
138
- prev_t2 = 0
139
-
140
- char_segments_arr = {
141
- "segment-idx": [],
142
- "subsegment-idx": [],
143
- "word-idx": [],
144
- "char": [],
145
- "start": [],
146
- "end": [],
147
- "score": [],
148
- }
149
-
150
- for sdx, segment in enumerate(transcript):
151
- while True:
152
- segment_align_success = False
153
-
154
- # strip spaces at beginning / end, but keep track of the amount.
155
- num_leading = len(segment["text"]) - len(segment["text"].lstrip())
156
- num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
157
- transcription = segment["text"]
158
-
159
- # TODO: convert number tokenizer / symbols to phonetic words for alignment.
160
- # e.g. "$300" -> "three hundred dollars"
161
- # currently "$300" is ignored since no characters present in the phonetic dictionary
162
-
163
- # split into words
164
- if model_lang not in LANGUAGES_WITHOUT_SPACES:
165
- per_word = transcription.split(" ")
166
- else:
167
- per_word = transcription
168
-
169
- # first check that characters in transcription can be aligned (they are contained in align model"s dictionary)
170
- clean_char, clean_cdx = [], []
171
- for cdx, char in enumerate(transcription):
172
- char_ = char.lower()
173
- # wav2vec2 models use "|" character to represent spaces
174
- if model_lang not in LANGUAGES_WITHOUT_SPACES:
175
- char_ = char_.replace(" ", "|")
176
-
177
- # ignore whitespace at beginning and end of transcript
178
- if cdx < num_leading:
179
- pass
180
- elif cdx > len(transcription) - num_trailing - 1:
181
- pass
182
- elif char_ in model_dictionary.keys():
183
- clean_char.append(char_)
184
- clean_cdx.append(cdx)
185
-
186
- clean_wdx = []
187
- for wdx, wrd in enumerate(per_word):
188
- if any([c in model_dictionary.keys() for c in wrd]):
189
- clean_wdx.append(wdx)
190
-
191
- # if no characters are in the dictionary, then we skip this segment...
192
- if len(clean_char) == 0:
193
- print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
194
- break
195
-
196
- transcription_cleaned = "".join(clean_char)
197
- tokens = [model_dictionary[c] for c in transcription_cleaned]
198
-
199
- # we only pad if not using VAD filtering
200
- if "seg_text" not in segment:
201
- # pad according original timestamps
202
- t1 = max(segment["start"] - extend_duration, 0)
203
- t2 = min(segment["end"] + extend_duration, MAX_DURATION)
204
-
205
- # use prev_t2 as current t1 if it"s later
206
- if start_from_previous and t1 < prev_t2:
207
- t1 = prev_t2
208
-
209
- # check if timestamp range is still valid
210
- if t1 >= MAX_DURATION:
211
- print("Failed to align segment: original start time longer than audio duration, skipping...")
212
- break
213
- if t2 - t1 < 0.02:
214
- print("Failed to align segment: duration smaller than 0.02s time precision")
215
- break
216
-
217
- f1 = int(t1 * SAMPLE_RATE)
218
- f2 = int(t2 * SAMPLE_RATE)
219
-
220
- waveform_segment = audio[:, f1:f2]
221
-
222
- with torch.inference_mode():
223
- if model_type == "torchaudio":
224
- emissions, _ = model(waveform_segment.to(device))
225
- elif model_type == "huggingface":
226
- emissions = model(waveform_segment.to(device)).logits
227
- else:
228
- raise NotImplementedError(f"Align model of type {model_type} not supported.")
229
- emissions = torch.log_softmax(emissions, dim=-1)
230
-
231
- emission = emissions[0].cpu().detach()
232
-
233
- trellis = get_trellis(emission, tokens)
234
- path = backtrack(trellis, emission, tokens)
235
- if path is None:
236
- print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
237
- break
238
- char_segments = merge_repeats(path, transcription_cleaned)
239
- # word_segments = merge_words(char_segments)
240
-
241
-
242
- # sub-segments
243
- if "seg-text" not in segment:
244
- segment["seg-text"] = [transcription]
245
-
246
- seg_lens = [0] + [len(x) for x in segment["seg-text"]]
247
- seg_lens_cumsum = list(np.cumsum(seg_lens))
248
- sub_seg_idx = 0
249
-
250
- wdx = 0
251
- duration = t2 - t1
252
- ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
253
- for cdx, char in enumerate(transcription + " "):
254
- is_last = False
255
- if cdx == len(transcription):
256
- break
257
- elif cdx+1 == len(transcription):
258
- is_last = True
259
-
260
-
261
- start, end, score = None, None, None
262
- if cdx in clean_cdx:
263
- char_seg = char_segments[clean_cdx.index(cdx)]
264
- start = char_seg.start * ratio + t1
265
- end = char_seg.end * ratio + t1
266
- score = char_seg.score
267
-
268
- char_segments_arr["char"].append(char)
269
- char_segments_arr["start"].append(start)
270
- char_segments_arr["end"].append(end)
271
- char_segments_arr["score"].append(score)
272
- char_segments_arr["word-idx"].append(wdx)
273
- char_segments_arr["segment-idx"].append(sdx)
274
- char_segments_arr["subsegment-idx"].append(sub_seg_idx)
275
-
276
- # word-level info
277
- if model_lang in LANGUAGES_WITHOUT_SPACES:
278
- # character == word
279
- wdx += 1
280
- elif is_last or transcription[cdx+1] == " " or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
281
- wdx += 1
282
-
283
- if is_last or cdx == seg_lens_cumsum[sub_seg_idx+1] - 1:
284
- wdx = 0
285
- sub_seg_idx += 1
286
-
287
- prev_t2 = segment["end"]
288
-
289
- segment_align_success = True
290
- # end while True loop
291
- break
292
-
293
- # reset prev_t2 due to drifting issues
294
- if not segment_align_success:
295
- prev_t2 = 0
296
-
297
- char_segments_arr = pd.DataFrame(char_segments_arr)
298
- not_space = char_segments_arr["char"] != " "
299
-
300
- per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index = False)
301
- char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop = True)).reset_index()
302
- per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
303
- per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
304
- per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
305
- char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
306
- per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"]) # regroup
307
-
308
- word_segments_arr = {}
309
-
310
- # start of word is first char with a timestamp
311
- word_segments_arr["start"] = per_word_grp["start"].min().values
312
- # end of word is last char with a timestamp
313
- word_segments_arr["end"] = per_word_grp["end"].max().values
314
- # score of word is mean (excluding nan)
315
- word_segments_arr["score"] = per_word_grp["score"].mean().values
316
-
317
- word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
318
- word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values+1
319
- word_segments_arr = pd.DataFrame(word_segments_arr)
320
-
321
- word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
322
- segments_arr = {}
323
- segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
324
- segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
325
- segments_arr = pd.DataFrame(segments_arr)
326
- segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
327
- segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1
328
-
329
- # interpolate missing words / sub-segments
330
- if interpolate_method != "ignore":
331
- wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
332
- wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
333
- # we still know which word timestamps are interpolated because their score == nan
334
- word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
335
- word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
336
-
337
- word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
338
- word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
339
-
340
- sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
341
- segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
342
- segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
343
-
344
- # merge words & subsegments which are missing times
345
- word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])
346
-
347
- word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
348
- word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
349
- word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)
350
-
351
- seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
352
- segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
353
- segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
354
- segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
355
- else:
356
- word_segments_arr.dropna(inplace=True)
357
- segments_arr.dropna(inplace=True)
358
-
359
- # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
360
- segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
361
- segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
362
- segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
363
- segments_arr['subsegment-idx-end'].fillna(1, inplace=True)
364
-
365
-
366
- aligned_segments = []
367
- aligned_segments_word = []
368
-
369
- word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
370
- char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)
371
-
372
- for sdx, srow in segments_arr.iterrows():
373
-
374
- seg_idx = int(srow["segment-idx"])
375
- sub_start = int(srow["subsegment-idx-start"])
376
- sub_end = int(srow["subsegment-idx-end"])
377
-
378
- seg = transcript[seg_idx]
379
- text = "".join(seg["seg-text"][sub_start:sub_end])
380
-
381
- wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
382
- wseg["start"].fillna(srow["start"], inplace=True)
383
- wseg["end"].fillna(srow["end"], inplace=True)
384
- wseg["segment-text-start"].fillna(0, inplace=True)
385
- wseg["segment-text-end"].fillna(len(text)-1, inplace=True)
386
-
387
- cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end-1]
388
- # fixes bug for single segment in transcript
389
- cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
390
- cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
391
- if 'level_1' in cseg: del cseg['level_1']
392
- if 'level_0' in cseg: del cseg['level_0']
393
- cseg.reset_index(inplace=True)
394
- aligned_segments.append(
395
- {
396
- "start": srow["start"],
397
- "end": srow["end"],
398
- "text": text,
399
- "word-segments": wseg,
400
- "char-segments": cseg
401
- }
402
- )
403
-
404
- def get_raw_text(word_row):
405
- return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"])+1]
406
-
407
- wdx = 0
408
- curr_text = get_raw_text(wseg.iloc[wdx])
409
- if len(wseg) > 1:
410
- for _, wrow in wseg.iloc[1:].iterrows():
411
- if wrow['start'] != wseg.iloc[wdx]['start']:
412
- aligned_segments_word.append(
413
- {
414
- "text": curr_text.strip(),
415
- "start": wseg.iloc[wdx]["start"],
416
- "end": wseg.iloc[wdx]["end"],
417
- }
418
- )
419
- curr_text = ""
420
- curr_text += " " + get_raw_text(wrow)
421
- wdx += 1
422
- aligned_segments_word.append(
423
- {
424
- "text": curr_text.strip(),
425
- "start": wseg.iloc[wdx]["start"],
426
- "end": wseg.iloc[wdx]["end"]
427
- }
428
- )
429
-
430
-
431
- return {"segments": aligned_segments, "word_segments": aligned_segments_word}
432
-
433
-
434
- """
435
- source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
436
- """
437
- def get_trellis(emission, tokens, blank_id=0):
438
- num_frame = emission.size(0)
439
- num_tokens = len(tokens)
440
-
441
- # Trellis has extra diemsions for both time axis and tokens.
442
- # The extra dim for tokens represents <SoS> (start-of-sentence)
443
- # The extra dim for time axis is for simplification of the code.
444
- trellis = torch.empty((num_frame + 1, num_tokens + 1))
445
- trellis[0, 0] = 0
446
- trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
447
- trellis[0, -num_tokens:] = -float("inf")
448
- trellis[-num_tokens:, 0] = float("inf")
449
-
450
- for t in range(num_frame):
451
- trellis[t + 1, 1:] = torch.maximum(
452
- # Score for staying at the same token
453
- trellis[t, 1:] + emission[t, blank_id],
454
- # Score for changing to the next token
455
- trellis[t, :-1] + emission[t, tokens],
456
- )
457
- return trellis
458
-
459
- @dataclass
460
- class Point:
461
- token_index: int
462
- time_index: int
463
- score: float
464
-
465
- def backtrack(trellis, emission, tokens, blank_id=0):
466
- # Note:
467
- # j and t are indices for trellis, which has extra dimensions
468
- # for time and tokens at the beginning.
469
- # When referring to time frame index `T` in trellis,
470
- # the corresponding index in emission is `T-1`.
471
- # Similarly, when referring to token index `J` in trellis,
472
- # the corresponding index in transcript is `J-1`.
473
- j = trellis.size(1) - 1
474
- t_start = torch.argmax(trellis[:, j]).item()
475
-
476
- path = []
477
- for t in range(t_start, 0, -1):
478
- # 1. Figure out if the current position was stay or change
479
- # Note (again):
480
- # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
481
- # Score for token staying the same from time frame J-1 to T.
482
- stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
483
- # Score for token changing from C-1 at T-1 to J at T.
484
- changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]
485
-
486
- # 2. Store the path with frame-wise probability.
487
- prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
488
- # Return token index and time index in non-trellis coordinate.
489
- path.append(Point(j - 1, t - 1, prob))
490
-
491
- # 3. Update the token
492
- if changed > stayed:
493
- j -= 1
494
- if j == 0:
495
- break
496
- else:
497
- # failed
498
- return None
499
- return path[::-1]
500
-
501
- # Merge the labels
502
- @dataclass
503
- class Segment:
504
- label: str
505
- start: int
506
- end: int
507
- score: float
508
-
509
- def __repr__(self):
510
- return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"
511
-
512
- @property
513
- def length(self):
514
- return self.end - self.start
515
-
516
- def merge_repeats(path, transcript):
517
- i1, i2 = 0, 0
518
- segments = []
519
- while i1 < len(path):
520
- while i2 < len(path) and path[i1].token_index == path[i2].token_index:
521
- i2 += 1
522
- score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
523
- segments.append(
524
- Segment(
525
- transcript[path[i1].token_index],
526
- path[i1].time_index,
527
- path[i2 - 1].time_index + 1,
528
- score,
529
- )
530
- )
531
- i1 = i2
532
- return segments
533
-
534
- def merge_words(segments, separator="|"):
535
- words = []
536
- i1, i2 = 0, 0
537
- while i1 < len(segments):
538
- if i2 >= len(segments) or segments[i2].label == separator:
539
- if i1 != i2:
540
- segs = segments[i1:i2]
541
- word = "".join([seg.label for seg in segs])
542
- score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
543
- words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
544
- i1 = i2 + 1
545
- i2 = i1
546
- else:
547
- i2 += 1
548
- return words
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperx/asr.py DELETED
@@ -1,429 +0,0 @@
1
- import warnings
2
- from typing import TYPE_CHECKING, Optional, Tuple, Union
3
- import numpy as np
4
- import torch
5
- import tqdm
6
- import ffmpeg
7
- from whisper.audio import (
8
- FRAMES_PER_SECOND,
9
- HOP_LENGTH,
10
- N_FRAMES,
11
- N_SAMPLES,
12
- SAMPLE_RATE,
13
- CHUNK_LENGTH,
14
- log_mel_spectrogram,
15
- pad_or_trim,
16
- load_audio
17
- )
18
- from whisper.decoding import DecodingOptions, DecodingResult
19
- from whisper.timing import add_word_timestamps
20
- from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
21
- from whisper.utils import (
22
- exact_div,
23
- format_timestamp,
24
- make_safe,
25
- )
26
-
27
- if TYPE_CHECKING:
28
- from whisper.model import Whisper
29
-
30
- from .vad import merge_chunks
31
-
32
- def transcribe(
33
- model: "Whisper",
34
- audio: Union[str, np.ndarray, torch.Tensor] = None,
35
- mel: np.ndarray = None,
36
- verbose: Optional[bool] = None,
37
- temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
38
- compression_ratio_threshold: Optional[float] = 2.4,
39
- logprob_threshold: Optional[float] = -1.0,
40
- no_speech_threshold: Optional[float] = 0.6,
41
- condition_on_previous_text: bool = True,
42
- initial_prompt: Optional[str] = None,
43
- word_timestamps: bool = False,
44
- prepend_punctuations: str = "\"'“¿([{-",
45
- append_punctuations: str = "\"'.。,,!!??::”)]}、",
46
- **decode_options,
47
- ):
48
- """
49
- Transcribe an audio file using Whisper.
50
- We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio)
51
-
52
- Parameters
53
- ----------
54
- model: Whisper
55
- The Whisper model instance
56
-
57
- audio: Union[str, np.ndarray, torch.Tensor]
58
- The path to the audio file to open, or the audio waveform
59
-
60
- mel: np.ndarray
61
- Mel spectrogram of audio segment.
62
-
63
- verbose: bool
64
- Whether to display the text being decoded to the console. If True, displays all the details,
65
- If False, displays minimal details. If None, does not display anything
66
-
67
- temperature: Union[float, Tuple[float, ...]]
68
- Temperature for sampling. It can be a tuple of temperatures, which will be successively used
69
- upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
70
-
71
- compression_ratio_threshold: float
72
- If the gzip compression ratio is above this value, treat as failed
73
-
74
- logprob_threshold: float
75
- If the average log probability over sampled tokens is below this value, treat as failed
76
-
77
- no_speech_threshold: float
78
- If the no_speech probability is higher than this value AND the average log probability
79
- over sampled tokens is below `logprob_threshold`, consider the segment as silent
80
-
81
- condition_on_previous_text: bool
82
- if True, the previous output of the model is provided as a prompt for the next window;
83
- disabling may make the text inconsistent across windows, but the model becomes less prone to
84
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
85
-
86
- word_timestamps: bool
87
- Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
88
- and include the timestamps for each word in each segment.
89
-
90
- prepend_punctuations: str
91
- If word_timestamps is True, merge these punctuation symbols with the next word
92
-
93
- append_punctuations: str
94
- If word_timestamps is True, merge these punctuation symbols with the previous word
95
-
96
- initial_prompt: Optional[str]
97
- Optional text to provide as a prompt for the first window. This can be used to provide, or
98
- "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
99
- to make it more likely to predict those word correctly.
100
-
101
- decode_options: dict
102
- Keyword arguments to construct `DecodingOptions` instances
103
-
104
- Returns
105
- -------
106
- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
107
- the spoken language ("language"), which is detected when `decode_options["language"]` is None.
108
- """
109
- dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
110
- if model.device == torch.device("cpu"):
111
- if torch.cuda.is_available():
112
- warnings.warn("Performing inference on CPU when CUDA is available")
113
- if dtype == torch.float16:
114
- warnings.warn("FP16 is not supported on CPU; using FP32 instead")
115
- dtype = torch.float32
116
-
117
- if dtype == torch.float32:
118
- decode_options["fp16"] = False
119
-
120
- # Pad 30-seconds of silence to the input audio, for slicing
121
- if mel is None:
122
- if audio is None:
123
- raise ValueError("Transcribe needs either audio or mel as input, currently both are none.")
124
- mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
125
- content_frames = mel.shape[-1] - N_FRAMES
126
-
127
- if decode_options.get("language", None) is None:
128
- if not model.is_multilingual:
129
- decode_options["language"] = "en"
130
- else:
131
- if verbose:
132
- print(
133
- "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
134
- )
135
- mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
136
- _, probs = model.detect_language(mel_segment)
137
- decode_options["language"] = max(probs, key=probs.get)
138
- if verbose is not None:
139
- print(
140
- f"Detected language: {LANGUAGES[decode_options['language']].title()}"
141
- )
142
-
143
- language: str = decode_options["language"]
144
- task: str = decode_options.get("task", "transcribe")
145
- tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
146
-
147
- if word_timestamps and task == "translate":
148
- warnings.warn("Word-level timestamps on translations may not be reliable.")
149
-
150
- def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
151
- temperatures = (
152
- [temperature] if isinstance(temperature, (int, float)) else temperature
153
- )
154
- decode_result = None
155
-
156
- for t in temperatures:
157
- kwargs = {**decode_options}
158
- if t > 0:
159
- # disable beam_size and patience when t > 0
160
- kwargs.pop("beam_size", None)
161
- kwargs.pop("patience", None)
162
- else:
163
- # disable best_of when t == 0
164
- kwargs.pop("best_of", None)
165
-
166
- options = DecodingOptions(**kwargs, temperature=t)
167
- decode_result = model.decode(segment, options)
168
-
169
- needs_fallback = False
170
- if (
171
- compression_ratio_threshold is not None
172
- and decode_result.compression_ratio > compression_ratio_threshold
173
- ):
174
- needs_fallback = True # too repetitive
175
- if (
176
- logprob_threshold is not None
177
- and decode_result.avg_logprob < logprob_threshold
178
- ):
179
- needs_fallback = True # average log probability is too low
180
-
181
- if not needs_fallback:
182
- break
183
-
184
- return decode_result
185
-
186
- seek = 0
187
- input_stride = exact_div(
188
- N_FRAMES, model.dims.n_audio_ctx
189
- ) # mel frames per output token: 2
190
- time_precision = (
191
- input_stride * HOP_LENGTH / SAMPLE_RATE
192
- ) # time per output token: 0.02 (seconds)
193
- all_tokens = []
194
- all_segments = []
195
- prompt_reset_since = 0
196
-
197
- if initial_prompt is not None:
198
- initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
199
- all_tokens.extend(initial_prompt_tokens)
200
- else:
201
- initial_prompt_tokens = []
202
-
203
- def new_segment(
204
- *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
205
- ):
206
- tokens = tokens.tolist()
207
- text_tokens = [token for token in tokens if token < tokenizer.eot]
208
- return {
209
- "seek": seek,
210
- "start": start,
211
- "end": end,
212
- "text": tokenizer.decode(text_tokens),
213
- "tokens": tokens,
214
- "temperature": result.temperature,
215
- "avg_logprob": result.avg_logprob,
216
- "compression_ratio": result.compression_ratio,
217
- "no_speech_prob": result.no_speech_prob,
218
- }
219
-
220
-
221
- # show the progress bar when verbose is False (if True, transcribed text will be printed)
222
- with tqdm.tqdm(
223
- total=content_frames, unit="frames", disable=verbose is not False
224
- ) as pbar:
225
- while seek < content_frames:
226
- time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
227
- mel_segment = mel[:, seek : seek + N_FRAMES]
228
- segment_size = min(N_FRAMES, content_frames - seek)
229
- segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
230
- mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
231
-
232
- decode_options["prompt"] = all_tokens[prompt_reset_since:]
233
- result: DecodingResult = decode_with_fallback(mel_segment)
234
- tokens = torch.tensor(result.tokens)
235
- if no_speech_threshold is not None:
236
- # no voice activity check
237
- should_skip = result.no_speech_prob > no_speech_threshold
238
- if (
239
- logprob_threshold is not None
240
- and result.avg_logprob > logprob_threshold
241
- ):
242
- # don't skip if the logprob is high enough, despite the no_speech_prob
243
- should_skip = False
244
-
245
- if should_skip:
246
- seek += segment_size # fast-forward to the next segment boundary
247
- continue
248
-
249
- previous_seek = seek
250
- current_segments = []
251
-
252
- timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
253
- single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
254
-
255
- consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
256
- consecutive.add_(1)
257
- if len(consecutive) > 0:
258
- # if the output contains two consecutive timestamp tokens
259
- slices = consecutive.tolist()
260
- if single_timestamp_ending:
261
- slices.append(len(tokens))
262
-
263
- last_slice = 0
264
- for current_slice in slices:
265
- sliced_tokens = tokens[last_slice:current_slice]
266
- start_timestamp_pos = (
267
- sliced_tokens[0].item() - tokenizer.timestamp_begin
268
- )
269
- end_timestamp_pos = (
270
- sliced_tokens[-1].item() - tokenizer.timestamp_begin
271
- )
272
- current_segments.append(
273
- new_segment(
274
- start=time_offset + start_timestamp_pos * time_precision,
275
- end=time_offset + end_timestamp_pos * time_precision,
276
- tokens=sliced_tokens,
277
- result=result,
278
- )
279
- )
280
- last_slice = current_slice
281
-
282
- if single_timestamp_ending:
283
- # single timestamp at the end means no speech after the last timestamp.
284
- seek += segment_size
285
- else:
286
- # otherwise, ignore the unfinished segment and seek to the last timestamp
287
- last_timestamp_pos = (
288
- tokens[last_slice - 1].item() - tokenizer.timestamp_begin
289
- )
290
- seek += last_timestamp_pos * input_stride
291
- else:
292
- duration = segment_duration
293
- timestamps = tokens[timestamp_tokens.nonzero().flatten()]
294
- if (
295
- len(timestamps) > 0
296
- and timestamps[-1].item() != tokenizer.timestamp_begin
297
- ):
298
- # no consecutive timestamps but it has a timestamp; use the last one.
299
- last_timestamp_pos = (
300
- timestamps[-1].item() - tokenizer.timestamp_begin
301
- )
302
- duration = last_timestamp_pos * time_precision
303
-
304
- current_segments.append(
305
- new_segment(
306
- start=time_offset,
307
- end=time_offset + duration,
308
- tokens=tokens,
309
- result=result,
310
- )
311
- )
312
- seek += segment_size
313
-
314
- if not condition_on_previous_text or result.temperature > 0.5:
315
- # do not feed the prompt tokens if a high temperature was used
316
- prompt_reset_since = len(all_tokens)
317
-
318
- if word_timestamps:
319
- add_word_timestamps(
320
- segments=current_segments,
321
- model=model,
322
- tokenizer=tokenizer,
323
- mel=mel_segment,
324
- num_frames=segment_size,
325
- prepend_punctuations=prepend_punctuations,
326
- append_punctuations=append_punctuations,
327
- )
328
- word_end_timestamps = [
329
- w["end"] for s in current_segments for w in s["words"]
330
- ]
331
- if not single_timestamp_ending and len(word_end_timestamps) > 0:
332
- seek_shift = round(
333
- (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
334
- )
335
- if seek_shift > 0:
336
- seek = previous_seek + seek_shift
337
-
338
- if verbose:
339
- for segment in current_segments:
340
- start, end, text = segment["start"], segment["end"], segment["text"]
341
- line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
342
- print(make_safe(line))
343
-
344
- # if a segment is instantaneous or does not contain text, clear it
345
- for i, segment in enumerate(current_segments):
346
- if segment["start"] == segment["end"] or segment["text"].strip() == "":
347
- segment["text"] = ""
348
- segment["tokens"] = []
349
- segment["words"] = []
350
-
351
- all_segments.extend(
352
- [
353
- {"id": i, **segment}
354
- for i, segment in enumerate(
355
- current_segments, start=len(all_segments)
356
- )
357
- ]
358
- )
359
- all_tokens.extend(
360
- [token for segment in current_segments for token in segment["tokens"]]
361
- )
362
-
363
- # update progress bar
364
- pbar.update(min(content_frames, seek) - previous_seek)
365
-
366
-
367
- return dict(
368
- text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
369
- segments=all_segments,
370
- language=language,
371
- )
372
-
373
-
374
- def transcribe_with_vad(
375
- model: "Whisper",
376
- audio: str,
377
- vad_pipeline,
378
- mel = None,
379
- verbose: Optional[bool] = None,
380
- **kwargs
381
- ):
382
- """
383
- Transcribe per VAD segment
384
- """
385
-
386
- vad_segments = vad_pipeline(audio)
387
-
388
- # if not torch.is_tensor(audio):
389
- # if isinstance(audio, str):
390
- audio = load_audio(audio)
391
- audio = torch.from_numpy(audio)
392
-
393
- prev = 0
394
- output = {"segments": []}
395
-
396
- # merge segments to approx 30s inputs to make whisper most appropraite
397
- vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH)
398
- if len(vad_segments) == 0:
399
- return output
400
-
401
- print(">>Performing transcription...")
402
- for sdx, seg_t in enumerate(vad_segments):
403
- if verbose:
404
- print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
405
- seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE)
406
- local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
407
- audio = audio[local_f_start:] # seek forward
408
- seg_audio = audio[:local_f_end-local_f_start] # seek forward
409
- prev = seg_f_start
410
- local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES)
411
- # need to pad
412
-
413
- result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
414
- seg_t["text"] = result["text"]
415
- output["segments"].append(
416
- {
417
- "start": seg_t["start"],
418
- "end": seg_t["end"],
419
- "language": result["language"],
420
- "text": result["text"],
421
- "seg-text": [x["text"] for x in result["segments"]],
422
- "seg-start": [x["start"] for x in result["segments"]],
423
- "seg-end": [x["end"] for x in result["segments"]],
424
- }
425
- )
426
-
427
- output["language"] = output["segments"][0]["language"]
428
-
429
- return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperx/diarize.py DELETED
@@ -1,76 +0,0 @@
1
- import numpy as np
2
- import pandas as pd
3
- from pyannote.audio import Pipeline
4
-
5
- class DiarizationPipeline:
6
- def __init__(
7
- self,
8
- model_name="pyannote/[email protected]",
9
- use_auth_token=None,
10
- ):
11
- self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)
12
-
13
- def __call__(self, audio, min_speakers=None, max_speakers=None):
14
- segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
15
- diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
16
- diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
17
- diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
18
- return diarize_df
19
-
20
- def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
21
- for seg in result_segments:
22
- wdf = seg['word-segments']
23
- if len(wdf['start'].dropna()) == 0:
24
- wdf['start'] = seg['start']
25
- wdf['end'] = seg['end']
26
- speakers = []
27
- for wdx, wrow in wdf.iterrows():
28
- if not np.isnan(wrow['start']):
29
- diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
30
- diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
31
- # remove no hit
32
- if not fill_nearest:
33
- dia_tmp = diarize_df[diarize_df['intersection'] > 0]
34
- else:
35
- dia_tmp = diarize_df
36
- if len(dia_tmp) == 0:
37
- speaker = None
38
- else:
39
- speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
40
- else:
41
- speaker = None
42
- speakers.append(speaker)
43
- seg['word-segments']['speaker'] = speakers
44
-
45
- speaker_count = pd.Series(speakers).value_counts()
46
- if len(speaker_count) == 0:
47
- seg["speaker"]= "UNKNOWN"
48
- else:
49
- seg["speaker"] = speaker_count.index[0]
50
-
51
- # create word level segments for .srt
52
- word_seg = []
53
- for seg in result_segments:
54
- wseg = pd.DataFrame(seg["word-segments"])
55
- for wdx, wrow in wseg.iterrows():
56
- if wrow["start"] is not None:
57
- speaker = wrow['speaker']
58
- if speaker is None or speaker == np.nan:
59
- speaker = "UNKNOWN"
60
- word_seg.append(
61
- {
62
- "start": wrow["start"],
63
- "end": wrow["end"],
64
- "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
65
- }
66
- )
67
-
68
- # TODO: create segments but split words on new speaker
69
-
70
- return result_segments, word_seg
71
-
72
- class Segment:
73
- def __init__(self, start, end, speaker=None):
74
- self.start = start
75
- self.end = end
76
- self.speaker = speaker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperx/transcribe.py DELETED
@@ -1,220 +0,0 @@
1
- import argparse
2
- import os
3
- import gc
4
- import warnings
5
- from typing import TYPE_CHECKING, Optional, Tuple, Union
6
- import numpy as np
7
- import torch
8
- import tempfile
9
- import ffmpeg
10
- from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
11
- from whisper.audio import SAMPLE_RATE
12
- from whisper.utils import (
13
- optional_float,
14
- optional_int,
15
- str2bool,
16
- )
17
-
18
- from .alignment import load_align_model, align
19
- from .asr import transcribe, transcribe_with_vad
20
- from .diarize import DiarizationPipeline, assign_word_speakers
21
- from .utils import get_writer
22
- from .vad import load_vad_model
23
-
24
- def cli():
25
- from whisper import available_models
26
-
27
- # fmt: off
28
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
29
- parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
30
- parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
31
- parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
32
- parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
33
- parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
34
- parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced")
35
- parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
36
-
37
- parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
38
- parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
39
-
40
- # alignment params
41
- parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
42
- parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).")
43
- parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)")
44
- parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
45
- parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")
46
-
47
- # vad params
48
- parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
49
- parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
50
- parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")
51
-
52
- # diarization params
53
- parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
54
- parser.add_argument("--min_speakers", default=None, type=int)
55
- parser.add_argument("--max_speakers", default=None, type=int)
56
-
57
- parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
58
- parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
59
- parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
60
- parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
61
- parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
62
-
63
- parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
64
- parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
65
- parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
66
- parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
67
-
68
- parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
69
- parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
70
- parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
71
- parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
72
- parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
73
- parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
74
- parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
75
- parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
76
-
77
- parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
78
- # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
79
- parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
80
- # fmt: on
81
-
82
- args = parser.parse_args().__dict__
83
- model_name: str = args.pop("model")
84
- model_dir: str = args.pop("model_dir")
85
- output_dir: str = args.pop("output_dir")
86
- output_format: str = args.pop("output_format")
87
- device: str = args.pop("device")
88
- # model_flush: bool = args.pop("model_flush")
89
- os.makedirs(output_dir, exist_ok=True)
90
-
91
- tmp_dir: str = args.pop("tmp_dir")
92
- if tmp_dir is not None:
93
- os.makedirs(tmp_dir, exist_ok=True)
94
-
95
- align_model: str = args.pop("align_model")
96
- align_extend: float = args.pop("align_extend")
97
- align_from_prev: bool = args.pop("align_from_prev")
98
- interpolate_method: str = args.pop("interpolate_method")
99
- no_align: bool = args.pop("no_align")
100
-
101
- hf_token: str = args.pop("hf_token")
102
- vad_filter: bool = args.pop("vad_filter")
103
- vad_onset: float = args.pop("vad_onset")
104
- vad_offset: float = args.pop("vad_offset")
105
-
106
- diarize: bool = args.pop("diarize")
107
- min_speakers: int = args.pop("min_speakers")
108
- max_speakers: int = args.pop("max_speakers")
109
-
110
- if vad_filter:
111
- from pyannote.audio import Pipeline
112
- from pyannote.audio import Model, Pipeline
113
- vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
114
- else:
115
- vad_model = None
116
-
117
- # if model_flush:
118
- # print(">>Model flushing activated... Only loading model after ASR stage")
119
- # del align_model
120
- # align_model = ""
121
-
122
-
123
- if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
124
- if args["language"] is not None:
125
- warnings.warn(
126
- f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
127
- )
128
- args["language"] = "en"
129
-
130
- temperature = args.pop("temperature")
131
- if (increment := args.pop("temperature_increment_on_fallback")) is not None:
132
- temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
133
- else:
134
- temperature = [temperature]
135
-
136
- if (threads := args.pop("threads")) > 0:
137
- torch.set_num_threads(threads)
138
-
139
- from whisper import load_model
140
-
141
- writer = get_writer(output_format, output_dir)
142
-
143
- # Part 1: VAD & ASR Loop
144
- results = []
145
- tmp_results = []
146
- model = load_model(model_name, device=device, download_root=model_dir)
147
- for audio_path in args.pop("audio"):
148
- input_audio_path = audio_path
149
- tfile = None
150
-
151
- # >> VAD & ASR
152
- if vad_model is not None:
153
- if not audio_path.endswith(".wav"):
154
- print(">>VAD requires .wav format, converting to wav as a tempfile...")
155
- audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
156
- if tmp_dir is not None:
157
- input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
158
- else:
159
- input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
160
- ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
161
- print(">>Performing VAD...")
162
- result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
163
- else:
164
- print(">>Performing transcription...")
165
- result = transcribe(model, input_audio_path, temperature=temperature, **args)
166
-
167
- results.append((result, input_audio_path))
168
-
169
- # Unload Whisper and VAD
170
- del model
171
- del vad_model
172
- gc.collect()
173
- torch.cuda.empty_cache()
174
-
175
- # Part 2: Align Loop
176
- if not no_align:
177
- tmp_results = results
178
- results = []
179
- align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
180
- align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
181
- for result, input_audio_path in tmp_results:
182
- # >> Align
183
- if align_model is not None and len(result["segments"]) > 0:
184
- if result.get("language", "en") != align_metadata["language"]:
185
- # load new language
186
- print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
187
- align_model, align_metadata = load_align_model(result["language"], device)
188
- print(">>Performing alignment...")
189
- result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
190
- extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
191
- results.append((result, input_audio_path))
192
-
193
- # Unload align model
194
- del align_model
195
- gc.collect()
196
- torch.cuda.empty_cache()
197
-
198
- # >> Diarize
199
- if diarize:
200
- if hf_token is None:
201
- print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
202
- tmp_results = results
203
- results = []
204
- diarize_model = DiarizationPipeline(use_auth_token=hf_token)
205
- for result, input_audio_path in tmp_results:
206
- diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
207
- results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
208
- result = {"segments": results_segments, "word_segments": word_segments}
209
- results.append((result, input_audio_path))
210
-
211
- # >> Write
212
- for result, audio_path in results:
213
- writer(result, audio_path)
214
-
215
- # cleanup
216
- if input_audio_path != audio_path:
217
- os.remove(input_audio_path)
218
-
219
- if __name__ == "__main__":
220
- cli()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperx/utils.py DELETED
@@ -1,317 +0,0 @@
1
- import os
2
- import zlib
3
- from typing import Callable, TextIO, Iterator, Tuple
4
- import pandas as pd
5
- import numpy as np
6
-
7
- def interpolate_nans(x, method='nearest'):
8
- if x.notnull().sum() > 1:
9
- return x.interpolate(method=method).ffill().bfill()
10
- else:
11
- return x.ffill().bfill()
12
-
13
-
14
- def write_txt(transcript: Iterator[dict], file: TextIO):
15
- for segment in transcript:
16
- print(segment['text'].strip(), file=file, flush=True)
17
-
18
-
19
- def write_vtt(transcript: Iterator[dict], file: TextIO):
20
- print("WEBVTT\n", file=file)
21
- for segment in transcript:
22
- print(
23
- f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
24
- f"{segment['text'].strip().replace('-->', '->')}\n",
25
- file=file,
26
- flush=True,
27
- )
28
-
29
- def write_tsv(transcript: Iterator[dict], file: TextIO):
30
- print("start", "end", "text", sep="\t", file=file)
31
- for segment in transcript:
32
- print(segment['start'], file=file, end="\t")
33
- print(segment['end'], file=file, end="\t")
34
- print(segment['text'].strip().replace("\t", " "), file=file, flush=True)
35
-
36
-
37
- def write_srt(transcript: Iterator[dict], file: TextIO):
38
- """
39
- Write a transcript to a file in SRT format.
40
-
41
- Example usage:
42
- from pathlib import Path
43
- from whisper.utils import write_srt
44
-
45
- result = transcribe(model, audio_path, temperature=temperature, **args)
46
-
47
- # save SRT
48
- audio_basename = Path(audio_path).stem
49
- with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
50
- write_srt(result["segments"], file=srt)
51
- """
52
- for i, segment in enumerate(transcript, start=1):
53
- # write srt lines
54
- print(
55
- f"{i}\n"
56
- f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
57
- f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
58
- f"{segment['text'].strip().replace('-->', '->')}\n",
59
- file=file,
60
- flush=True,
61
- )
62
-
63
-
64
- def write_ass(transcript: Iterator[dict],
65
- file: TextIO,
66
- resolution: str = "word",
67
- color: str = None, underline=True,
68
- prefmt: str = None, suffmt: str = None,
69
- font: str = None, font_size: int = 24,
70
- strip=True, **kwargs):
71
- """
72
- Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
73
- Generate Advanced SubStation Alpha (ass) file from results to
74
- display both phrase-level & word-level timestamp simultaneously by:
75
- -using segment-level timestamps display phrases as usual
76
- -using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment
77
- Note: ass file is used in the same way as srt, vtt, etc.
78
- Parameters
79
- ----------
80
- transcript: dict
81
- results from modified model
82
- file: TextIO
83
- file object to write to
84
- resolution: str
85
- "word" or "char", timestamp resolution to highlight.
86
- color: str
87
- color code for a word at its corresponding timestamp
88
- <bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
89
- underline: bool
90
- whether to underline a word at its corresponding timestamp
91
- prefmt: str
92
- used to specify format for word-level timestamps (must be use with 'suffmt' and overrides 'color'&'underline')
93
- appears as such in the .ass file:
94
- Hi, {<prefmt>}how{<suffmt>} are you?
95
- reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
96
- suffmt: str
97
- used to specify format for word-level timestamps (must be use with 'prefmt' and overrides 'color'&'underline')
98
- appears as such in the .ass file:
99
- Hi, {<prefmt>}how{<suffmt>} are you?
100
- reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
101
- font: str
102
- word font (default: Arial)
103
- font_size: int
104
- word font size (default: 48)
105
- kwargs:
106
- used for format styles:
107
- 'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
108
- 'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
109
- 'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'
110
-
111
- """
112
-
113
- fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
114
- 'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
115
- 'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
116
- 'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
117
- 'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}
118
-
119
- for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
120
- kwargs[k] = f'&H{kwargs[k]}'
121
-
122
- fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)
123
-
124
- if font:
125
- fmt_style_dict.update(Fontname=font)
126
- if font_size:
127
- fmt_style_dict.update(Fontsize=font_size)
128
-
129
- fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'
130
-
131
- styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'
132
-
133
- ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
134
- f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
135
- f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'
136
-
137
- if prefmt or suffmt:
138
- if suffmt:
139
- assert prefmt, 'prefmt must be used along with suffmt'
140
- else:
141
- suffmt = r'\r'
142
- else:
143
- if not color:
144
- color = 'HFF00'
145
- underline_code = r'\u1' if underline else ''
146
-
147
- prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
148
- suffmt = r'{\r}'
149
-
150
- def secs_to_hhmmss(secs: Tuple[float, int]):
151
- mm, ss = divmod(secs, 60)
152
- hh, mm = divmod(mm, 60)
153
- return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'
154
-
155
-
156
- def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
157
- if idx_0 == -1:
158
- text = chars
159
- else:
160
- text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
161
- return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
162
- f"Default,,0,0,0,,{text.strip() if strip else text}"
163
-
164
- if resolution == "word":
165
- resolution_key = "word-segments"
166
- elif resolution == "char":
167
- resolution_key = "char-segments"
168
- else:
169
- raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)
170
-
171
- ass_arr = []
172
-
173
- for segment in transcript:
174
- # if "12" in segment['text']:
175
- # import pdb; pdb.set_trace()
176
- if resolution_key in segment:
177
- res_segs = pd.DataFrame(segment[resolution_key])
178
- prev = segment['start']
179
- if "speaker" in segment:
180
- speaker_str = f"[{segment['speaker']}]: "
181
- else:
182
- speaker_str = ""
183
- for cdx, crow in res_segs.iterrows():
184
- if not np.isnan(crow['start']):
185
- if resolution == "char":
186
- idx_0 = cdx
187
- idx_1 = cdx + 1
188
- elif resolution == "word":
189
- idx_0 = int(crow["segment-text-start"])
190
- idx_1 = int(crow["segment-text-end"])
191
- # fill gap
192
- if crow['start'] > prev:
193
- filler_ts = {
194
- "chars": speaker_str + segment['text'],
195
- "start": prev,
196
- "end": crow['start'],
197
- "idx_0": -1,
198
- "idx_1": -1
199
- }
200
-
201
- ass_arr.append(filler_ts)
202
- # highlight current word
203
- f_word_ts = {
204
- "chars": speaker_str + segment['text'],
205
- "start": crow['start'],
206
- "end": crow['end'],
207
- "idx_0": idx_0 + len(speaker_str),
208
- "idx_1": idx_1 + len(speaker_str)
209
- }
210
- ass_arr.append(f_word_ts)
211
- prev = crow['end']
212
-
213
- ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))
214
-
215
- file.write(ass_str)
216
-
217
-
218
- from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp
219
-
220
- class WriteASS(ResultWriter):
221
- extension: str = "ass"
222
-
223
- def write_result(self, result: dict, file: TextIO):
224
- write_ass(result["segments"], file, resolution="word")
225
-
226
- class WriteASSchar(ResultWriter):
227
- extension: str = "ass"
228
-
229
- def write_result(self, result: dict, file: TextIO):
230
- write_ass(result["segments"], file, resolution="char")
231
-
232
- class WritePickle(ResultWriter):
233
- extension: str = "ass"
234
-
235
- def write_result(self, result: dict, file: TextIO):
236
- pd.DataFrame(result["segments"]).to_pickle(file)
237
-
238
- class WriteSRTWord(ResultWriter):
239
- extension: str = "word.srt"
240
- always_include_hours: bool = True
241
- decimal_marker: str = ","
242
-
243
- def iterate_result(self, result: dict):
244
- for segment in result["word_segments"]:
245
- segment_start = self.format_timestamp(segment["start"])
246
- segment_end = self.format_timestamp(segment["end"])
247
- segment_text = segment["text"].strip().replace("-->", "->")
248
-
249
- if word_timings := segment.get("words", None):
250
- all_words = [timing["word"] for timing in word_timings]
251
- all_words[0] = all_words[0].strip() # remove the leading space, if any
252
- last = segment_start
253
- for i, this_word in enumerate(word_timings):
254
- start = self.format_timestamp(this_word["start"])
255
- end = self.format_timestamp(this_word["end"])
256
- if last != start:
257
- yield last, start, segment_text
258
-
259
- yield start, end, "".join(
260
- [
261
- f"<u>{word}</u>" if j == i else word
262
- for j, word in enumerate(all_words)
263
- ]
264
- )
265
- last = end
266
-
267
- if last != segment_end:
268
- yield last, segment_end, segment_text
269
- else:
270
- yield segment_start, segment_end, segment_text
271
-
272
- def write_result(self, result: dict, file: TextIO):
273
- if "word_segments" not in result:
274
- return
275
- for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
276
- print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
277
-
278
- def format_timestamp(self, seconds: float):
279
- return format_timestamp(
280
- seconds=seconds,
281
- always_include_hours=self.always_include_hours,
282
- decimal_marker=self.decimal_marker,
283
- )
284
-
285
- def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
286
- writers = {
287
- "txt": WriteTXT,
288
- "vtt": WriteVTT,
289
- "srt": WriteSRT,
290
- "tsv": WriteTSV,
291
- "ass": WriteASS,
292
- "srt-word": WriteSRTWord,
293
- # "ass-char": WriteASSchar,
294
- # "pickle": WritePickle,
295
- # "json": WriteJSON,
296
- }
297
-
298
- writers_other = {
299
- "pkl": WritePickle,
300
- "ass-char": WriteASSchar
301
- }
302
-
303
- if output_format == "all":
304
- all_writers = [writer(output_dir) for writer in writers.values()]
305
-
306
- def write_all(result: dict, file: TextIO):
307
- for writer in all_writers:
308
- writer(result, file)
309
-
310
- return write_all
311
-
312
- if output_format in writers:
313
- return writers[output_format](output_dir)
314
- elif output_format in writers_other:
315
- return writers_other[output_format](output_dir)
316
- else:
317
- raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperx/vad.py DELETED
@@ -1,305 +0,0 @@
1
- import os
2
- import urllib
3
- import pandas as pd
4
- import numpy as np
5
- import torch
6
- import hashlib
7
- from tqdm import tqdm
8
- from typing import Optional, Callable, Union, Text
9
- from pyannote.audio.core.io import AudioFile
10
- from pyannote.core import Annotation, Segment, SlidingWindowFeature
11
- from pyannote.audio.pipelines.utils import PipelineModel
12
- from pyannote.audio import Model
13
- from pyannote.audio.pipelines import VoiceActivityDetection
14
- from .diarize import Segment as SegmentX
15
- from typing import List, Tuple, Optional
16
-
17
- VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"
18
-
19
- def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
20
- model_dir = torch.hub._get_torch_home()
21
- os.makedirs(model_dir, exist_ok = True)
22
- model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
23
- if os.path.exists(model_fp) and not os.path.isfile(model_fp):
24
- raise RuntimeError(f"{model_fp} exists and is not a regular file")
25
-
26
- if not os.path.isfile(model_fp):
27
- with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
28
- with tqdm(
29
- total=int(source.info().get("Content-Length")),
30
- ncols=80,
31
- unit="iB",
32
- unit_scale=True,
33
- unit_divisor=1024,
34
- ) as loop:
35
- while True:
36
- buffer = source.read(8192)
37
- if not buffer:
38
- break
39
-
40
- output.write(buffer)
41
- loop.update(len(buffer))
42
-
43
- model_bytes = open(model_fp, "rb").read()
44
- if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
45
- raise RuntimeError(
46
- "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
47
- )
48
-
49
- vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
50
- hyperparameters = {"onset": vad_onset,
51
- "offset": vad_offset,
52
- "min_duration_on": 0.1,
53
- "min_duration_off": 0.1}
54
- vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
55
- vad_pipeline.instantiate(hyperparameters)
56
-
57
- return vad_pipeline
58
-
59
- class Binarize:
60
- """Binarize detection scores using hysteresis thresholding, with min-cut operation
61
- to ensure not segments are longer than max_duration.
62
-
63
- Parameters
64
- ----------
65
- onset : float, optional
66
- Onset threshold. Defaults to 0.5.
67
- offset : float, optional
68
- Offset threshold. Defaults to `onset`.
69
- min_duration_on : float, optional
70
- Remove active regions shorter than that many seconds. Defaults to 0s.
71
- min_duration_off : float, optional
72
- Fill inactive regions shorter than that many seconds. Defaults to 0s.
73
- pad_onset : float, optional
74
- Extend active regions by moving their start time by that many seconds.
75
- Defaults to 0s.
76
- pad_offset : float, optional
77
- Extend active regions by moving their end time by that many seconds.
78
- Defaults to 0s.
79
- max_duration: float
80
- The maximum length of an active segment, divides segment at timestamp with lowest score.
81
- Reference
82
- ---------
83
- Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
84
- RNN-based Voice Activity Detection", InterSpeech 2015.
85
-
86
- Modified by Max Bain to include WhisperX's min-cut operation
87
- https://arxiv.org/abs/2303.00747
88
-
89
- Pyannote-audio
90
- """
91
-
92
- def __init__(
93
- self,
94
- onset: float = 0.5,
95
- offset: Optional[float] = None,
96
- min_duration_on: float = 0.0,
97
- min_duration_off: float = 0.0,
98
- pad_onset: float = 0.0,
99
- pad_offset: float = 0.0,
100
- max_duration: float = float('inf')
101
- ):
102
-
103
- super().__init__()
104
-
105
- self.onset = onset
106
- self.offset = offset or onset
107
-
108
- self.pad_onset = pad_onset
109
- self.pad_offset = pad_offset
110
-
111
- self.min_duration_on = min_duration_on
112
- self.min_duration_off = min_duration_off
113
-
114
- self.max_duration = max_duration
115
-
116
- def __call__(self, scores: SlidingWindowFeature) -> Annotation:
117
- """Binarize detection scores
118
- Parameters
119
- ----------
120
- scores : SlidingWindowFeature
121
- Detection scores.
122
- Returns
123
- -------
124
- active : Annotation
125
- Binarized scores.
126
- """
127
-
128
- num_frames, num_classes = scores.data.shape
129
- frames = scores.sliding_window
130
- timestamps = [frames[i].middle for i in range(num_frames)]
131
-
132
- # annotation meant to store 'active' regions
133
- active = Annotation()
134
- for k, k_scores in enumerate(scores.data.T):
135
-
136
- label = k if scores.labels is None else scores.labels[k]
137
-
138
- # initial state
139
- start = timestamps[0]
140
- is_active = k_scores[0] > self.onset
141
- curr_scores = [k_scores[0]]
142
- curr_timestamps = [start]
143
- for t, y in zip(timestamps[1:], k_scores[1:]):
144
- # currently active
145
- if is_active:
146
- curr_duration = t - start
147
- if curr_duration > self.max_duration:
148
- # if curr_duration > 15:
149
- # import pdb; pdb.set_trace()
150
- search_after = len(curr_scores) // 2
151
- # divide segment
152
- min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
153
- min_score_t = curr_timestamps[min_score_div_idx]
154
- region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
155
- active[region, k] = label
156
- start = curr_timestamps[min_score_div_idx]
157
- curr_scores = curr_scores[min_score_div_idx+1:]
158
- curr_timestamps = curr_timestamps[min_score_div_idx+1:]
159
- # switching from active to inactive
160
- elif y < self.offset:
161
- region = Segment(start - self.pad_onset, t + self.pad_offset)
162
- active[region, k] = label
163
- start = t
164
- is_active = False
165
- curr_scores = []
166
- curr_timestamps = []
167
- # currently inactive
168
- else:
169
- # switching from inactive to active
170
- if y > self.onset:
171
- start = t
172
- is_active = True
173
- curr_scores.append(y)
174
- curr_timestamps.append(t)
175
-
176
- # if active at the end, add final region
177
- if is_active:
178
- region = Segment(start - self.pad_onset, t + self.pad_offset)
179
- active[region, k] = label
180
-
181
- # because of padding, some active regions might be overlapping: merge them.
182
- # also: fill same speaker gaps shorter than min_duration_off
183
- if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
184
- if self.max_duration < float("inf"):
185
- raise NotImplementedError(f"This would break current max_duration param")
186
- active = active.support(collar=self.min_duration_off)
187
-
188
- # remove tracks shorter than min_duration_on
189
- if self.min_duration_on > 0:
190
- for segment, track in list(active.itertracks()):
191
- if segment.duration < self.min_duration_on:
192
- del active[segment, track]
193
-
194
- return active
195
-
196
-
197
- class VoiceActivitySegmentation(VoiceActivityDetection):
198
- def __init__(
199
- self,
200
- segmentation: PipelineModel = "pyannote/segmentation",
201
- fscore: bool = False,
202
- use_auth_token: Union[Text, None] = None,
203
- **inference_kwargs,
204
- ):
205
-
206
- super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)
207
-
208
- def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
209
- """Apply voice activity detection
210
-
211
- Parameters
212
- ----------
213
- file : AudioFile
214
- Processed file.
215
- hook : callable, optional
216
- Hook called after each major step of the pipeline with the following
217
- signature: hook("step_name", step_artefact, file=file)
218
-
219
- Returns
220
- -------
221
- speech : Annotation
222
- Speech regions.
223
- """
224
-
225
- # setup hook (e.g. for debugging purposes)
226
- hook = self.setup_hook(file, hook=hook)
227
-
228
- # apply segmentation model (only if needed)
229
- # output shape is (num_chunks, num_frames, 1)
230
- if self.training:
231
- if self.CACHED_SEGMENTATION in file:
232
- segmentations = file[self.CACHED_SEGMENTATION]
233
- else:
234
- segmentations = self._segmentation(file)
235
- file[self.CACHED_SEGMENTATION] = segmentations
236
- else:
237
- segmentations: SlidingWindowFeature = self._segmentation(file)
238
-
239
- return segmentations
240
-
241
-
242
- def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):
243
-
244
- active = Annotation()
245
- for k, vad_t in enumerate(vad_arr):
246
- region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
247
- active[region, k] = 1
248
-
249
-
250
- if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
251
- active = active.support(collar=min_duration_off)
252
-
253
- # remove tracks shorter than min_duration_on
254
- if min_duration_on > 0:
255
- for segment, track in list(active.itertracks()):
256
- if segment.duration < min_duration_on:
257
- del active[segment, track]
258
-
259
- active = active.for_json()
260
- active_segs = pd.DataFrame([x['segment'] for x in active['content']])
261
- return active_segs
262
-
263
- def merge_chunks(segments, chunk_size):
264
- """
265
- Merge operation described in paper
266
- """
267
- curr_end = 0
268
- merged_segments = []
269
- seg_idxs = []
270
- speaker_idxs = []
271
-
272
- assert chunk_size > 0
273
- binarize = Binarize(max_duration=chunk_size)
274
- segments = binarize(segments)
275
- segments_list = []
276
- for speech_turn in segments.get_timeline():
277
- segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))
278
-
279
- if len(segments_list) == 0:
280
- print("No active speech found in audio")
281
- return []
282
- # assert segments_list, "segments_list is empty."
283
- # Make sur the starting point is the start of the segment.
284
- curr_start = segments_list[0].start
285
-
286
- for seg in segments_list:
287
- if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
288
- merged_segments.append({
289
- "start": curr_start,
290
- "end": curr_end,
291
- "segments": seg_idxs,
292
- })
293
- curr_start = seg.start
294
- seg_idxs = []
295
- speaker_idxs = []
296
- curr_end = seg.end
297
- seg_idxs.append((seg.start, seg.end))
298
- speaker_idxs.append(seg.speaker)
299
- # add final
300
- merged_segments.append({
301
- "start": curr_start,
302
- "end": curr_end,
303
- "segments": seg_idxs,
304
- })
305
- return merged_segments